1use std::arch::x86_64::{
2 __m256i, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_extract_epi64, _mm256_loadu_si256,
3 _mm256_sad_epu8, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_sub_epi8, _mm256_xor_si256,
4};
5
/// Broadcasts the unsigned byte `a` into all 32 byte lanes of a 256-bit vector.
///
/// Unsigned counterpart of `_mm256_set1_epi8`: `a` is reinterpreted bit-for-bit as an `i8`,
/// so values above 127 produce lanes with the same bit pattern (e.g. `0xFF`).
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn _mm256_set1_epu8(a: u8) -> __m256i {
    let same_bits = a as i8;
    _mm256_set1_epi8(same_bits)
}
10
/// Lane-wise byte inequality: each result lane is `0xFF` where `a` and `b` differ and `0x00`
/// where they are equal.
///
/// Computed by inverting the result of `_mm256_cmpeq_epi8` (XOR with an all-ones vector).
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i {
    let equal = _mm256_cmpeq_epi8(a, b);
    let all_ones = _mm256_set1_epi8(-1);
    _mm256_xor_si256(all_ones, equal)
}
15
/// Sliding lane mask: 32 bytes of `0x00` followed by 32 bytes of `0xFF`.
///
/// A 32-byte window starting at index `r` (for `r` in `0..=32`) contains `32 - r` zero bytes
/// followed by `r` bytes of `0xFF`, so AND-ing with it keeps only the last `r` lanes.
const MASK: [u8; 64] = [
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
];
21
/// Loads the 32 bytes `slice[offset..offset + 32]` into a 256-bit vector.
///
/// The load is unaligned (`_mm256_loadu_si256`), so `offset` need not be a multiple of 32.
///
/// # Safety
///
/// The executing CPU must support AVX2, and `offset + 32 <= slice.len()` must hold: the bytes
/// are read through a raw pointer without any bounds check in release builds.
#[target_feature(enable = "avx2")]
unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i {
    // Catch out-of-bounds reads in debug builds. The comparison is written so that it cannot
    // overflow even for huge `offset`. Release builds stay check-free.
    debug_assert!(
        offset <= slice.len() && slice.len() - offset >= 32,
        "mm256_from_offset: reading 32 bytes at offset {} overruns a slice of length {}",
        offset,
        slice.len()
    );
    _mm256_loadu_si256(slice.as_ptr().add(offset) as *const __m256i)
}
26
/// Adds up all 32 unsigned bytes of `u8s` and returns the total.
///
/// `_mm256_sad_epu8` against a zero vector sums each group of 8 bytes into one 64-bit lane
/// (at most 8 * 255 per lane). The four lane sums are then added together.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn sum(u8s: &__m256i) -> usize {
    let zero = _mm256_setzero_si256();
    let lane_sums = _mm256_sad_epu8(*u8s, zero);
    let lanes = [
        _mm256_extract_epi64(lane_sums, 0),
        _mm256_extract_epi64(lane_sums, 1),
        _mm256_extract_epi64(lane_sums, 2),
        _mm256_extract_epi64(lane_sums, 3),
    ];
    lanes.iter().sum::<i64>() as usize
}
35
36#[target_feature(enable = "avx2")]
37pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
38 assert!(haystack.len() >= 32);
39
40 let mut offset = 0;
41 let mut count = 0;
42
43 let needles = _mm256_set1_epu8(needle);
44
45 while haystack.len() >= offset + 32 * 255 {
47 let mut counts = _mm256_setzero_si256();
48 for _ in 0..255 {
49 counts = _mm256_sub_epi8(
50 counts,
51 _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
52 );
53 offset += 32;
54 }
55 count += sum(&counts);
56 }
57
58 if haystack.len() >= offset + 32 * 128 {
60 let mut counts = _mm256_setzero_si256();
61 for _ in 0..128 {
62 counts = _mm256_sub_epi8(
63 counts,
64 _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
65 );
66 offset += 32;
67 }
68 count += sum(&counts);
69 }
70
71 let mut counts = _mm256_setzero_si256();
73 for i in 0..(haystack.len() - offset) / 32 {
74 counts = _mm256_sub_epi8(
75 counts,
76 _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles),
77 );
78 }
79 if haystack.len() % 32 != 0 {
80 counts = _mm256_sub_epi8(
81 counts,
82 _mm256_and_si256(
83 _mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles),
84 mm256_from_offset(&MASK, haystack.len() % 32),
85 ),
86 );
87 }
88 count += sum(&counts);
89
90 count
91}
92
93#[target_feature(enable = "avx2")]
94unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i {
95 mm256_cmpneq_epi8(
96 _mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)),
97 _mm256_set1_epu8(0b1000_0000),
98 )
99}
100
101#[target_feature(enable = "avx2")]
102pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
103 assert!(utf8_chars.len() >= 32);
104
105 let mut offset = 0;
106 let mut count = 0;
107
108 while utf8_chars.len() >= offset + 32 * 255 {
110 let mut counts = _mm256_setzero_si256();
111
112 for _ in 0..255 {
113 counts = _mm256_sub_epi8(
114 counts,
115 is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
116 );
117 offset += 32;
118 }
119 count += sum(&counts);
120 }
121
122 if utf8_chars.len() >= offset + 32 * 128 {
124 let mut counts = _mm256_setzero_si256();
125 for _ in 0..128 {
126 counts = _mm256_sub_epi8(
127 counts,
128 is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
129 );
130 offset += 32;
131 }
132 count += sum(&counts);
133 }
134
135 let mut counts = _mm256_setzero_si256();
137 for i in 0..(utf8_chars.len() - offset) / 32 {
138 counts = _mm256_sub_epi8(
139 counts,
140 is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)),
141 );
142 }
143 if utf8_chars.len() % 32 != 0 {
144 counts = _mm256_sub_epi8(
145 counts,
146 _mm256_and_si256(
147 is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)),
148 mm256_from_offset(&MASK, utf8_chars.len() % 32),
149 ),
150 );
151 }
152 count += sum(&counts);
153
154 count
155}