module Codec.Crypto.SimpleAES
( Mode(..)
, Direction(..)
, Key, IV
, newIV, randomKey, crypt
, encryptMsg, encryptMsg', decryptMsg, decryptMsg'
) where
import qualified Codec.Crypto.AES.IO as AES
import Codec.Crypto.AES.IO(Mode(..), Direction(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Control.Monad
import System.IO.Unsafe
import Control.Monad.ST
import Data.Binary
import System.Random.MWC
-- | Snapshot the state of an MWC generator that 'withSystemRandom' seeds
-- from the system's entropy source, so it can later be 'restore'd in 'ST'.
newSeed :: IO Seed
newSeed = withSystemRandom snapshot
  where
    -- Pin the generator to IO; 'withSystemRandom' is polymorphic in the monad.
    snapshot :: GenIO -> IO Seed
    snapshot = save
-- | Raw AES key bytes, passed straight to 'AES.newCtx'.
-- 'randomKey' produces 32-byte keys. NOTE(review): other lengths are
-- presumably validated (or rejected) by 'AES.newCtx' — confirm.
type Key = B.ByteString
-- | Initialisation vector bytes. 'newIV' produces 'ivLength' (16) bytes, and
-- 'decryptMsg'' rejects messages whose IV prefix is not that length.
type IV = B.ByteString
-- | Length of an IV in bytes (one 16-byte AES block), usable at any numeric
-- type (e.g. 'Int' for strict and 'Int64' for lazy ByteString lengths).
ivLength :: Num a => a
ivLength = fromIntegral (16 :: Int)
-- | Generate a fresh 'ivLength'-byte (16-byte) initialisation vector, drawn
-- from an MWC generator seeded via 'newSeed'. The updated generator seed is
-- not needed and is discarded.
--
-- NOTE(review): MWC256 is a fast statistical PRNG; mwc-random does not present
-- it as cryptographically secure. Confirm that is acceptable for IVs.
newIV :: IO IV
newIV = do
  seed <- newSeed
  let (iv, _) = randomByteString seed ivLength
  return iv
-- | Generate a random 32-byte key (an AES-256-sized key), drawn from an MWC
-- generator seeded via 'newSeed'. The updated generator seed is discarded.
--
-- NOTE(review): MWC256 is not a cryptographically secure generator per
-- mwc-random's own positioning; for key material a CSPRNG is normally
-- expected. Confirm this is acceptable before relying on it for secrets.
randomKey :: IO Key
randomKey = do
  seed <- newSeed
  let (key, _) = randomByteString seed 32
  return key
-- | Encrypt a message under a freshly generated IV (see 'newIV'), producing
-- the self-describing framed format understood by 'decryptMsg'.
encryptMsg :: Mode -> Key -> BL.ByteString -> IO BL.ByteString
encryptMsg mode key bs =
  fmap (\freshIV -> encryptMsg' mode key freshIV bs) newIV
-- | Encrypt a message with an explicit IV, producing the framed layout:
--
--   1. the IV bytes as given,
--   2. the plaintext length, 'encode'd as an 'Int64' via "Data.Binary",
--   3. the ciphertext of the plaintext zero-padded (by 'repack') to whole
--      16-byte blocks.
--
-- The ciphertext is produced lazily chunk by chunk. 'unsafePerformIO' keeps
-- the API pure; each call builds its own context, so presumably the result
-- depends only on the arguments (assumes 'AES.newCtx'/'AES.crypt' have no
-- other observable effects — NOTE(review): confirm).
encryptMsg' :: Mode -> Key -> IV -> BL.ByteString -> BL.ByteString
encryptMsg' mode key iv bs = unsafePerformIO encryptAction
  where
    header = BL.fromChunks [iv] `BL.append` encode (BL.length bs)
    encryptAction = do
      ctx <- AES.newCtx mode key iv Encrypt
      cipherChunks <- unsafeInterleaveIO (lazyCrypt ctx (repack bs))
      return (header `BL.append` BL.fromChunks cipherChunks)
-- | Decrypt a message produced by 'encryptMsg', calling 'error' with the
-- diagnostic from 'decryptMsg'' if the message is malformed.
decryptMsg :: Mode -> Key -> BL.ByteString -> BL.ByteString
decryptMsg mode key bs =
  case decryptMsg' mode key bs of
    Left problem -> error problem
    Right plaintext -> plaintext
-- | Decrypt a message in the framed layout written by 'encryptMsg'':
-- 'ivLength' bytes of IV, an 8-byte "Data.Binary"-encoded 'Int64' plaintext
-- length, then the ciphertext.
--
-- Returns 'Left' with a diagnostic when the framing is inconsistent:
--
--   * the IV prefix is too short,
--   * the length field is truncated,
--   * the recorded length is negative or exceeds the ciphertext,
--   * the ciphertext is not a whole number of 16-byte blocks.
--
-- Otherwise the ciphertext is decrypted lazily and truncated to the recorded
-- length, removing the zero padding added on encryption.
--
-- The guards run top to bottom, so @len@ is decoded only after the length
-- field is known to be 8 bytes long.
decryptMsg' :: Mode -> Key -> BL.ByteString -> Either String BL.ByteString
decryptMsg' mode key bs
  | BL.length iv /= ivLength =
      Left "Codec.Crypto.SimpleAES.decryptMsg: Invalid IV length. Message garbled."
  | BL.length lenStr /= 8 =
      Left "Codec.Crypto.SimpleAES.decryptMsg: Invalid encoding. Message garbled."
  -- A negative length can only come from a corrupted header. Without the
  -- lower bound, 'BL.take' would silently yield an empty "plaintext".
  | len < 0 || BL.length encrypted < len =
      Left "Codec.Crypto.SimpleAES.decryptMsg: Invalid size. Message garbled."
  | BL.length encrypted `mod` 16 /= 0 =
      Left "Codec.Crypto.SimpleAES.decryptMsg: Invalid padding. Message garbled."
  | otherwise = Right $ unsafePerformIO $ do
      ctx <- AES.newCtx mode key (B.concat (BL.toChunks iv)) Decrypt
      chunks <- unsafeInterleaveIO $ lazyCrypt ctx (repack encrypted)
      return $ BL.take len (BL.fromChunks chunks)
  where
    (iv, afterIV) = BL.splitAt ivLength bs
    (lenStr, encrypted) = BL.splitAt 8 afterIV
    len = decode lenStr
-- | Raw AES transform (encrypt or decrypt, per the 'Direction') of a lazy
-- ByteString, with no framing.
--
-- The input is zero-padded by 'repack' to a multiple of 16 bytes. The output
-- is therefore rounded up to whole blocks, and decrypting does not strip that
-- padding. Output chunks are produced lazily via 'lazyCrypt'.
crypt :: Mode
      -> Key
      -> IV
      -> Direction
      -> BL.ByteString
      -> BL.ByteString
crypt mode key iv dir bs =
  unsafePerformIO (AES.newCtx mode key iv dir >>= runChunks)
  where
    runChunks ctx =
      fmap BL.fromChunks (unsafeInterleaveIO (lazyCrypt ctx (repack bs)))
-- | Apply 'AES.crypt' with one shared context to each chunk in list order.
--
-- The head chunk is processed immediately; the remainder is deferred with
-- 'unsafeInterleaveIO' until the list spine is forced. NOTE(review):
-- presumably 'AES.crypt' advances chaining state inside @ctx@, which is why
-- the chunks must be processed strictly in order — confirm.
lazyCrypt :: AES.AESCtx -> [B.ByteString] -> IO [B.ByteString]
lazyCrypt ctx = go
  where
    go [] = return []
    go (chunk:pending) = do
      transformed <- AES.crypt ctx chunk
      later <- unsafeInterleaveIO (go pending)
      return (transformed : later)
-- | Re-chunk a lazy ByteString into strict chunks whose lengths are all
-- multiples of the 16-byte AES block size, as 'AES.crypt' is fed one chunk
-- at a time.
--
-- Any trailing partial block of a non-final chunk is carried forward and
-- prepended to the next chunk, so only the last chunk is ever zero-padded up
-- to a whole block. Empty chunks are dropped, so an empty input yields @[]@.
repack :: BL.ByteString -> [B.ByteString]
repack = filter (not . B.null) . worker . BL.toChunks
  where
    blockSize :: Int
    blockSize = 16

    worker [] = []
    worker [x] = [pad x]
    worker (x:y:xs)
      | d == 0 = x : worker (y:xs)
      | otherwise = xa : worker (xb `B.append` y : xs)
      where
        d = B.length x `mod` blockSize
        -- Split off the whole-block prefix; the remainder joins the next chunk.
        (xa, xb) = B.splitAt (B.length x - d) x

    pad x
      | d == 0 = x
      | otherwise = x `B.append` B.replicate (blockSize - d) 0
      where
        d = B.length x `mod` blockSize
-- | Draw @len@ uniformly random bytes from the MWC generator captured in
-- @seed@. Returns the bytes together with the generator's seed after drawing,
-- so a caller could continue the same stream.
randomByteString :: Seed -> Int -> (B.ByteString, Seed)
randomByteString seed len =
  runST $ do
    gen <- restore seed
    bytes <- replicateM len (uniform gen)
    nextSeed <- save gen
    return (B.pack bytes, nextSeed)