1use super::{Error, SampleBorrow, SampleUniform, Uniform, UniformInt, UniformSampler};
13use crate::distr::Distribution;
14use crate::Rng;
15use core::time::Duration;
16
17#[cfg(feature = "serde")]
18use serde::{Deserialize, Serialize};
19
/// Enables `Uniform<char>`: sampling over `char` ranges is delegated to
/// [`UniformChar`].
impl SampleUniform for char {
    type Sampler = UniformChar;
}
23
/// The back-end implementing [`UniformSampler`] for `char`.
///
/// Wraps a [`UniformInt`] over a "compressed" `u32` code-point scale in which
/// the surrogate range `0xD800..=0xDFFF` has been closed up (see
/// `char_to_comp_u32`). Every integer on that scale maps back to exactly one
/// valid `char` in `sample`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UniformChar {
    // Integer sampler over compressed code points. The field name is part of
    // the serde representation and is read directly by the `SampleString`
    // impl for `Uniform<char>`.
    sampler: UniformInt<u32>,
}
38
/// First code point of the UTF-16 surrogate range; no `char` lives in it.
const CHAR_SURROGATE_START: u32 = 0xD800;
/// Width of the surrogate range `0xD800..0xE000` (i.e. `0x800`).
const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START;

/// Projects a `char` onto a gap-free `u32` scale by closing the surrogate hole.
///
/// Scalar values below `CHAR_SURROGATE_START` are returned unchanged. Values
/// above the hole (which must be `>= 0xE000`, since a `char` is never a
/// surrogate) are shifted down by `CHAR_SURROGATE_LEN`. The projection is
/// strictly increasing, so a range of `char`s becomes a contiguous `u32` range.
fn char_to_comp_u32(c: char) -> u32 {
    let scalar = u32::from(c);
    if scalar < CHAR_SURROGATE_START {
        scalar
    } else {
        scalar - CHAR_SURROGATE_LEN
    }
}
51
52impl UniformSampler for UniformChar {
53 type X = char;
54
55 #[inline] fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
58 where
59 B1: SampleBorrow<Self::X> + Sized,
60 B2: SampleBorrow<Self::X> + Sized,
61 {
62 let low = char_to_comp_u32(*low_b.borrow());
63 let high = char_to_comp_u32(*high_b.borrow());
64 let sampler = UniformInt::<u32>::new(low, high);
65 sampler.map(|sampler| UniformChar { sampler })
66 }
67
68 #[inline] fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
71 where
72 B1: SampleBorrow<Self::X> + Sized,
73 B2: SampleBorrow<Self::X> + Sized,
74 {
75 let low = char_to_comp_u32(*low_b.borrow());
76 let high = char_to_comp_u32(*high_b.borrow());
77 let sampler = UniformInt::<u32>::new_inclusive(low, high);
78 sampler.map(|sampler| UniformChar { sampler })
79 }
80
81 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
82 let mut x = self.sampler.sample(rng);
83 if x >= CHAR_SURROGATE_START {
84 x += CHAR_SURROGATE_LEN;
85 }
86 unsafe { core::char::from_u32_unchecked(x) }
90 }
91}
92
#[cfg(feature = "alloc")]
impl crate::distr::SampleString for Uniform<char> {
    /// Appends `len` chars drawn from this distribution to `string`.
    ///
    /// Capacity is reserved up front as `len` times the UTF-8 length of the
    /// char at the top of the sampler's range (4 bytes if that value somehow
    /// is not a valid `char`).
    fn append_string<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        string: &mut alloc::string::String,
        len: usize,
    ) {
        let int_sampler = &self.0.sampler;
        // Top of the compressed range, expanded back across the surrogate gap.
        let top_comp = int_sampler.low + int_sampler.range - 1;
        let top = if top_comp < CHAR_SURROGATE_START {
            top_comp
        } else {
            top_comp + CHAR_SURROGATE_LEN
        };
        let max_char_len = char::from_u32(top).map_or(4, char::len_utf8);
        string.reserve(max_char_len * len);
        string.extend(self.sample_iter(rng).take(len));
    }
}
112
/// The back-end implementing [`UniformSampler`] for [`Duration`].
///
/// The sampling strategy (`mode`) is chosen in `new_inclusive` from the width
/// of the range; see `UniformDurationMode`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UniformDuration {
    // Strategy selected for this range.
    mode: UniformDurationMode,
    // Sub-second nanoseconds of the lower bound. `sample` adds it to each
    // nanosecond draw in `Large` mode; the other modes do not read it.
    offset: u32,
}
123
/// Sampling strategy used by [`UniformDuration`], chosen in `new_inclusive`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum UniformDurationMode {
    /// Both bounds share one whole-second count (after `new_inclusive` has
    /// possibly borrowed a second from the upper bound), so only the
    /// nanoseconds vary.
    Small {
        // Whole seconds shared by every sample.
        secs: u64,
        // Nanoseconds passed to `Duration::new`. The upper bound may exceed
        // 999_999_999 after the borrow; `Duration::new` carries the excess
        // into the seconds.
        nanos: Uniform<u32>,
    },
    /// The upper bound, expressed as total nanoseconds, fits in a `u64`.
    Medium {
        // Total nanoseconds of the sampled duration.
        nanos: Uniform<u64>,
    },
    /// The upper bound overflows a `u64` nanosecond count. Seconds and
    /// nanoseconds are drawn separately, with rejection of out-of-range pairs.
    Large {
        // Whole seconds of the upper bound (after any borrow).
        max_secs: u64,
        // Largest accepted raw nanosecond draw (before `offset` is added)
        // when the seconds draw equals `max_secs`.
        max_nanos: u32,
        // Whole seconds, from the lower to the upper bound inclusive.
        secs: Uniform<u64>,
    },
}
140
/// Enables `Uniform<Duration>`: sampling over `Duration` ranges is delegated
/// to [`UniformDuration`].
impl SampleUniform for Duration {
    type Sampler = UniformDuration;
}
144
145impl UniformSampler for UniformDuration {
146 type X = Duration;
147
148 #[inline]
149 fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
150 where
151 B1: SampleBorrow<Self::X> + Sized,
152 B2: SampleBorrow<Self::X> + Sized,
153 {
154 let low = *low_b.borrow();
155 let high = *high_b.borrow();
156 if !(low < high) {
157 return Err(Error::EmptyRange);
158 }
159 UniformDuration::new_inclusive(low, high - Duration::new(0, 1))
160 }
161
162 #[inline]
163 fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
164 where
165 B1: SampleBorrow<Self::X> + Sized,
166 B2: SampleBorrow<Self::X> + Sized,
167 {
168 let low = *low_b.borrow();
169 let high = *high_b.borrow();
170 if !(low <= high) {
171 return Err(Error::EmptyRange);
172 }
173
174 let low_s = low.as_secs();
175 let low_n = low.subsec_nanos();
176 let mut high_s = high.as_secs();
177 let mut high_n = high.subsec_nanos();
178
179 if high_n < low_n {
180 high_s -= 1;
181 high_n += 1_000_000_000;
182 }
183
184 let mode = if low_s == high_s {
185 UniformDurationMode::Small {
186 secs: low_s,
187 nanos: Uniform::new_inclusive(low_n, high_n)?,
188 }
189 } else {
190 let max = high_s
191 .checked_mul(1_000_000_000)
192 .and_then(|n| n.checked_add(u64::from(high_n)));
193
194 if let Some(higher_bound) = max {
195 let lower_bound = low_s * 1_000_000_000 + u64::from(low_n);
196 UniformDurationMode::Medium {
197 nanos: Uniform::new_inclusive(lower_bound, higher_bound)?,
198 }
199 } else {
200 let max_nanos = high_n - low_n;
202 UniformDurationMode::Large {
203 max_secs: high_s,
204 max_nanos,
205 secs: Uniform::new_inclusive(low_s, high_s)?,
206 }
207 }
208 };
209 Ok(UniformDuration {
210 mode,
211 offset: low_n,
212 })
213 }
214
215 #[inline]
216 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Duration {
217 match self.mode {
218 UniformDurationMode::Small { secs, nanos } => {
219 let n = nanos.sample(rng);
220 Duration::new(secs, n)
221 }
222 UniformDurationMode::Medium { nanos } => {
223 let nanos = nanos.sample(rng);
224 Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32)
225 }
226 UniformDurationMode::Large {
227 max_secs,
228 max_nanos,
229 secs,
230 } => {
231 let nano_range = Uniform::new(0, 1_000_000_000).unwrap();
233 loop {
234 let s = secs.sample(rng);
235 let n = nano_range.sample(rng);
236 if !(s == max_secs && n > max_nanos) {
237 let sum = n + self.offset;
238 break Duration::new(s, sum);
239 }
240 }
241 }
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "serde")]
    fn test_serialization_uniform_duration() {
        let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap();
        let de_distr: UniformDuration =
            bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
        assert_eq!(distr, de_distr);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // skipped under Miri (presumably too slow there)
    fn test_char() {
        let mut rng = crate::test::rng(891);
        let mut max = core::char::from_u32(0).unwrap();
        for _ in 0..100 {
            let c = rng.random_range('A'..='Z');
            assert!(c.is_ascii_uppercase());
            max = max.max(c);
        }
        assert_eq!(max, 'Z');
        let d = Uniform::new(
            core::char::from_u32(0xD7F0).unwrap(),
            core::char::from_u32(0xE010).unwrap(),
        )
        .unwrap();
        for _ in 0..100 {
            let c = d.sample(&mut rng);
            assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
        }
        #[cfg(feature = "alloc")]
        {
            use crate::distr::SampleString;
            let string1 = d.sample_string(&mut rng, 100);
            assert_eq!(string1.capacity(), 300);
            let string2 = Uniform::new(
                core::char::from_u32(0x0000).unwrap(),
                core::char::from_u32(0x0080).unwrap(),
            )
            .unwrap()
            .sample_string(&mut rng, 100);
            assert_eq!(string2.capacity(), 100);
            let string3 = Uniform::new_inclusive(
                core::char::from_u32(0x0000).unwrap(),
                core::char::from_u32(0x0080).unwrap(),
            )
            .unwrap()
            .sample_string(&mut rng, 100);
            assert_eq!(string3.capacity(), 200);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // skipped under Miri (presumably too slow there)
    fn test_durations() {
        let mut rng = crate::test::rng(253);

        let v = &[
            (Duration::new(10, 50000), Duration::new(100, 1234)),
            (Duration::new(0, 100), Duration::new(1, 50)),
            (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)),
        ];
        for &(low, high) in v.iter() {
            let my_uniform = Uniform::new(low, high).unwrap();
            for _ in 0..1000 {
                let v = rng.sample(my_uniform);
                assert!(low <= v && v < high);
            }
        }
    }

    /// Empty ranges are rejected, and single-value inclusive ranges always
    /// yield their one value, for both `char` and `Duration`.
    #[test]
    fn test_degenerate_ranges() {
        let mut rng = crate::test::rng(254);

        // Half-open ranges need `low < high`; inclusive ranges need `low <= high`.
        assert!(matches!(
            Uniform::<char>::new('a', 'a'),
            Err(Error::EmptyRange)
        ));
        assert!(matches!(
            Uniform::<char>::new_inclusive('b', 'a'),
            Err(Error::EmptyRange)
        ));
        let d = Duration::new(5, 7);
        assert!(matches!(
            Uniform::<Duration>::new(d, d),
            Err(Error::EmptyRange)
        ));
        assert!(matches!(
            Uniform::<Duration>::new_inclusive(d, Duration::new(5, 6)),
            Err(Error::EmptyRange)
        ));

        // 0xE000 is the first scalar value after the surrogate gap, i.e. the
        // point where the compressed mapping shifts.
        let c = core::char::from_u32(0xE000).unwrap();
        let single_char = Uniform::<char>::new_inclusive(c, c).unwrap();
        let single_dur = Uniform::<Duration>::new_inclusive(d, d).unwrap();
        for _ in 0..10 {
            assert_eq!(rng.sample(single_char), c);
            assert_eq!(rng.sample(single_dur), d);
        }
    }
}