module Cryptography.Twistree
  ( Twistree
  , SBox -- reexported for use in Cryptanalysis.hs
  , sboxes -- "
  , sameBitcount -- "
  , compress -- "
  , linearSbox -- "
  , linearTwistree -- Only for cryptanalysis and testing
  , parListDeal
  , keyedTwistree
  , hash
  ) where

{-
This hash function uses a double-tree construction, as shown in this drawing:

                                                  2
                               -------------------+-------------------
               ----------------+---------------                      |
       --------+--------               -------+---------             |
   ----+----       ----+----       ----+----       ----+----       --+---
 --+--   --+--   --+--   --+--   --+--   --+--   --+--   --+--   --+--  |
-+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- -+- |
4 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
--+-- --+-- --+-- --+-- --+-- --+-- --+-- --+-- --+-- --+-- --+-- --+-- |
  ------+------     ------+------     ------+------     ------+------   |
        ------------------+------------------                 -----+-----
                          ---------------------+--------------------
                                               3
2 3
-+-
 H

*   A block of the message to be hashed, including padding at the end.
4   exp(4) written in binary, in two different representations: exp4_2adic is
    prepended to the blocks fed to the binary tree, and exp4_base2 to the
    blocks fed to the ternary tree.
3   Output of the ternary tree
2   Output of the binary tree
H   Final hash output
-}

import Cryptography.WringTwistree.Compress
import Cryptography.WringTwistree.Blockize
import Cryptography.WringTwistree.Sboxes
import Control.Parallel
import Control.Parallel.Strategies
import Data.List (transpose)
import Data.List.Split
import Data.Word
import Data.Bits
import Data.Array.Unboxed
import Data.Foldable (foldl')
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Vector.Unboxed as V

-- | A keyed Twistree hasher. Build one with 'keyedTwistree' (or
-- 'linearTwistree' for testing and cryptanalysis) and pass it to 'hash'.
data Twistree = Twistree
  { sbox :: SBox -- ^ S-box handed to every compression step in both trees
  }
  deriving (Show)

-- | Deal a list round-robin into at most @n@ piles, the way cards are dealt:
-- element @i@ ends up in pile @i `mod` n@. This is done by cutting the list
-- into rows of @n@ elements and reading it back column by column.
-- Intended for @n >= 1@; 'parListDeal' only calls it with @n >= 2@.
deal :: Int -> [a] -> [[a]]
deal n xs = transpose (chunksOf n xs)

-- | Parallel list strategy that uses at most @n@ sparks instead of one spark
-- per element.
--
-- The list is dealt round-robin into @n@ sublists (element @i@ goes to
-- sublist @i `mod` n@), and each sublist is evaluated with @strat@ in its own
-- spark. The evaluated sublists are then interleaved back
-- (@concat . transpose@ undoes @transpose . chunksOf n@). The result
-- therefore has the same order as the input, as a 'Strategy' must:
-- @using xs (parListDeal n s)@ is @xs@.
--
-- For @n <= 1@ the list is evaluated sequentially with @strat@.
parListDeal :: Int -> Strategy a -> Strategy [a]
parListDeal n strat xs
  | n <= 1    = evalList strat xs
    -- Rebuild the original order from the sparked results. Returning the
    -- spark results (rather than discarding them and returning @xs@) keeps
    -- the sparks referenced, so they are not garbage-collected unrun.
  | otherwise = concat . transpose <$> parList (evalList strat) (deal n xs)

-- | One level of the binary tree: compress adjacent blocks two at a time
-- with 'compress2' (tree tag 0). A trailing unpaired block is passed
-- through unchanged.
--
-- Each compressed block is forced to weak head normal form with 'pseq'
-- before the cons cell holding it is returned.
compressPairs :: SBox -> [V.Vector Word8] -> [V.Vector Word8]
compressPairs _ [] = []
compressPairs _ [x] = [x]
compressPairs box (x:y:rest) =
  -- Bind the compressed block once, so the value forced by 'pseq' is exactly
  -- the one stored in the list instead of relying on CSE to merge two calls.
  let c = compress2 box x y 0
  in c `pseq` (c : compressPairs box rest)

-- | Root of the binary tree: apply 'compressPairs' repeatedly until a single
-- block remains, and return it.
--
-- In this module 'hash' always prepends the exp(4) block, so the list is
-- never empty. The empty case is reported with an explicit error rather than
-- 'undefined'.
hashPairs :: SBox -> [V.Vector Word8] -> V.Vector Word8
hashPairs _ [] =
  error "Cryptography.Twistree.hashPairs: empty block list (exp(4) missing)"
hashPairs _ [x] = x
hashPairs box xs =
  -- Share the next tree level between the spark and the recursive call, so
  -- whatever the spark evaluates is the value the recursion consumes.
  let next = compressPairs box xs
  in next `par` hashPairs box next

-- | One level of the ternary tree: compress blocks three at a time with
-- 'compress3' (tree tag 1).
--
-- Leftover blocks at the end are handled as follows:
--
-- * a single leftover block passes through unchanged;
-- * a leftover pair is combined with 'compress2' (also tag 1).
--
-- Each three-way compressed block is forced to weak head normal form with
-- 'pseq' before the cons cell holding it is returned.
compressTriples :: SBox -> [V.Vector Word8] -> [V.Vector Word8]
compressTriples _ [] = []
compressTriples _ [x] = [x]
compressTriples box [x, y] = [compress2 box x y 1]
compressTriples box (x:y:z:rest) =
  -- Bind the compressed block once, so the value forced by 'pseq' is exactly
  -- the one stored in the list instead of relying on CSE to merge two calls.
  let c = compress3 box x y z 1
  in c `pseq` (c : compressTriples box rest)

-- | Root of the ternary tree: apply 'compressTriples' repeatedly until a
-- single block remains, and return it.
--
-- In this module 'hash' always prepends the exp(4) block, so the list is
-- never empty. The empty case is reported with an explicit error rather than
-- 'undefined'.
hashTriples :: SBox -> [V.Vector Word8] -> V.Vector Word8
hashTriples _ [] =
  error "Cryptography.Twistree.hashTriples: empty block list (exp(4) missing)"
hashTriples _ [x] = x
hashTriples box xs =
  -- Share the next tree level between the spark and the recursive call, so
  -- whatever the spark evaluates is the value the recursion consumes.
  let next = compressTriples box xs
  in next `par` hashTriples box next

-- | A `Twistree` built on 'linearSbox' instead of a keyed S-box.
-- Used only for testing and cryptanalysis.
linearTwistree :: Twistree
linearTwistree = Twistree { sbox = linearSbox }

-- | Creates a `Twistree` with the given key; its S-box is computed from the
-- key by 'sboxes'.
-- To convert a `String` to a `ByteString`, put @- utf8-string@ in your
-- package.yaml dependencies, @import Data.ByteString.UTF8@, and use
-- `fromString`.
keyedTwistree :: B.ByteString -> Twistree
keyedTwistree = Twistree . sboxes

-- | Hash a message with a keyed 'Twistree'.
--
-- The message is split into blocks by 'blockize'. Two trees hash those
-- blocks independently:
--
-- * the binary tree ('hashPairs') hashes @exp4_2adic : blocks@;
-- * the ternary tree ('hashTriples') hashes @exp4_base2 : blocks@.
--
-- The two roots are combined with 'compress2' (tag 2) to give the result.
--
-- NOTE(review): @blocks@ is consumed by both trees, so any prefix already
-- read by one tree stays live until the other tree has also consumed it.
-- Confirm actual memory use on very large inputs.
hash
  :: Twistree -- ^ The `Twistree` made with the key to hash with
  -> BL.ByteString -- ^ The text to be hashed. It's a lazy `ByteString`,
    -- so you can hash a file bigger than RAM.
  -> V.Vector Word8 -- ^ The returned hash, 32 bytes.
hash twistree stream =
  -- 'par' is infixr 0: spark blocks, then both tree roots, then combine.
  blocks `par` binaryRoot `par` ternaryRoot `par`
    compress2 key binaryRoot ternaryRoot 2
  where
    key = sbox twistree
    blocks = blockize stream
    binaryRoot = hashPairs key (exp4_2adic : blocks)
    ternaryRoot = hashTriples key (exp4_base2 : blocks)