rand/seq/coin_flipper.rs
1// Copyright 2018-2023 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::RngCore;
10
/// Turns a random number generator into a stream of fair coin flips,
/// pulling the generator's output 32 bits at a time.
///
/// The `R: RngCore` bound is placed on the `impl` block that uses the
/// generator rather than on the type itself.
pub(crate) struct CoinFlipper<R> {
    /// The underlying source of random bits.
    pub rng: R,
    /// Buffered random bits, consumed from the most significant end.
    // TODO(opt): the chunk width should depend on the RNG word size.
    chunk: u32,
    /// How many of the top bits of `chunk` are still unconsumed random bits;
    /// the bits below them are left over from shifting and carry no randomness.
    chunk_remaining: u32,
}
16
17impl<R: RngCore> CoinFlipper<R> {
18 pub fn new(rng: R) -> Self {
19 Self {
20 rng,
21 chunk: 0,
22 chunk_remaining: 0,
23 }
24 }
25
26 #[inline]
27 /// Returns true with a probability of 1 / d
28 /// Uses an expected two bits of randomness
29 /// Panics if d == 0
30 pub fn random_ratio_one_over(&mut self, d: usize) -> bool {
31 debug_assert_ne!(d, 0);
32 // This uses the same logic as `random_ratio` but is optimized for the case that
33 // the starting numerator is one (which it always is for `Sequence::Choose()`)
34
35 // In this case (but not `random_ratio`), this way of calculating c is always accurate
36 let c = (usize::BITS - 1 - d.leading_zeros()).min(32);
37
38 if self.flip_c_heads(c) {
39 let numerator = 1 << c;
40 self.random_ratio(numerator, d)
41 } else {
42 false
43 }
44 }
45
46 #[inline]
47 /// Returns true with a probability of n / d
48 /// Uses an expected two bits of randomness
49 fn random_ratio(&mut self, mut n: usize, d: usize) -> bool {
50 // Explanation:
51 // We are trying to return true with a probability of n / d
52 // If n >= d, we can just return true
53 // Otherwise there are two possibilities 2n < d and 2n >= d
54 // In either case we flip a coin.
55 // If 2n < d
56 // If it comes up tails, return false
57 // If it comes up heads, double n and start again
58 // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d
59 // (if 2n was greater than d we would effectively round it down to 1
60 // by returning true)
61 // If 2n >= d
62 // If it comes up tails, set n to 2n - d and start again
63 // If it comes up heads, return true
64 // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
65 // Note that if 2n = d and the coin comes up tails, n will be set to 0
66 // before restarting which is equivalent to returning false.
67
68 // As a performance optimization we can flip multiple coins at once
69 // This is efficient because we can use the `lzcnt` intrinsic
70 // We can check up to 32 flips at once but we only receive one bit of information
71 // - all heads or at least one tail.
72
73 // Let c be the number of coins to flip. 1 <= c <= 32
74 // If 2n < d, n * 2^c < d
75 // If the result is all heads, then set n to n * 2^c
76 // If there was at least one tail, return false
77 // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1
78 // Ideally, c will be as high as possible within these constraints
79
80 while n < d {
81 // Find a good value for c by counting leading zeros
82 // This will either give the highest possible c, or 1 less than that
83 let c = n
84 .leading_zeros()
85 .saturating_sub(d.leading_zeros() + 1)
86 .clamp(1, 32);
87
88 if self.flip_c_heads(c) {
89 // All heads
90 // Set n to n * 2^c
91 // If 2n >= d, the while loop will exit and we will return `true`
92 // If n * 2^c > `usize::MAX` we always return `true` anyway
93 n = n.saturating_mul(2_usize.pow(c));
94 } else {
95 // At least one tail
96 if c == 1 {
97 // Calculate 2n - d.
98 // We need to use wrapping as 2n might be greater than `usize::MAX`
99 let next_n = n.wrapping_add(n).wrapping_sub(d);
100 if next_n == 0 || next_n > n {
101 // This will happen if 2n < d
102 return false;
103 }
104 n = next_n;
105 } else {
106 // c > 1 so 2n < d so we can return false
107 return false;
108 }
109 }
110 }
111 true
112 }
113
114 /// If the next `c` bits of randomness all represent heads, consume them, return true
115 /// Otherwise return false and consume the number of heads plus one.
116 /// Generates new bits of randomness when necessary (in 32 bit chunks)
117 /// Has a 1 in 2 to the `c` chance of returning true
118 /// `c` must be less than or equal to 32
119 fn flip_c_heads(&mut self, mut c: u32) -> bool {
120 debug_assert!(c <= 32);
121 // Note that zeros on the left of the chunk represent heads.
122 // It needs to be this way round because zeros are filled in when left shifting
123 loop {
124 let zeros = self.chunk.leading_zeros();
125
126 if zeros < c {
127 // The happy path - we found a 1 and can return false
128 // Note that because a 1 bit was detected,
129 // We cannot have run out of random bits so we don't need to check
130
131 // First consume all of the bits read
132 // Using shl seems to give worse performance for size-hinted iterators
133 self.chunk = self.chunk.wrapping_shl(zeros + 1);
134
135 self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
136 return false;
137 } else {
138 // The number of zeros is larger than `c`
139 // There are two possibilities
140 if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
141 // Those zeroes were all part of our random chunk,
142 // throw away `c` bits of randomness and return true
143 self.chunk_remaining = new_remaining;
144 self.chunk <<= c;
145 return true;
146 } else {
147 // Some of those zeroes were part of the random chunk
148 // and some were part of the space behind it
149 // We need to take into account only the zeroes that were random
150 c -= self.chunk_remaining;
151
152 // Generate a new chunk
153 self.chunk = self.rng.next_u32();
154 self.chunk_remaining = 32;
155 // Go back to start of loop
156 }
157 }
158 }
159 }
160}