rand/distr/slice.rs
1// Copyright 2021 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Distributions over slices
10
11use core::num::NonZeroUsize;
12
13use crate::distr::uniform::{UniformSampler, UniformUsize};
14use crate::distr::Distribution;
15#[cfg(feature = "alloc")]
16use alloc::string::String;
17
18/// A distribution to uniformly sample elements of a slice
19///
20/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice
21/// without modification of the slice (so called "sampling with replacement").
22/// This distribution object may be a little faster for repeated sampling (but
23/// slower for small numbers of samples).
24///
25/// ## Examples
26///
27/// Since this is a distribution, [`Rng::sample_iter`] and
28/// [`Distribution::sample_iter`] may be used, for example:
29/// ```
30/// use rand::distr::{Distribution, slice::Choose};
31///
32/// let vowels = ['a', 'e', 'i', 'o', 'u'];
33/// let vowels_dist = Choose::new(&vowels).unwrap();
34///
35/// // build a string of 10 vowels
36/// let vowel_string: String = vowels_dist
37/// .sample_iter(&mut rand::rng())
38/// .take(10)
39/// .collect();
40///
41/// println!("{}", vowel_string);
42/// assert_eq!(vowel_string.len(), 10);
43/// assert!(vowel_string.chars().all(|c| vowels.contains(&c)));
44/// ```
45///
46/// For a single sample, [`IndexedRandom::choose`] may be preferred:
47/// ```
48/// use rand::seq::IndexedRandom;
49///
50/// let vowels = ['a', 'e', 'i', 'o', 'u'];
51/// let mut rng = rand::rng();
52///
53/// println!("{}", vowels.choose(&mut rng).unwrap());
54/// ```
55///
56/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose
57/// [`Rng::sample_iter`]: crate::Rng::sample_iter
58#[derive(Debug, Clone, Copy)]
59pub struct Choose<'a, T> {
60 slice: &'a [T],
61 range: UniformUsize,
62 num_choices: NonZeroUsize,
63}
64
65impl<'a, T> Choose<'a, T> {
66 /// Create a new `Choose` instance which samples uniformly from the slice.
67 ///
68 /// Returns error [`Empty`] if the slice is empty.
69 pub fn new(slice: &'a [T]) -> Result<Self, Empty> {
70 let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?;
71
72 Ok(Self {
73 slice,
74 range: UniformUsize::new(0, num_choices.get()).unwrap(),
75 num_choices,
76 })
77 }
78
79 /// Returns the count of choices in this distribution
80 pub fn num_choices(&self) -> NonZeroUsize {
81 self.num_choices
82 }
83}
84
85impl<'a, T> Distribution<&'a T> for Choose<'a, T> {
86 fn sample<R: crate::Rng + ?Sized>(&self, rng: &mut R) -> &'a T {
87 let idx = self.range.sample(rng);
88
89 debug_assert!(
90 idx < self.slice.len(),
91 "Uniform::new(0, {}) somehow returned {}",
92 self.slice.len(),
93 idx
94 );
95
96 // Safety: at construction time, it was ensured that the slice was
97 // non-empty, and that the `Uniform` range produces values in range
98 // for the slice
99 unsafe { self.slice.get_unchecked(idx) }
100 }
101}
102
/// Error: empty slice
///
/// This error is returned when [`Choose::new`] is given an empty slice.
///
/// Being a unit type, all values are equal; `PartialEq`, `Eq` and `Hash` are
/// derived so callers can compare against it (e.g. `assert_eq!(err, Empty)`)
/// or store it in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Empty;

impl core::fmt::Display for Empty {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // No formatting arguments, so write the literal directly.
        f.write_str("Tried to create a `rand::distr::slice::Choose` with an empty slice")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Empty {}
120
#[cfg(feature = "alloc")]
impl super::SampleString for Choose<'_, char> {
    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
        // Upper bound on the UTF-8 width of any sampled char, used to size
        // reservations. Slices of 200 or more chars are not scanned (the
        // 4-byte worst case is assumed); otherwise the scan stops as soon as
        // a 4-byte char is seen.
        let max_char_len = if self.slice.len() >= 200 {
            4
        } else {
            let mut widest = 1;
            for c in self.slice {
                widest = widest.max(c.len_utf8());
                if widest >= 4 {
                    break;
                }
            }
            widest
        };

        // Append in chunks so capacity reserved for the worst case, but left
        // unused by narrower chars, is picked up by later chunks. ASCII-only
        // slices and short requests are appended in a single chunk.
        let chunk = if max_char_len != 1 && len >= 100 {
            len / 4
        } else {
            len
        };

        let mut remaining = len;
        while remaining > 0 {
            let step = chunk.min(remaining);
            string.reserve(max_char_len * step);
            string.extend(self.sample_iter(&mut *rng).take(step));
            remaining -= step;
        }
    }
}
154
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn value_stability() {
        const EXPECTED: &[u8; 5] = b"eaxee";

        let rng = crate::test::rng(651);
        let dist = Choose::new(b"escaped emus explore extensively").unwrap();

        // Gather the first samples into a fixed-size array (no allocation, so
        // this works without `alloc`) so that a mismatch reports both
        // sequences instead of a bare `false`.
        let mut got = [0u8; 5];
        for (slot, &byte) in got.iter_mut().zip(dist.sample_iter(rng)) {
            *slot = byte;
        }
        assert_eq!(&got, EXPECTED);
    }

    #[test]
    fn empty_slice_is_rejected() {
        let empty: &[u8] = &[];
        assert!(matches!(Choose::new(empty), Err(Empty)));
    }
}
167}