1use super::{Error, SampleBorrow, SampleUniform, UniformSampler};
13use crate::distr::float::IntoFloat;
14use crate::distr::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD};
15use crate::Rng;
16
17#[cfg(feature = "simd_support")]
18use core::simd::prelude::*;
19#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24
/// Uniform sampler for floating-point values (scalars or SIMD vectors).
///
/// A sample is computed as `unit * scale + low`, where `unit` is a uniform
/// draw from `[0, 1)`. For a half-open range `scale` starts out as
/// `high - low`. For an inclusive range it starts as `(high - low)` divided by
/// `1 - EPSILON`. In both cases it is then reduced until
/// `low + scale * (1 - EPSILON)` no longer exceeds `high`.
///
/// Instances are built through the `UniformSampler` constructors
/// (`new` / `new_inclusive`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UniformFloat<X> {
    /// Lower bound of the range: the result for a unit draw of zero.
    low: X,
    /// Multiplier applied to the unit draw before `low` is added.
    scale: X,
}
54
macro_rules! uniform_float_impl {
    ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
        $(#[cfg($meta)])?
        impl UniformFloat<$ty> {
            /// `1 - EPSILON` broadcast to every lane: the value treated as the
            /// largest possible unit draw when bounding `scale`.
            #[inline]
            fn max_unit() -> $ty {
                <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON)
            }

            /// Draws one unit sample per lane.
            ///
            /// A random integer is shifted right by `$bits_to_discard` and
            /// converted with `into_float_with_exponent(0)`, which is expected
            /// to give a value in `[1, 2)`. Subtracting one moves it to
            /// `[0, 1)`, so the later multiplication by a finite `scale`
            /// cannot overflow.
            #[inline]
            fn random_unit<R: Rng + ?Sized>(rng: &mut R) -> $ty {
                let kept_bits = rng.random::<$uty>() >> $uty::splat($bits_to_discard);
                let one_to_two = kept_bits.into_float_with_exponent(0);
                one_to_two - <$ty>::splat(1.0)
            }

            /// Finalises a sampler from `low` and a candidate `scale`.
            ///
            /// Any lane where `scale * max_unit + low` lands above `high` gets
            /// its scale decreased, repeatedly, until no lane overshoots. The
            /// largest possible sample therefore never exceeds `high`.
            fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
                let top = Self::max_unit();

                let mut overshoot = (scale * top + low).gt_mask(high);
                while overshoot.any() {
                    scale = scale.decrease_masked(overshoot);
                    overshoot = (scale * top + low).gt_mask(high);
                }

                debug_assert!(<$ty>::splat(0.0).all_le(scale));

                UniformFloat { low, scale }
            }
        }

        $(#[cfg($meta)])?
        impl SampleUniform for $ty {
            type Sampler = UniformFloat<$ty>;
        }

        $(#[cfg($meta)])?
        impl UniformSampler for UniformFloat<$ty> {
            type X = $ty;

            /// Sampler for the half-open range `[low, high)`.
            ///
            /// Rounding may let a sample equal `high`, but never exceed it.
            ///
            /// Errors:
            /// - `NonFinite` if a bound is non-finite. This check only runs
            ///   in debug builds.
            /// - `EmptyRange` unless `low < high` in every lane.
            /// - `NonFinite` if `high - low` is not finite.
            fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let (low, high) = (*low_b.borrow(), *high_b.borrow());
                #[cfg(debug_assertions)]
                if !(low.all_finite() && high.all_finite()) {
                    return Err(Error::NonFinite);
                }
                if !low.all_lt(high) {
                    return Err(Error::EmptyRange);
                }

                let width = high - low;
                if !width.all_finite() {
                    return Err(Error::NonFinite);
                }

                Ok(Self::new_bounded(low, high, width))
            }

            /// Sampler for the closed range `[low, high]`.
            ///
            /// Errors:
            /// - `NonFinite` if a bound is non-finite. This check only runs
            ///   in debug builds.
            /// - `EmptyRange` unless `low <= high` in every lane.
            /// - `NonFinite` if the stretched width is not finite.
            fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let (low, high) = (*low_b.borrow(), *high_b.borrow());
                #[cfg(debug_assertions)]
                if !(low.all_finite() && high.all_finite()) {
                    return Err(Error::NonFinite);
                }
                if !low.all_le(high) {
                    return Err(Error::EmptyRange);
                }

                // Stretch the width so the largest unit draw can reach `high`.
                let width = (high - low) / Self::max_unit();
                if !width.all_finite() {
                    return Err(Error::NonFinite);
                }

                Ok(Self::new_bounded(low, high, width))
            }

            /// Draws one value as `unit * scale + low`, multiplying before adding.
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
                Self::random_unit(rng) * self.scale + self.low
            }

            /// One-shot sample from `[low, high)`.
            ///
            /// This simply delegates to `sample_single_inclusive`. As a
            /// result, only `low > high` is rejected here; a degenerate range
            /// `low == high` yields `low`.
            #[inline]
            fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                Self::sample_single_inclusive(low_b, high_b, rng)
            }

            /// One-shot sample from `[low, high]` without building a sampler.
            ///
            /// The bounds are validated before any randomness is consumed.
            #[inline]
            fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let (low, high) = (*low_b.borrow(), *high_b.borrow());
                #[cfg(debug_assertions)]
                if !(low.all_finite() && high.all_finite()) {
                    return Err(Error::NonFinite);
                }
                if !low.all_le(high) {
                    return Err(Error::EmptyRange);
                }
                let width = high - low;
                if !width.all_finite() {
                    return Err(Error::NonFinite);
                }

                Ok(Self::random_unit(rng) * width + low)
            }
        }
    };
}
198
// Scalar implementations. The last argument is how many low-order random bits
// are shifted away before the float conversion: 32 - 23 for `f32` (23 explicit
// IEEE-754 mantissa bits) and 64 - 52 for `f64` (52 explicit mantissa bits).
uniform_float_impl! { , f32, u32, f32, u32, 32 - 23 }
uniform_float_impl! { , f64, u64, f64, u64, 64 - 52 }

// SIMD `f32` lanes, built only with the `simd_support` feature. The feature
// appears both as the outer `#[cfg]` and as the macro's `$meta` argument; the
// macro attaches the latter to each generated `impl`.
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f32x2, u32x2, f32, u32, 32 - 23 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f32x4, u32x4, f32, u32, 32 - 23 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f32x8, u32x8, f32, u32, 32 - 23 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f32x16, u32x16, f32, u32, 32 - 23 }

// SIMD `f64` lanes, same feature gating as above.
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f64x2, u64x2, f64, u64, 64 - 52 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 }
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distr::{utils::FloatSIMDScalarUtils, Uniform};
    use crate::rngs::mock::StepRng;

    /// Samples a set of awkward ranges through every construction and
    /// sampling path. The ranges include subnormal bounds, signed zero and
    /// bounds near `MAX`. Results must stay within `[low, high]`, for scalar
    /// and (with `simd_support`) SIMD float types.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_floats() {
        let mut rng = crate::test::rng(252);
        // Constant generators (step 0): all-zero and all-one output words.
        let mut zero_rng = StepRng::new(0, 0);
        let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0);
        macro_rules! t {
            ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{
                let v: &[($f_scalar, $f_scalar)] = &[
                    (0.0, 100.0),
                    (-1e35, -1e25),
                    (1e-35, 1e-25),
                    (-1e35, 1e35),
                    (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)),
                    (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)),
                    (-<$f_scalar>::from_bits(5), 0.0),
                    (-<$f_scalar>::from_bits(7), -0.0),
                    (0.1 * $f_scalar::MAX, $f_scalar::MAX),
                    (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7),
                ];
                for &(low_scalar, high_scalar) in v.iter() {
                    for lane in 0..<$ty>::LEN {
                        // Only `lane` carries the bounds under test; the other
                        // lanes use the range 0..1.
                        let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
                        let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
                        let my_uniform = Uniform::new(low, high).unwrap();
                        let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
                        for _ in 0..100 {
                            let v = rng.sample(my_uniform).extract(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v = rng.sample(my_incl_uniform).extract(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v =
                                <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
                                    .unwrap()
                                    .extract(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low, high, &mut rng,
                            )
                            .unwrap()
                            .extract(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                        }

                        assert_eq!(
                            rng.sample(Uniform::new_inclusive(low, low).unwrap())
                                .extract(lane),
                            low_scalar
                        );

                        // An all-zero draw must map exactly onto `low`.
                        assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar);
                        assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar);
                        assert_eq!(
                            <$ty as SampleUniform>::Sampler::sample_single(
                                low,
                                high,
                                &mut zero_rng
                            )
                            .unwrap()
                            .extract(lane),
                            low_scalar
                        );
                        assert_eq!(
                            <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low,
                                high,
                                &mut zero_rng
                            )
                            .unwrap()
                            .extract(lane),
                            low_scalar
                        );

                        // An all-ones draw must never land above `high`.
                        assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar);
                        assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
                        assert!(
                            <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low,
                                high,
                                &mut max_rng
                            )
                            .unwrap()
                            .extract(lane)
                                <= high_scalar
                        );

                        if (high_scalar - low_scalar) > 0.0001 {
                            // Starts at all ones; the step is a negative
                            // power-of-two aligned to the discarded low bits.
                            let mut lowering_max_rng = StepRng::new(
                                0xffff_ffff_ffff_ffff,
                                (-1i64 << $bits_shifted) as u64,
                            );
                            assert!(
                                <$ty as SampleUniform>::Sampler::sample_single(
                                    low,
                                    high,
                                    &mut lowering_max_rng
                                )
                                .unwrap()
                                .extract(lane)
                                    <= high_scalar
                            );
                        }
                    }
                }

                assert_eq!(
                    rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()),
                    $f_scalar::MAX
                );
                assert_eq!(
                    rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()),
                    -$f_scalar::MAX
                );
            }};
        }

        t!(f32, f32, 32 - 23);
        t!(f64, f64, 64 - 52);
        #[cfg(feature = "simd_support")]
        {
            t!(f32x2, f32, 32 - 23);
            t!(f32x4, f32, 32 - 23);
            t!(f32x8, f32, 32 - 23);
            t!(f32x16, f32, 32 - 23);
            t!(f64x2, f64, 64 - 52);
            t!(f64x4, f64, 64 - 52);
            t!(f64x8, f64, 64 - 52);
        }
    }

    #[test]
    fn test_float_overflow() {
        assert_eq!(Uniform::try_from(f64::MIN..f64::MAX), Err(Error::NonFinite));
    }

    #[test]
    #[should_panic]
    fn test_float_overflow_single() {
        let mut rng = crate::test::rng(252);
        rng.random_range(f64::MIN..f64::MAX);
    }

    /// Pins the exact error variant for inputs whose classification does not
    /// depend on the build profile.
    #[test]
    fn test_float_error_variants() {
        // Empty or reversed ranges.
        assert_eq!(Uniform::new(1.0f64, 1.0f64), Err(Error::EmptyRange));
        assert_eq!(Uniform::new(2.0f64, 1.0f64), Err(Error::EmptyRange));
        assert_eq!(Uniform::new_inclusive(2.0f64, 1.0f64), Err(Error::EmptyRange));
        // Finite bounds whose width overflows.
        assert_eq!(
            Uniform::new_inclusive(f64::MIN, f64::MAX),
            Err(Error::NonFinite)
        );

        let mut rng = crate::test::rng(254);
        assert_eq!(
            <f64 as SampleUniform>::Sampler::sample_single(2.0f64, 1.0f64, &mut rng),
            Err(Error::EmptyRange)
        );
        assert_eq!(
            <f64 as SampleUniform>::Sampler::sample_single_inclusive(
                f64::MIN,
                f64::MAX,
                &mut rng
            ),
            Err(Error::NonFinite)
        );
    }

    #[test]
    #[cfg(all(feature = "std", panic = "unwind"))]
    fn test_float_assertions() {
        fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
            let mut rng = crate::test::rng(253);
            T::Sampler::sample_single(low, high, &mut rng)
        }

        macro_rules! t {
            ($ty:ident, $f_scalar:ident) => {{
                let v: &[($f_scalar, $f_scalar)] = &[
                    ($f_scalar::NAN, 0.0),
                    (1.0, $f_scalar::NAN),
                    ($f_scalar::NAN, $f_scalar::NAN),
                    (1.0, 0.5),
                    ($f_scalar::MAX, -$f_scalar::MAX),
                    ($f_scalar::INFINITY, $f_scalar::INFINITY),
                    ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY),
                    ($f_scalar::NEG_INFINITY, 5.0),
                    (5.0, $f_scalar::INFINITY),
                    ($f_scalar::NAN, $f_scalar::INFINITY),
                    ($f_scalar::NEG_INFINITY, $f_scalar::NAN),
                    ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY),
                ];
                for &(low_scalar, high_scalar) in v.iter() {
                    for lane in 0..<$ty>::LEN {
                        let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
                        let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
                        assert!(range(low, high).is_err());
                        assert!(Uniform::new(low, high).is_err());
                        assert!(Uniform::new_inclusive(low, high).is_err());
                        assert!(Uniform::new(low, low).is_err());
                    }
                }
            }};
        }

        t!(f32, f32);
        t!(f64, f64);
        #[cfg(feature = "simd_support")]
        {
            t!(f32x2, f32);
            t!(f32x4, f32);
            t!(f32x8, f32);
            t!(f32x16, f32);
            t!(f64x2, f64);
            t!(f64x4, f64);
            t!(f64x8, f64);
        }
    }

    #[test]
    fn test_uniform_from_std_range() {
        let r = Uniform::try_from(2.0f64..7.0).unwrap();
        assert_eq!(r.0.low, 2.0);
        assert_eq!(r.0.scale, 5.0);
    }

    #[test]
    fn test_uniform_from_std_range_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100.0..10.0).is_err());
        assert!(Uniform::try_from(100.0..100.0).is_err());
    }

    #[test]
    fn test_uniform_from_std_range_inclusive() {
        let r = Uniform::try_from(2.0f64..=7.0).unwrap();
        assert_eq!(r.0.low, 2.0);
        assert!(r.0.scale > 5.0);
        assert!(r.0.scale < 5.0 + 1e-14);
    }

    #[test]
    fn test_uniform_from_std_range_inclusive_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100.0..=10.0).is_err());
        assert!(Uniform::try_from(100.0..=99.0).is_err());
    }
}