{-# LANGUAGE BangPatterns, RecordWildCards #-}
-- | CommSec, for communications security.
module Network.CommSec
    ( -- * Types
      OutContext(..)
    , InContext(..)
      -- * Build contexts for use sending and receiving
    , newInContext, newOutContext
      -- * Encryption and decryption routines
    , recv
    , send
    ) where

import Prelude hiding (seq)
import qualified Crypto.Cipher.AES128.Internal as AES
import Crypto.Cipher.AES128.Internal (AESKey)
import Crypto.Cipher.AES128 ()
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Crypto.Classes (buildKey)
import Data.ByteString (ByteString)
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Word
import Data.List
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.Marshal.Alloc (allocaBytes)
import System.IO.Unsafe

-- IPSec inspired packet format:
--
--      [CNT (IV and seq) | Payload | Pad | ICV]

-- | A context useful for sending data.
data OutContext =
    Out { aesCtr     :: {-# UNPACK #-} !Word64
          -- ^ Counter written at the front of each outgoing message and
          -- used, together with the salt, as the GCM IV.
        , saltOut    :: {-# UNPACK #-} !Word32
          -- ^ Salt forming the first four bytes of the GCM IV.
        , outKey     :: AESKey
          -- ^ Key used to encrypt outgoing messages.
        }

-- | A context useful for receiving data.
data InContext =
    In  { base       :: {-# UNPACK #-} !Word64
          -- ^ Lowest sequence number of the current acceptance window.
        , mask       :: {-# UNPACK #-} !Word64
          -- ^ Bitmask of the sequence numbers, relative to 'base', that
          -- have already been accepted.
        , saltIn     :: {-# UNPACK #-} !Word32
          -- ^ Salt forming the first four bytes of the GCM IV.
        , inKey      :: AESKey
          -- ^ Key used to decrypt incoming messages.
        }

-- | Given at least 24 bytes of entropy, produce an out context that can
-- communicate with an identically initialized in context.
--
-- The first four bytes (big endian) become the salt; the key is built from
-- the bytes that follow.  Calls 'error' when fewer than 24 bytes are
-- supplied or when the key cannot be built.
newOutContext :: ByteString -> OutContext
newOutContext bs
    | B.length bs >= 24 = Out { aesCtr = 0, saltOut = salt, outKey = key }
    | otherwise         = error "Not enough entropy"
  where
    -- Reading four bytes is in bounds: this is only forced once the guard
    -- above has established that at least 24 bytes are present.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key  = fromMaybe (error "Could not build a key")
                     (buildKey (B.drop (sizeOf salt) bs))

-- | Given at least 24 bytes of entropy, produce an in context that can
-- communicate with an identically initialized out context.
newInContext :: ByteString -> InContext
newInContext bs
    | B.length bs >= 24 = In { base = 0, mask = 0, saltIn = salt, inKey = key }
    | otherwise         = error "Not enough entropy"
  where
    -- Reading four bytes is in bounds: this is only forced once the guard
    -- above has established that at least 24 bytes are present.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key  = fromMaybe (error "Could not build a key")
                     (buildKey (B.drop (sizeOf salt) bs))

-- Encrypts the (already padded) plaintext, returning a bytestring laid out
-- as [ctr, ct, tag]: the 8 byte big-endian counter, the ciphertext (same
-- length as the plaintext) and a 16 byte tag.
encryptGCM :: AESKey
           -> Word64     -- ^ AES GCM Counter (IV)
           -> Word32     -- ^ Salt
           -> ByteString -- ^ Plaintext
           -> ByteString
encryptGCM key ctr salt pt = unsafePerformIO $
    allocaBytes nonceLen $ \noncePtr -> do
        -- The IV is the salt followed by the counter, both big endian.
        pokeBE32 noncePtr salt
        pokeBE (noncePtr `plusPtr` saltLen) ctr
        B.unsafeUseAsCString pt $ \ptPtr ->
            B.create (ptLen + ctrLen + tagLen) $ \outPtr -> do
                let bodyPtr = outPtr `plusPtr` ctrLen
                    tagPtr  = bodyPtr `plusPtr` ptLen
                -- The counter travels in the clear at the front of the
                -- message.
                pokeBE outPtr ctr
                AES.encryptGCM key
                               noncePtr nonceLen
                               nullPtr 0
                               (castPtr ptPtr) ptLen
                               (castPtr bodyPtr) tagPtr
  where
    saltLen  = sizeOf salt
    ctrLen   = sizeOf ctr
    nonceLen = ctrLen + saltLen
    tagLen   = 16
    ptLen    = B.length pt

-- Decrypts the ciphertext under the IV derived from the counter and salt,
-- and checks the result against the supplied tag.  Yields the (still padded)
-- plaintext, or an error message when the tag is too short or does not match.
decryptGCM :: AESKey
           -> Word64     -- ^ AES GCM Counter (IV)
           -> Word32     -- ^ Salt
           -> ByteString -- ^ Ciphertext
           -> ByteString -- ^ Tag
           -> Either String ByteString -- Plaintext (or an error due to bad tag)
decryptGCM key ctr salt ct tag
  | B.length tag < 16 = Left "Tag too small"
  | otherwise = unsafePerformIO $ do
    let ivLen     = sizeOf ctr + sizeOf salt
        tagLen    = 16
        paddedLen = B.length ct
    allocaBytes ivLen $ \ptrIV -> allocaBytes tagLen $ \ctagPtr -> do
      -- Build the IV: salt followed by the counter, both big endian.
      pokeBE32 ptrIV salt
      pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
      B.unsafeUseAsCString tag $ \tagPtr ->
        B.unsafeUseAsCString ct $ \ptrCT -> do
          pt <- B.create paddedLen $ \ptrPT ->
                  AES.decryptGCM key ptrIV ivLen nullPtr 0
                                 (castPtr ptrCT) (B.length ct)
                                 (castPtr ptrPT) ctagPtr
          -- Computed tag (w1, w2) versus the tag received on the wire
          -- (y1, y2), each read as two big-endian 64 bit words.
          w1 <- peekBE ctagPtr
          w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
          y1 <- peekBE (castPtr tagPtr)
          y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
          -- Fold both halves into a single value so the comparison does not
          -- short-circuit on the first differing word.
          let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
          -- The error deliberately omits the tag values: reporting the
          -- computed tag would hand the correct tag for this ciphertext to
          -- whoever gets to see the message.
          return $ if diff == 0
                     then Right pt
                     else Left "Tags do not match"

-- |Use an 'OutContext' to protect a message for transport.
-- Message format: [ctr, ct, padding, tag].
--
-- Returns the protected message together with a context whose counter has
-- been incremented; the returned context must be used for the next message.
send :: OutContext -> ByteString -> (ByteString, OutContext)
send ctx@(Out {..}) pt =
  -- The bang forces the encryption as soon as the result pair is demanded.
  let !packet = encryptGCM outKey aesCtr saltOut (pad pt)
      ctx'    = ctx { aesCtr = aesCtr + 1 }
  in (packet, ctx')

-- |Use an 'InContext' to decrypt a message, verifying the ICV and sequence
-- number.
-- Message format: [ctr, ct, padding, tag].
--
-- Returns @Left@ with a description when the context is exhausted, the
-- packet is too short to hold a counter and a tag, the sequence number was
-- already seen or is too old, the tag does not verify, or the padding is
-- malformed.  The updated context is only returned on success.
recv :: InContext -> ByteString -> Either String (ByteString, InContext)
recv (In {..}) pkg
  | base >= maxBound - 64 = Left "This cipher context has been used too long."
    -- Without this check the counter would be read past the end of a short
    -- buffer, since it is fetched through a raw pointer.
  | B.length pkg < cntLen + tagLen = Left "Packet too small"
  | otherwise =
      case updateBaseMask base mask cnt of
        Nothing -> Left "Dup!"
        Just (base', mask') -> do
          ptPad <- decryptGCM inKey cnt saltIn ct tag
          pt    <- maybe (Left "Bad padding") Right (unpad ptPad)
          return (pt, In base' mask' saltIn inKey)
  where
    cntLen = sizeOf base
    tagLen = 16
    -- Only forced after the length guard above has passed.
    cnt = unsafePerformIO $ B.unsafeUseAsCString pkg (peekBE . castPtr)
    -- Everything between the counter and the trailing 16 byte tag is
    -- ciphertext.
    (ct, tag) = B.splitAt (B.length pkg - cntLen - tagLen) (B.drop cntLen pkg)

-- Appends between 1 and 16 bytes, each holding the pad length, so that the
-- result is a multiple of 16 bytes long.
pad :: ByteString -> ByteString
pad bs = B.append bs (B.replicate pdLen (fromIntegral pdLen))
  where
    blockLen = 16
    -- Always in the range 1..16: a full block is added when the input is
    -- already block aligned.
    pdLen = blockLen - (B.length bs `rem` blockLen)

-- Removes the padding added by 'pad'.  Yields Nothing when the input is
-- empty, when the pad length stored in the last byte is zero or exceeds the
-- input length, or when the padding bytes are not all equal to that length.
unpad :: ByteString -> Maybe ByteString
unpad bs
  | len == 0                      = Nothing
  | padLen < 1 || padLen > len    = Nothing
  | B.any (/= padByte) padding    = Nothing
  | otherwise                     = Just body
  where
    len     = B.length bs
    padByte = B.last bs
    padLen  = fromIntegral padByte :: Int
    (body, padding) = B.splitAt (len - padLen) bs

-- Replay check.  Given the current window base, the bitmask of sequence
-- numbers already accepted relative to that base, and an incoming sequence
-- number, yields the updated (base, mask) or Nothing when the sequence
-- number must be rejected:
--
--  * below the base: rejected;
--  * within 64 of the base: accepted unless its bit is already set;
--  * 64 or more past the base: the window restarts at the new sequence
--    number with only that number marked as seen.
updateBaseMask :: Word64 -> Word64 -> Word64 -> Maybe (Word64,Word64)
updateBaseMask !base !mask !seq
  | seq < base = Nothing
  | gap < 64   = if testBit mask pos
                   then Nothing
                   else Just (base, setBit mask pos)
    -- Shifting the old mask by 64 or more bits always leaves zero, so the
    -- result is built directly; this also avoids converting a huge gap to a
    -- negative Int shift amount.
  | otherwise  = Just (seq, bit 0)
  where
    gap = seq - base
    pos = fromIntegral gap :: Int

-- Reads a big-endian 64 bit word from the eight bytes at the pointer.
peekBE :: Ptr Word8 -> IO Word64
peekBE p = foldl' step 0 `fmap` mapM (peekElemOff p) [0 .. 7]
  where
    step :: Word64 -> Word8 -> Word64
    step acc b = (acc `shiftL` 8) .|. fromIntegral b

-- Writes a 64 bit word as eight big-endian bytes at the pointer.
pokeBE :: Ptr Word8 -> Word64 -> IO ()
pokeBE p w = mapM_ store [0 .. 7]
  where
    store :: Int -> IO ()
    store i = pokeElemOff p i (fromIntegral (w `shiftR` (56 - 8 * i)))

-- Writes a 32 bit word as four big-endian bytes at the pointer.
pokeBE32 :: Ptr Word8 -> Word32 -> IO ()
pokeBE32 p w = mapM_ store [0 .. 3]
  where
    store :: Int -> IO ()
    store i = pokeElemOff p i (fromIntegral (w `shiftR` (24 - 8 * i)))

-- Reads a big-endian 32 bit word from the four bytes at the pointer.
peekBE32 :: Ptr Word8 -> IO Word32
peekBE32 p = foldl' step 0 `fmap` mapM (peekElemOff p) [0 .. 3]
  where
    step :: Word32 -> Word8 -> Word32
    step acc b = (acc `shiftL` 8) .|. fromIntegral b