-- |
-- Module      : Crypto.Store.KeyWrap.AES
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- AES Key Wrap (<https://tools.ietf.org/html/rfc3394 RFC 3394>) and Extended
-- Key Wrap (<https://tools.ietf.org/html/rfc5649 RFC 5649>)
--
-- Should be used with a cipher from module "Crypto.Cipher.AES".
{-# LANGUAGE BangPatterns #-}
module Crypto.Store.KeyWrap.AES
    ( wrap
    , unwrap
    , wrapPad
    , unwrapPad
    ) where

import           Data.Bits
import           Data.ByteArray (ByteArray, ByteArrayAccess, Bytes)
import qualified Data.ByteArray as B
import           Data.List
import           Data.Word

import Crypto.Cipher.Types

import Foreign.Storable

import Crypto.Store.Error
import Crypto.Store.Util

-- | A byte array split into consecutive 8-byte chunks (see 'chunks').
type Chunked ba = [ba]

-- | Two byte arrays handled together; in the AES helpers below the first
-- component is the most-significant half of the block.
type Pair ba = (ba, ba)

-- TODO: should use a low-level AES implementation to reduce allocations

-- | Concatenate the two halves (most-significant first) and encrypt the
-- result in ECB mode.  Every caller in this module passes 8-byte halves, so
-- this is a single AES block encryption.
aes' :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> ba
aes' cipher = ecbEncrypt cipher . uncurry B.append

-- | Like 'aes'' but split the encrypted block back into its first 8 bytes
-- and the remainder.
aes :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> Pair ba
aes cipher halves = B.splitAt 8 (aes' cipher halves)

-- | Decrypt a block in ECB mode and split the plaintext into its first
-- 8 bytes and the remainder.
aesrev' :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Pair ba
aesrev' cipher block = B.splitAt 8 (ecbDecrypt cipher block)

-- | Inverse of 'aes': join the two halves, decrypt, and split again.
aesrev :: (BlockCipher aes, ByteArray ba) => aes -> Pair ba -> Pair ba
aesrev cipher = aesrev' cipher . uncurry B.append

-- | Core of the RFC 3394 wrapping process (index-based form).
--
-- Given the initial content of the register A and the data split into n
-- chunks, run six passes (j = 0 to 5).  Each pass visits the chunks in
-- order, with counter t = n*j + i for the i-th chunk (1-based):
--
-- * encrypt A together with the chunk;
-- * the first 8 bytes XORed with t become the new A;
-- * the remaining bytes replace the chunk.
--
-- The result is the final register followed by the transformed chunks.
wrapc :: (BlockCipher aes, ByteArray ba)
      => aes -> ba -> Chunked ba -> Chunked ba
wrapc cipher iiv list = finalA : finalR
  where
    (finalA, finalR) = foldl' runPass (iiv, list) [0 .. 5]
    !n = fromIntegral (length list)
    -- one pass threads the register through the chunks, pairing each chunk
    -- with its counter value
    runPass (a0, rs) j = mapAccumL step a0 (zip [n * j + 1 ..] rs)
    step a (t, r) =
        let (msb, lsb) = aes cipher (a, r)
         in (xorWith msb t, lsb)
rs

-- | Core of the RFC 3394 unwrapping process (index-based form): the
-- inverse of 'wrapc'.
--
-- The first chunk is the initial content of the register A and the
-- remaining n chunks are the wrapped data.  Six passes run in reverse
-- order (j = 5 down to 0).  Each pass visits the chunks from last to first
-- with counter t = n*j + i.  The result is the final register, which the
-- caller is expected to check, and the recovered chunks in their original
-- order.
--
-- An input with no data chunk after the register is rejected.  With n = 0
-- no decryption would happen at all: the register chunk would come back
-- unchanged, i.e. completely unauthenticated.
unwrapc :: (BlockCipher aes, ByteArray ba)
        => aes -> Chunked ba -> Either StoreError (ba, Chunked ba)
unwrapc _      []         = Left (InvalidInput "KeyWrap.AES: input too short")
unwrapc _      [_]        = Left (InvalidInput "KeyWrap.AES: input too short")
unwrapc cipher (iv:list)  = Right (iiv, reverse out)
  where
    -- chunks are processed back to front, hence the two reversals
    (iiv, out) = foldl' pass (iv, reverse list) (reverse [0 .. 5])
    !n = fromIntegral (length list)
    pass (a, l) j = go a (n * j + n) l
    go a !_ [] = (a, [])
    go a !i (r : rs) =
        let (a', t) = aesrev cipher (xorWith a i, r)
         in (t :) <$> go a' (pred i) rs

-- | Wrap a key with the specified AES cipher.
--
-- The key length must be a non-zero multiple of 8 bytes, otherwise
-- 'InvalidInput' is returned.  An empty key is rejected because there is
-- no block to encrypt: the output would only be the constant initial
-- value, in the clear.
wrap :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
wrap cipher bs
    | B.null bs = Left (InvalidInput "KeyWrap.AES: input is empty")
    | otherwise = unchunks . wrapc cipher iiv <$> chunks bs
  where iiv = B.replicate 8 0xA6

-- | Unwrap an encrypted key with the specified AES cipher.
--
-- The input is split into 8-byte chunks and unwrapped.  The recovered
-- register must be the RFC 3394 default initial value (eight 0xA6 bytes),
-- otherwise 'BadChecksum' is returned.
unwrap :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
unwrap cipher bs = do
    blocks     <- chunks bs
    (iiv, out) <- unwrapc cipher blocks
    if constAllEq 0xA6 iiv
        then Right (unchunks out)
        else Left BadChecksum

-- | Split a byte array into 8-byte chunks.  Fails with 'InvalidInput' when
-- the length is not a multiple of 8; an empty input gives no chunk.
chunks :: ByteArray ba => ba -> Either StoreError (Chunked ba)
chunks bs
    | B.length bs `mod` 8 /= 0 =
        Left (InvalidInput "KeyWrap.AES: input is not multiple of 8 bytes")
    | otherwise = Right (split bs)
  where
    split rest
        | B.null rest = []
        | otherwise   = let (hd, tl) = B.splitAt 8 rest in hd : split tl

-- | Join chunks back into a single byte array (inverse of 'chunks').
unchunks :: ByteArray ba => Chunked ba -> ba
unchunks pieces = B.concat pieces

-- | Template for the RFC 5649 alternative initial value: the 32-bit
-- constant A65959A6 followed by four zero bytes, into which the message
-- length is XORed (see 'pad' and 'unpad').
padMask :: Bytes
padMask = B.pack (magic ++ replicate 4 0)
  where magic = [0xA6, 0x59, 0x59, 0xA6]

-- | Prepare a key for RFC 5649 wrapping.
--
-- @inlen@ is the key length.  The result pairs the alternative initial
-- value (the length XORed into the trailing bytes of 'padMask') with the
-- key, zero-padded to a multiple of 8 bytes.  An empty key is rejected.
pad :: ByteArray ba => Int -> ba -> Either StoreError (Pair ba)
pad inlen bs
    | inlen == 0 = Left (InvalidInput "KeyWrap.AES: input is empty")
    | otherwise  = Right (aiv, padded)
  where
    aiv = xorWith padMask (fromIntegral inlen)
    padded = case inlen `mod` 8 of
        0         -> bs
        remainder -> bs `B.append` B.zero (8 - remainder)

-- | Check and strip the padding added by 'pad', following RFC 5649.
--
-- @inlen@ is the length of the whole ciphertext.  The pair holds the
-- unwrapped alternative initial value and the padded key.  XORing the
-- initial value with 'padMask' yields a 64-bit value.  The key is accepted
-- only when all of the following hold:
--
-- * its upper 32 bits are zero, i.e. the constant A65959A6 matched;
-- * its value MLI satisfies @inlen - 16 < MLI <= inlen - 8@;
-- * the padding bytes after the first MLI bytes are all zero.
--
-- Any failure is reported as 'BadChecksum'.
unpad :: ByteArray ba => Int -> Pair ba -> Either StoreError ba
unpad inlen (aiv, b)
    | magicOk && lengthOk && constAllEq 0 padding = Right key
    | otherwise                                   = Left BadChecksum
  where
    -- Work in Integer.  Converting the Word64 to Int would drop the upper
    -- half on platforms with a 32-bit Int (skipping the constant check),
    -- and could wrap to a negative length on 64-bit platforms.
    mli      = toInteger (unxor padMask aiv)
    total    = toInteger inlen
    magicOk  = mli < 2 ^ (32 :: Int)
    lengthOk = total >= mli + 8 && total < mli + 16
    -- only forced once both checks hold, so MLI fits in an Int
    (key, padding) = B.splitAt (fromInteger mli) b

-- | Pad and wrap a key with the specified AES cipher.
--
-- The key is zero-padded to a multiple of 8 bytes.  An alternative initial
-- value encoding its length is computed (see 'pad').  What happens next
-- depends on the key length:
--
-- * at most 8 bytes: initial value and padded key are encrypted as one AES
--   block;
-- * otherwise: the RFC 3394 wrapping passes are used.
wrapPad :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
wrapPad cipher bs = do
    (aiv, padded) <- pad inlen bs
    if inlen <= 8
        then Right (aes' cipher (aiv, padded))
        else unchunks . wrapc cipher aiv <$> chunks padded
  where
    inlen = B.length bs

-- | Unwrap and unpad an encrypted key with the specified AES cipher.
--
-- How the input is unwrapped depends on its length:
--
-- * 16 bytes: a single AES block, decrypted directly;
-- * otherwise: the RFC 3394 unwrapping passes are used.
--
-- The recovered initial value and padded key are then validated by
-- 'unpad'.
unwrapPad :: (BlockCipher aes, ByteArray ba) => aes -> ba -> Either StoreError ba
unwrapPad cipher bs = doUnwrap >>= unpad inlen
  where
    inlen = B.length bs
    doUnwrap
        | inlen == 16 = Right (aesrev' cipher bs)
        | otherwise   = do
            blocks     <- chunks bs
            (aiv, out) <- unwrapc cipher blocks
            Right (aiv, unchunks out)

-- | Copy a byte array and XOR the big-endian bytes of a counter into its
-- trailing bytes.  The least-significant counter byte goes into the last
-- array byte, the next one into the byte before, and so on.  Processing
-- stops once the remaining counter is zero or the array start is reached
-- (early exit is fine: constant time is not needed here).
xorWith :: (ByteArrayAccess bin, ByteArray bout) => bin -> Word64 -> bout
xorWith bs !i = B.copyAndFreeze bs $ \dst ->
    mapM_ (xorByte dst) (zip offsets counters)
  where
    len = B.length bs
    -- byte offsets from the last byte towards the first
    offsets = [len - 1, len - 2 .. 0]
    -- i, i >> 8, i >> 16, ... up to (excluding) the first zero value
    counters = takeWhile (/= 0) (iterate (`shiftR` 8) i)
    xorByte dst (off, c) = do
        b <- peekByteOff dst off
        pokeByteOff dst off (xor b (fromIntegral c :: Word8))

-- | XOR two byte arrays byte by byte (up to the shorter length) and read
-- the result as a big-endian 'Word64'.
unxor :: (ByteArrayAccess bx, ByteArrayAccess by) => bx -> by -> Word64
unxor x y = foldl' pushByte 0 (zipWith xor (B.unpack x) (B.unpack y))
  where
    -- the shifted accumulator has zero low bits, so OR-ing appends the byte
    pushByte acc z = (acc `shiftL` 8) .|. fromIntegral z