rand/distr/utils.rs

// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature = "simd_support")]
use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, SimdElement, SupportedLaneCount};

pub(crate) trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ty),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> Simd::splat($shift)).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
wmul_impl! { u64, u128, 64 }
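
// Illustrative addition (not part of the upstream file): `wmul` returns the
// (high, low) halves of the full double-width product. The values below are
// worked out by hand, e.g. 255 * 255 = 0xFE01, so high = 0xFE and low = 0x01.
#[cfg(test)]
mod wmul_scalar_example {
    use super::WideningMultiply;

    #[test]
    fn scalar_wmul_splits_the_product() {
        // Product fits in the low half: high half is zero.
        assert_eq!(5u8.wmul(7), (0, 35));
        // Product overflows u8: 255 * 255 = 65025 = 0xFE01.
        assert_eq!(255u8.wmul(255), (0xFE, 0x01));
        // u64 goes through the u128 widening path: u64::MAX * 2 = 2^65 - 2.
        assert_eq!(u64::MAX.wmul(2), (1, u64::MAX - 1));
    }
}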

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    let lower_mask = <$ty>::splat(!0 >> $half);
                    let half = <$ty>::splat($half);
                    let mut low = (self & lower_mask) * (b & lower_mask);
                    let mut t = low >> half;
                    low &= lower_mask;
                    t += (self >> half) * (b & lower_mask);
                    low += (t & lower_mask) << half;
                    let mut high = t >> half;
                    t = low >> half;
                    low &= lower_mask;
                    t += (b >> half) * (self & lower_mask);
                    low += (t & lower_mask) << half;
                    high += t >> half;
                    high += (self >> half) * (b >> half);

                    (high, low)
                }
            }
        )+
    };
}
wmul_impl_large! { u128, 64 }
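
// Added sketch (not from the upstream source): a quick sanity check that the
// long multiplication above agrees with values worked out by hand for u128,
// where no wider primitive type exists to compare against.
#[cfg(test)]
mod wmul_u128_example {
    use super::WideningMultiply;

    #[test]
    fn u128_wmul_matches_hand_computed_values() {
        // 2^64 * 2^64 = 2^128, i.e. high = 1, low = 0.
        assert_eq!((1u128 << 64).wmul(1u128 << 64), (1, 0));
        // (2^128 - 1)^2 = 2^256 - 2^129 + 1, so high = 2^128 - 2, low = 1.
        assert_eq!(u128::MAX.wmul(u128::MAX), (u128::MAX - 1, 1));
    }
}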

macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "16")]
wmul_impl_usize! { u16 }
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }

#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    wmul_impl! {
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),
        (u8x64, Simd<u16, 64>),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }
    #[cfg(not(target_feature = "avx512bw"))]
    wmul_impl! { (u16x32, Simd<u32, 32>),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let hi = unsafe { $mulhi(self.into(), x.into()) }.into();
                    let lo = unsafe { $mullo(self.into(), x.into()) }.into();
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    #[cfg(target_feature = "avx512bw")]
    wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 }

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),
        (u32x16, Simd<u64, 16>),,
        32
    }

    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
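
// Added sketch, gated like the module above: the SIMD `wmul` implementations
// should produce the same (high, low) lanes as the scalar versions. The values
// mirror the scalar u32 case, since 0xFFFF_FFFF^2 has high 0xFFFF_FFFE, low 1.
#[cfg(all(test, feature = "simd_support"))]
mod wmul_simd_example {
    use super::*;

    #[test]
    fn u32x4_wmul_matches_scalar() {
        let a = u32x4::splat(0xFFFF_FFFF);
        let b = u32x4::splat(0xFFFF_FFFF);
        let (hi, lo) = a.wmul(b);
        // Compare each lane against the scalar widening multiply.
        let (s_hi, s_lo) = 0xFFFF_FFFFu32.wmul(0xFFFF_FFFF);
        assert_eq!(hi, u32x4::splat(s_hi));
        assert_eq!(lo, u32x4::splat(s_lo));
    }
}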

/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    type Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}

#[cfg(test)]
pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
    type Scalar;

    fn replace(self, index: usize, new_value: Self::Scalar) -> Self;
    fn extract(self, index: usize) -> Self::Scalar;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
    #[cfg(test)]
    const LEN: usize = 1;

    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

pub(crate) trait IntAsSIMD: Sized {
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

impl IntAsSIMD for u32 {}
impl IntAsSIMD for u64 {}

pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool {
        self
    }
}

macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i as $ty
            }
        }

        #[cfg(test)]
        impl FloatSIMDScalarUtils for $ty {
            type Scalar = $ty;

            #[inline]
            fn replace(self, index: usize, new_value: Self::Scalar) -> Self {
                debug_assert_eq!(index, 0);
                new_value
            }

            #[inline]
            fn extract(self, index: usize) -> Self::Scalar {
                debug_assert_eq!(index, 0);
                self
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
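
// Added sketch (not part of the upstream file): the scalar impls above make
// `f32`/`f64` behave like one-lane vectors, so `decrease_masked` steps to the
// next representable value below and `cast_from_int` is a plain numeric cast.
#[cfg(test)]
mod scalar_float_example {
    use super::FloatSIMDUtils;

    #[test]
    fn f32_acts_like_a_one_lane_vector() {
        // Stepping 1.0 down by one ULP gives the largest float below 1.0.
        let below_one = 1.0f32.decrease_masked(true);
        assert!(below_one < 1.0);
        assert_eq!(below_one.to_bits(), 1.0f32.to_bits() - 1);
        // Comparisons reduce over the single "lane".
        assert!(0.5f32.all_lt(1.0));
        assert!(1.0f32.all_le(1.0));
        // Conversion keeps the numeric value, not the bit pattern.
        assert_eq!(f32::cast_from_int(3), 3.0);
    }
}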

#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($fty:ident, $uty:ident) => {
        impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
            type UInt = Simd<$uty, LANES>;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.simd_lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.simd_le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite().all()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.simd_gt(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $fty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                Self::from_bits(self.to_bits() + mask.to_int().cast())
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }

        #[cfg(test)]
        impl<const LANES: usize> FloatSIMDScalarUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Scalar = $fty;

            #[inline]
            fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self {
                self.as_mut_array()[index] = new_value;
                self
            }

            #[inline]
            fn extract(self, index: usize) -> Self::Scalar {
                self.as_array()[index]
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl!(f32, u32);
#[cfg(feature = "simd_support")]
simd_impl!(f64, u64);
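
// Added sketch, mirroring the scalar example above for the SIMD impls: only
// the lanes selected by the mask are stepped down, which is what the
// mask-to-int trick in `decrease_masked` relies on.
#[cfg(all(test, feature = "simd_support"))]
mod simd_float_example {
    use super::*;

    #[test]
    fn f64x2_decrease_masked_only_touches_selected_lanes() {
        let v = f64x2::splat(1.0);
        let mask = Mask::from_array([true, false]);
        let out = v.decrease_masked(mask);
        // Lane 0 moves to the next value below 1.0; lane 1 is unchanged.
        assert_eq!(out.extract(0).to_bits(), 1.0f64.to_bits() - 1);
        assert_eq!(out.extract(1), 1.0);
    }
}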