module System.Random.Random123.Threefry (
threefry2, threefry4, threefry2R, threefry4R, ThreefryWord) where
import Data.Word
import Data.Bits
import Data.Array.Base
import System.Random.Random123.Types
import System.Random.Random123.Misc
-- | Table of rotation amounts used by the Threefry mix step; the tables in
-- this module are built with bounds @(0, 7)@ and looked up by @round `mod` 8@.
type RotationConstants = UArray Int Int

-- | Fetch a rotation amount by raw 0-based offset, with no bounds check
-- (a thin wrapper over 'unsafeAt').
--
-- Because the tables start at index 0, the raw offset and the logical
-- index coincide. The only caller ('sbox'') reduces the round number
-- @`mod` 8@ before calling, which keeps the offset inside the 8-entry tables.
getRotationConstant :: RotationConstants -> Int -> Int
getRotationConstant table idx = table `unsafeAt` idx
-- | Word types the Threefry functions in this module can operate on.
--
-- The instances in this module ignore the argument of the rotation-table
-- methods; it is only there to select the instance.
class ThreefryWord a where
  -- | Constant XORed together with the key words to form the extra
  -- word of the extended key.
  parityConstant :: a
  -- | Rotation table for the 2-word variant (indexed by @round `mod` 8@).
  rotationConstant2 :: a -> RotationConstants
  -- | Rotation tables for the 4-word variant: @_0@ drives the mix that
  -- involves word 0, @_1@ the mix that involves word 2.
  rotationConstant4_0, rotationConstant4_1 :: a -> RotationConstants
-- | 32-bit words. The parity constant and rotation tables appear to match
-- the 32-bit tables of the Random123 reference code (TODO: re-verify
-- against @threefry.h@ if these are ever edited).
instance ThreefryWord Word32 where
  parityConstant = 0x1BD11BDA
  -- Each table has 8 entries, one per value of @round `mod` 8@.
  rotationConstant2   = const (listArray (0, 7) [13, 15, 26,  6, 17, 29, 16, 24])
  rotationConstant4_0 = const (listArray (0, 7) [10, 11, 13, 23,  6, 17, 25, 18])
  rotationConstant4_1 = const (listArray (0, 7) [26, 21, 27,  5, 20, 11, 10, 20])
-- | 64-bit words. The parity constant and rotation tables appear to match
-- the 64-bit tables of the Random123 reference code (TODO: re-verify
-- against @threefry.h@ if these are ever edited).
instance ThreefryWord Word64 where
  parityConstant = 0x1BD11BDAA9FC1A22
  -- Each table has 8 entries, one per value of @round `mod` 8@.
  rotationConstant2   = const (listArray (0, 7) [16, 42, 12, 31, 16, 32, 24, 21])
  rotationConstant4_0 = const (listArray (0, 7) [14, 52, 23,  5, 25, 46, 58, 32])
  rotationConstant4_1 = const (listArray (0, 7) [16, 57, 40, 37, 33, 12, 22, 32])
-- | One Threefry mix step on a pair of words for round @r@:
--
-- > x0' = x0 + x1
-- > x1' = x0' `xor` (x1 `rotate` k)
--
-- where @k@ is entry @r `mod` 8@ of the rotation table returned by
-- @rConstant@. Both input words are forced (bang patterns).
--
-- The table accessors in this module ignore their argument and use it only
-- to pick the instance. The already-evaluated @x0@ is therefore passed,
-- rather than an @undefined :: a@ proxy. Without ScopedTypeVariables, that
-- proxy's @a@ was a fresh type variable unrelated to the signature's @a@.
sbox' :: (Num a, Bits a) => Int -> (a -> RotationConstants) -> Array2 a -> Array2 a
sbox' r rConstant (!x0, !x1) = (x0', x1')
  where
    rot = getRotationConstant (rConstant x0) (r `mod` 8)
    x0' = x0 + x1
    x1' = x0' `xor` (x1 `rotate` rot)
-- | Threefry-2 mix step for round @r@: 'sbox'' driven by the 2-word
-- rotation table of the word type.
sbox2 :: (ThreefryWord a, Bits a, Num a) => Int -> Array2 a -> Array2 a
sbox2 r pair = sbox' r rotationConstant2 pair
-- | Threefry-4 mix step for round @r@: two independent pair mixes.
--
-- * Even rounds mix @(x0, x1)@ with table @_0@ and @(x2, x3)@ with table @_1@.
-- * Odd rounds mix @(x0, x3)@ with table @_0@ and @(x2, x1)@ with table @_1@.
--
-- Each mixed pair's results go back into the positions they came from.
-- All four input words are forced.
sbox4 :: (ThreefryWord a, Bits a, Num a) => Int -> Array4 a -> Array4 a
sbox4 r (!x0, !x1, !x2, !x3)
  | even r =
      let (y0, y1) = mix0 (x0, x1)
          (y2, y3) = mix1 (x2, x3)
      in (y0, y1, y2, y3)
  | otherwise =
      let (y0, y3) = mix0 (x0, x3)
          (y2, y1) = mix1 (x2, x1)
      in (y0, y1, y2, y3)
  where
    mix0 = sbox' r rotationConstant4_0
    mix1 = sbox' r rotationConstant4_1
-- | Take two consecutive words of a 3-word extended key, starting at
-- position @i `mod` 3@ and wrapping around to the front.
shiftTuple2 :: Int -> (a, a, a) -> Array2 a
shiftTuple2 i (k0, k1, k2) =
  case i `mod` 3 of
    0 -> (k0, k1)
    1 -> (k1, k2)
    _ -> (k2, k0)
-- | Take four consecutive words of a 5-word extended key, starting at
-- position @i `mod` 5@ and wrapping around to the front.
shiftTuple4 :: Int -> (a, a, a, a, a) -> Array4 a
shiftTuple4 i (k0, k1, k2, k3, k4) =
  case i `mod` 5 of
    0 -> (k0, k1, k2, k3)
    1 -> (k1, k2, k3, k4)
    2 -> (k2, k3, k4, k0)
    3 -> (k3, k4, k0, k1)
    _ -> (k4, k0, k1, k2)
-- | Word-wise sum of two 2-word values.
addTuple2 :: Num a => Array2 a -> Array2 a -> Array2 a
addTuple2 (a0, a1) (b0, b1) = (a0 + b0, a1 + b1)

-- | Word-wise sum of two 4-word values.
addTuple4 :: Num a => Array4 a -> Array4 a -> Array4 a
addTuple4 (a0, a1, a2, a3) (b0, b1, b2, b3) =
  (a0 + b0, a1 + b1, a2 + b2, a3 + b3)
-- | Threefry-2 key injection for round @r@.
--
-- With @s = r `div` 4 + 1@:
--
-- 1. Take two words of the 3-word extended key starting at position
--    @s `mod` 3@.
-- 2. Add them word-wise to the state.
-- 3. Additionally add @s@ to the last state word.
--
-- Only 'Num' operations are used, so no 'Bits' constraint is required.
-- Relaxing the constraint is backward-compatible for every caller.
pbox2 :: Num a => (a, a, a) -> Int -> Array2 a -> Array2 a
pbox2 extendedKey r x = (x0', x1' + fromIntegral tshift)
  where
    tshift = r `div` 4 + 1
    (x0', x1') = addTuple2 x (shiftTuple2 tshift extendedKey)
-- | Threefry-4 key injection for round @r@.
--
-- With @s = r `div` 4 + 1@:
--
-- 1. Take four words of the 5-word extended key starting at position
--    @s `mod` 5@.
-- 2. Add them word-wise to the state.
-- 3. Additionally add @s@ to the last state word.
--
-- Only 'Num' operations are used, so no 'Bits' constraint is required.
-- Relaxing the constraint is backward-compatible for every caller.
pbox4 :: Num a => (a, a, a, a, a) -> Int -> Array4 a -> Array4 a
pbox4 extendedKey r x = (x0', x1', x2', x3' + fromIntegral tshift)
  where
    tshift = r `div` 4 + 1
    (x0', x1', x2', x3') = addTuple4 x (shiftTuple4 tshift extendedKey)
-- | Run round @r@ on state @x@.
--
-- The mix step @sbox r@ always runs. On every round with
-- @r `mod` 4 == 3@ it is followed by the key-injection step @pbox r@.
threefryRound :: (Int -> c -> c) -> (Int -> c -> c) -> Int -> c -> c
threefryRound pbox sbox r x
  | r `mod` 4 == 3 = pbox r mixed
  | otherwise = mixed
  where
    mixed = sbox r x
-- | Extend a 2-word key with a third word. The extra word is the XOR of
-- both key words and 'parityConstant'.
extendKey2 :: (ThreefryWord a, Bits a) => Array2 a -> (a, a, a)
extendKey2 (k0, k1) = (k0, k1, parity)
  where
    parity = k0 `xor` k1 `xor` parityConstant

-- | Extend a 4-word key with a fifth word. The extra word is the XOR of
-- all four key words and 'parityConstant'.
extendKey4 :: (ThreefryWord a, Bits a) => Array4 a -> (a, a, a, a, a)
extendKey4 (k0, k1, k2, k3) = (k0, k1, k2, k3, parity)
  where
    parity = k0 `xor` k1 `xor` k2 `xor` k3 `xor` parityConstant
-- | Threefry-2 with an explicit number of rounds.
--
-- Steps:
--
-- 1. Add the key word-wise to the counter.
-- 2. Run 'threefryRound' (mix via 'sbox2', key injection via 'pbox2' on
--    the extended key) for the requested number of rounds. This is driven
--    by @apply@ from "System.Random.Random123.Misc", presumably over round
--    indices @0 .. rounds-1@ (confirm there).
--
-- Calls 'error' when @rounds@ is outside @[1, 32]@.
threefry2R :: (ThreefryWord a, Bits a, Num a)
  => Int      -- ^ number of rounds, 1 to 32
  -> Array2 a -- ^ key
  -> Array2 a -- ^ counter
  -> Array2 a
threefry2R rounds key ctr
  | rounds < 1 || rounds > 32 = error "The number of rounds in Threefry-2 must be between 1 and 32"
  | otherwise = apply (threefryRound injectKey sbox2) rounds initialState
  where
    initialState = addTuple2 key ctr
    injectKey = pbox2 (extendKey2 key)
-- | Threefry-4 with an explicit number of rounds.
--
-- Steps:
--
-- 1. Add the key word-wise to the counter.
-- 2. Run 'threefryRound' (mix via 'sbox4', key injection via 'pbox4' on
--    the extended key) for the requested number of rounds. This is driven
--    by @apply@ from "System.Random.Random123.Misc", presumably over round
--    indices @0 .. rounds-1@ (confirm there).
--
-- Calls 'error' when @rounds@ is outside @[1, 72]@.
threefry4R :: (ThreefryWord a, Bits a, Num a)
  => Int      -- ^ number of rounds, 1 to 72
  -> Array4 a -- ^ key
  -> Array4 a -- ^ counter
  -> Array4 a
threefry4R rounds key ctr
  | rounds < 1 || rounds > 72 = error "The number of rounds in Threefry-4 must be between 1 and 72"
  | otherwise = apply (threefryRound injectKey sbox4) rounds initialState
  where
    initialState = addTuple4 key ctr
    injectKey = pbox4 (extendKey4 key)
-- | Threefry-2 with the default 20 rounds (same as @'threefry2R' 20@).
threefry2 :: (ThreefryWord a, Bits a, Num a)
  => Array2 a -- ^ key
  -> Array2 a -- ^ counter
  -> Array2 a
threefry2 key ctr = threefry2R defaultRounds key ctr
  where
    defaultRounds = 20

-- | Threefry-4 with the default 20 rounds (same as @'threefry4R' 20@).
threefry4 :: (ThreefryWord a, Bits a, Num a)
  => Array4 a -- ^ key
  -> Array4 a -- ^ counter
  -> Array4 a
threefry4 key ctr = threefry4R defaultRounds key ctr
  where
    defaultRounds = 20