rand_mt/
mt64.rs

// src/mt64.rs
//
// Copyright (c) 2015,2017 rust-mersenne-twister developers
// Copyright (c) 2020 Ryan Lopopolo <rjl@hyperbo.la>
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE> or <http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT> or <http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
11
12use core::convert::TryFrom;
13use core::fmt;
14use core::num::Wrapping;
15
16use crate::RecoverRngError;
17
18#[cfg(feature = "rand-traits")]
19mod rand;
20
/// Number of 64-bit words in the generator state.
const NN: usize = 312;
/// Offset of the partner word mixed into each regenerated state word.
const MM: usize = NN / 2;
/// Mask for the low bit, which selects whether `MATRIX_A` is applied.
const ONE: Wrapping<u64> = Wrapping(1);
/// Twist-matrix constant from the MT19937-64 reference implementation.
const MATRIX_A: Wrapping<u64> = Wrapping(0xb502_6f5a_a966_19e9);
/// Least significant 31 bits.
const LM: Wrapping<u64> = Wrapping((1 << 31) - 1);
/// Most significant 33 bits (the bitwise complement of `LM`).
const UM: Wrapping<u64> = Wrapping(!LM.0);
27
28/// The 64-bit flavor of the Mersenne Twister pseudorandom number generator.
29///
30/// # Size
31///
32/// `Mt64` requires approximately 2.5 kilobytes of internal state.
33///
34/// You may wish to store an `Mt64` on the heap in a [`Box`] to make it easier
35/// to embed in another struct.
36///
37/// `Mt64` is also the same size as [`Mt`](crate::Mt).
38///
39/// ```
40/// # use rand_mt::{Mt, Mt64};
41/// assert_eq!(2504, size_of::<Mt64>());
42/// assert_eq!(size_of::<Mt>(), size_of::<Mt64>());
43/// ```
44///
45/// [`Box`]: https://doc.rust-lang.org/std/boxed/struct.Box.html
46#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
47pub struct Mt64 {
48    idx: usize,
49    state: [Wrapping<u64>; NN],
50}
51
52impl fmt::Debug for Mt64 {
53    #[inline]
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        f.write_str("Mt64 {}")
56    }
57}
58
59impl Default for Mt64 {
60    /// Return a new `Mt64` with the default seed.
61    ///
62    /// Equivalent to calling [`Mt64::new_unseeded`].
63    #[inline]
64    fn default() -> Self {
65        Self::new_unseeded()
66    }
67}
68
69impl From<[u8; 8]> for Mt64 {
70    /// Construct a Mersenne Twister RNG from 8 bytes.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// # use rand_mt::Mt64;
76    /// // Default MT seed
77    /// let seed = 5489_u64.to_le_bytes();
78    /// let mut mt = Mt64::from(seed);
79    /// assert_ne!(mt.next_u64(), mt.next_u64());
80    /// ```
81    #[inline]
82    fn from(seed: [u8; 8]) -> Self {
83        Self::new(u64::from_le_bytes(seed))
84    }
85}
86
87impl From<u64> for Mt64 {
88    /// Construct a Mersenne Twister RNG from a `u64` seed.
89    ///
90    /// This function is equivalent to [`new`].
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// # use rand_mt::Mt64;
96    /// // Default MT seed
97    /// let seed = 5489_u64;
98    /// let mt1 = Mt64::from(seed);
99    /// let mt2 = Mt64::new(seed);
100    /// assert_eq!(mt1, mt2);
101    ///
102    /// // Non-default MT seed
103    /// let seed = 9927_u64;
104    /// let mt1 = Mt64::from(seed);
105    /// let mt2 = Mt64::new(seed);
106    /// assert_eq!(mt1, mt2);
107    /// ```
108    ///
109    /// [`new`]: Self::new
110    #[inline]
111    fn from(seed: u64) -> Self {
112        Self::new(seed)
113    }
114}
115
116impl From<[u64; NN]> for Mt64 {
117    /// Recover the internal state of a Mersenne Twister using the past 312
118    /// samples.
119    ///
120    /// This conversion takes a history of samples from a RNG and returns a
121    /// RNG that will produce identical output to the RNG that supplied the
122    /// samples.
123    #[inline]
124    fn from(key: [u64; NN]) -> Self {
125        let mut mt = Self {
126            idx: NN,
127            state: [Wrapping(0); NN],
128        };
129        for (sample, out) in key.iter().copied().zip(mt.state.iter_mut()) {
130            *out = Wrapping(untemper(sample));
131        }
132        mt
133    }
134}
135
136impl TryFrom<&[u64]> for Mt64 {
137    type Error = RecoverRngError;
138
139    /// Attempt to recover the internal state of a Mersenne Twister using the
140    /// past 312 samples.
141    ///
142    /// This conversion takes a history of samples from a RNG and returns a
143    /// RNG that will produce identical output to the RNG that supplied the
144    /// samples.
145    ///
146    /// This conversion is implemented with [`Mt64::recover`].
147    ///
148    /// # Errors
149    ///
150    /// If `key` has less than 312 elements, an error is returned because there
151    /// is not enough data to fully initialize the RNG.
152    ///
153    /// If `key` has more than 312 elements, an error is returned because the
154    /// recovered RNG will not produce identical output to the RNG that supplied
155    /// the samples.
156    #[inline]
157    fn try_from(key: &[u64]) -> Result<Self, Self::Error> {
158        Self::recover(key.iter().copied())
159    }
160}
161
162impl Mt64 {
163    /// Default seed used by [`Mt64::new_unseeded`].
164    pub const DEFAULT_SEED: u64 = 5489_u64;
165
166    /// Create a new Mersenne Twister random number generator using the given
167    /// seed.
168    ///
169    /// # Examples
170    ///
171    /// ## Constructing with a `u64` seed
172    ///
173    /// ```
174    /// # use rand_mt::Mt64;
175    /// let seed = 123_456_789_u64;
176    /// let mt1 = Mt64::new(seed);
177    /// let mt2 = Mt64::from(seed.to_le_bytes());
178    /// assert_eq!(mt1, mt2);
179    /// ```
180    ///
181    /// ## Constructing with default seed
182    ///
183    /// ```
184    /// # use rand_mt::Mt64;
185    /// let mt1 = Mt64::new(Mt64::DEFAULT_SEED);
186    /// let mt2 = Mt64::new_unseeded();
187    /// assert_eq!(mt1, mt2);
188    /// ```
189    #[inline]
190    #[must_use]
191    pub fn new(seed: u64) -> Self {
192        let mut mt = Self {
193            idx: 0,
194            state: [Wrapping(0); NN],
195        };
196        mt.reseed(seed);
197        mt
198    }
199
200    /// Create a new Mersenne Twister random number generator using the given
201    /// key.
202    ///
203    /// Key can have any length.
204    #[inline]
205    #[must_use]
206    pub fn new_with_key<I>(key: I) -> Self
207    where
208        I: IntoIterator<Item = u64>,
209        I::IntoIter: Clone,
210    {
211        let mut mt = Self {
212            idx: 0,
213            state: [Wrapping(0); NN],
214        };
215        mt.reseed_with_key(key);
216        mt
217    }
218
219    /// Create a new Mersenne Twister random number generator using the default
220    /// fixed seed.
221    ///
222    /// # Examples
223    ///
224    /// ```
225    /// # use rand_mt::Mt64;
226    /// // Default MT seed
227    /// let seed = 5489_u64;
228    /// let mt = Mt64::new(seed);
229    /// let unseeded = Mt64::new_unseeded();
230    /// assert_eq!(mt, unseeded);
231    /// ```
232    #[inline]
233    #[must_use]
234    pub fn new_unseeded() -> Self {
235        Self::new(Self::DEFAULT_SEED)
236    }
237
238    /// Generate next `u64` output.
239    ///
240    /// `u64` is the native output of the generator. This function advances the
241    /// RNG step counter by one.
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// # use rand_mt::Mt64;
247    /// let mut mt = Mt64::new_unseeded();
248    /// assert_ne!(mt.next_u64(), mt.next_u64());
249    /// ```
250    #[inline]
251    pub fn next_u64(&mut self) -> u64 {
252        // Failing this check indicates that, somehow, the structure
253        // was not initialized.
254        debug_assert!(self.idx != 0);
255        if self.idx >= NN {
256            fill_next_state(self);
257        }
258        let Wrapping(x) = self.state[self.idx];
259        self.idx += 1;
260        temper(x)
261    }
262
263    /// Generate next `u32` output.
264    ///
265    /// This function is implemented by generating one `u64` from the RNG and
266    /// performing shifting and masking to turn it into a `u32` output.
267    ///
268    /// # Examples
269    ///
270    /// ```
271    /// # use rand_mt::Mt64;
272    /// let mut mt = Mt64::new_unseeded();
273    /// assert_ne!(mt.next_u32(), mt.next_u32());
274    /// ```
275    #[inline]
276    #[expect(
277        clippy::cast_possible_truncation,
278        reason = "Mt64 natively generates 64-bit ints, so truncate to yield a u32"
279    )]
280    pub fn next_u32(&mut self) -> u32 {
281        self.next_u64() as u32
282    }
283
284    /// Fill a buffer with bytes generated from the RNG.
285    ///
286    /// This method generates random `u64`s (the native output unit of the RNG)
287    /// until `dest` is filled.
288    ///
289    /// This method may discard some output bits if `dest.len()` is not a
290    /// multiple of 8.
291    ///
292    /// # Examples
293    ///
294    /// ```
295    /// # use rand_mt::Mt64;
296    /// let mut mt = Mt64::new_unseeded();
297    /// let mut buf = [0; 32];
298    /// mt.fill_bytes(&mut buf);
299    /// assert_ne!([0; 32], buf);
300    /// let mut buf = [0; 31];
301    /// mt.fill_bytes(&mut buf);
302    /// assert_ne!([0; 31], buf);
303    /// ```
304    #[inline]
305    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
306        const CHUNK: usize = size_of::<u64>();
307        let mut dest_chunks = dest.chunks_exact_mut(CHUNK);
308
309        for next in &mut dest_chunks {
310            let chunk: [u8; CHUNK] = self.next_u64().to_le_bytes();
311            next.copy_from_slice(&chunk);
312        }
313
314        let remainder = dest_chunks.into_remainder();
315        if remainder.is_empty() {
316            return;
317        }
318        remainder
319            .iter_mut()
320            .zip(self.next_u64().to_le_bytes().iter())
321            .for_each(|(cell, &byte)| {
322                *cell = byte;
323            });
324    }
325
326    /// Attempt to recover the internal state of a Mersenne Twister using the
327    /// past 312 samples.
328    ///
329    /// This conversion takes a history of samples from a RNG and returns a
330    /// RNG that will produce identical output to the RNG that supplied the
331    /// samples.
332    ///
333    /// This constructor is also available as a [`TryFrom`] implementation for
334    /// `&[u32]`.
335    ///
336    /// # Errors
337    ///
338    /// If `key` has less than 312 elements, an error is returned because there
339    /// is not enough data to fully initialize the RNG.
340    ///
341    /// If `key` has more than 312 elements, an error is returned because the
342    /// recovered RNG will not produce identical output to the RNG that supplied
343    /// the samples.
344    #[inline]
345    pub fn recover<I>(key: I) -> Result<Self, RecoverRngError>
346    where
347        I: IntoIterator<Item = u64>,
348    {
349        let mut mt = Self {
350            idx: NN,
351            state: [Wrapping(0); NN],
352        };
353        let mut state = mt.state.iter_mut();
354        for sample in key {
355            let out = state.next().ok_or(RecoverRngError::TooManySamples(NN))?;
356            *out = Wrapping(untemper(sample));
357        }
358        // If the state iterator still has unfilled cells, the given iterator
359        // was too short. If there are no additional cells, return the
360        // initialized RNG.
361        if state.next().is_none() {
362            Ok(mt)
363        } else {
364            Err(RecoverRngError::TooFewSamples(NN))
365        }
366    }
367
368    /// Reseed a Mersenne Twister from a single `u64`.
369    ///
370    /// # Examples
371    ///
372    /// ```
373    /// # use rand_mt::Mt64;
374    /// // Default MT seed
375    /// let mut mt = Mt64::new_unseeded();
376    /// let first = mt.next_u64();
377    /// mt.fill_bytes(&mut [0; 512]);
378    /// // Default MT seed
379    /// mt.reseed(5489_u64);
380    /// assert_eq!(first, mt.next_u64());
381    /// ```
382    #[inline]
383    pub fn reseed(&mut self, seed: u64) {
384        self.idx = NN;
385        self.state[0] = Wrapping(seed);
386        for i in 1..NN {
387            self.state[i] = Wrapping(6_364_136_223_846_793_005)
388                * (self.state[i - 1] ^ (self.state[i - 1] >> 62))
389                + Wrapping(i as u64);
390        }
391    }
392
393    /// Reseed a Mersenne Twister from am iterator of `u64`s.
394    ///
395    /// Key can have any length.
396    #[inline]
397    pub fn reseed_with_key<I>(&mut self, key: I)
398    where
399        I: IntoIterator<Item = u64>,
400        I::IntoIter: Clone,
401    {
402        self.reseed(19_650_218_u64);
403        let mut i = 1_usize;
404        for (j, piece) in key.into_iter().enumerate().cycle().take(NN) {
405            self.state[i] = (self.state[i]
406                ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 62))
407                    * Wrapping(3_935_559_000_370_003_845)))
408                + Wrapping(piece)
409                + Wrapping(j as u64);
410            i += 1;
411            if i >= NN {
412                self.state[0] = self.state[NN - 1];
413                i = 1;
414            }
415        }
416        for _ in 0..NN - 1 {
417            self.state[i] = (self.state[i]
418                ^ ((self.state[i - 1] ^ (self.state[i - 1] >> 62))
419                    * Wrapping(2_862_933_555_777_941_757)))
420                - Wrapping(i as u64);
421            i += 1;
422            if i >= NN {
423                self.state[0] = self.state[NN - 1];
424                i = 1;
425            }
426        }
427        self.state[0] = Wrapping(1 << 63);
428    }
429}
430
/// Apply the MT19937-64 output tempering transform to a raw state word.
///
/// Every step XORs in a masked shift of the running value, which makes the
/// transform linear over XOR and invertible.
#[inline]
fn temper(x: u64) -> u64 {
    let y = x ^ ((x >> 29) & 0x5555_5555_5555_5555);
    let y = y ^ ((y << 17) & 0x71d6_7fff_eda6_0000);
    let y = y ^ ((y << 37) & 0xfff7_eee0_0000_0000);
    y ^ (y >> 43)
}
439
/// Invert [`temper`], recovering the raw state word from a generator output.
///
/// The tempering steps are undone in reverse order. A left shift by 17 is
/// undone in three passes, because each pass can only restore 17 more bits.
/// A right shift by 29 is undone in two passes, for the same reason.
#[inline]
fn untemper(tempered: u64) -> u64 {
    // reverse `x ^= x >> 43;`
    let y = tempered ^ (tempered >> 43);

    // reverse `x ^= (x << 37) & 0xfff7_eee0_0000_0000;`
    let y = y ^ ((y << 37) & 0xfff7_eee0_0000_0000);

    // reverse `x ^= (x << 17) & 0x71d6_7fff_eda6_0000;`
    let y = y ^ ((y << 17) & 0x0000_0003_eda6_0000);
    let y = y ^ ((y << 17) & 0x0006_7ffc_0000_0000);
    let y = y ^ ((y << 17) & 0x71d0_0000_0000_0000);

    // reverse `x ^= (x >> 29) & 0x5555_5555_5555_5555;`
    let y = y ^ ((y >> 29) & 0x0000_0005_5555_5540);
    y ^ ((y >> 29) & 0x0000_0000_0000_0015)
}
459
460#[inline]
461fn fill_next_state(rng: &mut Mt64) {
462    for i in 0..NN - MM {
463        let x = (rng.state[i] & UM) | (rng.state[i + 1] & LM);
464        rng.state[i] = rng.state[i + MM] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
465    }
466    for i in NN - MM..NN - 1 {
467        let x = (rng.state[i] & UM) | (rng.state[i + 1] & LM);
468        rng.state[i] = rng.state[i + MM - NN] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
469    }
470    let x = (rng.state[NN - 1] & UM) | (rng.state[0] & LM);
471    rng.state[NN - 1] = rng.state[MM - 1] ^ (x >> 1) ^ ((x & ONE) * MATRIX_A);
472    rng.idx = 0;
473}
474
#[cfg(test)]
mod tests {
    use core::convert::TryFrom;
    use core::num::Wrapping;

    use super::{temper, untemper, Mt64, NN};
    use crate::vectors::mt64::{STATE_SEEDED_BY_SLICE, STATE_SEEDED_BY_U64, TEST_OUTPUT};
    use crate::RecoverRngError;

    /// Key whose seeded state and output are pinned by
    /// `STATE_SEEDED_BY_SLICE` and `TEST_OUTPUT`.
    const KEY: [u64; 4] = [0x12345, 0x23456, 0x34567, 0x45678];

    /// Sample counts that `recover` must reject as too short.
    const TOO_FEW: [usize; 3] = [0, 1, NN - 1];

    /// Sample counts that `recover` must reject as too long.
    const TOO_MANY: [usize; 2] = [NN + 1, 1000];

    /// Assert that the leading words of `mt`'s raw state equal `expected`.
    fn assert_state_eq(mt: &Mt64, expected: &[u64]) {
        for (&Wrapping(actual), &want) in mt.state.iter().zip(expected) {
            assert_eq!(actual, want);
        }
    }

    /// For many seeds and skip offsets, capture `NN` consecutive outputs,
    /// rebuild a generator from them with `rebuild`, and check that the
    /// original and the rebuilt generator agree for `2 * NN` further outputs.
    fn assert_recovers<F>(rebuild: F)
    where
        F: Fn([u64; NN]) -> Mt64,
    {
        for _ in 0..100 {
            let seed = getrandom::u64().unwrap();
            for skip in 0..256 {
                let mut orig_mt = Mt64::new(seed);
                // Skip some samples so the RNG is in an intermediate state.
                (0..skip).for_each(|_| {
                    orig_mt.next_u64();
                });
                let mut samples = [0; NN];
                samples.iter_mut().for_each(|s| *s = orig_mt.next_u64());
                let mut recovered_mt = rebuild(samples);
                for _ in 0..NN * 2 {
                    assert_eq!(orig_mt.next_u64(), recovered_mt.next_u64());
                }
            }
        }
    }

    #[test]
    fn seeded_state_from_u64_seed() {
        const SEED: u64 = 0x0123_4567_89ab_cdef;
        let mt = Mt64::new(SEED);
        let mt_from_seed = Mt64::from(SEED.to_le_bytes());
        assert_eq!(mt.state, mt_from_seed.state);
        assert_state_eq(&mt, &STATE_SEEDED_BY_U64);
        assert_state_eq(&mt_from_seed, &STATE_SEEDED_BY_U64);
    }

    #[test]
    fn seeded_state_from_u64_slice_key() {
        let mt = Mt64::new_with_key(KEY.iter().copied());
        assert_state_eq(&mt, &STATE_SEEDED_BY_SLICE);
    }

    #[test]
    fn seed_with_empty_iter_returns() {
        let _rng = Mt64::new_with_key(core::iter::empty());
    }

    #[test]
    fn output_from_u64_slice_key() {
        let mut mt = Mt64::new_with_key(KEY.iter().copied());
        for &expected in &TEST_OUTPUT {
            assert_eq!(expected, mt.next_u64());
        }
    }

    #[test]
    fn temper_untemper_is_identity() {
        (0..10_000)
            .map(|_| getrandom::u64().unwrap())
            .for_each(|x| assert_eq!(x, untemper(temper(x))));
    }

    #[test]
    fn untemper_temper_is_identity() {
        (0..10_000)
            .map(|_| getrandom::u64().unwrap())
            .for_each(|x| assert_eq!(x, temper(untemper(x))));
    }

    #[test]
    fn recovery_via_from() {
        assert_recovers(Mt64::from);
    }

    #[test]
    fn recovery_via_recover() {
        assert_recovers(|samples| Mt64::recover(samples.iter().copied()).unwrap());
    }

    #[test]
    fn recover_required_exact_sample_length_via_from() {
        let zeros = [0_u64; 1000];
        for &len in &TOO_FEW {
            assert_eq!(
                Mt64::try_from(&zeros[..len]),
                Err(RecoverRngError::TooFewSamples(NN))
            );
        }
        Mt64::try_from(&zeros[..NN]).unwrap();
        for &len in &TOO_MANY {
            assert_eq!(
                Mt64::try_from(&zeros[..len]),
                Err(RecoverRngError::TooManySamples(NN))
            );
        }
    }

    #[test]
    fn recover_required_exact_sample_length_via_recover() {
        let zeros = [0_u64; 1000];
        for &len in &TOO_FEW {
            assert_eq!(
                Mt64::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooFewSamples(NN))
            );
        }
        Mt64::recover(zeros[..NN].iter().copied()).unwrap();
        for &len in &TOO_MANY {
            assert_eq!(
                Mt64::recover(zeros[..len].iter().copied()),
                Err(RecoverRngError::TooManySamples(NN))
            );
        }
    }

    #[test]
    fn fmt_debug_does_not_leak_seed() {
        use core::fmt::Write as _;
        use std::string::String;

        for (seed, needle) in [(874_u64, "874"), (123_456_u64, "123456")] {
            let random = Mt64::new(seed);
            let mut rendered = String::new();
            write!(&mut rendered, "{random:?}").unwrap();
            assert!(!rendered.contains(needle));
            assert_eq!(rendered, "Mt64 {}");
        }
    }

    #[test]
    fn default_is_new_unseeded() {
        let mut default = Mt64::default();
        let mut unseeded = Mt64::new_unseeded();

        assert_eq!(default, unseeded);
        (0..1024).for_each(|_| assert_eq!(default.next_u64(), unseeded.next_u64()));
    }
}