{-# LANGUAGE BangPatterns #-}
module Cryptography.WringTwistree.RotBitcount
  ( rotBitcount
  , rotBitcount'
  , rotFixed
  , rotFixed'
  )  where

{- This module is used in both Wring and Twistree.
 - It rotates an array of bytes by a multiple of its bitcount (the number of
 - 1 bits it contains), producing another array of the same size. As long as
 - the multiplier is relatively prime to the number of bits in the array,
 - this operation satisfies the strict avalanche criterion. Changing *two*
 - bits, however, has about a 50% chance of leaving the bitcount unchanged
 - (one bit set, one cleared), in which case only those two bits change in
 - the output.
 -
 - Bit 0 of byte 0 is bit 0 of the array. Bit 0 of byte 1 is bit 8 of the array.
 - e1 00 00 00 00 00 00 00, rotated toward higher bit numbers by its
 - bitcount (4), becomes 10 0e 00 00 00 00 00 00.
 -}

import Data.Bits
import Data.Word
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Control.Monad.ST
import Control.Monad
import Debug.Trace

-- | Rotate @src@ by @mult@ times its bitcount (the total number of set bits
-- across all bytes), producing a vector of the same length.
--
-- The bytes form one bit string in which bit 0 of byte 0 is bit 0 and bit 0
-- of byte 1 is bit 8. Bits move toward higher bit numbers, wrapping from the
-- last byte back to byte 0. The rotation amount is reduced modulo
-- @8 * length src@, so a negative @mult@ is accepted. An empty vector yields
-- an empty vector.
--
-- See Rust code for a timing leak which may be present in (.>>.).
rotBitcount :: V.Vector Word8 -> Int -> V.Vector Word8
rotBitcount src mult = V.generate n rotatedByte
  where
    n = V.length src
    nBits = n * 8
    ones = V.foldl' (\acc w -> acc + popCount w) 0 src
    -- The n > 0 guard keeps `mod` away from a zero divisor. For an empty
    -- vector, ones is 0, so the amount is 0 either way.
    amount
      | n > 0     = (ones * (mult `mod` nBits)) `mod` nBits
      | otherwise = ones * mult
    !wholeBytes = amount .>>. 3
    !extraBits  = amount .&. 7
    -- Output byte i takes its high bits from the byte wholeBytes positions
    -- back, and its low bits from the byte just before that. When extraBits
    -- is 0, the second shift is by 8, which relies on a Word8 shifted right
    -- by 8 giving 0.
    rotatedByte i =
      let cur  = src V.! ((i + n - wholeBytes) `mod` n)
          prev = src V.! ((i + n - wholeBytes - 1) `mod` n)
      in  (cur .<<. extraBits) .|. (prev .>>. (8 - extraBits))

-- | Mutable counterpart of 'rotBitcount'. Rotates the bytes of @src@ by
-- @mult@ times their bitcount and writes the result to indexes
-- @0 .. MV.length src - 1@ of @dst@. @src@ itself is only read.
--
-- @src@ is read while @dst@ is being written, so the two should not be the
-- same buffer. Otherwise later reads would see bytes that were already
-- rotated. NOTE(review): presumably @dst@ has the same length as @src@;
-- confirm against callers.
rotBitcount' :: MV.MVector s Word8 -> Int -> MV.MVector s Word8 -> ST s ()
rotBitcount' src mult dst = do
    ones <- MV.foldl' (\acc w -> acc + popCount w) 0 src
    let n = MV.length src
        nBits = n * 8
        -- The guard keeps `mod`/`rem` away from a zero divisor on empty input.
        amount
          | n > 0     = (ones * (mult `mod` nBits)) `rem` nBits
          | otherwise = ones * mult
        !wholeBytes = amount .>>. 3
        !extraBits  = amount .&. 7
        -- Fill dst one byte at a time, from index i up to n - 1.
        go i
          | i >= n    = pure ()
          | otherwise = do
              cur  <- MV.read src ((i + n - wholeBytes) `rem` n)
              prev <- MV.read src ((i + n - wholeBytes - 1) `rem` n)
              MV.write dst i
                ((cur .<<. extraBits) .|. (prev .>>. (8 - extraBits)))
              go (i + 1)
    go 0

-- | Rotate @src@ by a fixed amount that does not depend on its contents.
-- The amount is @mult@ times (@4 * length src + 2@) bits, i.e. two more
-- than half the number of bits in the buffer, reduced modulo the buffer's
-- bit length. Bits are numbered and moved exactly as in 'rotBitcount'.
--
-- This is for cryptanalyzing a weakened version that replaces the SACful
-- 'rotBitcount' with this linear fixed rotation.
--
-- An empty vector yields an empty vector.
rotFixed :: V.Vector Word8 -> Int -> V.Vector Word8
rotFixed src mult = V.generate len rotatedByte
  where
    len = V.length src
    nbits = len * 8
    bitcount = len * 4 + 2
    -- An empty buffer has nothing to rotate. The guard also keeps `mod` from
    -- dividing by zero; the strict bindings below would otherwise force that
    -- division even though no output bytes are produced.
    rotcount
      | len > 0   = (bitcount * (mult `mod` nbits)) `mod` nbits
      | otherwise = 0
    !byteShift = rotcount .>>. 3
    !bitShift  = rotcount .&. 7
    -- The high bits come from the byte byteShift positions back; the low
    -- bits come from the byte before it.
    rotatedByte i =
      (src V.! ((i + len - byteShift) `mod` len) .<<. bitShift) .|.
      (src V.! ((i + len - byteShift - 1) `mod` len) .>>. (8 - bitShift))

-- | Mutable counterpart of 'rotFixed'. Rotates the bytes of @src@ by
-- @mult@ times (@4 * length src + 2@) bits and writes the result to indexes
-- @0 .. MV.length src - 1@ of @dst@. @src@ itself is only read.
--
-- @src@ is read while @dst@ is written, so they should not be the same
-- buffer. NOTE(review): presumably @dst@ is at least as long as @src@;
-- confirm against callers. An empty @src@ writes nothing.
rotFixed' :: MV.MVector s Word8 -> Int -> MV.MVector s Word8 -> ST s ()
rotFixed' src mult dst = do
    let len = MV.length src
        nbits = len * 8
        bitcount = len * 4 + 2
        -- The guard avoids `mod`/`rem` by zero for an empty buffer; the
        -- strict bindings below would otherwise force that division.
        rotcount
          | len > 0   = (bitcount * (mult `mod` nbits)) `rem` nbits
          | otherwise = 0
        !byteShift = rotcount .>>. 3
        !bitShift  = rotcount .&. 7
    forM_ [0 .. len - 1] $ \i -> do
        l <- MV.read src ((i + len - byteShift) `rem` len)
        r <- MV.read src ((i + len - byteShift - 1) `rem` len)
        MV.write dst i ((l .<<. bitShift) .|. (r .>>. (8 - bitShift)))