module MT19937.Internal where

import Data.Bits
import Data.Word ( Word32 )
import Data.Vector.Unboxed.Mutable qualified as VUM

-- | Apply the MT19937 tempering transform to a raw state word.
--
-- Tempering is four xor-shift steps applied in order:
--
-- 1. @y ^= (y >> 11) .&. 0xFFFFFFFF@
-- 2. @y ^= (y << 7)  .&. 0x9D2C5680@
-- 3. @y ^= (y << 15) .&. 0xEFC60000@
-- 4. @y ^= y >> 18@
--
-- No final mask is applied to the result.
-- NOTE(review): the constants and shift amounts are the 32-bit ones, so this
-- is presumably meant for 'Word32'. For wider types, bits above 32 pass
-- through unmasked. For signed types, 'shiftR' is arithmetic — confirm
-- callers only use unsigned 32-bit words.
temper :: (Num a, Bits a) => a -> a
temper = finalShift . leftShift15 . leftShift7 . rightShift11
  where
    -- Each step xors the word with a shifted (and, except the last, masked)
    -- copy of itself.
    rightShift11 w = w `xor` (shiftR w 11 .&. 0xFFFFFFFF)
    leftShift7   w = w `xor` (shiftL w 7  .&. 0x9D2C5680)
    leftShift15  w = w `xor` (shiftL w 15 .&. 0xEFC60000)
    finalShift   w = w `xor` shiftR w 18

-- | Twist an MT19937 state vector in place.
--
-- Rebuilds words 0..623 in index order. Word @i@ is computed as follows:
--
-- * Combine the top bit of word @i@ with the low 31 bits of word
--   @(i + 1) mod 624@.
-- * Shift that combination right by one.
-- * Xor the result with word @(i + 397) mod 624@.
-- * Xor it additionally with @0x9908B0DF@ when the combination is odd.
--
-- The update is in place, so once @i + 397@ (or @i + 1@) wraps around, the
-- words read are ones already rebuilt in this pass.
--
-- All accesses use 'VUM.unsafeRead' / 'VUM.unsafeWrite' with no bounds
-- check. NOTE(review): the vector is assumed to hold at least 624 elements;
-- confirm every caller guarantees this.
twist :: VUM.PrimMonad m => VUM.MVector (VUM.PrimState m) Word32 -> m ()
twist mt = loop 0
  where
    stateSize   = 624
    shiftOffset = 397
    matrixA     = 0x9908B0DF

    loop idx
      | idx == stateSize = pure ()
      | otherwise = do
          cur  <- VUM.unsafeRead mt idx
          next <- VUM.unsafeRead mt ((idx + 1) `mod` stateSize)
          far  <- VUM.unsafeRead mt ((idx + shiftOffset) `mod` stateSize)
          let combined = (cur .&. 0x80000000) + (next .&. 0x7FFFFFFF)
              mixed    = far `xor` (combined `shiftR` 1)
              -- An odd 'combined' selects the extra xor with the matrix
              -- constant.
              updated
                | testBit combined 0 = mixed `xor` matrixA
                | otherwise          = mixed
          VUM.unsafeWrite mt idx updated
          loop (idx + 1)