1use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
138 where
139 R: Rng + ?Sized,
140 {
141 let mut consumed = 0;
142 let mut result = None;
143 let mut coin_flipper = CoinFlipper::new(rng);
144
145 loop {
146 let mut next = 0;
151
152 let (lower, _) = self.size_hint();
153 if lower >= 2 {
154 let highest_selected = (0..lower)
155 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
156 .last();
157
158 consumed += lower;
159 next = lower;
160
161 if let Some(ix) = highest_selected {
162 result = self.nth(ix);
163 next -= ix + 1;
164 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
165 }
166 }
167
168 let elem = self.nth(next);
169 if elem.is_none() {
170 return result;
171 }
172
173 if coin_flipper.random_ratio_one_over(consumed + 1) {
174 result = elem;
175 }
176 consumed += 1;
177 }
178 }
179
180 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
196 where
197 R: Rng + ?Sized,
198 {
199 let amount = buf.len();
200 let mut len = 0;
201 while len < amount {
202 if let Some(elem) = self.next() {
203 buf[len] = elem;
204 len += 1;
205 } else {
206 return len;
208 }
209 }
210
211 for (i, elem) in self.enumerate() {
213 let k = rng.random_range(..i + 1 + amount);
214 if let Some(slot) = buf.get_mut(k) {
215 *slot = elem;
216 }
217 }
218 len
219 }
220
221 #[cfg(feature = "alloc")]
236 fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
237 where
238 R: Rng + ?Sized,
239 {
240 let mut reservoir = Vec::with_capacity(amount);
241 reservoir.extend(self.by_ref().take(amount));
242
243 if reservoir.len() == amount {
248 for (i, elem) in self.enumerate() {
249 let k = rng.random_range(..i + 1 + amount);
250 if let Some(slot) = reservoir.get_mut(k) {
251 *slot = elem;
252 }
253 }
254 } else {
255 reservoir.shrink_to_fit();
258 }
259 reservoir
260 }
261}
262
263impl<I> IteratorRandom for I where I: Iterator + Sized {}
264
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Test adaptor that forwards `next` but keeps the default `size_hint()`
    /// of `(0, None)`, so the samplers get no length information at all.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Test adaptor whose lower size bound only covers the rest of the
    /// current chunk of `chunk_size` elements. The upper bound is the exact
    /// remaining length when `hint_total_size` is set, and `None` otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk once the current one has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Test adaptor whose lower size bound is a sliding window of at most
    /// `window_size` elements. The upper bound is the exact remaining length
    /// when `hint_total_size` is set, and `None` otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // thousands of samples per iterator: too slow under Miri
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each count is Binomial(1000, 1/9): mean ~111, sd ~9.9, so
                // these bounds allow roughly four standard deviations.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // thousands of samples per iterator: too slow under Miri
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each count is Binomial(1000, 1/9): mean ~111, sd ~9.9, so
                // these bounds allow roughly four standard deviations.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must make `choose_stable` itself return `None`
        // (these previously exercised `choose` by mistake).
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // thousands of samples per iterator: too slow under Miri
    fn test_iterator_choose_stable_stability() {
        // Each call starts from the same seed, so `choose_stable` must produce
        // identical histograms no matter how the iterator hints its size.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // No randomization happens when `amount` exceeds the length.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` comes from the half-open range `min_val..max_val`.
        assert!(small_sample
            .iter()
            .all(|e| { **e >= min_val && **e < max_val }));
    }

    // Pinned outputs for a fixed seed: any change in how the samplers consume
    // the RNG shows up here. Note that `choose` depends on the size hint.
    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    // Pinned outputs for a fixed seed; `choose_stable` must give the same
    // answer for every hinting strategy.
    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    // `choose_multiple_fill` and `choose_multiple` must agree for the same seed.
    #[test]
    fn value_stability_choose_multiple() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}