-- | A pure interface to AES
module Codec.Crypto.SimpleAES(
  Mode(..), Direction(..), Key, IV, newIV, randomKey, crypt
  ) where

import qualified Codec.Crypto.AES.IO as AES
import Codec.Crypto.AES.IO(Mode(..), Direction(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Control.Monad
import System.IO.Unsafe
import Control.Monad.ST

import System.Random.MWC

-- | Produce a fresh generator seed.
--
-- A generator is obtained through 'withSystemRandom' (presumably seeded from
-- the operating system's random source) and its state is captured with
-- 'save', so callers can later 'restore' it in pure code.
newSeed :: IO Seed
newSeed = withSystemRandom $ \gen -> save gen

-- | An AES key as raw bytes; 'crypt' expects 16, 24 or 32 bytes.
type Key = B.ByteString
-- | An initialisation vector as raw bytes; 'newIV' produces 16 bytes.
type IV = B.ByteString

-- | Generate a fresh 16-byte initialisation vector.
--
-- Each call obtains a new seed via 'newSeed' and draws 16 bytes from a
-- generator restored from it; the updated seed is discarded.
--
-- NOTE(review): the bytes come from mwc-random's MWC generator, which is a
-- fast statistical PRNG rather than a cryptographically secure one; confirm
-- this is acceptable for IV generation.
newIV :: IO IV
newIV = do
  seed <- newSeed
  let (iv, _) = randomByteString seed 16
  return iv

-- | Generate a random 32-byte key (the AES-256 key size).
--
-- Each call obtains a new seed via 'newSeed' and draws 32 bytes from a
-- generator restored from it; the updated seed is discarded.
--
-- NOTE(review): the bytes come from mwc-random's MWC generator, which is a
-- fast statistical PRNG rather than a cryptographically secure one; confirm
-- this is acceptable for key material, or switch to an OS CSPRNG.
randomKey :: IO Key
randomKey = do
  seed <- newSeed
  let (key, _) = randomByteString seed 32
  return key

{-
-- Unfinished sketch of an encryption API that generates its own IV.
-- It is commented out and not compiled.
-- Properties:
--   decryptMsg mode key . encryptMsg mode key == id
--   x == y => encryptMsg mode key x == encryptMsg mode key y
encryptMsg :: Mode -> Key -> BL.ByteString -> IO BL.ByteString
encryptMsg mode key bs
    = do iv <- newIV
         ctx <- AES.newCtx mode key iv Encrypt
         chunks <- unsafeInterleaveIO $ lazyCrypt ctx (repack bs)
-}


-- | Encryption/decryption for lazy bytestrings.
--
--   The input is re-chunked into pieces whose lengths are multiples of 16
--   bytes. The last piece is padded with zero bytes up to a 16-byte boundary
--   before being handed to the cipher. That padding is why the round-trip
--   property below only holds after truncating to the original length.
--
--   Properties:
--     x == y => crypt mode key iv dir x == crypt mode key iv dir y
--     take (length x) (crypt mode key iv Decrypt (crypt mode key iv Encrypt x)) == x
crypt :: Mode -- ^ Cipher mode of operation (from "Codec.Crypto.AES.IO")
        -> Key -- ^ The AES key - 16, 24 or 32 bytes
        -> IV -- ^ Initialisation vector (presumably one 16-byte block, as 'newIV' produces)
        -> Direction -- ^ 'Encrypt' or 'Decrypt'
        -> BL.ByteString -- ^ Bytestring to encrypt/decrypt
        -> BL.ByteString
crypt mode key iv dir bs
    -- unsafePerformIO presents a pure interface. The AES context is created
    -- inside this call and never escapes, so only the output is observable.
    -- Purity also assumes AES.newCtx/AES.crypt are deterministic for equal
    -- inputs; that is not visible here, so confirm against the AES binding.
    -- Chunks are processed lazily and in list order as the result is
    -- consumed (see lazyCrypt).
    = unsafePerformIO $
      do ctx <- AES.newCtx mode key iv dir
         chunks <- unsafeInterleaveIO $ lazyCrypt ctx (repack bs)
         return $ BL.fromChunks chunks

-- | Run 'AES.crypt' over a list of chunks, lazily.
--
-- Running the action processes only the first chunk. The remaining tail is
-- deferred with 'unsafeInterleaveIO' and processed when the caller forces it.
-- Each cons cell is produced by running 'AES.crypt' on its own chunk before
-- the rest is deferred. Forcing the result's spine therefore hands chunks to
-- the (presumably stateful) context strictly in list order.
lazyCrypt :: AES.AESCtx -> [B.ByteString] -> IO [B.ByteString]
lazyCrypt _ [] = return []
lazyCrypt ctx (x:xs)
    = do x' <- AES.crypt ctx x
         -- Defer the rest of the list until the consumer demands it.
         xs' <- unsafeInterleaveIO $ lazyCrypt ctx xs
         return (x' : xs')

-- | Re-chunk a lazy bytestring for block-wise processing.
--
-- Every returned chunk is non-empty and its length is a multiple of 16 bytes.
-- Bytes left over at the end of one input chunk are carried into the next.
-- The final chunk is padded with zero bytes up to the next 16-byte boundary.
-- An empty input yields an empty list.
repack :: BL.ByteString -> [B.ByteString]
repack = filter (not . B.null) . go . BL.toChunks
  where
    -- Walk the chunk list, keeping every emitted chunk 16-byte aligned.
    go []  = []
    go [c] = [padTo16 c]
    go (c:next:rest)
      | spill == 0 = c : go (next : rest)
      | otherwise  = aligned : go (B.append carry next : rest)
      where
        spill            = B.length c `mod` 16
        (aligned, carry) = B.splitAt (B.length c - spill) c

    -- Append zero bytes until the length is a multiple of 16.
    padTo16 c
      | r == 0    = c
      | otherwise = c `B.append` B.replicate (16 - r) 0
      where
        r = B.length c `mod` 16

-- | Draw @len@ uniformly distributed bytes from a generator restored from
-- @seed@. Returns the bytes together with the generator's updated seed.
--
-- The generator lives only inside 'runST', so the function is pure: the same
-- seed and length always give the same bytes. A @len@ of zero or less yields
-- an empty bytestring.
--
-- NOTE(review): this output is used for keys and IVs in this module. The MWC
-- generator is a statistical PRNG, not a CSPRNG; confirm that is acceptable.
randomByteString :: Seed -> Int -> (B.ByteString, Seed)
randomByteString seed len
    = runST (do gen <- restore seed
                -- Named 'bytes' rather than 'words' to avoid shadowing Prelude.words.
                bytes <- replicateM len (uniform gen)
                seed' <- save gen
                return (B.pack bytes, seed'))