module Crypto.Classes
        (
        
          Hash(..)
        , hashFunc'
        , hashFunc
        
        , BlockCipher(..)
        , blockSizeBytes
        , keyLengthBytes
        , buildKeyIO
        , buildKeyGen
        , StreamCipher(..)
        , buildStreamKeyIO
        , buildStreamKeyGen
        , AsymCipher(..)
        , buildKeyPairIO
        , buildKeyPairGen
        , Signing(..)
        , buildSigningKeyPairIO
        , buildSigningKeyPairGen
        
        , encode
        , zeroIV
        , incIV
        , getIV, getIVIO
        , chunkFor, chunkFor'
        , module Crypto.Util
        , module Crypto.Types
        ) where
import Data.Data
import Data.Typeable
import Data.Serialize
import qualified Data.Serialize.Get as SG
import qualified Data.Serialize.Put as SP
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as I
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT(..), runStateT)
import Control.Monad (liftM)
import Data.Bits
import Data.List (foldl', genericDrop)
import Data.Word (Word8, Word16, Word64)
import Data.Tagged
import Data.Proxy
import Crypto.Types
import Crypto.Random
import Crypto.Util
import System.IO.Unsafe (unsafePerformIO)
import Foreign (Ptr)
import Foreign.C (CChar(..), CInt(..))
import System.Entropy
import  Crypto.Modes
-- | A hash algorithm with an incremental context @ctx@ and a digest type @d@.
-- Each of @ctx@ and @d@ determines the other.
--
-- Instances supply the block/output sizes and the three context operations;
-- 'hash' and 'hash'' have defaults built from them.
class (Serialize d, Eq d, Ord d)
    => Hash ctx d | d -> ctx, ctx -> d where
  -- | Length of the digest, in bits.
  outputLength  :: Tagged d BitLength
  -- | Block length, in bits. The default 'hash' / 'hash'' only hand
  -- 'updateCtx' data whose length is a multiple of this (in bytes).
  blockLength   :: Tagged d BitLength
  -- | The context before any data has been absorbed.
  initialCtx    :: ctx
  -- | Absorb data into the context.
  updateCtx     :: ctx -> B.ByteString -> ctx
  -- | Absorb the trailing (possibly partial-block) data and produce the digest.
  finalize      :: ctx -> B.ByteString -> d

  -- | Hash a lazy ByteString: whole blocks go through 'updateCtx' (strict
  -- left fold), and the remainder shorter than a block goes to 'finalize'.
  hash :: (Hash ctx d) => L.ByteString -> d
  hash msg = res
    where
    res = finalize ctx end
    ctx = foldl' updateCtx initialCtx blks
    (blks,end) = makeBlocks msg blockLen
    blockLen = (blockLength .::. res) `div` 8

  -- | Hash a strict ByteString: the largest block-multiple prefix is passed
  -- to 'updateCtx' in one call, and the rest goes to 'finalize'.
  hash' :: (Hash ctx d) => B.ByteString -> d
  hash' msg = res
    where
    res = finalize (updateCtx initialCtx top) end
    (top, end) = B.splitAt remlen msg
    -- Length of the longest prefix that is a whole number of blocks.
    remlen = B.length msg - (B.length msg `rem` bLen)
    bLen = blockLength `for` res `div` 8
-- | Select the lazy 'hash' function for the digest type of the argument.
-- The argument is used only for its type and is never evaluated.
hashFunc :: Hash c d => d -> (L.ByteString -> d)
hashFunc _ = hash
-- | Select the strict 'hash'' function for the digest type of the argument.
-- The argument is used only for its type and is never evaluated.
hashFunc' :: Hash c d => d -> (B.ByteString -> d)
hashFunc' _ = hash'
-- | Split a lazy ByteString into strict pieces whose lengths are each a
-- multiple of @len@, plus a trailing remainder shorter than @len@.
--
-- Lazy chunks that are too small are glued onto the following chunk until
-- at least @len@ bytes are available.
makeBlocks :: L.ByteString -> ByteLength -> ([B.ByteString], B.ByteString)
makeBlocks msg len = go (L.toChunks msg)
  where
  go [] = ([],B.empty)
  go (x:xs)
    | B.length x >= len =
        -- Keep the largest block-multiple prefix; push the rest forward.
        let l = B.length x - (B.length x `rem` len)
            (top,end) = B.splitAt l x
            (rest,trueEnd) = go (end:xs)
        in (top:rest, trueEnd)
    | otherwise =
        case xs of
                [] -> ([], x)
                (a:as) -> go (B.append x a : as)
-- | A block cipher whose key type is @k@.
--
-- Instances must provide 'blockSize', 'encryptBlock', 'decryptBlock',
-- 'buildKey' and 'keyLength'. Every mode of operation below has a default
-- defined in terms of 'encryptBlock' / 'decryptBlock', and an instance may
-- override any of them.
--
-- Input length: the ECB, CBC and CFB defaults split their input with
-- 'chunkFor' / 'chunkFor'', which only yield whole blocks, so a trailing
-- partial block is silently dropped. The CTR and OFB defaults XOR the input
-- against a keystream and accept any length.
class ( Serialize k) => BlockCipher k where
  -- | Block size, in bits.
  blockSize     :: Tagged k BitLength
  -- | Encrypt with the given key. The mode implementations in this module
  -- pass it block-sized inputs (data blocks, IVs or counters).
  encryptBlock  :: k -> B.ByteString -> B.ByteString
  -- | Decrypt with the given key; the decrypting modes rely on it inverting
  -- 'encryptBlock'.
  decryptBlock  :: k -> B.ByteString -> B.ByteString
  -- | Build a key from raw key material, or 'Nothing' if those bytes are not
  -- a usable key. 'buildKeyIO' / 'buildKeyGen' retry on 'Nothing'.
  buildKey      :: B.ByteString -> Maybe k
  -- | Key length, in bits.
  keyLength     :: Tagged k BitLength

  -- | Electronic codebook encryption (strict). Default: 'modeEcb''.
  ecb           :: k -> B.ByteString -> B.ByteString
  ecb = modeEcb'

  -- | Electronic codebook decryption (strict). Default: 'modeUnEcb''.
  unEcb         :: k -> B.ByteString -> B.ByteString
  unEcb = modeUnEcb'

  -- | Cipher block chaining encryption (strict). Also returns the last
  -- ciphertext block as the IV for chaining a further call.
  -- Default: 'modeCbc''.
  cbc           :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  cbc = modeCbc'

  -- | Cipher block chaining decryption (strict). Also returns the last
  -- ciphertext block as the next IV. Default: 'modeUnCbc''.
  unCbc         :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  unCbc = modeUnCbc'

  -- | Counter mode (strict), stepping the counter with 'incIV'. Also returns
  -- the first unused counter. Default: 'modeCtr'' 'incIV'.
  ctr           :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  ctr = modeCtr' incIV

  -- | Counter mode inverse (strict); identical to 'ctr' by default.
  unCtr         :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  unCtr = modeUnCtr' incIV

  -- | Counter mode on a lazy ByteString. Default: 'modeCtr' 'incIV'.
  ctrLazy           :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  ctrLazy = modeCtr incIV

  -- | Counter mode inverse on a lazy ByteString. Default: 'modeUnCtr' 'incIV'.
  unCtrLazy         :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  unCtrLazy = modeUnCtr incIV

  -- | Cipher feedback encryption (strict). Default: 'modeCfb''.
  cfb           :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  cfb = modeCfb'

  -- | Cipher feedback decryption (strict). Default: 'modeUnCfb''.
  unCfb         :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  unCfb = modeUnCfb'

  -- | Output feedback mode (strict). Default: 'modeOfb''.
  ofb           :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  ofb = modeOfb'

  -- | Output feedback mode inverse (strict). Default: 'modeUnOfb''.
  unOfb         :: k -> IV k -> B.ByteString -> (B.ByteString, IV k)
  unOfb = modeUnOfb'

  -- | Cipher block chaining encryption on a lazy ByteString. Default: 'modeCbc'.
  cbcLazy       :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  cbcLazy = modeCbc

  -- | Cipher block chaining decryption on a lazy ByteString. Default: 'modeUnCbc'.
  unCbcLazy     :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  unCbcLazy = modeUnCbc

  -- | Synthetic-IV (SIV-style, cf. RFC 5297) authenticated encryption on lazy
  -- ByteStrings. The first key computes the synthetic IV from the
  -- associated-data list and the plaintext; the second key drives CTR
  -- encryption. The result is the synthetic IV followed by the ciphertext.
  -- Returns 'Nothing' if too many associated-data strings are supplied.
  -- Default: 'modeSiv'.
  sivLazy :: k -> k -> [L.ByteString] -> L.ByteString -> Maybe L.ByteString
  sivLazy = modeSiv

  -- | Inverse of 'sivLazy'. Returns 'Nothing' if there are too many
  -- associated-data strings, the input is shorter than one block, or the
  -- recomputed synthetic IV does not match. Default: 'modeUnSiv'.
  unSivLazy :: k -> k -> [L.ByteString] -> L.ByteString -> Maybe L.ByteString
  unSivLazy = modeUnSiv

  -- | Strict counterpart of 'sivLazy'. Default: 'modeSiv''.
  siv :: k -> k -> [B.ByteString] -> B.ByteString -> Maybe B.ByteString
  siv = modeSiv'

  -- | Strict counterpart of 'unSivLazy'. Default: 'modeUnSiv''.
  unSiv :: k -> k -> [B.ByteString] -> B.ByteString -> Maybe B.ByteString
  unSiv = modeUnSiv'

  -- | Electronic codebook encryption on a lazy ByteString. Default: 'modeEcb'.
  ecbLazy :: k -> L.ByteString -> L.ByteString
  ecbLazy = modeEcb

  -- | Electronic codebook decryption on a lazy ByteString. Default: 'modeUnEcb'.
  unEcbLazy :: k -> L.ByteString -> L.ByteString
  unEcbLazy = modeUnEcb

  -- | Cipher feedback encryption on a lazy ByteString. Default: 'modeCfb'.
  cfbLazy :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  cfbLazy = modeCfb

  -- | Cipher feedback decryption on a lazy ByteString. Default: 'modeUnCfb'.
  unCfbLazy :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  unCfbLazy = modeUnCfb

  -- | Output feedback mode on a lazy ByteString. Default: 'modeOfb'.
  ofbLazy  :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  ofbLazy = modeOfb

  -- | Output feedback mode inverse on a lazy ByteString. Default: 'modeUnOfb'.
  unOfbLazy :: k -> IV k -> L.ByteString -> (L.ByteString, IV k)
  unOfbLazy = modeUnOfb
-- | OFB encryption of a lazy ByteString. OFB only XORs the input with a
-- keystream derived from the key and IV, so this delegates to 'modeUnOfb'.
modeOfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeOfb key iv input = modeUnOfb key iv input
-- | Output feedback mode (OFB) on a lazy ByteString. Encryption and
-- decryption are the same operation.
--
-- The keystream is E(iv), E(E(iv)), ... and is XORed with the message.
-- The returned IV is the last keystream block consumed (E^n(iv) after n
-- whole blocks; @iv@ itself for an empty message), so passing it to a
-- further call continues the same keystream. If the message ends mid-block,
-- the returned IV is not a meaningful continuation point.
modeUnOfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeUnOfb k (IV iv) msg =
        let states    = iterate (encryptBlock k) iv     -- iv, E(iv), E(E(iv)), ...
            ivLen     = fromIntegral (B.length iv)
            -- The IV itself is never XORed in; the keystream starts at E(iv).
            keystream = L.fromChunks (drop 1 states)
            -- Offset into 'states' (not the keystream) so that after n blocks
            -- we return E^n(iv), the block the next call must encrypt next.
            newIV     = IV . B.concat . L.toChunks . L.take ivLen . L.drop (L.length msg) . L.fromChunks $ states
        in (zwp keystream msg, newIV)
-- | Cipher feedback (CFB) encryption of a lazy ByteString.
--
-- Each ciphertext block is E(previous ciphertext block) XOR plaintext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor'). The
-- returned IV is the last ciphertext block, or the input IV if there were
-- no blocks.
modeCfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeCfb k (IV v) msg = (L.fromChunks out, IV finalReg)
  where
  (out, finalReg) = feed v (chunkFor k msg)
  feed reg [] = ([], reg)
  feed reg (p:ps) =
    let c           = zwp' (encryptBlock k reg) p
        (rest, fin) = feed c ps
    in (c : rest, fin)
-- | Cipher feedback (CFB) decryption of a lazy ByteString.
--
-- Each plaintext block is E(previous ciphertext block) XOR ciphertext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor'). The
-- returned IV is the last ciphertext block, or the input IV if there were
-- no blocks.
modeUnCfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeUnCfb k (IV v) msg = (L.fromChunks out, IV finalReg)
  where
  (out, finalReg) = feed v (chunkFor k msg)
  feed reg [] = ([], reg)
  feed reg (c:cs) =
    let p           = zwp' (encryptBlock k reg) c
        (rest, fin) = feed c cs
    in (p : rest, fin)
-- | Draw a fresh IV for cipher @k@ from a random generator.
--
-- Requests exactly 'ivBlockSizeBytes' bytes from 'genBytes'. A generator
-- error is returned unchanged; a 'GenErrorOther' is returned if the
-- generator yields a different number of bytes than requested.
getIV :: (BlockCipher k, CryptoRandomGen g) => g -> Either GenError (IV k, g)
getIV g =
        -- 'bytes' and 'iv' are defined in terms of each other. This is safe
        -- because 'ivBlockSizeBytes' only looks at the *type* of 'iv' (to find
        -- the cipher @k@), never its value.
        let bytes = ivBlockSizeBytes iv
            gen = genBytes bytes g
            -- Partial, but only forced through 'iv' in the 'Right' branch below.
            fromRight (Right x) = x
            iv  = IV (fst  . fromRight $ gen)
        in case gen of
                Left err -> Left err
                Right (bs,g')
                        | B.length bs == bytes  -> Right (iv, g')
                        | otherwise             -> Left (GenErrorOther "Generator failed to provide requested number of bytes")
-- | Draw a fresh IV for cipher @k@ from system entropy ('getEntropy'),
-- sized to one block ('blockSize' divided by 8).
getIVIO :: (BlockCipher k) => IO (IV k)
getIVIO = fetchFor Proxy
  where
  -- The 'Proxy' carries the cipher type, so 'blockSize' can be read without a
  -- key; the result type of 'getIVIO' fixes which cipher that is.
  fetchFor :: BlockCipher c => Proxy c -> IO (IV c)
  fetchFor pr = fmap IV (getEntropy (proxy blockSize pr `div` 8))
-- | Turn a proxy for a cipher into a proxy for that cipher's IV type.
ivProxy :: Proxy k -> Proxy (IV k)
ivProxy _ = Proxy
-- | Turn a proxy for an IV type into a proxy for its cipher.
deIVProxy :: Proxy (IV k) -> Proxy k
deIVProxy _ = Proxy
-- | Electronic codebook (ECB) encryption of a lazy ByteString: each whole
-- block from 'chunkFor' is encrypted independently, and a trailing partial
-- block is dropped.
modeEcb :: BlockCipher k => k -> L.ByteString -> L.ByteString
modeEcb k = L.fromChunks . map (encryptBlock k) . chunkFor k
-- | Electronic codebook (ECB) decryption of a lazy ByteString: each whole
-- block from 'chunkFor' is decrypted independently, and a trailing partial
-- block is dropped.
modeUnEcb :: BlockCipher k => k -> L.ByteString -> L.ByteString
modeUnEcb k = L.fromChunks . map (decryptBlock k) . chunkFor k
-- | SIV-style authenticated encryption on lazy ByteStrings.
--
-- @k1@ computes the synthetic IV with 'cMacStar' over the associated data
-- @xs@ followed by the plaintext @m@. @k2@ encrypts @m@ in counter mode,
-- starting from that IV after 'sivMask'. Output: synthetic IV ++ ciphertext.
-- Returns 'Nothing' when more than @blockSize - 1@ associated-data strings
-- are supplied.
modeSiv :: BlockCipher k => k -> k -> [L.ByteString] -> L.ByteString -> Maybe L.ByteString
modeSiv k1 k2 xs m
    | length xs > bSizeb - 1 = Nothing
    | otherwise = Just
                . L.append iv
                . fst
                . ctrLazy k2 (IV . sivMask . B.concat . L.toChunks $ iv)
                $ m
  where
       bSizeb = fromIntegral $ blockSize `for` k1
       iv = cMacStar k1 $ xs ++ [m]
-- | Inverse of 'modeSiv' on lazy ByteStrings.
--
-- Splits off the one-block synthetic IV, decrypts the rest in counter mode
-- with @k2@, and accepts the plaintext only if 'cMacStar' under @k1@
-- reproduces the IV. Returns 'Nothing' when there are too many
-- associated-data strings, when the input is shorter than one block, or
-- when authentication fails.
modeUnSiv :: BlockCipher k => k -> k -> [L.ByteString] -> L.ByteString -> Maybe L.ByteString
modeUnSiv k1 k2 xs c
    | length xs > bSizeb - 1 = Nothing
    | L.length c < fromIntegral bSize = Nothing
    -- NOTE(review): '/=' on ByteStrings is not constant-time; consider a
    -- constant-time comparison for the MAC check.
    | iv /= (cMacStar k1 $ xs ++ [dm]) = Nothing
    | otherwise = Just dm
  where
       bSize = fromIntegral $ blockSizeBytes `for` k1
       bSizeb = fromIntegral $ blockSize `for` k1
       (iv,m) = L.splitAt (fromIntegral bSize) c
       dm = fst $ modeUnCtr incIV k2 (IV $ sivMask $ B.concat $ L.toChunks iv) m
-- | SIV-style authenticated encryption on strict ByteStrings; see 'modeSiv'.
-- Output: synthetic IV ++ ciphertext, or 'Nothing' when more than
-- @blockSize - 1@ associated-data strings are supplied.
modeSiv' :: BlockCipher k => k -> k -> [B.ByteString] -> B.ByteString -> Maybe B.ByteString
modeSiv' k1 k2 xs m
    | length xs > bSizeb - 1 = Nothing
    | otherwise = Just $ B.append iv $ fst $ Crypto.Classes.ctr k2 (IV $ sivMask iv) m
  where
       bSizeb = fromIntegral $ blockSize `for` k1
       iv = cMacStar' k1 $ xs ++ [m]
-- | Inverse of 'modeSiv'' on strict ByteStrings; see 'modeUnSiv'.
-- Returns 'Nothing' for too many associated-data strings, input shorter
-- than one block, or a synthetic-IV mismatch.
modeUnSiv' :: BlockCipher k => k -> k -> [B.ByteString] -> B.ByteString -> Maybe B.ByteString
modeUnSiv' k1 k2 xs c
    | length xs > bSizeb - 1 = Nothing
    | B.length c < bSize = Nothing
    -- NOTE(review): '/=' on ByteStrings is not constant-time; consider a
    -- constant-time comparison for the MAC check.
    | iv /= (cMacStar' k1 $ xs ++ [dm]) = Nothing
    | otherwise = Just dm
  where
       bSize = fromIntegral $ blockSizeBytes `for` k1
       bSizeb = fromIntegral $ blockSize `for` k1
       (iv,m) = B.splitAt bSize c
       dm = fst $ Crypto.Classes.unCtr k2 (IV $ sivMask iv) m
-- | Cipher block chaining (CBC) encryption of a lazy ByteString.
--
-- Each ciphertext block is E(previous ciphertext XOR plaintext block),
-- seeded with the IV. Only whole blocks are processed ('chunkFor'). Also
-- returns the last ciphertext block (or the input IV if there were no
-- blocks) for chaining.
modeCbc :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeCbc k (IV v) plaintext = (L.fromChunks cipherBlocks, IV lastCipher)
  where
  (cipherBlocks, lastCipher) = chain v (chunkFor k plaintext)
  chain prev [] = ([], prev)
  chain prev (p:ps) =
    let c             = encryptBlock k (zwp' prev p)
        (cs, final)   = chain c ps
    in (c : cs, final)
-- | Cipher block chaining (CBC) decryption of a lazy ByteString.
--
-- Each plaintext block is D(ciphertext block) XOR previous ciphertext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor'). Also
-- returns the last ciphertext block (or the input IV if there were no
-- blocks).
modeUnCbc :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeUnCbc k (IV v) ciphertext = (L.fromChunks plainBlocks, IV lastCipher)
  where
  (plainBlocks, lastCipher) = unchain v (chunkFor k ciphertext)
  unchain prev [] = ([], prev)
  unchain prev (c:cs) =
    let p           = zwp' (decryptBlock k c) prev
        (ps, final) = unchain c cs
    in (p : ps, final)
-- | Counter-mode encryption of a lazy ByteString. CTR only XORs the input
-- with an encrypted-counter keystream, so this delegates to 'modeUnCtr'.
modeCtr :: BlockCipher k => (IV k -> IV k) -> k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeCtr next k iv input = modeUnCtr next k iv input
-- | Counter mode (CTR) on a lazy ByteString; encryption and decryption are
-- the same operation.
--
-- The counters are iv, f iv, f (f iv), ...; the keystream is their
-- encryptions concatenated, XORed with the message. Returns the first
-- counter not consumed, i.e. the counter after ceiling(length msg / ivLen)
-- steps.
modeUnCtr :: BlockCipher k => (IV k -> IV k) -> k -> IV k -> L.ByteString -> (L.ByteString, IV k)
modeUnCtr f k (IV iv) msg =
       let counters  = iterate f $ IV iv
           ivLen     = fromIntegral $ B.length iv
           -- ceiling division: number of counter blocks the message consumes
           used      = (ivLen - 1 + L.length msg) `div` ivLen
           -- 'counters' comes from 'iterate' and is infinite, so 'head' is safe.
           newIV     = head $ genericDrop used counters
           keystream = L.fromChunks $ map (encryptBlock k . initializationVector) counters
       in (zwp keystream msg, newIV)
-- | Block size in bytes: 'blockSize' divided by 8, rounded down.
blockSizeBytes :: (BlockCipher k) => Tagged k ByteLength
blockSizeBytes = fmap bitsToBytes blockSize
  where bitsToBytes bits = bits `div` 8
-- | Key length in bytes: 'keyLength' divided by 8, rounded down.
-- NOTE(review): 'buildKeyM' rounds up when requesting key material; the two
-- agree whenever 'keyLength' is a multiple of 8.
keyLengthBytes :: (BlockCipher k) => Tagged k ByteLength
keyLengthBytes = fmap bitsToBytes keyLength
  where bitsToBytes bits = bits `div` 8
-- | Build a random block-cipher key from system entropy ('getEntropy'),
-- failing in 'IO' if no valid key is found (see 'buildKeyM').
buildKeyIO :: (BlockCipher k) => IO k
buildKeyIO = buildKeyM getEntropy giveUp
  where giveUp msg = fail msg
-- | Build a random block-cipher key using a 'CryptoRandomGen', returning the
-- advanced generator. Generator errors propagate; exhausting the retry
-- budget in 'buildKeyM' becomes a 'GenErrorOther'.
buildKeyGen :: (BlockCipher k, CryptoRandomGen g) => g -> Either GenError (k, g)
buildKeyGen g = runStateT (buildKeyM draw abort) g
  where
  draw n    = StateT (genBytes n)
  abort msg = lift (Left (GenErrorOther msg))
-- | Repeatedly fetch @ceil(keyLength / 8)@ bytes with @getMore@ and try
-- 'buildKey' on them, giving up via @err@ after 1000 rejected attempts.
buildKeyM :: (BlockCipher k, Monad m) => (Int -> m B.ByteString) -> (String -> m k) -> m k
buildKeyM getMore err = attempt (0 :: Int)
  where
  attempt 1000 = err "Tried 1000 times to generate a key from the system entropy.\
                     \  No keys were returned! Perhaps the system entropy is broken\
                     \ or perhaps the BlockCipher instance being used has a non-flat\
                     \ keyspace."
  attempt n = do
    -- 'lenTag' is tied to the key type by 'asTaggedTypeOf' below, which is
    -- how the right 'keyLength' instance is selected.
    let lenTag = keyLength
    raw <- getMore ((7 + untag lenTag) `div` 8)
    maybe (attempt (n + 1))
          (\key -> return (key `asTaggedTypeOf` lenTag))
          (buildKey raw)
-- | An asymmetric cipher with public-key type @p@ and private-key type @v@.
-- Each key type determines the other.
class AsymCipher p v | p -> v, v -> p where
  -- | Generate a key pair of the requested size, threading the generator.
  buildKeyPair :: CryptoRandomGen g => g -> BitLength -> Either GenError ((p,v),g)
  -- | Encrypt with the public key, threading the generator.
  encryptAsym      :: (CryptoRandomGen g) => g -> p -> B.ByteString -> Either GenError (B.ByteString, g)
  -- | Decrypt with the private key, threading the generator.
  decryptAsym      :: (CryptoRandomGen g) => g -> v -> B.ByteString -> Either GenError (B.ByteString, g)
  -- | Size of a public key, in bits.
  publicKeyLength  :: p -> BitLength
  -- | Size of a private key, in bits.
  privateKeyLength :: v -> BitLength
-- | Generate an asymmetric key pair using a fresh 'SystemRandom' generator.
-- The advanced generator is discarded.
buildKeyPairIO :: AsymCipher p v => BitLength -> IO (Either GenError (p,v))
buildKeyPairIO bl = do
        gen <- newGenIO :: IO SystemRandom
        either (return . Left) (return . Right . fst) (buildKeyPair gen bl)
-- | 'buildKeyPair' with the size argument first and the generator second.
buildKeyPairGen :: (CryptoRandomGen g, AsymCipher p v) => BitLength -> g -> Either GenError ((p,v),g)
buildKeyPairGen bl g = buildKeyPair g bl
-- | A stream cipher with key type @k@ and IV/state type @iv@. The key type
-- determines the IV type.
class (Serialize k) => StreamCipher k iv | k -> iv where
  -- | Build a key from raw key material, or 'Nothing' if those bytes are not
  -- a usable key. 'buildStreamKeyIO' / 'buildStreamKeyGen' retry on 'Nothing'.
  buildStreamKey        :: B.ByteString -> Maybe k
  -- | Encrypt, returning the ciphertext and an updated IV.
  encryptStream         :: k -> iv -> B.ByteString -> (B.ByteString, iv)
  -- | Decrypt, returning the plaintext and an updated IV.
  decryptStream         :: k -> iv -> B.ByteString -> (B.ByteString, iv)
  -- | Key length, in bits.
  streamKeyLength       :: Tagged k BitLength
-- | Build a random stream-cipher key from system entropy ('getEntropy'),
-- failing in 'IO' if no valid key is found (see 'buildStreamKeyM').
buildStreamKeyIO :: StreamCipher k iv => IO k
buildStreamKeyIO = buildStreamKeyM getEntropy giveUp
  where giveUp msg = fail msg
-- | Build a random stream-cipher key using a 'CryptoRandomGen', returning
-- the advanced generator. Generator errors propagate; exhausting the retry
-- budget in 'buildStreamKeyM' becomes a 'GenErrorOther'.
buildStreamKeyGen :: (StreamCipher k iv, CryptoRandomGen g) => g -> Either GenError (k, g)
buildStreamKeyGen g = runStateT (buildStreamKeyM draw abort) g
  where
  draw n    = StateT (genBytes n)
  abort msg = lift (Left (GenErrorOther msg))
-- | Repeatedly fetch @ceil(streamKeyLength / 8)@ bytes with @getMore@ and try
-- 'buildStreamKey' on them, giving up via @err@ after 1000 rejected attempts.
buildStreamKeyM :: (Monad m, StreamCipher k iv) => (Int -> m B.ByteString) -> (String -> m k) -> m k
buildStreamKeyM getMore err = go (0::Int)
  where
  -- The message names the class actually involved (StreamCipher).
  go 1000 = err "Tried 1000 times to generate a stream key from the system entropy.\
                \  No keys were returned! Perhaps the system entropy is broken\
                \ or perhaps the StreamCipher instance being used has a non-flat\
                \ keyspace."
  go i = do
    -- 'lenTag' is tied to the key type by 'asTaggedTypeOf' below, which
    -- selects the right 'streamKeyLength' instance.
    let lenTag = streamKeyLength
    kd <- getMore ((untag lenTag + 7) `div` 8)
    case buildStreamKey kd of
        Nothing  -> go (i+1)
        Just key -> return $ key `asTaggedTypeOf` lenTag
-- | A signature scheme with verifying-key type @p@ and signing-key type @v@.
-- Each key type determines the other.
class (Serialize p, Serialize v) => Signing p v | p -> v, v -> p  where
  -- | Sign a message with the signing key, threading the generator.
  sign   :: CryptoRandomGen g => g -> v -> L.ByteString -> Either GenError (B.ByteString, g)
  -- | Check a signature against a message with the verifying key.
  verify :: p -> L.ByteString -> B.ByteString -> Bool
  -- | Generate a (verifying, signing) key pair of the requested size.
  buildSigningPair :: CryptoRandomGen g => g -> BitLength -> Either GenError ((p, v), g)
  -- | Size of a signing key, in bits.
  signingKeyLength :: v -> BitLength
  -- | Size of a verifying key, in bits.
  verifyingKeyLength :: p -> BitLength
-- | Generate a signing key pair using a fresh 'SystemRandom' generator.
-- The advanced generator is discarded.
buildSigningKeyPairIO :: (Signing p v) => BitLength -> IO (Either GenError (p,v))
buildSigningKeyPairIO bl = do
        gen <- newGenIO :: IO SystemRandom
        either (return . Left) (return . Right . fst) (buildSigningPair gen bl)
-- | 'buildSigningPair' with the size argument first and the generator second.
buildSigningKeyPairGen :: (Signing p v, CryptoRandomGen g) => BitLength -> g -> Either GenError ((p, v), g)
buildSigningKeyPairGen bl g = buildSigningPair g bl
-- | Electronic codebook (ECB) encryption of a strict ByteString: each whole
-- block from 'chunkFor'' is encrypted independently, and a trailing partial
-- block is dropped.
modeEcb' :: BlockCipher k => k -> B.ByteString -> B.ByteString
modeEcb' k = B.concat . map (encryptBlock k) . chunkFor' k
-- | Electronic codebook (ECB) decryption of a strict ByteString: each whole
-- block from 'chunkFor'' is decrypted independently, and a trailing partial
-- block is dropped.
modeUnEcb' :: BlockCipher k => k -> B.ByteString -> B.ByteString
modeUnEcb' k = B.concat . map (decryptBlock k) . chunkFor' k
-- | Cipher block chaining (CBC) encryption of a strict ByteString.
--
-- Each ciphertext block is E(previous ciphertext XOR plaintext block),
-- seeded with the IV. Only whole blocks are processed ('chunkFor''). Also
-- returns the last ciphertext block (or the input IV if there were no
-- blocks) for chaining.
modeCbc' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeCbc' k (IV v) plaintext = (B.concat cipherBlocks, IV lastCipher)
  where
  (cipherBlocks, lastCipher) = chain v (chunkFor' k plaintext)
  chain prev [] = ([], prev)
  chain prev (p:ps) =
    let c           = encryptBlock k (zwp' prev p)
        (cs, final) = chain c ps
    in (c : cs, final)
-- | Cipher block chaining (CBC) decryption of a strict ByteString.
--
-- Each plaintext block is D(ciphertext block) XOR previous ciphertext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor''). Also
-- returns the last ciphertext block (or the input IV if there were no
-- blocks).
modeUnCbc' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeUnCbc' k (IV v) ciphertext = (B.concat plainBlocks, IV lastCipher)
  where
  (plainBlocks, lastCipher) = unchain v (chunkFor' k ciphertext)
  unchain prev [] = ([], prev)
  unchain prev (c:cs) =
    let p           = zwp' (decryptBlock k c) prev
        (ps, final) = unchain c cs
    in (p : ps, final)
-- | OFB encryption of a strict ByteString. OFB only XORs the input with a
-- keystream derived from the key and IV, so this delegates to 'modeUnOfb''.
modeOfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeOfb' key iv input = modeUnOfb' key iv input
-- | Output feedback mode (OFB) on a strict ByteString. Encryption and
-- decryption are the same operation.
--
-- The keystream is E(iv), E(E(iv)), ... and is XORed with the message.
-- The returned IV is the last keystream block consumed (E^n(iv) after n
-- whole blocks; @iv@ itself for an empty message), so passing it to a
-- further call continues the same keystream. If the message ends mid-block,
-- the returned IV is not a meaningful continuation point.
modeUnOfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeUnOfb' k (IV iv) msg =
        let ivLen     = B.length iv
            mLen      = B.length msg
            -- iv, E(iv), E(E(iv)), ... concatenated; assumes 'collect' yields at
            -- least mLen + ivLen bytes from this infinite list (confirm in Crypto.Util).
            stream    = B.concat (collect (mLen + ivLen) (iterate (encryptBlock k) iv))
            -- The IV itself is never XORed in; the keystream starts at E(iv).
            keystream = B.drop ivLen stream
            -- Offset into 'stream' (which starts at iv) so that after n blocks we
            -- return E^n(iv), the block the next call must encrypt next.
            newIV     = IV (B.take ivLen (B.drop mLen stream))
        in (zwp' keystream msg, newIV)
-- | Counter-mode encryption of a strict ByteString. CTR only XORs the input
-- with an encrypted-counter keystream, so this delegates to 'modeUnCtr''.
modeCtr' :: BlockCipher k => (IV k -> IV k) -> k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeCtr' next k iv input = modeUnCtr' next k iv input
-- | Counter mode (CTR) on a strict ByteString; encryption and decryption are
-- the same operation.
--
-- Works byte by byte. When the current keystream block is exhausted, the
-- current counter is encrypted to refill it and the counter is advanced with
-- @f@. Returns the output and the first counter not yet encrypted.
-- Any unused keystream bytes from the final block are discarded.
modeUnCtr' :: BlockCipher k => (IV k -> IV k) -> k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeUnCtr' f k iv0 msg =
       let step (pad, ctr@(IV c)) byte =
             case B.uncons pad of
               -- keystream exhausted: encrypt the counter, advance it, retry the byte
               Nothing        -> step (encryptBlock k c, f ctr) byte
               Just (p, pad') -> ((pad', ctr), xor byte p)
           ((_, newIV), res) = B.mapAccumL step (B.empty, iv0) msg
       in (res, newIV)
-- | Cipher feedback (CFB) encryption of a strict ByteString.
--
-- Each ciphertext block is E(previous ciphertext block) XOR plaintext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor''). The
-- returned IV is the last ciphertext block, or the input IV if there were
-- no blocks.
modeCfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeCfb' k (IV v) msg = (B.concat out, IV finalReg)
  where
  (out, finalReg) = feed v (chunkFor' k msg)
  feed reg [] = ([], reg)
  feed reg (p:ps) =
    let c           = zwp' (encryptBlock k reg) p
        (rest, fin) = feed c ps
    in (c : rest, fin)
-- | Cipher feedback (CFB) decryption of a strict ByteString.
--
-- Each plaintext block is E(previous ciphertext block) XOR ciphertext block,
-- seeded with the IV. Only whole blocks are processed ('chunkFor''). The
-- returned IV is the last ciphertext block, or the input IV if there were
-- no blocks.
modeUnCfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
modeUnCfb' k (IV v) msg = (B.concat out, IV finalReg)
  where
  (out, finalReg) = feed v (chunkFor' k msg)
  feed reg [] = ([], reg)
  feed reg (c:cs) =
    let p           = zwp' (encryptBlock k reg) c
        (rest, fin) = feed c cs
    in (p : rest, fin)
-- | Split a strict ByteString into consecutive pieces of @n@ bytes; the last
-- piece may be shorter. An empty input yields no pieces.
-- (A non-positive @n@ never makes progress on non-empty input.)
toChunks :: Int -> B.ByteString -> [B.ByteString]
toChunks n = unfold
  where
  unfold bytes
    | B.null bytes = []
    | otherwise    = B.take n bytes : unfold (B.drop n bytes)
-- | Increment an IV as a big-endian unsigned integer, wrapping around on
-- overflow (the final carry is discarded).
incIV :: BlockCipher k => IV k -> IV k
incIV (IV bytes) = IV incremented
  where
  -- Walk from the rightmost (least significant) byte with an initial carry of 1.
  (_, incremented) = B.mapAccumR addCarry 1 bytes
  addCarry :: Word16 -> Word8 -> (Word16, Word8)
  addCarry carry byte =
    let total = carry + fromIntegral byte
    in (total `shiftR` 8, fromIntegral total)
-- | An all-zero IV, one cipher block long.
zeroIV :: (BlockCipher k) => IV k
zeroIV = zeros
  where
  -- 'ivBlockSizeBytes' only reads the type of its argument, so sizing 'zeros'
  -- by itself is safe; the signature of 'zeroIV' picks the cipher.
  zeros = IV (B.replicate (ivBlockSizeBytes zeros) 0)
-- | An all-zero IV that is 5 bytes shorter than one cipher block
-- (presumably the nonce size for CWC mode, per the name — confirm).
-- For ciphers with blocks shorter than 5 bytes, 'B.replicate' of a negative
-- count yields an empty IV.
zeroIVcwc :: BlockCipher k => IV k
zeroIVcwc = iv
  where bytes = ivBlockSizeBytes iv - 5
        iv    = IV $ B.replicate bytes 0
-- | Split a lazy ByteString into strict, whole cipher blocks for key @k@.
-- A trailing partial block is dropped.
--
-- Each step splits off at most one block and checks only that block's
-- length. The input is never measured in full, so the whole split is linear
-- and streams lazily.
chunkFor :: (BlockCipher k) => k -> L.ByteString -> [B.ByteString]
chunkFor k = go
  where
  blkSz  = (blockSize `for` k) `div` 8
  blkSzI = fromIntegral blkSz
  go bs =
    let (blk, rest) = L.splitAt blkSzI bs
    in if L.length blk < blkSzI
         then []
         else B.concat (L.toChunks blk) : go rest
-- | Split a strict ByteString into whole cipher blocks for key @k@.
-- A trailing partial block is dropped.
chunkFor' :: (BlockCipher k) => k -> B.ByteString -> [B.ByteString]
chunkFor' k = slice
  where
  width = (blockSize `for` k) `div` 8
  slice bytes
    | B.length bytes >= width = B.take width bytes : slice (B.drop width bytes)
    | otherwise               = []
-- | Clear the high bit of the 4th and 8th bytes counted from the end (bits
-- 31 and 63 when the string is read as a big-endian integer). This matches
-- the CTR-IV mask that SIV (RFC 5297) applies. All other bytes are unchanged.
sivMask :: B.ByteString -> B.ByteString
sivMask = snd . B.mapAccumR step 0
  where
  -- 'offset' is the bit offset of the current byte, counted from the right.
  step :: Int -> Word8 -> (Int, Word8)
  step offset w
    | offset == 24 || offset == 56 = (offset + 8, clearBit w 7)
    | otherwise                    = (offset + 8, w)
-- | Block size in bytes of the cipher an IV belongs to. Only the IV's type
-- is used; its contents are never evaluated, which lets callers size an IV
-- by referring to that IV itself.
ivBlockSizeBytes :: BlockCipher k => IV k -> Int
ivBlockSizeBytes iv = proxy blockSize (deIVProxy (typeOfValue iv)) `div` 8
  where
  typeOfValue :: a -> Proxy a
  typeOfValue _ = Proxy
-- | An IV serializes as its raw bytes. Deserialization reads exactly one
-- cipher block ('blockSizeBytes') of bytes.
instance (BlockCipher k) => Serialize (IV k) where
        get = readIV Proxy
          where
          -- The 'Proxy' carries the cipher type from the instance head.
          readIV :: BlockCipher c => Proxy c -> Get (IV c)
          readIV pr = fmap IV (SG.getByteString (proxy blockSizeBytes pr))
        put (IV iv) = SP.putByteString iv