rand_mt/
mt.rs

1// src/mt.rs
2//
3// Copyright (c) 2015,2017 rust-mersenne-twister developers
4// Copyright (c) 2020 Ryan Lopopolo <rjl@hyperbo.la>
5//
6// Licensed under the Apache License, Version 2.0
7// <LICENSE-APACHE> or <http://www.apache.org/licenses/LICENSE-2.0> or the MIT
8// license <LICENSE-MIT> or <http://opensource.org/licenses/MIT>, at your
9// option. All files in the project carrying such notice may not be copied,
10// modified, or distributed except according to those terms.
11
12use core::convert::TryFrom;
13use core::fmt;
14use core::num::Wrapping;
15
16use crate::RecoverRngError;
17
// Extra trait implementations compiled only when the `rand-traits` cargo
// feature is enabled (presumably the `rand_core` RNG traits — confirm in the
// `mt/rand.rs` submodule).
#[cfg(feature = "rand-traits")]
mod rand;

/// Number of 32-bit words in the generator state.
const N: usize = 624;
/// Distance between a state word and the partner word it is XORed with
/// during the twist performed by `fill_next_state`.
const M: usize = 397;
/// Mask that isolates the low bit of a combined word during the twist.
const ONE: Wrapping<u32> = Wrapping(1);
/// Twist-matrix constant; XORed in (via `(x & ONE) * MATRIX_A`) only when the
/// low bit of the combined word is set.
const MATRIX_A: Wrapping<u32> = Wrapping(0x9908_b0df);
/// Selects the most significant bit of a state word.
const UPPER_MASK: Wrapping<u32> = Wrapping(0x8000_0000);
/// Selects the low 31 bits of a state word.
const LOWER_MASK: Wrapping<u32> = Wrapping(0x7fff_ffff);
27
/// The 32-bit flavor of the Mersenne Twister pseudorandom number
/// generator.
///
/// The official name of this RNG is `MT19937`. It natively outputs `u32`.
///
/// # Size
///
/// `Mt` requires approximately 2.5 kilobytes of internal state: 624 `u32`
/// state words plus a `usize` read index, which is 2504 bytes on 64-bit
/// targets.
///
/// You may wish to store an `Mt` on the heap in a [`Box`] to make it easier to
/// embed in another struct.
///
/// `Mt` is also the same size as [`Mt64`](crate::Mt64).
///
/// The example below assumes a 64-bit target.
///
/// ```
/// # use core::mem;
/// # use rand_mt::{Mt, Mt64};
/// assert_eq!(2504, size_of::<Mt>());
/// assert_eq!(size_of::<Mt64>(), size_of::<Mt>());
/// ```
///
/// [`Box`]: https://doc.rust-lang.org/std/boxed/struct.Box.html
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Mt {
    // Field order is significant: the derived `Hash`, `PartialOrd` and `Ord`
    // impls visit `idx` before `state`.
    //
    // Index of the next `state` word to temper and return. Any value `>= N`
    // means the whole state must be regenerated before the next output.
    idx: usize,
    // Generator state. Stored as `Wrapping` so seeding and twisting
    // arithmetic is performed modulo 2^32.
    state: [Wrapping<u32>; N],
}
55
56impl fmt::Debug for Mt {
57    #[inline]
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        f.write_str("Mt {}")
60    }
61}
62
63impl Default for Mt {
64    /// Return a new `Mt` with the default seed.
65    ///
66    /// Equivalent to calling [`Mt::new_unseeded`].
67    #[inline]
68    fn default() -> Self {
69        Self::new_unseeded()
70    }
71}
72
73impl From<[u8; 4]> for Mt {
74    /// Construct a Mersenne Twister RNG from 4 bytes.
75    ///
76    /// The given bytes are treated as a little endian encoded `u32`.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// # use rand_mt::Mt;
82    /// // Default MT seed
83    /// let seed = 5489_u32.to_le_bytes();
84    /// let mut mt = Mt::from(seed);
85    /// assert_ne!(mt.next_u32(), mt.next_u32());
86    /// ```
87    ///
88    /// This constructor is equivalent to passing a little endian encoded `u32`.
89    ///
90    /// ```
91    /// # use rand_mt::Mt;
92    /// // Default MT seed
93    /// let seed = 5489_u32.to_le_bytes();
94    /// let mt1 = Mt::from(seed);
95    /// let mt2 = Mt::new(5489_u32);
96    /// assert_eq!(mt1, mt2);
97    /// ```
98    #[inline]
99    fn from(seed: [u8; 4]) -> Self {
100        Self::new(u32::from_le_bytes(seed))
101    }
102}
103
104impl From<u32> for Mt {
105    /// Construct a Mersenne Twister RNG from a `u32` seed.
106    ///
107    /// This function is equivalent to [`new`].
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// # use rand_mt::Mt;
113    /// // Default MT seed
114    /// let seed = 5489_u32;
115    /// let mt1 = Mt::from(seed);
116    /// let mt2 = Mt::new(seed);
117    /// assert_eq!(mt1, mt2);
118    ///
119    /// // Non-default MT seed
120    /// let seed = 9927_u32;
121    /// let mt1 = Mt::from(seed);
122    /// let mt2 = Mt::new(seed);
123    /// assert_eq!(mt1, mt2);
124    /// ```
125    ///
126    /// [`new`]: Self::new
127    #[inline]
128    fn from(seed: u32) -> Self {
129        Self::new(seed)
130    }
131}
132
133impl From<[u32; N]> for Mt {
134    /// Recover the internal state of a Mersenne Twister using the past 624
135    /// samples.
136    ///
137    /// This conversion takes a history of samples from a RNG and returns a
138    /// RNG that will produce identical output to the RNG that supplied the
139    /// samples.
140    #[inline]
141    fn from(key: [u32; N]) -> Self {
142        let mut mt = Self {
143            idx: N,
144            state: [Wrapping(0); N],
145        };
146        for (sample, out) in key.iter().copied().zip(mt.state.iter_mut()) {
147            *out = Wrapping(untemper(sample));
148        }
149        mt
150    }
151}
152
153impl TryFrom<&[u32]> for Mt {
154    type Error = RecoverRngError;
155
156    /// Attempt to recover the internal state of a Mersenne Twister using the
157    /// past 624 samples.
158    ///
159    /// This conversion takes a history of samples from a RNG and returns a RNG
160    /// that will produce identical output to the RNG that supplied the samples.
161    ///
162    /// This conversion is implemented with [`Mt::recover`].
163    ///
164    /// # Errors
165    ///
166    /// If `key` has less than 624 elements, an error is returned because there
167    /// is not enough data to fully initialize the RNG.
168    ///
169    /// If `key` has more than 624 elements, an error is returned because the
170    /// recovered RNG will not produce identical output to the RNG that supplied
171    /// the samples.
172    #[inline]
173    fn try_from(key: &[u32]) -> Result<Self, Self::Error> {
174        Self::recover(key.iter().copied())
175    }
176}
177
178impl Mt {
179    /// Default seed used by [`Mt::new_unseeded`].
180    pub const DEFAULT_SEED: u32 = 5489_u32;
181
182    /// Create a new Mersenne Twister random number generator using the given
183    /// seed.
184    ///
185    /// # Examples
186    ///
187    /// ## Constructing with a `u32` seed
188    ///
189    /// ```
190    /// # use rand_mt::Mt;
191    /// let seed = 123_456_789_u32;
192    /// let mt1 = Mt::new(seed);
193    /// let mt2 = Mt::from(seed.to_le_bytes());
194    /// assert_eq!(mt1, mt2);
195    /// ```
196    ///
197    /// ## Constructing with default seed
198    ///
199    /// ```
200    /// # use rand_mt::Mt;
201    /// let mt1 = Mt::new(Mt::DEFAULT_SEED);
202    /// let mt2 = Mt::new_unseeded();
203    /// assert_eq!(mt1, mt2);
204    /// ```
205    #[inline]
206    #[must_use]
207    pub fn new(seed: u32) -> Self {
208        let mut mt = Self {
209            idx: 0,
210            state: [Wrapping(0); N],
211        };
212        mt.reseed(seed);
213        mt
214    }
215
216    /// Create a new Mersenne Twister random number generator using the given
217    /// key.
218    ///
219    /// Key can have any length.
220    #[inline]
221    #[must_use]
222    pub fn new_with_key<I>(key: I) -> Self
223    where
224        I: IntoIterator<Item = u32>,
225        I::IntoIter: Clone,
226    {
227        let mut mt = Self {
228            idx: 0,
229            state: [Wrapping(0); N],
230        };
231        mt.reseed_with_key(key);
232        mt
233    }
234
235    /// Create a new Mersenne Twister random number generator using the default
236    /// fixed seed.
237    ///
238    /// # Examples
239    ///
240    /// ```
241    /// # use rand_mt::Mt;
242    /// // Default MT seed
243    /// let seed = 5489_u32;
244    /// let mt = Mt::new(seed);
245    /// let unseeded = Mt::new_unseeded();
246    /// assert_eq!(mt, unseeded);
247    /// ```
248    #[inline]
249    #[must_use]
250    pub fn new_unseeded() -> Self {
251        Self::new(Self::DEFAULT_SEED)
252    }
253
254    /// Generate next `u64` output.
255    ///
256    /// This function is implemented by generating two `u32`s from the RNG and
257    /// performing shifting and masking to turn them into a `u64` output.
258    ///
259    /// # Examples
260    ///
261    /// ```
262    /// # use rand_mt::Mt;
263    /// let mut mt = Mt::new_unseeded();
264    /// assert_ne!(mt.next_u64(), mt.next_u64());
265    /// ```
266    #[inline]
267    pub fn next_u64(&mut self) -> u64 {
268        let out = u64::from(self.next_u32());
269        let out = out << 32;
270        out | u64::from(self.next_u32())
271    }
272
273    /// Generate next `u32` output.
274    ///
275    /// `u32` is the native output of the generator. This function advances the
276    /// RNG step counter by one.
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// # use rand_mt::Mt;
282    /// let mut mt = Mt::new_unseeded();
283    /// assert_ne!(mt.next_u32(), mt.next_u32());
284    /// ```
285    #[inline]
286    pub fn next_u32(&mut self) -> u32 {
287        // Failing this check indicates that, somehow, the structure
288        // was not initialized.
289        debug_assert!(self.idx != 0);
290        if self.idx >= N {
291            fill_next_state(self);
292        }
293        let Wrapping(x) = self.state[self.idx];
294        self.idx += 1;
295        temper(x)
296    }
297
298    /// Fill a buffer with bytes generated from the RNG.
299    ///
300    /// This method generates random `u32`s (the native output unit of the RNG)
301    /// until `dest` is filled.
302    ///
303    /// This method may discard some output bits if `dest.len()` is not a
304    /// multiple of 4.
305    ///
306    /// # Examples
307    ///
308    /// ```
309    /// # use rand_mt::Mt;
310    /// let mut mt = Mt::new_unseeded();
311    /// let mut buf = [0; 32];
312    /// mt.fill_bytes(&mut buf);
313    /// assert_ne!([0; 32], buf);
314    /// let mut buf = [0; 31];
315    /// mt.fill_bytes(&mut buf);
316    /// assert_ne!([0; 31], buf);
317    /// ```
318    #[inline]
319    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
320        const CHUNK: usize = size_of::<u32>();
321        let mut dest_chunks = dest.chunks_exact_mut(CHUNK);
322
323        for next in &mut dest_chunks {
324            let chunk: [u8; CHUNK] = self.next_u32().to_le_bytes();
325            next.copy_from_slice(&chunk);
326        }
327
328        let remainder = dest_chunks.into_remainder();
329        if remainder.is_empty() {
330            return;
331        }
332        remainder
333            .iter_mut()
334            .zip(self.next_u32().to_le_bytes().iter())
335            .for_each(|(cell, &byte)| {
336                *cell = byte;
337            });
338    }
339
340    /// Attempt to recover the internal state of a Mersenne Twister using the
341    /// past 624 samples.
342    ///
343    /// This conversion takes a history of samples from a RNG and returns a
344    /// RNG that will produce identical output to the RNG that supplied the
345    /// samples.
346    ///
347    /// This constructor is also available as a [`TryFrom`] implementation for
348    /// `&[u32]`.
349    ///
350    /// # Errors
351    ///
352    /// If `key` has less than 624 elements, an error is returned because there
353    /// is not enough data to fully initialize the RNG.
354    ///
355    /// If `key` has more than 624 elements, an error is returned because the
356    /// recovered RNG will not produce identical output to the RNG that supplied
357    /// the samples.
358    #[inline]
359    pub fn recover<I>(key: I) -> Result<Self, RecoverRngError>
360    where
361        I: IntoIterator<Item = u32>,
362    {
363        let mut mt = Self {
364            idx: N,
365            state: [Wrapping(0); N],
366        };
367        let mut state = mt.state.iter_mut();
368        for sample in key {
369            let out = state.next().ok_or(RecoverRngError::TooManySamples(N))?;
370            *out = Wrapping(untemper(sample));
371        }
372        // If the state iterator still has unfilled cells, the given iterator
373        // was too short. If there are no additional cells, return the
374        // initialized RNG.
375        if state.next().is_none() {
376            Ok(mt)
377        } else {
378            Err(RecoverRngError::TooFewSamples(N))
379        }
380    }
381
382    /// Reseed a Mersenne Twister from a single `u32`.
383    ///
384    /// # Examples
385    ///
386    /// ```
387    /// # use rand_mt::Mt;
388    /// // Default MT seed
389    /// let mut mt = Mt::new_unseeded();
390    /// let first = mt.next_u32();
391    /// mt.fill_bytes(&mut [0; 512]);
392    /// // Default MT seed
393    /// mt.reseed(5489_u32);
394    /// assert_eq!(first, mt.next_u32());
395    /// ```
396    #[inline]
397    #[expect(
398        clippy::cast_possible_truncation,
399        reason = "const N is always less than u32::MAX"
400    )]
401    pub fn reseed(&mut self, seed: u32) {
402        self.idx = N;
403        self.state[0] = Wrapping(seed);
404        for i in 1..N {
405            self.state[i] = Wrapping(1_812_433_253)
406                * (self.state[i - 1] ^ (self.state[i - 1] >> 30))
407                + Wrapping(i as u32);
408        }
409    }
410
411    /// Reseed a Mersenne Twister from am iterator of `u32`s.
412    ///
413    /// Key can have any length.
414    #[inline]
415    #[allow(clippy::cast_possible_truncation)]
416    pub fn reseed_with_key<I>(&mut self, key: I)
417    where
418        I: IntoIterator<Item = u32>,
419        I::IntoIter: Clone,
420    {
421        self.reseed(19_650_218_u32);
422        let mut i = 1_usize;
423        for (j, piece) in key.into_iter().enumerate().cycle().take(N) {
424            self.state[i] = (self.state[i]
425                ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 30)) * Wrapping(1_664_525)))
426                + Wrapping(piece)
427                + Wrapping(j as u32);
428            i += 1;
429            if i >= N {
430                self.state[0] = self.state[N - 1];
431                i = 1;
432            }
433        }
434        for _ in 0..N - 1 {
435            self.state[i] = (self.state[i]
436                ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 30)) * Wrapping(1_566_083_941)))
437                - Wrapping(i as u32);
438            i += 1;
439            if i >= N {
440                self.state[0] = self.state[N - 1];
441                i = 1;
442            }
443        }
444        self.state[0] = Wrapping(1 << 31);
445    }
446}
447
/// Apply the MT19937 tempering transform to a raw state word.
///
/// The transform is four xor-shift steps and is undone by `untemper`.
#[inline]
fn temper(x: u32) -> u32 {
    let y = x ^ (x >> 11);
    let y = y ^ ((y << 7) & 0x9d2c_5680);
    let y = y ^ ((y << 15) & 0xefc6_0000);
    y ^ (y >> 18)
}
456
/// Invert `temper`, recovering the raw state word that produced a tempered
/// output.
///
/// The four tempering steps are undone in reverse order. Each left-shift step
/// is inverted by re-applying the same shift with its mask split into
/// disjoint pieces that together equal the original mask. Applying the pieces
/// in order means each pass only reads bits that an earlier pass (or the
/// untouched low bits) already restored.
#[inline]
fn untemper(mut x: u32) -> u32 {
    // reverse `x ^=  x>>18;`
    // An 18-bit shift is at least half the word, so the step is its own
    // inverse.
    x ^= x >> 18;

    // reverse `x ^= (x<<15) & 0xefc6_0000;`
    // 0x2fc6_0000 | 0xc000_0000 == 0xefc6_0000; the high two bits depend on
    // bits restored by the first pass, so they are fixed up second.
    x ^= (x << 15) & 0x2fc6_0000;
    x ^= (x << 15) & 0xc000_0000;

    // reverse `x ^= (x<< 7) & 0x9d2c_5680;`
    // The four masks partition 0x9d2c_5680, from low bits to high bits.
    x ^= (x << 7) & 0x0000_1680;
    x ^= (x << 7) & 0x000c_4000;
    x ^= (x << 7) & 0x0d20_0000;
    x ^= (x << 7) & 0x9000_0000;

    // reverse `x ^=  x>>11;`
    // Applying shifts of 11 and then 22 cancels every term, because
    // 11 + 22 >= 32.
    x ^= x >> 11;
    x ^= x >> 22;

    x
}
478
479#[inline]
480fn fill_next_state(rng: &mut Mt) {
481    for i in 0..N - M {
482        let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
483        rng.state[i] = rng.state[i + M] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
484    }
485    for i in N - M..N - 1 {
486        let x = (rng.state[i] & UPPER_MASK) | (rng.state[i + 1] & LOWER_MASK);
487        rng.state[i] = rng.state[i + M - N] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
488    }
489    let x = (rng.state[N - 1] & UPPER_MASK) | (rng.state[0] & LOWER_MASK);
490    rng.state[N - 1] = rng.state[M - 1] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
491    rng.idx = 0;
492}
493
#[cfg(test)]
mod tests {
    use core::convert::TryFrom;
    use core::iter;
    use core::num::Wrapping;

    use super::{Mt, N};
    use crate::vectors::mt::{STATE_SEEDED_BY_SLICE, STATE_SEEDED_BY_U32, TEST_OUTPUT};
    use crate::RecoverRngError;

    /// Key used with the `STATE_SEEDED_BY_SLICE` and `TEST_OUTPUT` vectors.
    const SLICE_KEY: [u32; 4] = [0x123_u32, 0x234_u32, 0x345_u32, 0x456_u32];

    /// Assert that each state word of `mt` equals the matching entry of
    /// `expected`, comparing pairwise over the shorter of the two.
    fn assert_state_matches(mt: &Mt, expected: &[u32]) {
        for (&Wrapping(actual), &want) in mt.state.iter().zip(expected) {
            assert_eq!(actual, want);
        }
    }

    /// Seed a generator and advance it by `skip` outputs, then record the next
    /// 624 outputs. Returns the advanced generator and the recorded samples.
    fn sampled_generator(seed: u32, skip: usize) -> (Mt, [u32; 624]) {
        let mut mt = Mt::new(seed);
        for _ in 0..skip {
            mt.next_u32();
        }
        let mut samples = [0; 624];
        for slot in samples.iter_mut() {
            *slot = mt.next_u32();
        }
        (mt, samples)
    }

    #[test]
    fn seeded_state_from_u32_seed() {
        let seed = 0x1234_5678_u32;
        let from_int = Mt::new(seed);
        let from_bytes = Mt::from(seed.to_le_bytes());
        assert_eq!(from_int.state, from_bytes.state);
        assert_state_matches(&from_int, &STATE_SEEDED_BY_U32);
        assert_state_matches(&from_bytes, &STATE_SEEDED_BY_U32);
    }

    #[test]
    fn seeded_state_from_u32_slice_key() {
        let mt = Mt::new_with_key(SLICE_KEY.iter().copied());
        assert_state_matches(&mt, &STATE_SEEDED_BY_SLICE);
    }

    #[test]
    fn seed_with_empty_iter_returns() {
        let _rng = Mt::new_with_key(iter::empty());
    }

    #[test]
    fn output_from_u32_slice_key() {
        let mut mt = Mt::new_with_key(SLICE_KEY.iter().copied());
        for &expected in TEST_OUTPUT.iter() {
            assert_eq!(expected, mt.next_u32());
        }
    }

    #[test]
    fn temper_untemper_is_identity() {
        (0..10_000).for_each(|_| {
            let x = getrandom::u32().unwrap();
            assert_eq!(x, super::untemper(super::temper(x)));
        });
    }

    #[test]
    fn untemper_temper_is_identity() {
        (0..10_000).for_each(|_| {
            let x = getrandom::u32().unwrap();
            assert_eq!(x, super::temper(super::untemper(x)));
        });
    }

    #[test]
    fn recovery_via_from() {
        for _ in 0..100 {
            let seed = getrandom::u32().unwrap();
            for skip in 0..256 {
                let (mut orig_mt, samples) = sampled_generator(seed, skip);
                let mut recovered_mt = Mt::from(samples);
                for _ in 0..624 * 2 {
                    assert_eq!(orig_mt.next_u32(), recovered_mt.next_u32());
                }
            }
        }
    }

    #[test]
    fn recovery_via_recover() {
        for _ in 0..100 {
            let seed = getrandom::u32().unwrap();
            for skip in 0..256 {
                let (mut orig_mt, samples) = sampled_generator(seed, skip);
                let mut recovered_mt = Mt::recover(samples.iter().copied()).unwrap();
                for _ in 0..624 * 2 {
                    assert_eq!(orig_mt.next_u32(), recovered_mt.next_u32());
                }
            }
        }
    }

    #[test]
    fn recover_required_exact_sample_length_via_from() {
        let zeros = [0_u32; 1000];
        for &len in &[0_usize, 1, 623] {
            assert_eq!(
                Mt::try_from(&zeros[..len]),
                Err(RecoverRngError::TooFewSamples(N))
            );
        }
        Mt::try_from(&zeros[..624]).unwrap();
        for &len in &[625_usize, 1000] {
            assert_eq!(
                Mt::try_from(&zeros[..len]),
                Err(RecoverRngError::TooManySamples(N))
            );
        }
    }

    #[test]
    fn recover_required_exact_sample_length_via_recover() {
        let zeros = [0_u32; 1000];
        for &len in &[0_usize, 1, 623] {
            assert_eq!(
                Mt::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooFewSamples(N))
            );
        }
        Mt::recover(zeros[..624].iter().copied()).unwrap();
        for &len in &[625_usize, 1000] {
            assert_eq!(
                Mt::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooManySamples(N))
            );
        }
    }

    #[test]
    fn fmt_debug_does_not_leak_seed() {
        use core::fmt::Write as _;
        use std::string::String;

        for &(seed, digits) in &[(874_u32, "874"), (123_456_u32, "123456")] {
            let random = Mt::new(seed);

            let mut buf = String::new();
            write!(&mut buf, "{random:?}").unwrap();
            assert!(!buf.contains(digits));
            assert_eq!(buf, "Mt {}");
        }
    }

    #[test]
    fn default_is_new_unseeded() {
        let mut from_default = Mt::default();
        let mut unseeded = Mt::new_unseeded();

        assert_eq!(from_default, unseeded);
        for _ in 0..1024 {
            assert_eq!(from_default.next_u32(), unseeded.next_u32());
        }
    }
}