-- |
-- Module      : Crypto.Store.KeyWrap.AES
-- License     : BSD-style
-- Maintainer  : Olivier Chéron
-- Stability   : experimental
-- Portability : unknown
--
-- AES Key Wrap (<https://www.rfc-editor.org/rfc/rfc3394 RFC 3394>) and Extended
-- Key Wrap (<https://www.rfc-editor.org/rfc/rfc5649 RFC 5649>)
--
-- Should be used with a cipher from module "Crypto.Cipher.AES".
module Crypto.Store.KeyWrap.AES
    ( wrap
    , unwrap
    , wrapPad
    , unwrapPad
    ) where

import           Data.Bits
import           Data.ByteArray (ByteArray, ByteArrayAccess, Bytes)
import qualified Data.ByteArray as B
import           Data.List
import           Data.Word

import Crypto.Cipher.Types

import Foreign.Storable

import Crypto.Store.Error
import Crypto.Store.Util

-- | Data held as a list of byte arrays (8-byte pieces when produced by
-- 'chunks').
type Chunked ba = [ba]

-- | Two byte arrays handled together.  In the cipher helpers below the first
-- component is the leading 8 bytes of a block and the second is the rest.
type Pair ba = (ba, ba)

-- TODO: should use a low-level AES implementation to reduce allocations

-- | Concatenate both halves of the pair and ECB-encrypt the result.
aes' :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> ba
aes' cipher (msb, lsb) = ecbEncrypt cipher (B.append msb lsb)

-- | Like 'aes'', but split the ciphertext after its first 8 bytes.
aes :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> Pair ba
aes cipher = B.splitAt 8 . aes' cipher

-- | ECB-decrypt the input and split the plaintext after its first 8 bytes.
aesrev' :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Pair ba
aesrev' cipher = B.splitAt 8 . ecbDecrypt cipher

-- | Like 'aesrev'', but takes the ciphertext as two halves to concatenate.
aesrev :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> Pair ba
aesrev cipher (msb, lsb) = aesrev' cipher (B.append msb lsb)

-- | Core wrapping transform.
--
-- Starting from the initial register @iiv@, makes six passes (@j@ = 0..5)
-- over the chunk list.  Each step encrypts the current register together
-- with one chunk, XORs the step counter into the first half of the result to
-- obtain the next register, and keeps the second half as the new chunk.  The
-- counter starts at 1 and keeps increasing across passes (@n * j + 1@
-- onwards, where @n@ is the number of chunks).
--
-- The result is the final register followed by the transformed chunks.
wrapc :: (BlockCipher aes, ByteArray ba)
      => aes -> ba -> Chunked ba -> Chunked ba
wrapc cipher iiv list = uncurry (:) $ foldl' pass (iiv, list) [0 .. 5]
  where
    n = fromIntegral (length list)

    -- One pass over all chunks, threading the register through.
    pass (a, l) j = mapAccumL f a $ zip [n * j + 1 ..] l

    -- One step: encrypt, then fold the counter into the new register.
    f a (i, r) =
        let (msb, lsb) = aes cipher (a, r)
         in (xorWith msb i, lsb)

-- | Inverse of 'wrapc'.
--
-- The first chunk of the input is taken as the register and the remaining
-- chunks as the wrapped data; an empty list is rejected with 'InvalidInput'.
-- The six passes run with @j@ from 5 down to 0 and, within a pass, chunks are
-- visited last to first with counters from @n * j + n@ down to @n * j + 1@.
-- Each step XORs the counter into the register and decrypts the pair.
--
-- Returns the recovered register and the chunks in their original order.
-- The register is not validated here; callers compare it with the value they
-- expect.
unwrapc :: (BlockCipher aes, ByteArray ba)
        => aes -> Chunked ba -> Either StoreError (ba, Chunked ba)
unwrapc _      []        = Left (InvalidInput "KeyWrap.AES: input too short")
unwrapc cipher (iv:list) = Right (iiv, reverse out)
  where
    -- The chunk list stays reversed through all passes and is restored once
    -- at the end.
    (iiv, out) = foldl' pass (iv, reverse list) (reverse [0 .. 5])
    n = fromIntegral (length list)
    pass (a, l) j = mapAccumL f a $ zip (reverse [n * j + 1 .. n * j + n]) l
    f a (i, r) = aesrev cipher (xorWith a i, r)

-- | Wrap a key with the specified AES cipher.
--
-- The key is split into 8-byte chunks and transformed with an initial
-- register made of eight @0xA6@ bytes.  Returns 'InvalidInput' when the key
-- length is not a multiple of 8 bytes.
wrap :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
-- NOTE(review): RFC 3394 defines the algorithm for at least two 8-byte
-- chunks of key data; shorter (including empty) keys are not rejected here.
wrap cipher bs = unchunks . wrapc cipher iiv <$> chunks bs
  where
    iiv = B.replicate 8 0xA6

-- | Unwrap an encrypted key with the specified AES cipher.
--
-- Returns 'InvalidInput' when the input is empty or its length is not a
-- multiple of 8 bytes, and 'BadChecksum' when the recovered register does
-- not pass the 'constAllEq' comparison against @0xA6@.
unwrap :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
unwrap cipher bs = unchunks <$> (check =<< unwrapc cipher =<< chunks bs)
  where
    -- Integrity check on the register recovered by 'unwrapc'.
    check (iiv, out)
        | constAllEq 0xA6 iiv = Right out
        | otherwise           = Left BadChecksum

-- | Split a byte array into consecutive 8-byte chunks.
--
-- An empty input yields an empty list.  If fewer than 8 bytes remain at any
-- point, the whole split fails with 'InvalidInput'.
chunks :: ByteArray ba => ba -> Either StoreError (Chunked ba)
chunks bs
    | B.null bs       = Right []
    | B.length bs < 8 = Left (InvalidInput "KeyWrap.AES: input is not multiple of 8 bytes")
    | otherwise       = let (a, b) = B.splitAt 8 bs in (a :) <$> chunks b

-- | Concatenate chunks back into a single byte array.
unchunks :: ByteArray ba => Chunked ba -> ba
unchunks = B.concat

-- | Base value of the initial register used by the padded variant: four
-- fixed bytes followed by four zero bytes.  'pad' XORs the input length into
-- it and 'unpad' XORs it back out.
padMask :: Bytes
padMask = B.pack [0xA6, 0x59, 0x59, 0xA6, 0x00, 0x00, 0x00, 0x00]

-- | Prepare a key for padded wrapping.
--
-- @inlen@ is the length of @bs@ as supplied by the caller.  Returns the
-- initial register (the 'padMask' with @inlen@ XORed in) together with the
-- key extended with zero bytes up to the next multiple of 8 bytes.  An
-- @inlen@ of zero is rejected with 'InvalidInput'.
pad :: ByteArray ba => Int -> ba -> Either StoreError (Pair ba)
pad inlen bs
    | inlen == 0  = Left (InvalidInput "KeyWrap.AES: input is empty")
    | padlen == 8 = Right (aiv, bs)  -- already aligned, nothing to append
    | otherwise   = Right (aiv, bs `B.append` B.zero padlen)
  where
    -- In 1..8; the value 8 means the input is already a multiple of 8.
    padlen = 8 - mod inlen 8
    aiv    = xorWith padMask (fromIntegral inlen)

-- | Validate and strip the padding after a padded unwrap.
--
-- @inlen@ is the length of the whole wrapped input, i.e. 8 bytes of register
-- plus the padded key.  The key length is recovered by XORing the register
-- with 'padMask'; it is accepted only when
-- @aivlen + 8 <= inlen < aivlen + 16@.  The bytes following the key must
-- then pass the 'constAllEq' comparison against zero.  Any failure is
-- reported as 'BadChecksum'.
unpad :: ByteArray ba => Int -> Pair ba -> Either StoreError ba
unpad inlen (aiv, b)
    | badlen         = Left BadChecksum
    | constAllEq 0 p = Right bs
    | otherwise      = Left BadChecksum
  where
    -- All 8 register bytes contribute, so a mismatch in the fixed part of
    -- the mask produces a length outside the accepted range.
    aivlen  = fromIntegral (unxor padMask aiv)
    badlen  = inlen < aivlen + 8 || inlen >= aivlen + 16
    (bs, p) = B.splitAt aivlen b

-- | Pad and wrap a key with the specified AES cipher.
--
-- Keys of up to 8 bytes are padded and encrypted, together with the initial
-- register, in a single ECB call; longer keys go through the chunked
-- wrapping transform.  An empty key is rejected with 'InvalidInput'.
wrapPad :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
wrapPad cipher bs = doWrap =<< pad inlen bs
  where
    inlen = B.length bs
    doWrap (aiv, b)
        | inlen <= 8 = Right $ aes' cipher (aiv, b)
        | otherwise  = unchunks . wrapc cipher aiv <$> chunks b

-- | Unwrap and unpad an encrypted key with the specified AES cipher.
--
-- A 16-byte input is decrypted as a single block; longer inputs go through
-- the chunked unwrapping transform.  Returns 'InvalidInput' when the input
-- is shorter than 16 bytes or not a multiple of 8 bytes, and 'BadChecksum'
-- when the recovered register or padding is not valid.
unwrapPad :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
unwrapPad cipher bs = unpad inlen =<< doUnwrap
  where
    inlen = B.length bs
    doUnwrap
        | inlen == 16 = Right (aesrev' cipher bs)
        | otherwise   =
            fmap unchunks <$> (unwrapc cipher =<< requireData =<< chunks bs)

    -- A lone 8-byte chunk would reach 'unwrapc' as a register with no data,
    -- so nothing would be decrypted and the checks in 'unpad' would not
    -- depend on the key at all.  The padded wrap never produces fewer than
    -- 16 bytes, so such input is rejected outright.
    requireData [_] = Left (InvalidInput "KeyWrap.AES: input too short")
    requireData cs  = Right cs

-- | XOR an integer into the trailing bytes of a byte array.
--
-- The result is a copy of @bs@ in which the last byte is XORed with the
-- least significant byte of @i@, the byte before it with the next byte of
-- @i@, and so on towards the front.  The loop ends when the start of the
-- array is reached or when the remaining bits of @i@ are all zero.
xorWith :: (ByteArrayAccess bin, ByteArray bout) => bin -> Word64 -> bout
xorWith bs i = B.copyAndFreeze bs $ \dst -> loop dst len i
  where
    len = B.length bs

    -- @loop p n j@: @n@ bytes are still untouched at the front of @p@ and
    -- @j@ holds the bits not yet applied.
    loop _ 0 _ = return ()
    loop _ _ 0 = return () -- return early (constant-time not needed)
    loop p n j = do
        b <- peekByteOff p (n - 1)
        let mask = fromIntegral j :: Word8
        pokeByteOff p (n - 1) (xor b mask)
        loop p (n - 1) (shiftR j 8)

-- | XOR two byte arrays bytewise and read the result as a big-endian
-- integer.
--
-- Only the common prefix of both inputs is considered (the shorter one
-- decides), and with more than 8 bytes the earliest ones are shifted out of
-- the 64-bit result.
unxor :: (ByteArrayAccess bx, ByteArrayAccess by) => bx -> by -> Word64
unxor x y = foldl' f 0 $ zipWith xor (B.unpack x) (B.unpack y)
  where
    f acc z = shiftL acc 8 + fromIntegral z