rand/distr/uniform_other.rs

// Copyright 2018-2020 Developers of the Rand project.
// Copyright 2017 The Rust Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
9
//! `UniformChar` and `UniformDuration` implementations.
11
12use super::{Error, SampleBorrow, SampleUniform, Uniform, UniformInt, UniformSampler};
13use crate::distr::Distribution;
14use crate::Rng;
15use core::time::Duration;
16
17#[cfg(feature = "serde")]
18use serde::{Deserialize, Serialize};
19
20impl SampleUniform for char {
21    type Sampler = UniformChar;
22}
23
24/// The back-end implementing [`UniformSampler`] for `char`.
25///
26/// Unless you are implementing [`UniformSampler`] for your own type, this type
27/// should not be used directly, use [`Uniform`] instead.
28///
29/// This differs from integer range sampling since the range `0xD800..=0xDFFF`
30/// are used for surrogate pairs in UCS and UTF-16, and consequently are not
31/// valid Unicode code points. We must therefore avoid sampling values in this
32/// range.
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35pub struct UniformChar {
36    sampler: UniformInt<u32>,
37}
38
/// First code point of the UTF-16 surrogate block (`U+D800`).
const CHAR_SURROGATE_START: u32 = 0xD800;
/// Number of code points in the surrogate block `U+D800..=U+DFFF`, i.e. the
/// size of the gap closed when compressing `char` values into a dense `u32`.
const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START;
43
44/// Convert `char` to compressed `u32`
45fn char_to_comp_u32(c: char) -> u32 {
46    match c as u32 {
47        c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN,
48        c => c,
49    }
50}
51
52impl UniformSampler for UniformChar {
53    type X = char;
54
55    #[inline] // if the range is constant, this helps LLVM to do the
56              // calculations at compile-time.
57    fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
58    where
59        B1: SampleBorrow<Self::X> + Sized,
60        B2: SampleBorrow<Self::X> + Sized,
61    {
62        let low = char_to_comp_u32(*low_b.borrow());
63        let high = char_to_comp_u32(*high_b.borrow());
64        let sampler = UniformInt::<u32>::new(low, high);
65        sampler.map(|sampler| UniformChar { sampler })
66    }
67
68    #[inline] // if the range is constant, this helps LLVM to do the
69              // calculations at compile-time.
70    fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
71    where
72        B1: SampleBorrow<Self::X> + Sized,
73        B2: SampleBorrow<Self::X> + Sized,
74    {
75        let low = char_to_comp_u32(*low_b.borrow());
76        let high = char_to_comp_u32(*high_b.borrow());
77        let sampler = UniformInt::<u32>::new_inclusive(low, high);
78        sampler.map(|sampler| UniformChar { sampler })
79    }
80
81    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
82        let mut x = self.sampler.sample(rng);
83        if x >= CHAR_SURROGATE_START {
84            x += CHAR_SURROGATE_LEN;
85        }
86        // SAFETY: x must not be in surrogate range or greater than char::MAX.
87        // This relies on range constructors which accept char arguments.
88        // Validity of input char values is assumed.
89        unsafe { core::char::from_u32_unchecked(x) }
90    }
91}
92
#[cfg(feature = "alloc")]
impl crate::distr::SampleString for Uniform<char> {
    /// Appends `len` chars sampled from this distribution to `string`.
    ///
    /// Capacity is reserved up front using the UTF-8 width of the largest char
    /// this distribution can yield. Ranges of narrow chars therefore avoid
    /// reserving the worst-case 4 bytes per char.
    fn append_string<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        string: &mut alloc::string::String,
        len: usize,
    ) {
        let inner = &self.0.sampler;
        // Largest compressed value the sampler can yield; undo the
        // surrogate-gap compression to recover the real code point.
        let top_compressed = inner.low + inner.range - 1;
        let top = match top_compressed {
            c if c >= CHAR_SURROGATE_START => c + CHAR_SURROGATE_LEN,
            c => c,
        };
        let widest = char::from_u32(top).map_or(4, char::len_utf8);
        string.reserve(widest * len);
        string.extend(self.sample_iter(rng).take(len));
    }
}
112
113/// The back-end implementing [`UniformSampler`] for `Duration`.
114///
115/// Unless you are implementing [`UniformSampler`] for your own types, this type
116/// should not be used directly, use [`Uniform`] instead.
117#[derive(Clone, Copy, Debug, PartialEq, Eq)]
118#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
119pub struct UniformDuration {
120    mode: UniformDurationMode,
121    offset: u32,
122}
123
124#[derive(Debug, Copy, Clone, PartialEq, Eq)]
125#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
126enum UniformDurationMode {
127    Small {
128        secs: u64,
129        nanos: Uniform<u32>,
130    },
131    Medium {
132        nanos: Uniform<u64>,
133    },
134    Large {
135        max_secs: u64,
136        max_nanos: u32,
137        secs: Uniform<u64>,
138    },
139}
140
141impl SampleUniform for Duration {
142    type Sampler = UniformDuration;
143}
144
145impl UniformSampler for UniformDuration {
146    type X = Duration;
147
148    #[inline]
149    fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
150    where
151        B1: SampleBorrow<Self::X> + Sized,
152        B2: SampleBorrow<Self::X> + Sized,
153    {
154        let low = *low_b.borrow();
155        let high = *high_b.borrow();
156        if !(low < high) {
157            return Err(Error::EmptyRange);
158        }
159        UniformDuration::new_inclusive(low, high - Duration::new(0, 1))
160    }
161
162    #[inline]
163    fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
164    where
165        B1: SampleBorrow<Self::X> + Sized,
166        B2: SampleBorrow<Self::X> + Sized,
167    {
168        let low = *low_b.borrow();
169        let high = *high_b.borrow();
170        if !(low <= high) {
171            return Err(Error::EmptyRange);
172        }
173
174        let low_s = low.as_secs();
175        let low_n = low.subsec_nanos();
176        let mut high_s = high.as_secs();
177        let mut high_n = high.subsec_nanos();
178
179        if high_n < low_n {
180            high_s -= 1;
181            high_n += 1_000_000_000;
182        }
183
184        let mode = if low_s == high_s {
185            UniformDurationMode::Small {
186                secs: low_s,
187                nanos: Uniform::new_inclusive(low_n, high_n)?,
188            }
189        } else {
190            let max = high_s
191                .checked_mul(1_000_000_000)
192                .and_then(|n| n.checked_add(u64::from(high_n)));
193
194            if let Some(higher_bound) = max {
195                let lower_bound = low_s * 1_000_000_000 + u64::from(low_n);
196                UniformDurationMode::Medium {
197                    nanos: Uniform::new_inclusive(lower_bound, higher_bound)?,
198                }
199            } else {
200                // An offset is applied to simplify generation of nanoseconds
201                let max_nanos = high_n - low_n;
202                UniformDurationMode::Large {
203                    max_secs: high_s,
204                    max_nanos,
205                    secs: Uniform::new_inclusive(low_s, high_s)?,
206                }
207            }
208        };
209        Ok(UniformDuration {
210            mode,
211            offset: low_n,
212        })
213    }
214
215    #[inline]
216    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Duration {
217        match self.mode {
218            UniformDurationMode::Small { secs, nanos } => {
219                let n = nanos.sample(rng);
220                Duration::new(secs, n)
221            }
222            UniformDurationMode::Medium { nanos } => {
223                let nanos = nanos.sample(rng);
224                Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32)
225            }
226            UniformDurationMode::Large {
227                max_secs,
228                max_nanos,
229                secs,
230            } => {
231                // constant folding means this is at least as fast as `Rng::sample(Range)`
232                let nano_range = Uniform::new(0, 1_000_000_000).unwrap();
233                loop {
234                    let s = secs.sample(rng);
235                    let n = nano_range.sample(rng);
236                    if !(s == max_secs && n > max_nanos) {
237                        let sum = n + self.offset;
238                        break Duration::new(s, sum);
239                    }
240                }
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "serde")]
    fn test_serialization_uniform_duration() {
        let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap();
        let de_distr: UniformDuration =
            bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
        assert_eq!(distr, de_distr);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_char() {
        let mut rng = crate::test::rng(891);
        let mut max = core::char::from_u32(0).unwrap();
        for _ in 0..100 {
            let c = rng.random_range('A'..='Z');
            assert!(c.is_ascii_uppercase());
            max = max.max(c);
        }
        assert_eq!(max, 'Z');
        // A range straddling the surrogate block must never yield a surrogate.
        let d = Uniform::new(
            core::char::from_u32(0xD7F0).unwrap(),
            core::char::from_u32(0xE010).unwrap(),
        )
        .unwrap();
        for _ in 0..100 {
            let c = d.sample(&mut rng);
            assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
        }
        #[cfg(feature = "alloc")]
        {
            use crate::distr::SampleString;
            // Capacity is reserved from the UTF-8 width of the widest sampled char.
            let string1 = d.sample_string(&mut rng, 100);
            assert_eq!(string1.capacity(), 300);
            let string2 = Uniform::new(
                core::char::from_u32(0x0000).unwrap(),
                core::char::from_u32(0x0080).unwrap(),
            )
            .unwrap()
            .sample_string(&mut rng, 100);
            assert_eq!(string2.capacity(), 100);
            let string3 = Uniform::new_inclusive(
                core::char::from_u32(0x0000).unwrap(),
                core::char::from_u32(0x0080).unwrap(),
            )
            .unwrap()
            .sample_string(&mut rng, 100);
            assert_eq!(string3.capacity(), 200);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_durations() {
        let mut rng = crate::test::rng(253);

        let v = &[
            // Medium mode, upper nanos below lower nanos.
            (Duration::new(10, 50000), Duration::new(100, 1234)),
            // Small mode after borrowing a second from the upper bound.
            (Duration::new(0, 100), Duration::new(1, 50)),
            // Large mode over the full representable range.
            (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)),
            // Large mode where the upper bound's nanos are below the lower
            // bound's: sampling relies on `Duration::new` carrying nanos into
            // the near-maximal seconds field.
            (Duration::new(0, 500), Duration::new(u64::MAX, 100)),
            // Single-value range: the only admissible result is `low`.
            (Duration::new(5, 999_999_999), Duration::new(6, 0)),
        ];
        for &(low, high) in v.iter() {
            let my_uniform = Uniform::new(low, high).unwrap();
            for _ in 0..1000 {
                let v = rng.sample(my_uniform);
                assert!(low <= v && v < high);
            }
        }
    }

    #[test]
    fn test_duration_empty_range() {
        let one = Duration::from_secs(1);
        assert!(matches!(UniformDuration::new(one, one), Err(Error::EmptyRange)));
        assert!(matches!(
            UniformDuration::new_inclusive(one, Duration::from_millis(999)),
            Err(Error::EmptyRange)
        ));
        assert!(UniformDuration::new_inclusive(one, one).is_ok());
    }
}