module Network.CommSec.Package
(
OutContext(..)
, InContext(..)
, CommSecError(..)
, SequenceMode(..)
, newInContext, newOutContext
, decode
, encode
, decodePtr
, encodePtr
, encBytes, decBytes
, peekBE32
, pokeBE32
, peekBE
, pokeBE
) where
import Prelude hiding (seq)
import qualified Crypto.Cipher.AES128.Internal as AES
import Crypto.Cipher.AES128.Internal (AESKey)
import Crypto.Cipher.AES128 ()
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Crypto.Classes (buildKey)
import Data.ByteString (ByteString)
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Word
import Data.List
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Utils (copyBytes)
import System.IO.Unsafe
import Data.Data
import Data.Typeable
import Control.Exception
import Network.CommSec.Types
import Network.CommSec.BitWindow
-- | Largest padding value accepted when stripping padding from a plaintext.
gPadMax :: Int
gPadMax = 16

-- | Block length in bytes; plaintexts are padded up to a multiple of this.
gBlockLen :: Int
gBlockLen = 16

-- | Length in bytes of the GCM authentication tag appended to every package.
gTagLen :: Int
gTagLen = 16

-- | Length in bytes of the big-endian counter prefixed to every package.
gCtrSize :: Int
gCtrSize = 8
-- | Sending-side state: everything needed to seal the next outgoing package.
data OutContext = Out
  { aesCtr  :: !Word64 -- ^ Counter for the next package; written in front of the
                       --   ciphertext and used as the low 64 bits of the GCM IV.
  , saltOut :: !Word32 -- ^ Salt forming the leading 32 bits of every GCM IV.
  , outKey  :: AESKey  -- ^ AES key used to encrypt (deliberately a lazy field).
  }
-- | Receiving-side state. The constructor selects how package counters are
-- checked against previously accepted ones.
data InContext
  = -- | Out-of-order delivery allowed; accepted counters are tracked in a
    -- 'BitWindow' (see 'updateBitWindow').
    In
      { bitWindow :: !BitWindow -- ^ Record of recently accepted counters.
      , saltIn    :: !Word32    -- ^ Salt forming the leading 32 bits of the IV.
      , inKey     :: AESKey     -- ^ AES key used to decrypt.
      }
  | -- | Only counters strictly greater than the last accepted one pass.
    InStrict
      { seqVal :: !Word64 -- ^ Last accepted counter.
      , saltIn :: !Word32
      , inKey  :: AESKey
      }
  | -- | Only the counter exactly one past the last accepted one passes.
    InSequential
      { seqVal :: !Word64 -- ^ Last accepted counter.
      , saltIn :: !Word32
      , inKey  :: AESKey
      }
-- | Build a sending context from key material.
--
-- The first 4 bytes, read big-endian, become the IV salt. The remaining bytes
-- are handed to 'buildKey'. The counter starts at 1. Calls 'error' when fewer
-- than 20 bytes are supplied or when the key cannot be built.
newOutContext :: ByteString -> OutContext
newOutContext bs
  | available < 20 = error $ "Not enough entropy: " ++ show available
  | otherwise = Out { aesCtr = 1, saltOut = salt, outKey = key }
  where
    available = B.length bs
    -- Pure big-endian decode of the first four bytes (length checked above).
    salt :: Word32
    salt = B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 (B.take 4 bs)
    key = fromMaybe (error "Could not build a key") (buildKey (B.drop (sizeOf salt) bs))
-- | Build a receiving context from key material, using the replay-checking
-- policy selected by the 'SequenceMode'.
--
-- The first 4 bytes, read big-endian, become the IV salt. The remaining bytes
-- are handed to 'buildKey'. Counter state starts empty ('zeroWindow' or 0).
-- Calls 'error' when fewer than 20 bytes are supplied or when the key cannot
-- be built.
newInContext :: ByteString -> SequenceMode -> InContext
newInContext bs md
  | available < 20 = error $ "Not enough entropy: " ++ show available
  | otherwise =
      case md of
        AllowOutOfOrder -> In { bitWindow = zeroWindow, saltIn = salt, inKey = key }
        StrictOrdering  -> InStrict { seqVal = 0, saltIn = salt, inKey = key }
        Sequential      -> InSequential { seqVal = 0, saltIn = salt, inKey = key }
  where
    available = B.length bs
    -- Pure big-endian decode of the first four bytes (length checked above).
    salt :: Word32
    salt = B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 (B.take 4 bs)
    key = fromMaybe (error "Could not build a key") (buildKey (B.drop (sizeOf salt) bs))
-- | Seal an (already padded) plaintext into a fresh package laid out as
-- @counter (8 bytes, big-endian) || ciphertext || tag (gTagLen bytes)@.
--
-- The GCM IV is @salt (BE32) || counter (BE64)@ and no additional
-- authenticated data is used. The pointer work is done by 'encryptGCMPtr'
-- directly into the new package buffer.
encryptGCM :: AESKey
           -> Word64
           -> Word32
           -> ByteString
           -> ByteString
encryptGCM key ctr salt pt = unsafePerformIO $
  B.unsafeUseAsCString pt $ \ptPtr ->
    B.create pkgLen $ \pkgPtr ->
      encryptGCMPtr key ctr salt (castPtr ptPtr) ptLen pkgPtr
  where
    ptLen  = B.length pt
    pkgLen = sizeOf ctr + ptLen + gTagLen
-- | Encrypt @ptLen@ bytes at @ptPtr@ into the package buffer @ctPtr@.
--
-- The buffer must hold @8 + ptLen + gTagLen@ bytes. It receives the big-endian
-- counter, then the ciphertext, then the tag. The GCM IV is
-- @salt (BE32) || counter (BE64)@ and no additional authenticated data is used.
encryptGCMPtr :: AESKey
              -> Word64
              -> Word32
              -> Ptr Word8
              -> Int
              -> Ptr Word8
              -> IO ()
encryptGCMPtr key ctr salt ptPtr ptLen ctPtr =
  allocaBytes ivLen $ \ivPtr -> do
    -- IV = salt || counter
    pokeBE32 ivPtr salt
    pokeBE (ivPtr `plusPtr` saltLen) ctr
    -- Package header: the counter, so the receiver can rebuild the IV.
    pokeBE ctPtr ctr
    AES.encryptGCM key ivPtr ivLen nullPtr 0 (castPtr ptPtr) ptLen (castPtr bodyPtr) tagPtr
  where
    saltLen = sizeOf salt
    ctrLen  = sizeOf ctr
    ivLen   = saltLen + ctrLen
    bodyPtr = ctPtr `plusPtr` ctrLen
    tagPtr  = bodyPtr `plusPtr` ptLen
-- | Decrypt @ctLen@ ciphertext bytes at @ctPtr@ into @ptPtr@ and verify the
-- @tagLen@-byte tag at @tagPtr@.
--
-- The GCM IV is @salt (BE32) || counter (BE64)@. Returns @Left InvalidICV@
-- when the tag length is not 'gTagLen', when @ctLen@ is negative, or when
-- the computed tag does not match. The output buffer is written even when
-- the tag does not match.
decryptGCMPtr :: AESKey
              -> Word64
              -> Word32
              -> Ptr Word8
              -> Int
              -> Ptr Word8
              -> Int
              -> Ptr Word8
              -> IO (Either CommSecError ())
decryptGCMPtr key ctr salt ctPtr ctLen tagPtr tagLen ptPtr
  | tagLen /= gTagLen = return $ Left InvalidICV
  -- Never hand a negative length to the C implementation.
  | ctLen < 0 = return $ Left InvalidICV
  | otherwise =
      allocaBytes ivLen $ \ptrIV -> allocaBytes tagLen $ \ctagPtr -> do
        pokeBE32 ptrIV salt
        pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
        AES.decryptGCM key ptrIV ivLen nullPtr 0 (castPtr ctPtr) ctLen (castPtr ptPtr) ctagPtr
        w1 <- peekBE ctagPtr
        w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
        y1 <- peekBE (castPtr tagPtr)
        y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
        -- Constant-time comparison: always combine both halves instead of
        -- short-circuiting, so timing does not reveal which half differed.
        let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
        return $ if diff == 0 then Right () else Left InvalidICV
  where
    ivLen = sizeOf ctr + sizeOf salt
-- | Decrypt a ciphertext and verify its GCM tag.
--
-- The IV is @salt (BE32) || counter (BE64)@ and no additional authenticated
-- data is used. Only the first 'gTagLen' bytes of @tag@ are compared.
-- Returns @Left InvalidICV@ when the tag is too short or does not match.
decryptGCM :: AESKey
           -> Word64
           -> Word32
           -> ByteString
           -> ByteString
           -> Either CommSecError ByteString
decryptGCM key ctr salt ct tag
  | B.length tag < gTagLen = Left InvalidICV
  | otherwise = unsafePerformIO $
      allocaBytes ivLen $ \ptrIV -> allocaBytes gTagLen $ \ctagPtr -> do
        pokeBE32 ptrIV salt
        pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
        B.unsafeUseAsCString tag $ \tagPtr ->
          B.unsafeUseAsCString ct $ \ptrCT -> do
            pt <- B.create ctLen $ \ptrPT ->
              AES.decryptGCM key ptrIV ivLen nullPtr 0 (castPtr ptrCT) ctLen (castPtr ptrPT) ctagPtr
            w1 <- peekBE ctagPtr
            w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
            y1 <- peekBE (castPtr tagPtr)
            y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
            -- Constant-time comparison: no short-circuit between the halves.
            let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
            return $ if diff == 0 then Right pt else Left InvalidICV
  where
    ivLen = sizeOf ctr + sizeOf salt
    ctLen = B.length ct
-- | Pad and seal one message, returning the package and the context advanced
-- to the next counter.
--
-- Throws 'OldContext' once the counter has reached @maxBound@, so a counter
-- (and therefore an IV) is never reused.
encode :: OutContext -> ByteString -> (ByteString, OutContext)
encode ctx@(Out {..}) pt
  | aesCtr /= maxBound =
      let !package = encryptGCM outKey aesCtr saltOut (pad pt)
          next     = ctx { aesCtr = aesCtr + 1 }
      in (package, next)
  | otherwise = throw OldContext
-- | Size in bytes of the package produced for a message of @lenMsg@ bytes:
-- counter + message + padding + tag.
--
-- Padding is always between 1 and 'gBlockLen' bytes for a non-negative
-- length, following the same rule used when padding the message.
encBytes :: Int -> Int
encBytes lenMsg = gCtrSize + lenMsg + pdLen + gTagLen
  where
    -- A message that is already block-aligned still gets a full block of padding.
    pdLen = gBlockLen - (lenMsg `rem` gBlockLen)
-- | Number of plaintext bytes (padding included) that decrypting a package of
-- @lenPkg@ bytes writes. Use it to size the output buffer for 'decodePtr'.
decBytes :: Int -> Int
decBytes lenPkg = lenPkg - gTagLen - gCtrSize
-- | Pad and seal @ptLen@ bytes at @ptPtr@ into the package buffer @pkgPtr@,
-- which must hold @encBytes ptLen@ bytes. Returns the context advanced to the
-- next counter.
--
-- Throws 'OldContext' (via 'throwIO') once the counter has reached @maxBound@.
encodePtr :: OutContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO OutContext
encodePtr ctx@(Out {..}) ptPtr pkgPtr ptLen
  | aesCtr == maxBound = throwIO OldContext
  | otherwise = do
      let !padding  = padLen ptLen
          !totalLen = ptLen + padding
      -- Copy into a temporary buffer so padding can be appended without
      -- touching the caller's plaintext.
      allocaBytes totalLen $ \ptPaddedPtr -> do
        copyBytes ptPaddedPtr ptPtr ptLen
        memset (ptPaddedPtr `plusPtr` ptLen) padding (fromIntegral padding)
        encryptGCMPtr outKey aesCtr saltOut ptPaddedPtr totalLen pkgPtr
      return (ctx { aesCtr = 1 + aesCtr })
  where
    -- Write @len@ copies of @val@ starting at @dst@.
    memset :: Ptr Word8 -> Int -> Word8 -> IO ()
    memset dst len val = mapM_ (\o -> pokeElemOff dst o val) [0 .. len - 1]
-- | Open a package of @pkgLen@ bytes at @pkgPtr@, writing the plaintext to
-- @msgPtr@. @msgPtr@ must hold @decBytes pkgLen@ bytes; padding is written
-- there too, but the returned length excludes it.
--
-- Order of checks:
--
-- 1. The package must hold at least a counter and a tag; otherwise 'InvalidICV'.
-- 2. The GCM tag is verified.
-- 3. The counter is checked against the context's sequencing policy.
-- 4. The padding is stripped.
--
-- On success, returns the message length and the updated context.
decodePtr :: InContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO (Either CommSecError (Int,InContext))
decodePtr ctx pkgPtr msgPtr pkgLen
  -- Too short to contain a counter and a tag: reject before reading anything.
  | pkgLen < gCtrSize + gTagLen = return (Left InvalidICV)
  | otherwise = do
      cnt <- peekBE pkgPtr
      let ctPtr  = pkgPtr `plusPtr` gCtrSize
          ctLen  = pkgLen - gTagLen - gCtrSize
          tagPtr = pkgPtr `plusPtr` (pkgLen - gTagLen)
      r <- decryptGCMPtr (inKey ctx) cnt (saltIn ctx) ctPtr ctLen tagPtr gTagLen msgPtr
      case r of
        Left err -> return (Left err)
        Right () ->
          case acceptSeq ctx cnt of
            Left err -> return (Left err)
            Right ctx' -> do
              pdLen <- padLenPtr msgPtr ctLen
              return $ case pdLen of
                Nothing -> Left BadPadding
                Just l  -> Right (ctLen - l, ctx')
  where
    -- Apply the context's replay policy to counter @n@, producing the context
    -- that records @n@ as accepted.
    acceptSeq :: InContext -> Word64 -> Either CommSecError InContext
    acceptSeq (InStrict {..}) n
      | n > seqVal = Right (InStrict n saltIn inKey)
      | otherwise  = Left DuplicateSeq
    acceptSeq (InSequential {..}) n
      | n == seqVal + 1 = Right (InSequential n saltIn inKey)
      | otherwise       = Left DuplicateSeq
    acceptSeq (In {..}) n =
      fmap (\bw -> In bw saltIn inKey) (updateBitWindow bitWindow n)
-- | Open a package laid out as @counter (8 bytes, BE) || ciphertext || tag@.
--
-- Packages shorter than a counter plus a tag are rejected with 'InvalidICV'.
-- Otherwise the checks run in this order:
--
-- 1. The counter is checked against the context's sequencing policy.
-- 2. The GCM tag is verified.
-- 3. The padding is stripped.
--
-- Decryption is skipped when the counter is rejected. Returns the message
-- and the updated context.
decode :: InContext -> ByteString -> Either CommSecError (ByteString, InContext)
decode ctx pkg
  -- Guard before touching the counter bytes: a short package used to trigger
  -- an out-of-bounds 8-byte read.
  | pkgLen < gCtrSize + gTagLen = Left InvalidICV
  | otherwise = do
      ctx'  <- acceptSeq ctx
      ptPad <- decryptGCM (inKey ctx) cnt (saltIn ctx) ct tag
      pt    <- maybe (Left BadPadding) Right (unpad ptPad)
      return (pt, ctx')
  where
    pkgLen = B.length pkg
    -- Big-endian counter from the first gCtrSize bytes, read purely.
    cnt :: Word64
    cnt = B.foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0 (B.take gCtrSize pkg)
    ct  = B.take (pkgLen - gCtrSize - gTagLen) (B.drop gCtrSize pkg)
    tag = B.drop (pkgLen - gTagLen) pkg
    -- Apply the context's replay policy to 'cnt', producing the context that
    -- records it as accepted.
    acceptSeq :: InContext -> Either CommSecError InContext
    acceptSeq (In {..}) =
      fmap (\bw -> In bw saltIn inKey) (updateBitWindow bitWindow cnt)
    acceptSeq (InStrict {..})
      | cnt > seqVal = Right (InStrict cnt saltIn inKey)
      | otherwise    = Left DuplicateSeq
    acceptSeq (InSequential {..})
      | cnt == seqVal + 1 = Right (InSequential cnt saltIn inKey)
      | otherwise         = Left DuplicateSeq
-- | Append @n@ copies of the byte @n@, where @n = padLen (length bs)@, so the
-- last byte of the result records how much padding to strip.
pad :: ByteString -> ByteString
pad bs = bs `B.append` B.replicate n (fromIntegral n)
  where
    n = padLen (B.length bs)
-- | Number of padding bytes needed to bring @ptLen@ up to the next multiple of
-- 'gBlockLen'. An already aligned length gets a whole extra block, so for a
-- non-negative length the result is always between 1 and 'gBlockLen'. (The
-- former @r == 0@ branch could never fire, because @rem@ is always strictly
-- smaller than the block length.)
padLen :: Int -> Int
padLen ptLen = gBlockLen - (ptLen `rem` gBlockLen)
-- | Strip padding whose length is recorded in the final byte.
--
-- Returns 'Nothing' for an empty input, or when the recorded pad length is 0,
-- exceeds 'gPadMax', or exceeds the input length. These are the same bounds
-- 'padLenPtr' enforces; previously an oversized value silently yielded an
-- empty or truncated message.
unpad :: ByteString -> Maybe ByteString
unpad bs
  | B.null bs                            = Nothing
  | pd < 1 || pd > gPadMax || pd > len   = Nothing
  | otherwise                            = Just (B.take (len - pd) bs)
  where
    len = B.length bs
    pd  = fromIntegral (B.last bs) :: Int
-- | Read the pad length recorded in the last byte of a @len@-byte buffer.
--
-- Returns 'Nothing' when the buffer is shorter than 'gPadMax', or when the
-- recorded value is outside @1 .. gPadMax@ ('pad' never emits 0).
padLenPtr :: Ptr Word8 -> Int -> IO (Maybe Int)
padLenPtr ptr len
  | len < gPadMax = return Nothing
  | otherwise = do
      lastByte <- peekElemOff ptr (len - 1)
      let r = fromIntegral (lastByte :: Word8)
      return $ if r >= 1 && r <= gPadMax then Just r else Nothing
-- | Read a big-endian 'Word64' from the 8 bytes starting at @p@.
peekBE :: Ptr Word8 -> IO Word64
peekBE p = go 0 0
  where
    go :: Int -> Word64 -> IO Word64
    go i acc
      | i >= 8 = return acc
      | otherwise = do
          byte <- peekElemOff p i
          go (i + 1) $! (acc `shiftL` 8) .|. fromIntegral byte
-- | Write a 'Word64' as 8 big-endian bytes starting at @p@: the most
-- significant byte goes to @p[0]@.
pokeBE :: Ptr Word8 -> Word64 -> IO ()
pokeBE p w = mapM_ op [0 .. 7]
  where
    -- Byte n holds bits (56 - 8n) .. (63 - 8n).
    op n = pokeElemOff p n (fromIntegral (w `shiftR` (56 - 8 * n)))
-- | Write a 'Word32' as 4 big-endian bytes starting at @p@: the most
-- significant byte goes to @p[0]@.
pokeBE32 :: Ptr Word8 -> Word32 -> IO ()
pokeBE32 p w = mapM_ op [0 .. 3]
  where
    -- Byte n holds bits (24 - 8n) .. (31 - 8n).
    op n = pokeElemOff p n (fromIntegral (w `shiftR` (24 - 8 * n)))
-- | Read a big-endian 'Word32' from the 4 bytes starting at @p@.
peekBE32 :: Ptr Word8 -> IO Word32
peekBE32 p = do
  bytes <- mapM (peekElemOff p) [0 .. 3]
  return (foldl' step 0 bytes)
  where
    step :: Word32 -> Word8 -> Word32
    step acc b = (acc `shiftL` 8) .|. fromIntegral b