bytecount/simd/
x86_sse2.rs

1#[cfg(target_arch = "x86")]
2use std::arch::x86::{
3    __m128i,
4    _mm_and_si128,
5    _mm_cmpeq_epi8,
6    _mm_cvtsi128_si32,
7    _mm_loadu_si128,
8    _mm_sad_epu8,
9    _mm_set1_epi8,
10    _mm_setzero_si128,
11    _mm_shuffle_epi32,
12    _mm_sub_epi8,
13    _mm_xor_si128,
14};
15
16#[cfg(target_arch = "x86_64")]
17use std::arch::x86_64::{
18    __m128i,
19    _mm_and_si128,
20    _mm_cmpeq_epi8,
21    _mm_cvtsi128_si32,
22    _mm_loadu_si128,
23    _mm_sad_epu8,
24    _mm_set1_epi8,
25    _mm_setzero_si128,
26    _mm_shuffle_epi32,
27    _mm_sub_epi8,
28    _mm_xor_si128,
29};
30
/// Broadcasts the unsigned byte `a` into all 16 lanes of an SSE2 vector.
///
/// Unsigned counterpart of `_mm_set1_epi8`: the bits of `a` are passed through
/// unchanged, only reinterpreted as `i8` to satisfy the intrinsic's signature.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[target_feature(enable = "sse2")]
pub unsafe fn _mm_set1_epu8(a: u8) -> __m128i {
    let reinterpreted = i8::from_ne_bytes([a]);
    _mm_set1_epi8(reinterpreted)
}
35
/// Lane-wise byte inequality test.
///
/// Each of the 16 result lanes is `0xFF` where the corresponding bytes of `a`
/// and `b` differ and `0x00` where they are equal. The result is built by
/// inverting the equality mask (XOR with all ones).
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[target_feature(enable = "sse2")]
pub unsafe fn mm_cmpneq_epi8(a: __m128i, b: __m128i) -> __m128i {
    let all_ones = _mm_set1_epi8(-1);
    let equal_mask = _mm_cmpeq_epi8(a, b);
    _mm_xor_si128(equal_mask, all_ones)
}
40
/// Sliding-window mask source: 16 zero bytes followed by 16 `0xFF` bytes.
///
/// A 16-byte load starting at index `r` (with `0 < r < 16`) yields `16 - r`
/// zero bytes followed by `r` all-ones bytes, i.e. a mask that keeps only the
/// last `r` lanes of a vector.
const MASK: [u8; 32] = [
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
];
45
/// Loads the 16 bytes `slice[offset..offset + 16]` into an SSE2 vector using an
/// unaligned load.
///
/// # Safety
///
/// * The executing CPU must support SSE2.
/// * `offset + 16` must not exceed `slice.len()`; the load is done through a
///   raw pointer and is not bounds-checked in release builds.
#[target_feature(enable = "sse2")]
unsafe fn mm_from_offset(slice: &[u8], offset: usize) -> __m128i {
    // Catch out-of-bounds windows in debug builds; written to avoid overflow in `offset + 16`.
    debug_assert!(slice.len() >= 16 && offset <= slice.len() - 16);
    // `add` takes the `usize` directly, avoiding the lossy-looking `as isize` cast.
    _mm_loadu_si128(slice.as_ptr().add(offset).cast::<__m128i>())
}
50
/// Adds up the 16 unsigned byte lanes of `u8s` and returns the total.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[target_feature(enable = "sse2")]
unsafe fn sum(u8s: &__m128i) -> usize {
    // Sum of absolute differences against zero: one partial sum of eight bytes
    // lands in each 64-bit half of the result.
    let halves = _mm_sad_epu8(*u8s, _mm_setzero_si128());
    let low = _mm_cvtsi128_si32(halves);
    // 0xaa replicates 32-bit lane 2 everywhere, bringing the upper half's
    // partial sum down to lane 0 where it can be extracted.
    let high = _mm_cvtsi128_si32(_mm_shuffle_epi32(halves, 0xaa));
    (low + high) as usize
}
56
57#[target_feature(enable = "sse2")]
58pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
59    assert!(haystack.len() >= 16);
60
61    let mut offset = 0;
62    let mut count = 0;
63
64    let needles = _mm_set1_epu8(needle);
65
66    // 4080
67    while haystack.len() >= offset + 16 * 255 {
68        let mut counts = _mm_setzero_si128();
69        for _ in 0..255 {
70            counts = _mm_sub_epi8(
71                counts,
72                _mm_cmpeq_epi8(mm_from_offset(haystack, offset), needles)
73            );
74            offset += 16;
75        }
76        count += sum(&counts);
77    }
78
79    // 2048
80    if haystack.len() >= offset + 16 * 128 {
81        let mut counts = _mm_setzero_si128();
82        for _ in 0..128 {
83            counts = _mm_sub_epi8(
84                counts,
85                _mm_cmpeq_epi8(mm_from_offset(haystack, offset), needles)
86            );
87            offset += 16;
88        }
89        count += sum(&counts);
90    }
91
92    // 16
93    let mut counts = _mm_setzero_si128();
94    for i in 0..(haystack.len() - offset) / 16 {
95        counts = _mm_sub_epi8(
96            counts,
97            _mm_cmpeq_epi8(mm_from_offset(haystack, offset + i * 16), needles)
98        );
99    }
100    if haystack.len() % 16 != 0 {
101        counts = _mm_sub_epi8(
102            counts,
103            _mm_and_si128(
104                _mm_cmpeq_epi8(mm_from_offset(haystack, haystack.len() - 16), needles),
105                                  mm_from_offset(&MASK, haystack.len() % 16)
106            )
107        );
108    }
109    count += sum(&counts);
110
111    count
112}
113
114#[target_feature(enable = "sse2")]
115unsafe fn is_leading_utf8_byte(u8s: __m128i) -> __m128i {
116    mm_cmpneq_epi8(_mm_and_si128(u8s, _mm_set1_epu8(0b1100_0000)), _mm_set1_epu8(0b1000_0000))
117}
118
119#[target_feature(enable = "sse2")]
120pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
121    assert!(utf8_chars.len() >= 16);
122
123    let mut offset = 0;
124    let mut count = 0;
125
126    // 4080
127    while utf8_chars.len() >= offset + 16 * 255 {
128        let mut counts = _mm_setzero_si128();
129
130        for _ in 0..255 {
131            counts = _mm_sub_epi8(
132                counts,
133                is_leading_utf8_byte(mm_from_offset(utf8_chars, offset))
134            );
135            offset += 16;
136        }
137        count += sum(&counts);
138    }
139
140    // 2048
141    if utf8_chars.len() >= offset + 16 * 128 {
142        let mut counts = _mm_setzero_si128();
143        for _ in 0..128 {
144            counts = _mm_sub_epi8(
145                counts,
146                is_leading_utf8_byte(mm_from_offset(utf8_chars, offset))
147            );
148            offset += 16;
149        }
150        count += sum(&counts);
151    }
152
153    // 16
154    let mut counts = _mm_setzero_si128();
155    for i in 0..(utf8_chars.len() - offset) / 16 {
156        counts = _mm_sub_epi8(
157            counts,
158            is_leading_utf8_byte(mm_from_offset(utf8_chars, offset + i * 16))
159        );
160    }
161    if utf8_chars.len() % 16 != 0 {
162        counts = _mm_sub_epi8(
163            counts,
164            _mm_and_si128(
165                is_leading_utf8_byte(mm_from_offset(utf8_chars, utf8_chars.len() - 16)),
166                                     mm_from_offset(&MASK,      utf8_chars.len() % 16)
167            )
168        );
169    }
170    count += sum(&counts);
171
172    count
173}