module Crypto.Cipher.AES
(
Key
, IV(..)
, initKey
, genCTR
, encryptECB
, encryptCBC
, encryptCTR
, encryptXTS
, encryptGCM
, decryptECB
, decryptCBC
, decryptCTR
, decryptXTS
, decryptGCM
) where
import Data.Word
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.C.Types
import Foreign.C.String
import Foreign.Marshal.Alloc
import Data.ByteString.Internal
import Data.ByteString.Unsafe
import qualified Data.ByteString as B
import System.IO.Unsafe (unsafePerformIO)
newtype Key = Key ByteString
newtype IV = IV ByteString
newtype GCM = GCM ByteString
sizeGCM :: Int
sizeGCM = 540
instance Storable GCM where
sizeOf _ = sizeGCM
alignment _ = 16
poke ptr (GCM b) = unsafeUseAsCString b (\cs -> memcpy (castPtr ptr) (castPtr cs) (fromIntegral sizeGCM))
peek ptr = create sizeGCM (\bptr -> memcpy bptr (castPtr ptr) (fromIntegral sizeGCM)) >>= return . GCM
keyToPtr :: Key -> (Ptr Key -> IO a) -> IO a
keyToPtr (Key b) f = unsafeUseAsCString b (f . castPtr)
ivToPtr :: IV -> (Ptr IV -> IO a) -> IO a
ivToPtr (IV b) f = unsafeUseAsCString b (f . castPtr)
withKeyAndIV :: Key -> IV -> (Ptr Key -> Ptr IV -> IO a) -> IO a
withKeyAndIV key iv f = keyToPtr key $ \kptr -> ivToPtr iv $ \ivp -> f kptr ivp
withKey2AndIV :: Key -> Key -> IV -> (Ptr Key -> Ptr Key -> Ptr IV -> IO a) -> IO a
withKey2AndIV key1 key2 iv f =
keyToPtr key1 $ \kptr1 -> keyToPtr key2 $ \kptr2 -> ivToPtr iv $ \ivp -> f kptr1 kptr2 ivp
initKey :: ByteString -> Key
initKey b@(B.length -> len)
| len == 16 = doInit 10
| len == 24 = doInit 12
| len == 32 = doInit 14
| otherwise = error "wrong key size: need to be 16, 24 or 32 bytes."
where doInit nbR = unsafePerformIO $ unsafeUseAsCString b (allocAndFill nbR)
allocAndFill nbR ikey = do
ptr <- mallocBytes (16+2*2*16*nbR)
c_aes_init ptr (castPtr ikey) (fromIntegral len)
fptr <- newForeignPtr c_free_finalizer (castPtr ptr)
return $ Key $ fromForeignPtr fptr 0 (16+2*2*16*nbR)
encryptECB :: Key -> ByteString -> ByteString
encryptECB = doECB c_aes_encrypt_ecb
encryptCBC :: Key -> IV -> ByteString -> ByteString
encryptCBC = doCBC c_aes_encrypt_cbc
genCTR :: Key
-> IV
-> Int
-> ByteString
genCTR key iv len = unsafeCreate (nbBlocks * 16) generate
where
generate o = withKeyAndIV key iv $ \k i -> c_aes_gen_ctr (castPtr o) k i (fromIntegral nbBlocks)
(nbBlocks',r) = len `divMod` 16
nbBlocks = if r == 0 then nbBlocks' else nbBlocks' + 1
encryptCTR :: Key -> IV -> ByteString -> ByteString
encryptCTR key iv input = unsafeCreate len doEncrypt
where doEncrypt o = withKeyAndIV key iv $ \k v -> unsafeUseAsCString input $ \i ->
c_aes_encrypt_ctr (castPtr o) k v i (fromIntegral len)
len = B.length input
encryptGCM :: Key
-> IV
-> ByteString
-> ByteString
-> (ByteString, ByteString)
encryptGCM = doGCM gcmAppendEncrypt
encryptXTS :: (Key,Key) -> IV -> Word32 -> ByteString -> ByteString
encryptXTS = doXTS c_aes_encrypt_xts
decryptECB :: Key -> ByteString -> ByteString
decryptECB = doECB c_aes_decrypt_ecb
decryptCBC :: Key -> IV -> ByteString -> ByteString
decryptCBC = doCBC c_aes_decrypt_cbc
decryptCTR :: Key -> IV -> ByteString -> ByteString
decryptCTR = encryptCTR
decryptXTS :: (Key,Key) -> IV -> Word32 -> ByteString -> ByteString
decryptXTS = doXTS c_aes_decrypt_xts
decryptGCM :: Key -> IV -> ByteString -> ByteString -> (ByteString, ByteString)
decryptGCM = doGCM gcmAppendDecrypt
doECB :: (Ptr b -> Ptr Key -> CString -> CUInt -> IO ())
-> Key -> ByteString -> ByteString
doECB f key input
| r /= 0 = error "cannot use with non multiple of block size"
| otherwise = unsafeCreate len $ \o -> keyToPtr key $ \k -> unsafeUseAsCString input $ \i ->
f (castPtr o) k i (fromIntegral nbBlocks)
where (nbBlocks, r) = len `divMod` 16
len = (B.length input)
doCBC :: (Ptr b -> Ptr Key -> Ptr IV -> CString -> CUInt -> IO ())
-> Key -> IV -> ByteString -> ByteString
doCBC f key iv input
| r /= 0 = error "cannot use with non multiple of block size"
| otherwise = unsafeCreate len $ \o -> withKeyAndIV key iv $ \k v -> unsafeUseAsCString input $ \i ->
f (castPtr o) k v i (fromIntegral nbBlocks)
where (nbBlocks, r) = len `divMod` 16
len = (B.length input)
doXTS :: (Ptr b -> Ptr Key -> Ptr Key -> Ptr IV -> CUInt -> CString -> CUInt -> IO ())
-> (Key, Key) -> IV -> Word32 -> ByteString -> ByteString
doXTS f (key1,key2) iv spoint input
| r /= 0 = error "cannot use with non multiple of block size (yet)"
| otherwise = unsafeCreate len $ \o -> withKey2AndIV key1 key2 iv $ \k1 k2 v -> unsafeUseAsCString input $ \i ->
f (castPtr o) k1 k2 v (fromIntegral spoint) i (fromIntegral nbBlocks)
where (nbBlocks, r) = len `divMod` 16
len = (B.length input)
doGCM :: (GCM -> ByteString -> (ByteString, GCM)) -> Key -> IV -> ByteString -> ByteString -> (ByteString, ByteString)
doGCM f key iv aad input = (cipher, tag)
where
tag = gcmFinish after 16
(cipher, after) = f afterAAD input
afterAAD = gcmAppendAAD ini aad
ini = gcmInit key iv
allocaFrom :: Storable a => a -> (Ptr a -> IO b) -> IO b
allocaFrom z f = alloca $ \ptr -> poke ptr z >> f ptr
gcmInit :: Key -> IV -> GCM
gcmInit key iv@(IV b) = unsafePerformIO $ alloca doInit
where doInit gcm = withKeyAndIV key iv (\k v -> c_aes_gcm_init gcm k v (fromIntegral $ B.length b)) >> peek gcm
gcmAppendAAD :: GCM -> ByteString -> GCM
gcmAppendAAD gcm input = unsafePerformIO $ allocaFrom gcm doAppend
where doAppend p = do
unsafeUseAsCString input $ \i -> c_aes_gcm_aad p i (fromIntegral $ B.length input)
peek p
gcmAppendEncrypt :: GCM -> ByteString -> (ByteString, GCM)
gcmAppendEncrypt gcm input = unsafePerformIO $ allocaFrom gcm doEnc
where len = B.length input
doEnc p = do
output <- create len $ \o -> unsafeUseAsCString input $ \i -> c_aes_gcm_encrypt (castPtr o) p i (fromIntegral len)
ngcm <- peek p
return (output, ngcm)
gcmAppendDecrypt :: GCM -> ByteString -> (ByteString, GCM)
gcmAppendDecrypt gcm input = unsafePerformIO $ allocaFrom gcm doDec
where len = B.length input
doDec p = do
output <- create len $ \o -> unsafeUseAsCString input $ \i -> c_aes_gcm_decrypt (castPtr o) p i (fromIntegral len)
ngcm <- peek p
return (output, ngcm)
gcmFinish :: GCM -> Int -> ByteString
gcmFinish gcm taglen = B.take taglen (unsafeCreate 16 $ \t -> allocaFrom gcm (finish t))
where finish t p = c_aes_gcm_finish (castPtr t) p
foreign import ccall "aes.h aes_initkey"
c_aes_init :: Ptr Key -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_encrypt_ecb"
c_aes_encrypt_ecb :: CString -> Ptr Key -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_decrypt_ecb"
c_aes_decrypt_ecb :: CString -> Ptr Key -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_encrypt_cbc"
c_aes_encrypt_cbc :: CString -> Ptr Key -> Ptr IV -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_decrypt_cbc"
c_aes_decrypt_cbc :: CString -> Ptr Key -> Ptr IV -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_encrypt_xts"
c_aes_encrypt_xts :: CString -> Ptr Key -> Ptr Key -> Ptr IV -> CUInt -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_decrypt_xts"
c_aes_decrypt_xts :: CString -> Ptr Key -> Ptr Key -> Ptr IV -> CUInt -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_gen_ctr"
c_aes_gen_ctr :: CString -> Ptr Key -> Ptr IV -> CUInt -> IO ()
foreign import ccall "aes.h aes_encrypt_ctr"
c_aes_encrypt_ctr :: CString -> Ptr Key -> Ptr IV -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_gcm_init"
c_aes_gcm_init :: Ptr GCM -> Ptr Key -> Ptr IV -> CUInt -> IO ()
foreign import ccall "aes.h aes_gcm_aad"
c_aes_gcm_aad :: Ptr GCM -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_gcm_encrypt"
c_aes_gcm_encrypt :: CString -> Ptr GCM -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_gcm_decrypt"
c_aes_gcm_decrypt :: CString -> Ptr GCM -> CString -> CUInt -> IO ()
foreign import ccall "aes.h aes_gcm_finish"
c_aes_gcm_finish :: CString -> Ptr GCM -> IO ()