-- | A pure interface to AES
module Codec.Crypto.SimpleAES
    ( Mode(..)
    , Direction(..)
    , Key, IV
    , newIV, randomKey, crypt
    , encryptMsg, encryptMsg', decryptMsg, decryptMsg'
    ) where

import qualified Codec.Crypto.AES.IO as AES
import Codec.Crypto.AES.IO(Mode(..), Direction(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Control.Monad
import System.IO.Unsafe
import Control.Monad.ST
import Data.Binary

import System.Random.MWC

-- | Obtain a fresh generator 'Seed'. 'withSystemRandom' builds a generator
--   seeded from system randomness; only its saved state is returned.
newSeed :: IO Seed
newSeed = withSystemRandom snapshot
  where
    -- The signature fixes the monad to IO, which withSystemRandom's
    -- polymorphic argument type would otherwise leave ambiguous.
    snapshot :: GenIO -> IO Seed
    snapshot = save

-- | An AES key as raw bytes. 'crypt' documents the accepted sizes as
--   16, 24 or 32 bytes; 'randomKey' produces 32.
type Key = B.ByteString
-- | An initialisation vector as raw bytes. 'newIV' produces 'ivLength'
--   bytes, and 'decryptMsg'' rejects messages whose stored IV has any other
--   length.
type IV = B.ByteString

-- | Length of an IV in bytes; equal to the 16-byte block size that 'repack'
--   pads to. Kept polymorphic because it is compared against both 'Int'
--   (random byte counts) and 'Int64' (lazy ByteString lengths).
ivLength :: Num a => a
ivLength = fromInteger 16

-- | Generate a random IV of 'ivLength' (16) bytes from a freshly seeded
--   generator (see 'newSeed').
newIV :: IO IV
newIV = do seed <- newSeed
           -- Only the bytes are needed; the advanced seed is discarded.
           let (iv, _) = randomByteString seed ivLength
           return iv

-- | Generate a random 32-byte (AES-256-sized) key from a freshly seeded
--   generator (see 'newSeed').
--
--   NOTE(review): the bytes come from mwc-random's multiply-with-carry
--   generator, a statistical PRNG rather than a dedicated CSPRNG; confirm
--   this meets the application's key-generation requirements.
randomKey :: IO Key
randomKey = do seed <- newSeed
               -- Only the bytes are needed; the advanced seed is discarded.
               let (key, _) = randomByteString seed 32
               return key


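-- Properties:
--   fmap (decryptMsg mode key) (encryptMsg mode key x) yields x
--   x == y => encryptMsg' mode key iv x == encryptMsg' mode key iv y
--   (encryptMsg itself is not deterministic: each call draws a fresh random IV.)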
-- | Encrypt a lazy ByteString under a freshly generated random IV
--   (see 'newIV'). Generating the IV is what makes this live in 'IO';
--   the output has the same layout as 'encryptMsg'' produces.
encryptMsg :: Mode -> Key -> BL.ByteString -> IO BL.ByteString
encryptMsg mode key bs = fmap (\iv -> encryptMsg' mode key iv bs) newIV

-- | Encrypt a lazy ByteString under a caller-supplied IV; this is the
--   deterministic core that 'encryptMsg' calls.
--
--   The result is laid out as:
--
--   1. the IV bytes, copied verbatim;
--   2. the input length before padding ('BL.length', an 'Int64'), in its
--      'Data.Binary' encoding;
--   3. the ciphertext of the input, zero-padded to a multiple of 16 bytes.
--
--   'decryptMsg'' only accepts a 16-byte IV, so pass 'ivLength' bytes here.
encryptMsg' :: Mode -> Key -> IV -> BL.ByteString -> BL.ByteString
encryptMsg' mode key iv bs = unsafePerformIO $ do
    -- The mutable AES context is created here and never escapes, so
    -- (assuming the AES primitives are deterministic) the result is a pure
    -- function of the arguments.
    ctx  <- AES.newCtx mode key iv Encrypt
    body <- unsafeInterleaveIO (lazyCrypt ctx (repack bs))
    return (BL.concat [ivField, lengthField, BL.fromChunks body])
  where
    ivField     = BL.fromChunks [iv]
    lengthField = encode (BL.length bs)

-- | Decrypt a message produced by 'encryptMsg' or 'encryptMsg''.
--   Malformed input makes this call 'error' with the message from
--   'decryptMsg''; use 'decryptMsg'' directly to handle that as a value.
decryptMsg :: Mode -> Key -> BL.ByteString -> BL.ByteString
decryptMsg mode key = either error id . decryptMsg' mode key

-- | Decrypt a message in the layout written by 'encryptMsg'':
--   an 'ivLength'-byte IV, an 8-byte 'Data.Binary'-encoded plaintext length,
--   then ciphertext whose length is a multiple of 16.
--
--   Returns 'Left' with a description when that framing is inconsistent.
--   Otherwise it returns 'Right' with the decrypted bytes truncated to the
--   stored length. Only lengths are checked, so a wrong key or mode is not
--   detected here.
decryptMsg' :: Mode -> Key -> BL.ByteString -> Either String BL.ByteString
decryptMsg' mode key bs
    = check (BL.length iv == ivLength) "Codec.Crypto.SimpleAES.decryptMsg: Invalid IV length. Message garbled."
    $ check (BL.length lenStr == 8) "Codec.Crypto.SimpleAES.decryptMsg: Invalid encoding. Message garbled."
    -- encryptMsg' never writes a negative length, so a negative one is
    -- corruption; previously it slipped through and yielded an empty result.
    $ check (len >= 0 && encLen >= len) "Codec.Crypto.SimpleAES.decryptMsg: Invalid size. Message garbled."
    $ check (encLen `mod` 16 == 0) "Codec.Crypto.SimpleAES.decryptMsg: Invalid padding. Message garbled."
    $ Right $ unsafePerformIO $ do
         -- The AES context is private to this call; see encryptMsg'.
         ctx <- AES.newCtx mode key (B.concat $ BL.toChunks iv) Decrypt
         chunks <- unsafeInterleaveIO $ lazyCrypt ctx (repack encrypted)
         -- Drop the zero padding that encryption added.
         return $ BL.take len (BL.fromChunks chunks)
    where (iv, bs') = BL.splitAt ivLength bs
          (lenStr, encrypted) = BL.splitAt 8 bs'
          -- Only decoded after the 8-byte check above has passed (laziness),
          -- so decode never sees a short input.
          len = decode lenStr
          encLen = BL.length encrypted
          -- Short-circuit: the first failing condition wins.
          check :: Bool -> String -> Either String a -> Either String a
          check False s _ = Left s
          check True _ f = f



-- | Encrypt or decrypt a lazy ByteString. The input is zero-padded up to a
--   multiple of 16 bytes, so the caller must record the original length
--   separately and truncate after decrypting. 'encryptMsg'' handles this
--   by storing the length in its output.
--
--   Properties:
--     x == y => crypt mode key iv dir x == crypt mode key iv dir y
--     take (length x) (crypt mode key iv Decrypt (crypt mode key iv Encrypt x)) == x
crypt :: Mode
      -> Key           -- ^ The AES key - 16, 24 or 32 bytes
      -> IV
      -> Direction
      -> BL.ByteString -- ^ Bytestring to encrypt/decrypt
      -> BL.ByteString
crypt mode key iv dir bs = BL.fromChunks . unsafePerformIO $ do
    -- The mutable AES context is created here and never escapes, and
    -- lazyCrypt feeds it chunks in list order. Assuming the AES primitives
    -- are deterministic, the result depends only on the arguments.
    ctx <- AES.newCtx mode key iv dir
    unsafeInterleaveIO (lazyCrypt ctx (repack bs))

-- | Run each chunk through the AES context, in list order, lazily.
--   The head chunk is processed immediately. Each later chunk is processed
--   only when its list cell is demanded (via 'unsafeInterleaveIO'), and a
--   cell is only reachable after its predecessor, so the context's state
--   advances in order.
lazyCrypt :: AES.AESCtx -> [B.ByteString] -> IO [B.ByteString]
lazyCrypt ctx = go
  where
    go []         = return []
    go (blk : bs) = do
        out  <- AES.crypt ctx blk
        rest <- unsafeInterleaveIO (go bs)
        return (out : rest)

-- | Re-chunk a lazy ByteString so that every chunk is a non-empty multiple
--   of 16 bytes. Any partial block at the end of a chunk is carried into
--   the next chunk. The final chunk is zero-padded up to a multiple of 16.
repack :: BL.ByteString -> [B.ByteString]
repack = filter (not . B.null) . align . BL.toChunks
  where
    blockSize :: Int
    blockSize = 16

    align [] = []
    align [final] = [zeroPad final]
    align (c : next : rest)
      | overhang == 0 = c : align (next : rest)
      | otherwise     = whole : align (B.append carry next : rest)
      where
        overhang       = B.length c `mod` blockSize
        (whole, carry) = B.splitAt (B.length c - overhang) c

    zeroPad c = case B.length c `mod` blockSize of
      0 -> c
      r -> c `B.append` B.replicate (blockSize - r) 0

-- | Draw @len@ uniformly random bytes from the MWC generator restored from
--   @seed@. Also return the generator's seed after the draws. This runs in
--   'runST', so it is pure: equal seeds and lengths give equal results.
randomByteString :: Seed -> Int -> (B.ByteString, Seed)
randomByteString seed len = runST $ do
    gen   <- restore seed
    bytes <- replicateM len (uniform gen)
    next  <- save gen
    return (B.pack bytes, next)