bytecount/simd/
x86_avx2.rs

1use std::arch::x86_64::{
2    __m256i, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_extract_epi64, _mm256_loadu_si256,
3    _mm256_sad_epu8, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_sub_epi8, _mm256_xor_si256,
4};
5
/// Broadcasts the unsigned byte `a` into all 32 byte lanes of a 256-bit vector.
///
/// This is an unsigned convenience wrapper around `_mm256_set1_epi8`, which
/// takes an `i8`; the bit pattern of `a` is passed through unchanged.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn _mm256_set1_epu8(a: u8) -> __m256i {
    // Reinterpret the byte as signed without altering its bits.
    let as_signed = i8::from_ne_bytes([a]);
    _mm256_set1_epi8(as_signed)
}
10
/// Lane-wise "not equal" comparison of two vectors of 32 bytes.
///
/// Each byte lane of the result is `0xFF` where `a` and `b` differ and `0x00`
/// where they are equal. AVX2 only offers an equality compare, so the result
/// of that compare is inverted by XOR-ing it with an all-ones vector.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i {
    let equal_lanes = _mm256_cmpeq_epi8(a, b);
    // -1 as i8 is 0xFF, i.e. every bit set in every lane.
    let all_ones = _mm256_set1_epi8(-1);
    _mm256_xor_si256(equal_lanes, all_ones)
}
15
// Sliding byte mask: 32 bytes of 0x00 followed by 32 bytes of 0xFF.
// A 32-byte load starting at index `r` (0 <= r <= 32) therefore yields
// `32 - r` zero bytes followed by `r` bytes of 0xFF, i.e. a mask that keeps
// only the last `r` lanes of a vector.
const MASK: [u8; 64] = [
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
];
21
/// Loads 32 bytes from `slice`, starting at byte index `offset`, into a vector.
///
/// The load is unaligned, so `offset` may be any byte position.
///
/// # Safety
///
/// No bounds check is performed: the caller must guarantee that
/// `offset + 32 <= slice.len()`. The CPU must also support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i {
    let start: *const u8 = slice.as_ptr().add(offset);
    _mm256_loadu_si256(start as *const __m256i)
}
26
/// Adds up all 32 unsigned byte lanes of `u8s` and returns the total.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn sum(u8s: &__m256i) -> usize {
    // Sum-of-absolute-differences against zero adds each group of eight
    // unsigned bytes into one of four 64-bit lanes.
    let zero = _mm256_setzero_si256();
    let partial_sums = _mm256_sad_epu8(*u8s, zero);

    let lane0 = _mm256_extract_epi64(partial_sums, 0);
    let lane1 = _mm256_extract_epi64(partial_sums, 1);
    let lane2 = _mm256_extract_epi64(partial_sums, 2);
    let lane3 = _mm256_extract_epi64(partial_sums, 3);

    let total = lane0 + lane1 + lane2 + lane3;
    total as usize
}
35
36#[target_feature(enable = "avx2")]
37pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
38    assert!(haystack.len() >= 32);
39
40    let mut offset = 0;
41    let mut count = 0;
42
43    let needles = _mm256_set1_epu8(needle);
44
45    // 8160
46    while haystack.len() >= offset + 32 * 255 {
47        let mut counts = _mm256_setzero_si256();
48        for _ in 0..255 {
49            counts = _mm256_sub_epi8(
50                counts,
51                _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
52            );
53            offset += 32;
54        }
55        count += sum(&counts);
56    }
57
58    // 4096
59    if haystack.len() >= offset + 32 * 128 {
60        let mut counts = _mm256_setzero_si256();
61        for _ in 0..128 {
62            counts = _mm256_sub_epi8(
63                counts,
64                _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
65            );
66            offset += 32;
67        }
68        count += sum(&counts);
69    }
70
71    // 32
72    let mut counts = _mm256_setzero_si256();
73    for i in 0..(haystack.len() - offset) / 32 {
74        counts = _mm256_sub_epi8(
75            counts,
76            _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles),
77        );
78    }
79    if haystack.len() % 32 != 0 {
80        counts = _mm256_sub_epi8(
81            counts,
82            _mm256_and_si256(
83                _mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles),
84                mm256_from_offset(&MASK, haystack.len() % 32),
85            ),
86        );
87    }
88    count += sum(&counts);
89
90    count
91}
92
93#[target_feature(enable = "avx2")]
94unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i {
95    mm256_cmpneq_epi8(
96        _mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)),
97        _mm256_set1_epu8(0b1000_0000),
98    )
99}
100
101#[target_feature(enable = "avx2")]
102pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
103    assert!(utf8_chars.len() >= 32);
104
105    let mut offset = 0;
106    let mut count = 0;
107
108    // 8160
109    while utf8_chars.len() >= offset + 32 * 255 {
110        let mut counts = _mm256_setzero_si256();
111
112        for _ in 0..255 {
113            counts = _mm256_sub_epi8(
114                counts,
115                is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
116            );
117            offset += 32;
118        }
119        count += sum(&counts);
120    }
121
122    // 4096
123    if utf8_chars.len() >= offset + 32 * 128 {
124        let mut counts = _mm256_setzero_si256();
125        for _ in 0..128 {
126            counts = _mm256_sub_epi8(
127                counts,
128                is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
129            );
130            offset += 32;
131        }
132        count += sum(&counts);
133    }
134
135    // 32
136    let mut counts = _mm256_setzero_si256();
137    for i in 0..(utf8_chars.len() - offset) / 32 {
138        counts = _mm256_sub_epi8(
139            counts,
140            is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)),
141        );
142    }
143    if utf8_chars.len() % 32 != 0 {
144        counts = _mm256_sub_epi8(
145            counts,
146            _mm256_and_si256(
147                is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)),
148                mm256_from_offset(&MASK, utf8_chars.len() % 32),
149            ),
150        );
151    }
152    count += sum(&counts);
153
154    count
155}