module Cryptography.Wring
  ( Wring
  , mix3Parts -- reexported for use in Cryptanalysis.hs
  , permut8 -- reexported for testing
  , mul65537 -- "
  , wringName -- For events in traces when profiling
  , xorn -- Used for big hash test
  , linearWring -- Only for cryptanalysis and testing
  , keyedWring
  , encrypt
  , decrypt
  , encryptFixed -- Only for cryptanalysis and testing
  ) where

{- This cipher is intended to be used with short random keys (32 bytes or less,
 - no hard limit) or long human-readable keys (up to 96 bytes). keyedWring
 - takes arbitrarily long keys, but do not use keys longer than 96 bytes as they
 - make the cipher more vulnerable to related-key attacks.
 -
 - The Haskell implementation needs four times as much RAM as the message size,
 - plus a constant overhead.
 -}

import Cryptography.WringTwistree.Mix3
import Cryptography.WringTwistree.RotBitcount
import Cryptography.WringTwistree.Sboxes
import Text.Printf
import Data.Word
import Data.Bits
import Data.Foldable (foldl')
import qualified Data.ByteString as B
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Control.Monad.ST
import Control.Monad

-- | A keyed instance of the cipher: the S-box used when encrypting,
-- paired with the S-box used when decrypting.
data Wring = Wring
  { sbox    :: SBox -- ^ S-box applied by the encryption rounds
  , invSbox :: SBox -- ^ S-box applied by the decryption rounds
  }
  deriving (Show)

-- | Generates a name from the first four bytes of the S-box, printed as
-- two-digit lowercase hex values joined by dashes (e.g. @"a1-07-3c-ff"@).
-- Used to tag events in a profiling log.
wringName :: Wring -> String
wringName wring = printf "%02x-%02x-%02x-%02x" b0 b1 b2 b3
  where
    box = sbox wring
    b0 = box V.! 0
    b1 = box V.! 1
    b2 = box V.! 2
    b3 = box V.! 3

-- | Number of rounds used for a message of @len@ bytes. Messages shorter
-- than 3 bytes get 3 rounds, and each factor of 3 in the length adds one
-- more. For @len >= 1@ this equals @3 + floor (logBase 3 len)@.
nRounds :: Integral a => a -> a
nRounds len
  | len >= 3  = 1 + nRounds (len `div` 3)
  | otherwise = 3

-- | Exclusive-ors together all the bytes of a nonnegative number, e.g.
-- @xorn 0x1234 == 0x12 `xor` 0x34@. A negative argument ends in `error`.
-- The only reason this function is public is that it's used to generate a
-- long `ByteString` for a test.
xorn :: (Integral a, Bits a) => a -> Word8
xorn n
  | n == 0    = 0
  -- Right shifts sign-extend on signed types, so any negative input is
  -- eventually shifted down to -1 and lands here.
  | n == -1   = error "xorn: negative"
  | otherwise = fromIntegral n `xor` xorn (n .>>. 8)

-- | Vector of length @n@ whose element @i@ is @xorn i@. Each round adds
-- element @i@ (xored with @xorn@ of the round number) to byte @i@.
xornArray :: Int -> V.Vector Word8
xornArray n = V.fromListN n [xorn i | i <- [0 .. n - 1]]

-- | A `Wring` built from the linear S-box and its inverse instead of
-- key-derived S-boxes. Used only for testing and cryptanalysis.
linearWring :: Wring
linearWring = Wring { sbox = linearSbox, invSbox = linearInvSbox }

-- | Creates a `Wring` with the given key.
--
-- The key determines the encryption S-box (via `sboxes`), and the
-- decryption S-box is its `invert`. Keys of any length are accepted, but
-- keys longer than 96 bytes make the cipher more vulnerable to related-key
-- attacks.
--
-- To convert a `String` to a `ByteString`, put @- utf8-string@ in your
-- package.yaml dependencies, @import Data.ByteString.UTF8@, and use
-- `fromString`.
keyedWring :: B.ByteString -> Wring
keyedWring key =
  let forward = sboxes key
  in Wring { sbox = forward, invSbox = invert forward }

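{- A round of encryption consists of four steps:
 - mix3Parts
 - S-box substitution (sboxInx picks the S-box index from the byte and
 -   from the byte position plus the round number, mod 3)
 - rotBitcount
 - add (xorn of the byte index) xor (xorn of the round number) to each byte
 - Decryption undoes these steps in reverse order.
 -}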

-- Original purely functional version, modified to use vectors

-- | One round of encryption, purely functional version.
-- @rprime@ is passed to `mix3Parts`, @box@ is the S-box, @xornary@ is
-- expected to be @xornArray@ of the buffer length, and @rond@ is the round
-- number.
roundEncryptFun ::
  Int ->
  SBox ->
  V.Vector Word8 ->
  V.Vector Word8 ->
  Int ->
  V.Vector Word8
roundEncryptFun rprime box xornary buf rond = keyed
  where
    n :: Int
    n = V.length buf
    roundConst :: Word8
    roundConst = xorn rond
    -- Step 1: mix the three parts of the buffer.
    mixed = mix3Parts buf rprime
    -- Step 2: S-box substitution. The selector for each byte comes from
    -- `cycle3`, shifted by the round number.
    substituted = V.fromListN n $
      zipWith (\k x -> box V.! sboxInx k x) (drop rond cycle3) (V.toList mixed)
    -- Step 3: rotate by bit count.
    rotated = rotBitcount substituted 1
    -- Step 4: add (xorn of byte index) xor (xorn of round number).
    keyed = V.fromListN n $
      zipWith (\x c -> x + (roundConst `xor` c)) (V.toList rotated) (V.toList xornary)

-- | One round of decryption, purely functional version. It applies the
-- round steps in the reverse order of `roundEncryptFun`: subtract the
-- per-byte constant, rotate by bit count in the opposite direction,
-- substitute through @box@ (`decryptFun` passes the inverse S-box), then
-- `mix3Parts`.
roundDecryptFun ::
  Int ->
  SBox ->
  V.Vector Word8 ->
  V.Vector Word8 ->
  Int ->
  V.Vector Word8
roundDecryptFun rprime box xornary buf rond = unmixed
  where
    n :: Int
    n = V.length buf
    roundConst :: Word8
    roundConst = xorn rond
    -- Undo step 4: subtract (xorn of byte index) xor (xorn of round number).
    unkeyed = V.fromListN n $
      zipWith (\x c -> x - (roundConst `xor` c)) (V.toList buf) (V.toList xornary)
    -- Undo step 3: rotate by bit count the other way.
    unrotated = rotBitcount unkeyed (-1)
    -- Undo step 2: S-box substitution with the same selectors as encryption.
    unsubstituted = V.fromListN n $
      zipWith (\k x -> box V.! sboxInx k x) (drop rond cycle3) (V.toList unrotated)
    -- Undo step 1.
    unmixed = mix3Parts unsubstituted rprime

-- | Encrypts a buffer using the purely functional rounds, folding
-- `roundEncryptFun` over round numbers @0 .. nRounds len - 1@.
encryptFun :: Wring -> V.Vector Word8 -> V.Vector Word8
encryptFun wring buf = foldl' applyRound buf roundNumbers
  where
    len :: Int
    len = V.length buf
    xornary :: V.Vector Word8
    xornary = xornArray len
    -- Parameter for mix3Parts, computed by findMaxOrder from len / 3.
    rprime :: Int
    rprime = fromIntegral (findMaxOrder (fromIntegral (len `div` 3)))
    roundNumbers :: [Int]
    roundNumbers = [0 .. nRounds len - 1]
    applyRound :: V.Vector Word8 -> Int -> V.Vector Word8
    applyRound = roundEncryptFun rprime (sbox wring) xornary

-- | Decrypts a buffer using the purely functional rounds, folding
-- `roundDecryptFun` (with the inverse S-box) over the round numbers from
-- @nRounds len - 1@ down to 0.
decryptFun :: Wring -> V.Vector Word8 -> V.Vector Word8
decryptFun wring buf = foldl' undoRound buf roundNumbers
  where
    len :: Int
    len = V.length buf
    xornary :: V.Vector Word8
    xornary = xornArray len
    -- Parameter for mix3Parts, computed by findMaxOrder from len / 3.
    rprime :: Int
    rprime = fromIntegral (findMaxOrder (fromIntegral (len `div` 3)))
    lastRound :: Int
    lastRound = nRounds len - 1
    -- Rounds are undone last-to-first.
    roundNumbers :: [Int]
    roundNumbers = [lastRound, lastRound - 1 .. 0]
    undoRound :: V.Vector Word8 -> Int -> V.Vector Word8
    undoRound = roundDecryptFun rprime (invSbox wring) xornary

-- ST monad version modifies memory in place
-- by int-e

-- One round of encryption, done in place on buf. Steps, in order:
--   1. mix3Parts' on buf;
--   2. S-box substitution, reading buf and writing tmp; sboxInx picks the
--      S-box index from (i + rond) `rem` 3 and the byte;
--   3. rotBitcount' tmp 1 buf (presumably rotates tmp by bit count back
--      into buf -- the next step reads buf);
--   4. add (xorn i) xor (xorn rond) to every byte i of buf.
-- tmp is scratch space; assumes it is at least as long as buf.
{-# NOINLINE roundEncryptST #-}
roundEncryptST ::
  Int ->
  SBox ->
  V.Vector Word8 ->
  MV.MVector s Word8 ->
  MV.MVector s Word8 ->
  Int ->
  ST s ()
roundEncryptST rprime box xornary buf tmp rond = do
  mix3Parts' buf rprime
  eachIndex $ \i -> MV.read buf i >>= MV.write tmp i . substitute i
  rotBitcount' tmp 1 buf
  eachIndex $ \i -> MV.read buf i >>= MV.write buf i . addRoundConst i
  where
    eachIndex act = forM_ [0 .. MV.length buf - 1] act
    roundConst = xorn rond
    substitute i byte = box V.! sboxInx ((i + rond) `rem` 3) byte
    addRoundConst i byte = byte + (roundConst `xor` (xornary V.! i))

-- Same as roundEncryptST except step 3 uses rotFixed' (a fixed rotation)
-- instead of rotBitcount'. Steps, in order:
--   1. mix3Parts' on buf;
--   2. S-box substitution, reading buf and writing tmp;
--   3. rotFixed' tmp 1 buf (presumably rotates tmp back into buf);
--   4. add (xorn i) xor (xorn rond) to every byte i of buf.
-- tmp is scratch space; assumes it is at least as long as buf.
{-# NOINLINE roundEncryptFixedST #-}
roundEncryptFixedST ::
  Int ->
  SBox ->
  V.Vector Word8 ->
  MV.MVector s Word8 ->
  MV.MVector s Word8 ->
  Int ->
  ST s ()
roundEncryptFixedST rprime box xornary buf tmp rond = do
  mix3Parts' buf rprime
  eachIndex $ \i -> MV.read buf i >>= MV.write tmp i . substitute i
  rotFixed' tmp 1 buf
  eachIndex $ \i -> MV.read buf i >>= MV.write buf i . addRoundConst i
  where
    eachIndex act = forM_ [0 .. MV.length buf - 1] act
    roundConst = xorn rond
    substitute i byte = box V.! sboxInx ((i + rond) `rem` 3) byte
    addRoundConst i byte = byte + (roundConst `xor` (xornary V.! i))

-- One round of decryption, done in place on buf. It runs the steps of
-- roundEncryptST in reverse order:
--   1. subtract (xorn i) xor (xorn rond) from every byte i, writing tmp;
--   2. rotBitcount' tmp (-1) buf (rotation in the opposite direction);
--   3. S-box substitution of buf in place (decryptST passes the inverse
--      S-box);
--   4. mix3Parts' on buf.
-- tmp is scratch space; assumes it is at least as long as buf.
{-# NOINLINE roundDecryptST #-}
roundDecryptST ::
  Int ->
  SBox ->
  V.Vector Word8 ->
  MV.MVector s Word8 ->
  MV.MVector s Word8 ->
  Int ->
  ST s ()
roundDecryptST rprime box xornary buf tmp rond = do
  eachIndex $ \i -> MV.read buf i >>= MV.write tmp i . subRoundConst i
  rotBitcount' tmp (-1) buf
  eachIndex $ \i -> MV.read buf i >>= MV.write buf i . substitute i
  mix3Parts' buf rprime
  where
    eachIndex act = forM_ [0 .. MV.length buf - 1] act
    roundConst = xorn rond
    subRoundConst i byte = byte - (roundConst `xor` (xornary V.! i))
    substitute i byte = box V.! sboxInx ((i + rond) `rem` 3) byte

-- | Encrypts a buffer with the in-place ST rounds.
--
-- The input vector is copied with `V.thaw`, so the argument is not
-- modified. A scratch buffer of the same length is allocated for the
-- rounds, and rounds @0 .. nRounds len - 1@ run in order.
encryptST :: Wring -> V.Vector Word8 -> V.Vector Word8
encryptST wring buf = V.create $ do
  let len :: Int
      len = V.length buf
      xornary :: V.Vector Word8
      xornary = xornArray len
      -- Parameter for mix3Parts', computed by findMaxOrder from len / 3.
      rprime :: Int
      rprime = fromIntegral $ findMaxOrder (fromIntegral $ len `div` 3)
      rounds :: [Int]
      rounds = [0 .. nRounds len - 1]
  work <- V.thaw buf
  tmp <- MV.new len
  forM_ rounds $ \rond ->
    roundEncryptST rprime (sbox wring) xornary work tmp rond
  pure work

-- Encrypts using a fixed rotation of the buffer, rather than rotating
-- by its population count. This removes a source of nonlinearity.
--
-- Apart from using roundEncryptFixedST, this is identical to encryptST:
-- the input is copied with V.thaw, a same-length scratch buffer is
-- allocated, and rounds 0 .. nRounds len - 1 run in order.
encryptFixedST :: Wring -> V.Vector Word8 -> V.Vector Word8
encryptFixedST wring buf = V.create $ do
  let len :: Int
      len = V.length buf
      xornary :: V.Vector Word8
      xornary = xornArray len
      -- Parameter for mix3Parts', computed by findMaxOrder from len / 3.
      rprime :: Int
      rprime = fromIntegral $ findMaxOrder (fromIntegral $ len `div` 3)
      rounds :: [Int]
      rounds = [0 .. nRounds len - 1]
  work <- V.thaw buf
  tmp <- MV.new len
  forM_ rounds $ \rond ->
    roundEncryptFixedST rprime (sbox wring) xornary work tmp rond
  pure work

-- | Decrypts a buffer with the in-place ST rounds.
--
-- The input vector is copied with `V.thaw`, so the argument is not
-- modified. Rounds are undone last-to-first, from @nRounds len - 1@ down
-- to 0, using the inverse S-box.
decryptST :: Wring -> V.Vector Word8 -> V.Vector Word8
decryptST wring buf = V.create $ do
  let len :: Int
      len = V.length buf
      xornary :: V.Vector Word8
      xornary = xornArray len
      -- Parameter for mix3Parts', computed by findMaxOrder from len / 3.
      rprime :: Int
      rprime = fromIntegral $ findMaxOrder (fromIntegral $ len `div` 3)
      lastRound :: Int
      lastRound = nRounds len - 1
      rounds :: [Int]
      rounds = [lastRound, lastRound - 1 .. 0]
  work <- V.thaw buf
  tmp <- MV.new len
  forM_ rounds $ \rond ->
    roundDecryptST rprime (invSbox wring) xornary work tmp rond
  pure work

-- Use either the ST version or the Fun version.
-- The Fun version takes 4.2 times as long as the ST version when encrypting
-- or decrypting a file, but when running six threads in parallel for
-- cryptanalysis it takes only 1.58 times as long. Turning off threading
-- makes encryption 5.2 times as fast.

-- | Encrypts a message. The ciphertext has the same length as the
-- plaintext. This uses the in-place ST implementation.
encrypt
  :: Wring          -- ^ The `Wring` made with the key to encrypt with
  -> V.Vector Word8 -- ^ The plaintext
  -> V.Vector Word8 -- ^ The returned ciphertext
encrypt wring plaintext = encryptST wring plaintext

-- | Like `encrypt`, but each round rotates by a fixed amount instead of
-- by the bit count, which removes a source of nonlinearity.
-- Used only for cryptanalysis.
encryptFixed :: Wring -> SBox -> SBox
encryptFixed wring plaintext = encryptFixedST wring plaintext

-- | Decrypts a message by running the rounds backwards with the inverse
-- S-box. This uses the in-place ST implementation.
decrypt
  :: Wring          -- ^ The `Wring` made with the key to decrypt with
  -> V.Vector Word8 -- ^ The ciphertext
  -> V.Vector Word8 -- ^ The returned plaintext
decrypt wring ciphertext = decryptST wring ciphertext