rand_mt/mt.rs
1// src/mt.rs
2//
3// Copyright (c) 2015,2017 rust-mersenne-twister developers
4// Copyright (c) 2020 Ryan Lopopolo <rjl@hyperbo.la>
5//
6// Licensed under the Apache License, Version 2.0
7// <LICENSE-APACHE> or <http://www.apache.org/licenses/LICENSE-2.0> or the MIT
8// license <LICENSE-MIT> or <http://opensource.org/licenses/MIT>, at your
9// option. All files in the project carrying such notice may not be copied,
10// modified, or distributed except according to those terms.
11
12use core::convert::TryFrom;
13use core::fmt;
14use core::num::Wrapping;
15
16use crate::RecoverRngError;
17
18#[cfg(feature = "rand-traits")]
19mod rand;
20
/// Number of 32-bit words held in the generator state.
const N: usize = 624;
/// Distance between a state word and the word it is combined with when the
/// state is regenerated.
const M: usize = 397;
/// The value one, pre-wrapped; used to isolate the lowest bit of a word.
const ONE: Wrapping<u32> = Wrapping(1);
/// Constant XORed into a regenerated word when the lowest bit of the mixed
/// word is set.
const MATRIX_A: Wrapping<u32> = Wrapping(0x9908_b0df);
/// Selects the most significant bit of a state word.
const UPPER_MASK: Wrapping<u32> = Wrapping(0x8000_0000);
/// Selects the 31 least significant bits of a state word.
const LOWER_MASK: Wrapping<u32> = Wrapping(0x7fff_ffff);
27
/// The 32-bit flavor of the Mersenne Twister pseudorandom number
/// generator.
///
/// The official name of this RNG is `MT19937`. It natively outputs `u32`.
///
/// # Size
///
/// `Mt` requires approximately 2.5 kilobytes of internal state: 624 `u32`
/// words plus one `usize` index.
///
/// You may wish to store an `Mt` on the heap in a [`Box`] to make it easier to
/// embed in another struct.
///
/// `Mt` is also the same size as [`Mt64`](crate::Mt64).
///
/// The exact figure asserted below assumes an 8-byte `usize`.
///
/// ```
/// # use core::mem;
/// # use rand_mt::{Mt, Mt64};
/// assert_eq!(2504, mem::size_of::<Mt>());
/// assert_eq!(mem::size_of::<Mt64>(), mem::size_of::<Mt>());
/// ```
///
/// [`Box`]: https://doc.rust-lang.org/std/boxed/struct.Box.html
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Mt {
    // Position in `state` of the next word to output. A value of `N` means
    // the state is exhausted and is regenerated before the next output.
    idx: usize,
    // The untempered state words of the generator.
    state: [Wrapping<u32>; N],
}
55
56impl fmt::Debug for Mt {
57 #[inline]
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.write_str("Mt {}")
60 }
61}
62
63impl Default for Mt {
64 /// Return a new `Mt` with the default seed.
65 ///
66 /// Equivalent to calling [`Mt::new_unseeded`].
67 #[inline]
68 fn default() -> Self {
69 Self::new_unseeded()
70 }
71}
72
73impl From<[u8; 4]> for Mt {
74 /// Construct a Mersenne Twister RNG from 4 bytes.
75 ///
76 /// The given bytes are treated as a little endian encoded `u32`.
77 ///
78 /// # Examples
79 ///
80 /// ```
81 /// # use rand_mt::Mt;
82 /// // Default MT seed
83 /// let seed = 5489_u32.to_le_bytes();
84 /// let mut mt = Mt::from(seed);
85 /// assert_ne!(mt.next_u32(), mt.next_u32());
86 /// ```
87 ///
88 /// This constructor is equivalent to passing a little endian encoded `u32`.
89 ///
90 /// ```
91 /// # use rand_mt::Mt;
92 /// // Default MT seed
93 /// let seed = 5489_u32.to_le_bytes();
94 /// let mt1 = Mt::from(seed);
95 /// let mt2 = Mt::new(5489_u32);
96 /// assert_eq!(mt1, mt2);
97 /// ```
98 #[inline]
99 fn from(seed: [u8; 4]) -> Self {
100 Self::new(u32::from_le_bytes(seed))
101 }
102}
103
104impl From<u32> for Mt {
105 /// Construct a Mersenne Twister RNG from a `u32` seed.
106 ///
107 /// This function is equivalent to [`new`].
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// # use rand_mt::Mt;
113 /// // Default MT seed
114 /// let seed = 5489_u32;
115 /// let mt1 = Mt::from(seed);
116 /// let mt2 = Mt::new(seed);
117 /// assert_eq!(mt1, mt2);
118 ///
119 /// // Non-default MT seed
120 /// let seed = 9927_u32;
121 /// let mt1 = Mt::from(seed);
122 /// let mt2 = Mt::new(seed);
123 /// assert_eq!(mt1, mt2);
124 /// ```
125 ///
126 /// [`new`]: Self::new
127 #[inline]
128 fn from(seed: u32) -> Self {
129 Self::new(seed)
130 }
131}
132
133impl From<[u32; N]> for Mt {
134 /// Recover the internal state of a Mersenne Twister using the past 624
135 /// samples.
136 ///
137 /// This conversion takes a history of samples from a RNG and returns a
138 /// RNG that will produce identical output to the RNG that supplied the
139 /// samples.
140 #[inline]
141 fn from(key: [u32; N]) -> Self {
142 let mut mt = Self {
143 idx: N,
144 state: [Wrapping(0); N],
145 };
146 for (sample, out) in key.iter().copied().zip(mt.state.iter_mut()) {
147 *out = Wrapping(untemper(sample));
148 }
149 mt
150 }
151}
152
153impl TryFrom<&[u32]> for Mt {
154 type Error = RecoverRngError;
155
156 /// Attempt to recover the internal state of a Mersenne Twister using the
157 /// past 624 samples.
158 ///
159 /// This conversion takes a history of samples from a RNG and returns a RNG
160 /// that will produce identical output to the RNG that supplied the samples.
161 ///
162 /// This conversion is implemented with [`Mt::recover`].
163 ///
164 /// # Errors
165 ///
166 /// If `key` has less than 624 elements, an error is returned because there
167 /// is not enough data to fully initialize the RNG.
168 ///
169 /// If `key` has more than 624 elements, an error is returned because the
170 /// recovered RNG will not produce identical output to the RNG that supplied
171 /// the samples.
172 #[inline]
173 fn try_from(key: &[u32]) -> Result<Self, Self::Error> {
174 Self::recover(key.iter().copied())
175 }
176}
177
178impl Mt {
179 /// Default seed used by [`Mt::new_unseeded`].
180 pub const DEFAULT_SEED: u32 = 5489_u32;
181
182 /// Create a new Mersenne Twister random number generator using the given
183 /// seed.
184 ///
185 /// # Examples
186 ///
187 /// ## Constructing with a `u32` seed
188 ///
189 /// ```
190 /// # use rand_mt::Mt;
191 /// let seed = 123_456_789_u32;
192 /// let mt1 = Mt::new(seed);
193 /// let mt2 = Mt::from(seed.to_le_bytes());
194 /// assert_eq!(mt1, mt2);
195 /// ```
196 ///
197 /// ## Constructing with default seed
198 ///
199 /// ```
200 /// # use rand_mt::Mt;
201 /// let mt1 = Mt::new(Mt::DEFAULT_SEED);
202 /// let mt2 = Mt::new_unseeded();
203 /// assert_eq!(mt1, mt2);
204 /// ```
205 #[inline]
206 #[must_use]
207 pub fn new(seed: u32) -> Self {
208 let mut mt = Self {
209 idx: 0,
210 state: [Wrapping(0); N],
211 };
212 mt.reseed(seed);
213 mt
214 }
215
216 /// Create a new Mersenne Twister random number generator using the given
217 /// key.
218 ///
219 /// Key can have any length.
220 #[inline]
221 #[must_use]
222 pub fn new_with_key<I>(key: I) -> Self
223 where
224 I: IntoIterator<Item = u32>,
225 I::IntoIter: Clone,
226 {
227 let mut mt = Self {
228 idx: 0,
229 state: [Wrapping(0); N],
230 };
231 mt.reseed_with_key(key);
232 mt
233 }
234
235 /// Create a new Mersenne Twister random number generator using the default
236 /// fixed seed.
237 ///
238 /// # Examples
239 ///
240 /// ```
241 /// # use rand_mt::Mt;
242 /// // Default MT seed
243 /// let seed = 5489_u32;
244 /// let mt = Mt::new(seed);
245 /// let unseeded = Mt::new_unseeded();
246 /// assert_eq!(mt, unseeded);
247 /// ```
248 #[inline]
249 #[must_use]
250 pub fn new_unseeded() -> Self {
251 Self::new(Self::DEFAULT_SEED)
252 }
253
254 /// Generate next `u64` output.
255 ///
256 /// This function is implemented by generating two `u32`s from the RNG and
257 /// performing shifting and masking to turn them into a `u64` output.
258 ///
259 /// # Examples
260 ///
261 /// ```
262 /// # use rand_mt::Mt;
263 /// let mut mt = Mt::new_unseeded();
264 /// assert_ne!(mt.next_u64(), mt.next_u64());
265 /// ```
266 #[inline]
267 pub fn next_u64(&mut self) -> u64 {
268 let out = u64::from(self.next_u32());
269 let out = out << 32;
270 out | u64::from(self.next_u32())
271 }
272
273 /// Generate next `u32` output.
274 ///
275 /// `u32` is the native output of the generator. This function advances the
276 /// RNG step counter by one.
277 ///
278 /// # Examples
279 ///
280 /// ```
281 /// # use rand_mt::Mt;
282 /// let mut mt = Mt::new_unseeded();
283 /// assert_ne!(mt.next_u32(), mt.next_u32());
284 /// ```
285 #[inline]
286 pub fn next_u32(&mut self) -> u32 {
287 // Failing this check indicates that, somehow, the structure
288 // was not initialized.
289 debug_assert!(self.idx != 0);
290 if self.idx >= N {
291 fill_next_state(self);
292 }
293 let Wrapping(x) = self.state[self.idx];
294 self.idx += 1;
295 temper(x)
296 }
297
298 /// Fill a buffer with bytes generated from the RNG.
299 ///
300 /// This method generates random `u32`s (the native output unit of the RNG)
301 /// until `dest` is filled.
302 ///
303 /// This method may discard some output bits if `dest.len()` is not a
304 /// multiple of 4.
305 ///
306 /// # Examples
307 ///
308 /// ```
309 /// # use rand_mt::Mt;
310 /// let mut mt = Mt::new_unseeded();
311 /// let mut buf = [0; 32];
312 /// mt.fill_bytes(&mut buf);
313 /// assert_ne!([0; 32], buf);
314 /// let mut buf = [0; 31];
315 /// mt.fill_bytes(&mut buf);
316 /// assert_ne!([0; 31], buf);
317 /// ```
318 #[inline]
319 pub fn fill_bytes(&mut self, dest: &mut [u8]) {
320 const CHUNK: usize = size_of::<u32>();
321 let mut dest_chunks = dest.chunks_exact_mut(CHUNK);
322
323 for next in &mut dest_chunks {
324 let chunk: [u8; CHUNK] = self.next_u32().to_le_bytes();
325 next.copy_from_slice(&chunk);
326 }
327
328 let remainder = dest_chunks.into_remainder();
329 if remainder.is_empty() {
330 return;
331 }
332 remainder
333 .iter_mut()
334 .zip(self.next_u32().to_le_bytes().iter())
335 .for_each(|(cell, &byte)| {
336 *cell = byte;
337 });
338 }
339
340 /// Attempt to recover the internal state of a Mersenne Twister using the
341 /// past 624 samples.
342 ///
343 /// This conversion takes a history of samples from a RNG and returns a
344 /// RNG that will produce identical output to the RNG that supplied the
345 /// samples.
346 ///
347 /// This constructor is also available as a [`TryFrom`] implementation for
348 /// `&[u32]`.
349 ///
350 /// # Errors
351 ///
352 /// If `key` has less than 624 elements, an error is returned because there
353 /// is not enough data to fully initialize the RNG.
354 ///
355 /// If `key` has more than 624 elements, an error is returned because the
356 /// recovered RNG will not produce identical output to the RNG that supplied
357 /// the samples.
358 #[inline]
359 pub fn recover<I>(key: I) -> Result<Self, RecoverRngError>
360 where
361 I: IntoIterator<Item = u32>,
362 {
363 let mut mt = Self {
364 idx: N,
365 state: [Wrapping(0); N],
366 };
367 let mut state = mt.state.iter_mut();
368 for sample in key {
369 let out = state.next().ok_or(RecoverRngError::TooManySamples(N))?;
370 *out = Wrapping(untemper(sample));
371 }
372 // If the state iterator still has unfilled cells, the given iterator
373 // was too short. If there are no additional cells, return the
374 // initialized RNG.
375 if state.next().is_none() {
376 Ok(mt)
377 } else {
378 Err(RecoverRngError::TooFewSamples(N))
379 }
380 }
381
382 /// Reseed a Mersenne Twister from a single `u32`.
383 ///
384 /// # Examples
385 ///
386 /// ```
387 /// # use rand_mt::Mt;
388 /// // Default MT seed
389 /// let mut mt = Mt::new_unseeded();
390 /// let first = mt.next_u32();
391 /// mt.fill_bytes(&mut [0; 512]);
392 /// // Default MT seed
393 /// mt.reseed(5489_u32);
394 /// assert_eq!(first, mt.next_u32());
395 /// ```
396 #[inline]
397 #[expect(
398 clippy::cast_possible_truncation,
399 reason = "const N is always less than u32::MAX"
400 )]
401 pub fn reseed(&mut self, seed: u32) {
402 self.idx = N;
403 self.state[0] = Wrapping(seed);
404 for i in 1..N {
405 self.state[i] = Wrapping(1_812_433_253)
406 * (self.state[i - 1] ^ (self.state[i - 1] >> 30))
407 + Wrapping(i as u32);
408 }
409 }
410
411 /// Reseed a Mersenne Twister from am iterator of `u32`s.
412 ///
413 /// Key can have any length.
414 #[inline]
415 #[allow(clippy::cast_possible_truncation)]
416 pub fn reseed_with_key<I>(&mut self, key: I)
417 where
418 I: IntoIterator<Item = u32>,
419 I::IntoIter: Clone,
420 {
421 self.reseed(19_650_218_u32);
422 let mut i = 1_usize;
423 for (j, piece) in key.into_iter().enumerate().cycle().take(N) {
424 self.state[i] = (self.state[i]
425 ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 30)) * Wrapping(1_664_525)))
426 + Wrapping(piece)
427 + Wrapping(j as u32);
428 i += 1;
429 if i >= N {
430 self.state[0] = self.state[N - 1];
431 i = 1;
432 }
433 }
434 for _ in 0..N - 1 {
435 self.state[i] = (self.state[i]
436 ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 30)) * Wrapping(1_566_083_941)))
437 - Wrapping(i as u32);
438 i += 1;
439 if i >= N {
440 self.state[0] = self.state[N - 1];
441 i = 1;
442 }
443 }
444 self.state[0] = Wrapping(1 << 31);
445 }
446}
447
/// Scramble a raw state word into the value handed out by the generator.
///
/// Applies four XOR-with-shift steps in order: right by 11, left by 7 under
/// mask `0x9d2c_5680`, left by 15 under mask `0xefc6_0000`, right by 18.
#[inline]
fn temper(x: u32) -> u32 {
    let y = x ^ (x >> 11);
    let y = y ^ ((y << 7) & 0x9d2c_5680);
    let y = y ^ ((y << 15) & 0xefc6_0000);
    y ^ (y >> 18)
}
456
/// Undo the four XOR-with-shift steps `x ^= x >> 11`,
/// `x ^= (x << 7) & 0x9d2c_5680`, `x ^= (x << 15) & 0xefc6_0000` and
/// `x ^= x >> 18`, recovering the word they were applied to.
///
/// The steps are reversed last to first. The masked left shifts are undone in
/// several partial steps, each restoring the next group of bits from the bits
/// already recovered below it.
#[inline]
fn untemper(x: u32) -> u32 {
    // reverse `x ^= x >> 18;`
    let mut y = x ^ (x >> 18);

    // reverse `x ^= (x << 15) & 0xefc6_0000;`
    for mask in [0x2fc6_0000_u32, 0xc000_0000] {
        y ^= (y << 15) & mask;
    }

    // reverse `x ^= (x << 7) & 0x9d2c_5680;`
    for mask in [0x0000_1680_u32, 0x000c_4000, 0x0d20_0000, 0x9000_0000] {
        y ^= (y << 7) & mask;
    }

    // reverse `x ^= x >> 11;`
    y ^= y >> 11;
    y ^ (y >> 22)
}
478
479#[inline]
480fn fill_next_state(rng: &mut Mt) {
481 for i in 0..N - M {
482 let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
483 rng.state[i] = rng.state[i + M] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
484 }
485 for i in N - M..N - 1 {
486 let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
487 rng.state[i] = rng.state[i + M - N] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
488 }
489 let x = (rng.state[N - 1] & UPPER_MASK) | (rng.state[0] & LOWER_MASK);
490 rng.state[N - 1] = rng.state[M - 1] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
491 rng.idx = 0;
492}
493
#[cfg(test)]
mod tests {
    use core::convert::TryFrom;
    use core::iter;
    use core::num::Wrapping;

    use super::{Mt, N};
    use crate::vectors::mt::{STATE_SEEDED_BY_SLICE, STATE_SEEDED_BY_U32, TEST_OUTPUT};
    use crate::RecoverRngError;

    /// Key whose seeded state and output are checked against
    /// `STATE_SEEDED_BY_SLICE` and `TEST_OUTPUT`.
    const SLICE_KEY: [u32; 4] = [0x123, 0x234, 0x345, 0x456];

    /// Draw a fresh `u32` from the operating system.
    fn random_word() -> u32 {
        getrandom::u32().unwrap()
    }

    /// Seed an RNG, discard `skip` outputs so it sits in an intermediate
    /// state, then record its next 624 outputs.
    fn rng_with_history(seed: u32, skip: usize) -> (Mt, [u32; N]) {
        let mut rng = Mt::new(seed);
        for _ in 0..skip {
            rng.next_u32();
        }
        let mut samples = [0_u32; N];
        for sample in &mut samples {
            *sample = rng.next_u32();
        }
        (rng, samples)
    }

    /// Check that two RNGs agree on their next `2 * 624` outputs.
    fn assert_in_lockstep(original: &mut Mt, recovered: &mut Mt) {
        for _ in 0..N * 2 {
            assert_eq!(original.next_u32(), recovered.next_u32());
        }
    }

    #[test]
    fn seeded_state_from_u32_seed() {
        const SEED: u32 = 0x1234_5678;
        let mt = Mt::new(SEED);
        let mt_from_seed = Mt::from(SEED.to_le_bytes());
        assert_eq!(mt.state, mt_from_seed.state);
        for rng in [&mt, &mt_from_seed] {
            for (&Wrapping(actual), &expected) in rng.state.iter().zip(STATE_SEEDED_BY_U32.iter()) {
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn seeded_state_from_u32_slice_key() {
        let mt = Mt::new_with_key(SLICE_KEY.iter().copied());
        for (&Wrapping(actual), &expected) in mt.state.iter().zip(STATE_SEEDED_BY_SLICE.iter()) {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn seed_with_empty_iter_returns() {
        // Only checks that seeding with an empty key terminates.
        let _rng = Mt::new_with_key(iter::empty());
    }

    #[test]
    fn output_from_u32_slice_key() {
        let mut mt = Mt::new_with_key(SLICE_KEY.iter().copied());
        for &expected in &TEST_OUTPUT {
            assert_eq!(expected, mt.next_u32());
        }
    }

    #[test]
    fn temper_untemper_is_identity() {
        for _ in 0..10_000 {
            let word = random_word();
            assert_eq!(word, super::untemper(super::temper(word)));
        }
    }

    #[test]
    fn untemper_temper_is_identity() {
        for _ in 0..10_000 {
            let word = random_word();
            assert_eq!(word, super::temper(super::untemper(word)));
        }
    }

    #[test]
    fn recovery_via_from() {
        for _ in 0..100 {
            let seed = random_word();
            for skip in 0..256 {
                let (mut orig_mt, samples) = rng_with_history(seed, skip);
                let mut recovered_mt = Mt::from(samples);
                assert_in_lockstep(&mut orig_mt, &mut recovered_mt);
            }
        }
    }

    #[test]
    fn recovery_via_recover() {
        for _ in 0..100 {
            let seed = random_word();
            for skip in 0..256 {
                let (mut orig_mt, samples) = rng_with_history(seed, skip);
                let mut recovered_mt = Mt::recover(samples.iter().copied()).unwrap();
                assert_in_lockstep(&mut orig_mt, &mut recovered_mt);
            }
        }
    }

    #[test]
    fn recover_required_exact_sample_length_via_from() {
        let zeros = [0_u32; 1000];
        for len in [0_usize, 1, 623] {
            assert_eq!(
                Mt::try_from(&zeros[..len]),
                Err(RecoverRngError::TooFewSamples(N))
            );
        }
        Mt::try_from(&zeros[..624]).unwrap();
        for len in [625_usize, 1000] {
            assert_eq!(
                Mt::try_from(&zeros[..len]),
                Err(RecoverRngError::TooManySamples(N))
            );
        }
    }

    #[test]
    fn recover_required_exact_sample_length_via_recover() {
        let zeros = [0_u32; 1000];
        for len in [0_usize, 1, 623] {
            assert_eq!(
                Mt::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooFewSamples(N))
            );
        }
        Mt::recover(zeros[..624].iter().copied()).unwrap();
        for len in [625_usize, 1000] {
            assert_eq!(
                Mt::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooManySamples(N))
            );
        }
    }

    #[test]
    fn fmt_debug_does_not_leak_seed() {
        use core::fmt::Write as _;
        use std::string::String;

        // Each seed is paired with the digits that must not show up in the
        // debug output.
        for (seed, digits) in [(874_u32, "874"), (123_456_u32, "123456")] {
            let random = Mt::new(seed);

            let mut buf = String::new();
            write!(&mut buf, "{random:?}").unwrap();
            assert!(!buf.contains(digits));
            assert_eq!(buf, "Mt {}");
        }
    }

    #[test]
    fn default_is_new_unseeded() {
        let mut default = Mt::default();
        let mut unseeded = Mt::new_unseeded();

        assert_eq!(default, unseeded);
        for _ in 0..1024 {
            assert_eq!(default.next_u32(), unseeded.next_u32());
        }
    }
}
680}