module Crypto.Modes
	( ecb, unEcb
	, cbc, unCbc
	, cfb, unCfb
	, ofb, unOfb
	, ecb', unEcb'
	, cbc', unCbc'
	, cfb', unCfb'
	, ofb', unOfb'
	, IV
	, getIV, getIVIO
	
	
	
	
	) where
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Serialize
import qualified Data.Serialize.Put as SP
import qualified Data.Serialize.Get as SG
import Data.Bits (xor)
import Data.Tagged
import Crypto.Classes
import Crypto.Random
import System.Crypto.Random (getEntropy)
import Control.Monad (liftM)
-- | An initialization vector tied, through the phantom parameter @k@, to the
-- cipher it is meant for; the wrapped bytes are the IV itself.
data IV k = IV
  { initializationVector :: B.ByteString
  } deriving (Eq, Ord, Show)
-- | Take exactly @i@ bytes from the front of a list of chunks, keeping the
-- chunk structure: chunks are kept whole while they fit and the chunk that
-- crosses the limit is truncated.  Yields fewer bytes if the list runs out
-- first, and nothing at all when @i <= 0@.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _ | i <= 0 = []
collect _ [] = []
collect i (b:bs)
  | len < i   = b : collect (i - len) bs  -- whole chunk fits; need i - len more
  | otherwise = [B.take i b]              -- this chunk completes the request
  where
    len = B.length b
-- | Split a lazy message into strict blocks of the cipher's block size
-- (@blockSize@ bits / 8 bytes).  A trailing fragment shorter than one block
-- is dropped.  The result is produced lazily, one block at a time.
chunkFor :: (BlockCipher k) => k -> L.ByteString -> [B.ByteString]
chunkFor k = go
  where
    blkSzI = fromIntegral ((blockSize `for` k) `div` 8)
    -- Split first, then measure only the block just taken.  Measuring the
    -- whole remainder with 'L.length' on every step walked (and forced) all
    -- remaining chunks each time: quadratic work, and no streaming.
    go bs =
      let (blk, rest) = L.splitAt blkSzI bs
      in if L.length blk < blkSzI
           then []
           else B.concat (L.toChunks blk) : go rest
-- | Split a strict message into blocks of the cipher's block size
-- (@blockSize@ bits / 8 bytes), dropping any trailing fragment shorter than
-- one block.
chunkFor' :: (BlockCipher k) => k -> B.ByteString -> [B.ByteString]
chunkFor' k = split
  where
    width = (blockSize `for` k) `div` 8
    split bytes
      | B.length bytes >= width =
          let (blk, remaining) = B.splitAt width bytes
          in blk : split remaining
      | otherwise = []
-- | XOR two lazy ByteStrings byte-by-byte, truncating to the shorter one.
-- The inputs may be chunked differently, so at each step the two head chunks
-- are cut to a common length and any leftover is pushed back for the next
-- step.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp x y = L.fromChunks (go (L.toChunks x) (L.toChunks y))
  where
    go [] _ = []
    go _ [] = []
    go (a:as) (b:bs) =
      let n        = min (B.length a) (B.length b)
          (a', ar) = B.splitAt n a
          (b', br) = B.splitAt n b
          -- keep unconsumed bytes of the longer chunk for the next round
          as'      = if B.null ar then as else ar : as
          bs'      = if B.null br then bs else br : bs
      in zwp' a' b' : go as' bs'
-- | XOR two strict ByteStrings byte-by-byte; the result has the length of
-- the shorter input.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a = B.pack . B.zipWith xor a
-- | CBC encryption of a strict message.  Each whole block is XORed with the
-- previous ciphertext block (the IV for the first) and then encrypted; a
-- trailing fragment shorter than one block is ignored.  The returned IV is
-- the last ciphertext block, or the input IV when no block was processed.
cbc' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
cbc' k (IV v) plaintext = (B.concat ciphertexts, IV lastBlock)
  where
    (ciphertexts, lastBlock) = chain v (chunkFor' k plaintext)
    chain prev []       = ([], prev)
    chain prev (p : ps) =
      let c             = encryptBlock k (zwp' prev p)
          (rest, final) = chain c ps
      in (c : rest, final)
-- | CBC decryption of a strict message: each whole block is decrypted and
-- XORed with the previous ciphertext block (the IV for the first).  A
-- trailing fragment shorter than one block is ignored.  The returned IV is
-- the last ciphertext block, or the input IV when no block was processed.
unCbc' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
unCbc' k (IV v) ciphertext = (B.concat plaintexts, IV lastBlock)
  where
    (plaintexts, lastBlock) = unchain v (chunkFor' k ciphertext)
    unchain prev []       = ([], prev)
    unchain prev (c : cs) =
      let p             = zwp' (decryptBlock k c) prev
          (rest, final) = unchain c cs
      in (p : rest, final)
-- | CBC encryption of a lazy message (see 'cbc'' for the strict variant).
-- Whole blocks only; a trailing fragment shorter than one block is ignored.
-- The returned IV is the last ciphertext block, or the input IV when no block
-- was processed.
cbc :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
cbc k (IV v) plaintext = (L.fromChunks ciphertexts, IV lastBlock)
  where
    (ciphertexts, lastBlock) = chain v (chunkFor k plaintext)
    chain prev []       = ([], prev)
    chain prev (p : ps) =
      let c             = encryptBlock k (zwp' prev p)
          (rest, final) = chain c ps
      in (c : rest, final)
-- | CBC decryption of a lazy message: each whole block is decrypted and
-- XORed with the previous ciphertext block (the IV for the first).  A
-- trailing fragment shorter than one block is ignored.  The returned IV is
-- the last ciphertext block, or the input IV when no block was processed.
unCbc :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
unCbc k (IV v) ciphertext = (L.fromChunks plaintexts, IV lastBlock)
  where
    (plaintexts, lastBlock) = unchain v (chunkFor k ciphertext)
    unchain prev []       = ([], prev)
    unchain prev (c : cs) =
      let p             = zwp' (decryptBlock k c) prev
          (rest, final) = unchain c cs
      in (p : rest, final)
-- | ECB encryption of a lazy message: every whole block is encrypted
-- independently; a trailing fragment shorter than one block is dropped.
ecb :: BlockCipher k => k -> L.ByteString -> L.ByteString
ecb k = L.fromChunks . map (encryptBlock k) . chunkFor k
-- | ECB decryption of a lazy message: every whole block is decrypted
-- independently; a trailing fragment shorter than one block is dropped.
unEcb :: BlockCipher k => k -> L.ByteString -> L.ByteString
unEcb k = L.fromChunks . map (decryptBlock k) . chunkFor k
-- | ECB encryption of a strict message: every whole block is encrypted
-- independently; a trailing fragment shorter than one block is dropped.
ecb' :: BlockCipher k => k -> B.ByteString -> B.ByteString
ecb' k = B.concat . map (encryptBlock k) . chunkFor' k
-- | ECB decryption of a strict message: every whole block is decrypted
-- independently; a trailing fragment shorter than one block is dropped.
unEcb' :: BlockCipher k => k -> B.ByteString -> B.ByteString
unEcb' k = B.concat . map (decryptBlock k) . chunkFor' k
-- | CFB encryption of a lazy message, one whole block at a time: each
-- plaintext block is XORed with the encryption of the previous ciphertext
-- block (the IV for the first).  A trailing fragment shorter than one block
-- is dropped.  The returned IV is the last ciphertext block, or the input IV
-- when no block was processed.
cfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
cfb k (IV v) msg = (L.fromChunks cts, IV ivF)
  where
    (cts, ivF) = feed v (chunkFor k msg)
    feed reg []       = ([], reg)
    feed reg (p : ps) =
      let c             = zwp' (encryptBlock k reg) p
          (rest, final) = feed c ps
      in (c : rest, final)
-- | CFB decryption of a lazy message, one whole block at a time: each
-- ciphertext block is XORed with the encryption of the previous ciphertext
-- block (the IV for the first).  A trailing fragment shorter than one block
-- is dropped.  The returned IV is the last ciphertext block, or the input IV
-- when no block was processed.
unCfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
unCfb k (IV v) msg = (L.fromChunks pts, IV ivF)
  where
    (pts, ivF) = feed v (chunkFor k msg)
    feed reg []       = ([], reg)
    feed reg (c : cs) =
      let p             = zwp' (encryptBlock k reg) c
          (rest, final) = feed c cs
      in (p : rest, final)
-- | CFB encryption of a strict message, one whole block at a time: each
-- plaintext block is XORed with the encryption of the previous ciphertext
-- block (the IV for the first).  A trailing fragment shorter than one block
-- is dropped.  The returned IV is the last ciphertext block, or the input IV
-- when no block was processed.
cfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
cfb' k (IV v) msg = (B.concat cts, IV ivF)
  where
    (cts, ivF) = feed v (chunkFor' k msg)
    feed reg []       = ([], reg)
    feed reg (p : ps) =
      let c             = zwp' (encryptBlock k reg) p
          (rest, final) = feed c ps
      in (c : rest, final)
-- | CFB decryption of a strict message, one whole block at a time: each
-- ciphertext block is XORed with the encryption of the previous ciphertext
-- block (the IV for the first).  A trailing fragment shorter than one block
-- is dropped.  The returned IV is the last ciphertext block, or the input IV
-- when no block was processed.
unCfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
unCfb' k (IV v) msg = (B.concat pts, IV ivF)
  where
    (pts, ivF) = feed v (chunkFor' k msg)
    feed reg []       = ([], reg)
    feed reg (c : cs) =
      let p             = zwp' (encryptBlock k reg) c
          (rest, final) = feed c cs
      in (p : rest, final)
-- | OFB mode over a lazy message.  OFB XORs the message with a keystream
-- that does not depend on the message, so encryption and decryption are the
-- same function.
ofb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
ofb = unOfb

-- | OFB mode over a lazy message (identical to 'ofb').  The keystream is
-- @E(iv), E(E(iv)), ...@ and is XORed with the message; the output has the
-- length of the message, including any trailing partial block.
--
-- The returned IV is the last keystream block consumed (the input IV for an
-- empty message).  Passing it to a further call continues the keystream
-- exactly when the previous message was a whole number of blocks.
--
-- NOTE(review): the IV is assumed to be one cipher block long, as produced
-- by 'getIV'; an empty IV is not meaningful here.
unOfb :: BlockCipher k => k -> IV k -> L.ByteString -> (L.ByteString, IV k)
unOfb k (IV iv) msg =
  let states     = iterate (encryptBlock k) iv  -- iv, E(iv), E(E(iv)), ...
      keystream  = L.fromChunks (drop 1 states)
      blkLen     = fromIntegral (B.length iv)
      -- Keystream blocks touched by the message; a partial final block
      -- counts as used.
      blocksUsed = fromIntegral ((L.length msg + blkLen - 1) `div` blkLen)
      -- The register after 'blocksUsed' steps is the last keystream block
      -- used.  Returning the block after it (as before) made a chained call
      -- skip one keystream block.
      newIV      = case drop blocksUsed states of
                     (s:_) -> s
                     []    -> iv  -- unreachable: 'iterate' is infinite
  in (zwp keystream msg, IV newIV)
-- | OFB mode over a strict message.  OFB XORs the message with a keystream
-- that does not depend on the message, so encryption and decryption are the
-- same function.
ofb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
ofb' = unOfb'

-- | OFB mode over a strict message (identical to 'ofb'').  The keystream is
-- @E(iv), E(E(iv)), ...@ and is XORed with the message; the output has the
-- length of the message, including any trailing partial block.
--
-- The returned IV is the last keystream block consumed (the input IV for an
-- empty message).  Passing it to a further call continues the keystream
-- exactly when the previous message was a whole number of blocks.
--
-- NOTE(review): the IV is assumed to be one cipher block long, as produced
-- by 'getIV'; an empty IV is not meaningful here.
unOfb' :: BlockCipher k => k -> IV k -> B.ByteString -> (B.ByteString, IV k)
unOfb' k (IV iv) msg =
  let states     = iterate (encryptBlock k) iv  -- iv, E(iv), E(E(iv)), ...
      msgLen     = B.length msg
      blkLen     = B.length iv
      -- exactly msgLen bytes of keystream
      keystream  = B.concat (collect msgLen (drop 1 states))
      -- keystream blocks touched by the message (partial final block counts)
      blocksUsed = (msgLen + blkLen - 1) `div` blkLen
      -- The register after 'blocksUsed' steps is the last keystream block
      -- used.  Returning the block after it (as before) made a chained call
      -- skip one keystream block.
      newIV      = case drop blocksUsed states of
                     (s:_) -> s
                     []    -> iv  -- unreachable: 'iterate' is infinite
  in (zwp' keystream msg, IV newIV)
-- | Like 'Data.List.unfoldr', but also returns the seed that stopped the
-- unfolding.  The list is produced lazily; the final seed is only available
-- once the step function has returned 'Nothing'.
unfoldK :: (b -> Maybe (a,b)) -> b -> ([a],b)
unfoldK step = go
  where
    go seed = maybe ([], seed) emit (step seed)
    emit (x, next) =
      let (xs, final) = go next
      in (x : xs, final)
-- | Draw a fresh IV of the cipher's block size (@blockSize@ bits / 8 bytes)
-- from a 'CryptoRandomGen'.  It fails with the generator's own error, or
-- with 'GenErrorOther' when the generator returns a different number of
-- bytes than requested.
getIV :: (BlockCipher k, CryptoRandomGen g) => g -> Either GenError (IV k, g)
getIV = drawIV Proxy
  where
    -- The Proxy pins the cipher type so its block size can be read.
    drawIV :: (BlockCipher k, CryptoRandomGen g) => Proxy k -> g -> Either GenError (IV k, g)
    drawIV pr gen =
      let want = proxy blockSize pr `div` 8
      in case genBytes want gen of
           Left err -> Left err
           Right (bs, gen')
             | B.length bs == want -> Right (IV bs, gen')
             | otherwise           -> Left (GenErrorOther "Generator failed to provide requested number of bytes")
-- | Build an IV of the cipher's block size (@blockSize@ bits / 8 bytes) from
-- bytes returned by 'getEntropy'.
getIVIO :: (BlockCipher k) => IO (IV k)
getIVIO = fetchIV Proxy
  where
    -- The Proxy pins the cipher type so its block size can be read.
    fetchIV :: BlockCipher k => Proxy k -> IO (IV k)
    fetchIV pr = fmap IV (getEntropy (proxy blockSize pr `div` 8))
-- | Turn a proxy for a cipher into a proxy for that cipher's IV type.
ivProxy :: Proxy k -> Proxy (IV k)
ivProxy _ = Proxy
-- | Turn a proxy for an IV type back into a proxy for its cipher.
deIVProxy :: Proxy (IV k) -> Proxy k
deIVProxy _ = Proxy
-- | A proxy for the type of the given value; the value itself is never
-- inspected.
proxyOf :: a -> Proxy a
proxyOf _ = Proxy
-- | Block size, in bytes, of the cipher that the IV's type parameter names.
-- Only the IV's type is used; its contents are not inspected.
ivBlockSizeBytes :: BlockCipher k => IV k -> Int
ivBlockSizeBytes iv = bitsPerBlock `div` 8
  where
    bitsPerBlock = proxy blockSize (deIVProxy (proxyOf iv))
-- | Binary encoding of an IV: its raw bytes, with no length prefix.
-- Decoding reads exactly one cipher block (@blockSize@ bits / 8 bytes).
instance (BlockCipher k) => Serialize (IV k) where
  get = getSized Proxy
    where
      -- The Proxy pins the cipher type so its block size can be read.
      getSized :: BlockCipher k => Proxy k -> Get (IV k)
      getSized pr = fmap IV (SG.getByteString (proxy blockSize pr `div` 8))
  put (IV bytes) = SP.putByteString bytes