1#[cfg(feature = "simd_support")]
12use core::simd::prelude::*;
13#[cfg(feature = "simd_support")]
14use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
15
/// Full-width ("widening") multiplication.
///
/// The implementations in this module return the product split into two
/// halves of the operand width, as a `(high, low)` pair.
pub(crate) trait WideningMultiply<Rhs = Self> {
    /// Type of the double-width product.
    type Output;

    /// Multiplies `self` by `rhs` without losing any bits of the product.
    fn wmul(self, rhs: Rhs) -> Self::Output;
}
21
22macro_rules! wmul_impl {
23 ($ty:ty, $wide:ty, $shift:expr) => {
24 impl WideningMultiply for $ty {
25 type Output = ($ty, $ty);
26
27 #[inline(always)]
28 fn wmul(self, x: $ty) -> Self::Output {
29 let tmp = (self as $wide) * (x as $wide);
30 ((tmp >> $shift) as $ty, tmp as $ty)
31 }
32 }
33 };
34
35 ($(($ty:ident, $wide:ty),)+, $shift:expr) => {
37 $(
38 impl WideningMultiply for $ty {
39 type Output = ($ty, $ty);
40
41 #[inline(always)]
42 fn wmul(self, x: $ty) -> Self::Output {
43 let y: $wide = self.cast();
48 let x: $wide = x.cast();
49 let tmp = y * x;
50 let hi: $ty = (tmp >> Simd::splat($shift)).cast();
51 let lo: $ty = tmp.cast();
52 (hi, lo)
53 }
54 }
55 )+
56 };
57}
58wmul_impl! { u8, u16, 8 }
59wmul_impl! { u16, u32, 16 }
60wmul_impl! { u32, u64, 32 }
61wmul_impl! { u64, u128, 64 }
62
63macro_rules! wmul_impl_large {
70 ($ty:ty, $half:expr) => {
71 impl WideningMultiply for $ty {
72 type Output = ($ty, $ty);
73
74 #[inline(always)]
75 fn wmul(self, b: $ty) -> Self::Output {
76 const LOWER_MASK: $ty = !0 >> $half;
77 let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
78 let mut t = low >> $half;
79 low &= LOWER_MASK;
80 t += (self >> $half).wrapping_mul(b & LOWER_MASK);
81 low += (t & LOWER_MASK) << $half;
82 let mut high = t >> $half;
83 t = low >> $half;
84 low &= LOWER_MASK;
85 t += (b >> $half).wrapping_mul(self & LOWER_MASK);
86 low += (t & LOWER_MASK) << $half;
87 high += t >> $half;
88 high += (self >> $half).wrapping_mul(b >> $half);
89
90 (high, low)
91 }
92 }
93 };
94
95 (($($ty:ty,)+) $scalar:ty, $half:expr) => {
97 $(
98 impl WideningMultiply for $ty {
99 type Output = ($ty, $ty);
100
101 #[inline(always)]
102 fn wmul(self, b: $ty) -> Self::Output {
103 let lower_mask = <$ty>::splat(!0 >> $half);
105 let half = <$ty>::splat($half);
106 let mut low = (self & lower_mask) * (b & lower_mask);
107 let mut t = low >> half;
108 low &= lower_mask;
109 t += (self >> half) * (b & lower_mask);
110 low += (t & lower_mask) << half;
111 let mut high = t >> half;
112 t = low >> half;
113 low &= lower_mask;
114 t += (b >> half) * (self & lower_mask);
115 low += (t & lower_mask) << half;
116 high += t >> half;
117 high += (self >> half) * (b >> half);
118
119 (high, low)
120 }
121 }
122 )+
123 };
124}
125wmul_impl_large! { u128, 64 }
126
127macro_rules! wmul_impl_usize {
128 ($ty:ty) => {
129 impl WideningMultiply for usize {
130 type Output = (usize, usize);
131
132 #[inline(always)]
133 fn wmul(self, x: usize) -> Self::Output {
134 let (high, low) = (self as $ty).wmul(x as $ty);
135 (high as usize, low as usize)
136 }
137 }
138 };
139}
140#[cfg(target_pointer_width = "16")]
141wmul_impl_usize! { u16 }
142#[cfg(target_pointer_width = "32")]
143wmul_impl_usize! { u32 }
144#[cfg(target_pointer_width = "64")]
145wmul_impl_usize! { u64 }
146
#[cfg(feature = "simd_support")]
mod simd_wmul {
    //! Widening multiplication for the portable-SIMD unsigned vectors.
    //!
    //! Each vector type gets exactly one implementation. By default that is
    //! the generic widen/multiply/split macros. For some `u16` vectors it is
    //! instead the x86 `mulhi`/`mullo` intrinsics, when the matching target
    //! feature is enabled at compile time. The `cfg` pairs below keep these
    //! two paths mutually exclusive.

    use super::*;
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // 8-bit lanes: widen to 16 bits.
    wmul_impl! {
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),
        (u8x64, Simd<u16, 64>),,
        8
    }

    // 16-bit lanes that have no intrinsic alternative here.
    wmul_impl! {
        (u16x2, u32x2),
        (u16x4, u32x4),,
        16
    }

    // 16-bit lanes: generic fallback, used only when the intrinsic path for
    // the same vector type is compiled out.
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }
    #[cfg(not(target_feature = "avx512bw"))]
    wmul_impl! { (u16x32, Simd<u32, 32>),, 16 }

    // Two-instruction `wmul` for 16-bit lanes via `mulhi` (high half) and
    // `mullo` (low half). Allowed to be unused because every invocation
    // below may be compiled out.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // SAFETY: every invocation of this macro is gated on
                    // `cfg(target_feature = ...)` for the feature the
                    // intrinsics require, so the build already assumes the
                    // instructions are available.
                    let (hi, lo) = unsafe {
                        (
                            $mulhi(self.into(), x.into()),
                            $mullo(self.into(), x.into()),
                        )
                    };
                    (hi.into(), lo.into())
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    #[cfg(target_feature = "avx512bw")]
    wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 }

    // 32-bit lanes: widen to 64 bits.
    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),
        (u32x16, Simd<u64, 16>),,
        32
    }

    // 64-bit lanes: split each lane into 32-bit halves.
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
208
/// Helpers shared by scalar floats and SIMD float vectors, so the same generic
/// code can be written once for both.
///
/// `PartialOrd` on vectors compares lexicographically. The `all_*` methods
/// instead compare lane by lane and combine the results over all lanes.
pub(crate) trait FloatSIMDUtils {
    /// Per-lane boolean result of a comparison.
    type Mask;

    /// Unsigned integer type (with the same lane count) accepted by
    /// [`cast_from_int`](FloatSIMDUtils::cast_from_int).
    type UInt;

    /// Returns `true` when `self < other` holds in every lane.
    fn all_lt(self, other: Self) -> bool;

    /// Returns `true` when `self <= other` holds in every lane.
    fn all_le(self, other: Self) -> bool;

    /// Returns `true` when every lane is finite (neither infinite nor NaN).
    fn all_finite(self) -> bool;

    /// Lane-wise `self > other`.
    fn gt_mask(self, other: Self) -> Self::Mask;

    /// Lowers the bit pattern of every lane selected by `mask` by one. For
    /// positive, non-zero values this yields the next smaller representable
    /// float. At least one lane must be selected.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    /// Converts integer lanes by numeric value, not by reinterpreting bits.
    fn cast_from_int(i: Self::UInt) -> Self;
}
233
/// Test-only lane access, implemented by both scalar floats and SIMD float
/// vectors.
#[cfg(test)]
pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
    /// Element type of a single lane.
    type Scalar;

    /// Reads lane `index`.
    fn extract(self, index: usize) -> Self::Scalar;

    /// Returns a copy of `self` with lane `index` set to `new_value`.
    fn replace(self, index: usize, new_value: Self::Scalar) -> Self;
}
241
/// Lets scalar floats stand in for SIMD float vectors by providing the same
/// `splat` constructor. A scalar is treated as a vector with a single lane.
pub(crate) trait FloatAsSIMD: Sized {
    /// Number of lanes, which is always one for a scalar. Only tests need it.
    #[cfg(test)]
    const LEN: usize = 1;

    /// Broadcasts `scalar` to every lane. With only one lane there is nothing
    /// to broadcast, so the value is returned unchanged.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}
252
/// Gives the unsigned integer scalars below the `splat` constructor that SIMD
/// vectors have, so generic code can call `splat` on either kind of value.
pub(crate) trait IntAsSIMD: Sized {
    /// Broadcasts `scalar`. A scalar has a single lane, so it is returned as is.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

impl IntAsSIMD for u64 {}
impl IntAsSIMD for u32 {}
262
/// Gives `bool` the `any` reduction that SIMD masks have, so code generic over
/// masks also works with the scalar mask type `bool`.
pub(crate) trait BoolAsSIMD: Sized {
    /// Returns `true` if at least one lane is set.
    fn any(self) -> bool;
}

impl BoolAsSIMD for bool {
    /// A one-lane mask has a set lane exactly when the `bool` is `true`.
    #[inline(always)]
    fn any(self) -> bool {
        self
    }
}
273
274macro_rules! scalar_float_impl {
275 ($ty:ident, $uty:ident) => {
276 impl FloatSIMDUtils for $ty {
277 type Mask = bool;
278 type UInt = $uty;
279
280 #[inline(always)]
281 fn all_lt(self, other: Self) -> bool {
282 self < other
283 }
284
285 #[inline(always)]
286 fn all_le(self, other: Self) -> bool {
287 self <= other
288 }
289
290 #[inline(always)]
291 fn all_finite(self) -> bool {
292 self.is_finite()
293 }
294
295 #[inline(always)]
296 fn gt_mask(self, other: Self) -> Self::Mask {
297 self > other
298 }
299
300 #[inline(always)]
301 fn decrease_masked(self, mask: Self::Mask) -> Self {
302 debug_assert!(mask, "At least one lane must be set");
303 <$ty>::from_bits(self.to_bits() - 1)
304 }
305
306 #[inline]
307 fn cast_from_int(i: Self::UInt) -> Self {
308 i as $ty
309 }
310 }
311
312 #[cfg(test)]
313 impl FloatSIMDScalarUtils for $ty {
314 type Scalar = $ty;
315
316 #[inline]
317 fn replace(self, index: usize, new_value: Self::Scalar) -> Self {
318 debug_assert_eq!(index, 0);
319 new_value
320 }
321
322 #[inline]
323 fn extract(self, index: usize) -> Self::Scalar {
324 debug_assert_eq!(index, 0);
325 self
326 }
327 }
328
329 impl FloatAsSIMD for $ty {}
330 };
331}
332
333scalar_float_impl!(f32, u32);
334scalar_float_impl!(f64, u64);
335
/// Implements the float helper traits for every supported width of
/// `Simd<$fty, LANES>`. `$uty` is the unsigned integer type with the same bit
/// width as `$fty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($fty:ident, $uty:ident) => {
        impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
            type UInt = Simd<$uty, LANES>;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                let lanes = self.simd_lt(other);
                lanes.all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                let lanes = self.simd_le(other);
                lanes.all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                let lanes = self.is_finite();
                lanes.all()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.simd_gt(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask.any(), "At least one lane must be set");
                // `to_int` maps a set lane to -1 (all bits set) and a clear
                // lane to 0. Casting to unsigned and adding that to the bit
                // pattern therefore subtracts one from exactly the selected
                // lanes. This also works for infinity.
                let step: Simd<$uty, LANES> = mask.to_int().cast();
                Self::from_bits(self.to_bits() + step)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                // Lane-wise numeric conversion, not a bit reinterpretation.
                i.cast()
            }
        }

        #[cfg(test)]
        impl<const LANES: usize> FloatSIMDScalarUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Scalar = $fty;

            #[inline]
            fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self {
                let lanes = self.as_mut_array();
                lanes[index] = new_value;
                self
            }

            #[inline]
            fn extract(self, index: usize) -> Self::Scalar {
                let lanes = self.as_array();
                lanes[index]
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl!(f32, u32);
#[cfg(feature = "simd_support")]
simd_impl!(f64, u64);