-- |
-- Module      : Crypto.ConstructHash.MiyaguchiPreneel
-- License     : BSD-style
-- Maintainer  : Kei Hibino <ex8k.hibino@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Provides a hash function construction method built from a block cipher.
-- <https://en.wikipedia.org/wiki/One-way_compression_function>
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.ConstructHash.MiyaguchiPreneel
       ( compute, compute'
       , MiyaguchiPreneel
       ) where

import           Data.List (foldl')

import           Crypto.Data.Padding (pad, Format (ZERO))
import           Crypto.Cipher.Types
import           Crypto.Error (throwCryptoError)
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes)
import qualified Crypto.Internal.ByteArray as B


-- | Miyaguchi-Preneel digest, stored as raw 'Bytes'.
--
--   The type parameter @a@ is phantom: the 'MP' constructor does not use it,
--   it only records at the type level which cipher produced the digest.
--   The 'ByteArrayAccess' instance is the one of the wrapped 'Bytes',
--   obtained through newtype deriving.
newtype MiyaguchiPreneel a = MP Bytes
    deriving (ByteArrayAccess)

-- | Two digests are equal when their underlying bytes are equal.
--   The comparison is delegated to 'B.constEq' instead of a derived '=='.
--   NOTE(review): presumably chosen so that comparison time does not reveal
--   where two digests differ -- confirm against the ByteArray documentation.
instance Eq (MiyaguchiPreneel a) where
    MP b1 == MP b2 = B.constEq b1 b2


-- | Compute Miyaguchi-Preneel one way compress using the supplied block cipher.
--
--   The input is converted to 'Bytes', padded with the 'ZERO' format for the
--   cipher block size, cut into block-size chunks, and the chunks are folded
--   with 'step' starting from a block of zero bytes.
compute' :: (ByteArrayAccess bin, BlockCipher cipher)
         => (Bytes -> cipher)       -- ^ key build function to compute Miyaguchi-Preneel. care about block-size and key-size
         -> bin                     -- ^ input message
         -> MiyaguchiPreneel cipher -- ^ output tag
compute' g =
    MP . foldl' (step g) (B.replicate bsz 0) . chunks . pad (ZERO bsz) . B.convert
  where
    -- Block size of the cipher. The cipher built from the empty key is a
    -- dummy whose only purpose is to be handed to 'blockSize'.
    bsz :: Int
    bsz = blockSize (g B.empty)

    -- Cut the padded message into consecutive pieces of 'bsz' bytes,
    -- stopping once nothing is left.
    chunks :: Bytes -> [Bytes]
    chunks msg
      | B.null msg = []
      | otherwise  = hd : chunks tl
      where
        (hd, tl) = B.splitAt bsz msg

-- | Compute Miyaguchi-Preneel one way compress using the inferred block cipher.
--   Only safe when KEY-SIZE equals to BLOCK-SIZE.
--
--   The cipher is built with 'cipherInit' and the result is unwrapped with
--   'throwCryptoError', so a failed initialisation is not returned as a
--   value. This is 'compute'' with that key build function.
--
--   Simple usage: @compute msg :: MiyaguchiPreneel AES128@
compute :: (ByteArrayAccess bin, BlockCipher cipher)
        => bin                     -- ^ input message
        -> MiyaguchiPreneel cipher -- ^ output tag
compute = compute' (throwCryptoError . cipherInit)

-- | Computation step of Miyaguchi-Preneel.
--
--   Builds a cipher from the previous value with the key build function,
--   encrypts the message block with it in ECB mode, then XORs the
--   ciphertext with the previous value and with the message block.
step :: (ByteArray ba, BlockCipher k)
     => (ba -> k) -- ^ key build function
     -> ba        -- ^ previous value (the fold accumulator in 'compute'')
     -> ba        -- ^ message block
     -> ba        -- ^ next value
step g iv msg =
    ecbEncrypt k msg `bxor` iv `bxor` msg
  where
    -- The cipher is keyed by the previous value, not by the message block.
    k = g iv

-- | XOR of two byte arrays: 'B.xor' specialised so that both operands and
--   the result share a single 'ByteArray' type, which lets it be chained
--   infix without type annotations.
bxor :: ByteArray ba => ba -> ba -> ba
bxor = B.xor