rand/seq/
iterator.rs

1// Copyright 2018-2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! `IteratorRandom`
10
11use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35    /// Uniformly sample one element
36    ///
37    /// Assuming that the [`Iterator::size_hint`] is correct, this method
38    /// returns one uniformly-sampled random element of the slice, or `None`
39    /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40    /// cause this method to incorrectly return `None` if fewer elements than
41    /// the advertised `lower` bound are present and may prevent sampling of
42    /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43    /// is memory-safe, but may result in unexpected `None` result and
44    /// non-uniform distribution).
45    ///
46    /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47    /// a constant-time operation, this method can offer `O(1)` performance.
48    /// Where no size hint is
49    /// available, complexity is `O(n)` where `n` is the iterator length.
50    /// Partial hints (where `lower > 0`) also improve performance.
51    ///
52    /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53    /// samples used as well as the result (while remaining uniform sampling).
54    /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55    /// [`Iterator`] combinators which only change size hints from affecting the
56    /// results.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rand::seq::IteratorRandom;
62    ///
63    /// let words = "Mary had a little lamb".split(' ');
64    /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65    /// ```
66    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74        // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75        // seq_iter_choose_from_1000.
76        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        // Continue until the iterator is exhausted
88        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    /// Uniformly sample one element (stable)
122    ///
123    /// This method is very similar to [`choose`] except that the result
124    /// only depends on the length of the iterator and the values produced by
125    /// `rng`. Notably for any iterator of a given length this will make the
126    /// same requests to `rng` and if the same sequence of values are produced
127    /// the same index will be selected from `self`. This may be useful if you
128    /// need consistent results no matter what type of iterator you are working
129    /// with. If you do not need this stability prefer [`choose`].
130    ///
131    /// Note that this method still uses [`Iterator::size_hint`] to skip
132    /// constructing elements where possible, however the selection and `rng`
133    /// calls are the same in the face of this optimization. If you want to
134    /// force every element to be created regardless call `.inspect(|e| ())`.
135    ///
136    /// [`choose`]: IteratorRandom::choose
137    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
138    where
139        R: Rng + ?Sized,
140    {
141        let mut consumed = 0;
142        let mut result = None;
143        let mut coin_flipper = CoinFlipper::new(rng);
144
145        loop {
146            // Currently the only way to skip elements is `nth()`. So we need to
147            // store what index to access next here.
148            // This should be replaced by `advance_by()` once it is stable:
149            // https://github.com/rust-lang/rust/issues/77404
150            let mut next = 0;
151
152            let (lower, _) = self.size_hint();
153            if lower >= 2 {
154                let highest_selected = (0..lower)
155                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
156                    .last();
157
158                consumed += lower;
159                next = lower;
160
161                if let Some(ix) = highest_selected {
162                    result = self.nth(ix);
163                    next -= ix + 1;
164                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
165                }
166            }
167
168            let elem = self.nth(next);
169            if elem.is_none() {
170                return result;
171            }
172
173            if coin_flipper.random_ratio_one_over(consumed + 1) {
174                result = elem;
175            }
176            consumed += 1;
177        }
178    }
179
180    /// Uniformly sample `amount` distinct elements into a buffer
181    ///
182    /// Collects values at random from the iterator into a supplied buffer
183    /// until that buffer is filled.
184    ///
185    /// Although the elements are selected randomly, the order of elements in
186    /// the buffer is neither stable nor fully random. If random ordering is
187    /// desired, shuffle the result.
188    ///
189    /// Returns the number of elements added to the buffer. This equals the length
190    /// of the buffer unless the iterator contains insufficient elements, in which
191    /// case this equals the number of elements available.
192    ///
193    /// Complexity is `O(n)` where `n` is the length of the iterator.
194    /// For slices, prefer [`IndexedRandom::choose_multiple`].
195    fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
196    where
197        R: Rng + ?Sized,
198    {
199        let amount = buf.len();
200        let mut len = 0;
201        while len < amount {
202            if let Some(elem) = self.next() {
203                buf[len] = elem;
204                len += 1;
205            } else {
206                // Iterator exhausted; stop early
207                return len;
208            }
209        }
210
211        // Continue, since the iterator was not exhausted
212        for (i, elem) in self.enumerate() {
213            let k = rng.random_range(..i + 1 + amount);
214            if let Some(slot) = buf.get_mut(k) {
215                *slot = elem;
216            }
217        }
218        len
219    }
220
221    /// Uniformly sample `amount` distinct elements into a [`Vec`]
222    ///
223    /// This is equivalent to `choose_multiple_fill` except for the result type.
224    ///
225    /// Although the elements are selected randomly, the order of elements in
226    /// the buffer is neither stable nor fully random. If random ordering is
227    /// desired, shuffle the result.
228    ///
229    /// The length of the returned vector equals `amount` unless the iterator
230    /// contains insufficient elements, in which case it equals the number of
231    /// elements available.
232    ///
233    /// Complexity is `O(n)` where `n` is the length of the iterator.
234    /// For slices, prefer [`IndexedRandom::choose_multiple`].
235    #[cfg(feature = "alloc")]
236    fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
237    where
238        R: Rng + ?Sized,
239    {
240        let mut reservoir = Vec::with_capacity(amount);
241        reservoir.extend(self.by_ref().take(amount));
242
243        // Continue unless the iterator was exhausted
244        //
245        // note: this prevents iterators that "restart" from causing problems.
246        // If the iterator stops once, then so do we.
247        if reservoir.len() == amount {
248            for (i, elem) in self.enumerate() {
249                let k = rng.random_range(..i + 1 + amount);
250                if let Some(slot) = reservoir.get_mut(k) {
251                    *slot = elem;
252                }
253            }
254        } else {
255            // Don't hang onto extra memory. There is a corner case where
256            // `amount` was much less than `self.len()`.
257            reservoir.shrink_to_fit();
258        }
259        reservoir
260    }
261}
262
263impl<I> IteratorRandom for I where I: Iterator + Sized {}
264
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Iterator wrapper which forwards `next` only, so that the default
    /// `size_hint` of `(0, None)` is reported instead of the inner one.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Iterator wrapper whose lower size hint counts down within chunks of
    /// `chunk_size` elements; the upper bound is only reported when
    /// `hint_total_size` is set.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk (clamped to what is left) once the current
            // one has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Iterator wrapper whose lower size hint is the remaining length capped
    /// at `window_size`; the upper bound is only reported when
    /// `hint_total_size` is set.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must yield `None` from `choose_stable` itself
        // (these assertions previously exercised `choose` by mistake).
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable_stability() {
        // Each run starts from the same seed, so the histograms must match
        // exactly regardless of the size hints the iterator reports.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so every sampled
        // value must lie strictly below `max_val`.
        assert!(small_sample
            .iter()
            .all(|e| (min_val..max_val).contains(*e)));
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_choose_multiple() {
        // Both sampling entry points are seeded identically and must agree
        // with the recorded reference output.
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}