{-# LANGUAGE CPP, BangPatterns, RecordWildCards, DeriveDataTypeable #-}
-- | CommSec is a package that provides communication security for
-- use with Haskell sockets.  Using an ephemeral shared
-- secret you can build contexts for sending or receiving data between one
-- or more peers.
--
-- Do not reuse the shared secret!  Key agreement mechanisms that leverage
-- PKI might be added later.
module Network.CommSec.Package
    ( -- * Types
      OutContext(..)
    , InContext(..)
    , CommSecError(..)
    , SequenceMode(..)
    -- , Secret, Socket
      -- * Build contexts for use sending and receiving
    , newInContext, newOutContext
    -- , newSecret
      -- * Pure / ByteString based encryption and decryption routines
    , decode
    , encode
      -- * IO / Pointer based encryption and decryption routines
    , decodePtr
    , encodePtr
      -- * Utility functions
    , encBytes, decBytes
      -- * Wrappers for network sending and receiving
    -- , send, recv
    -- , sendPtr, recvPtr
    -- , connect, unsafeConnect
    -- , listen, unsafeListen
      -- * Utilities
    , peekBE32
    , pokeBE32
    , peekBE
    , pokeBE
    ) where

import Prelude hiding (seq)
import qualified Crypto.Cipher.AES128.Internal as AES
import Crypto.Cipher.AES128.Internal (AESKey)
import Crypto.Cipher.AES128 ()
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Crypto.Classes (buildKey)
import Data.ByteString (ByteString)
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Word
import Data.List
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Utils (copyBytes)
import System.IO.Unsafe
import Data.Data
import Data.Typeable
import Control.Exception
import Network.CommSec.Types
import Network.CommSec.BitWindow

-- Sizes (in bytes) that define the package layout.
--
--  * gPadMax   - largest pad length accepted when stripping padding
--                (checked by 'padLenPtr').
--  * gBlockLen - block size used when padding messages (see 'padLen').
--  * gTagLen   - length of the GCM authentication tag (ICV).
--  * gCtrSize  - size of the big-endian counter prefixed to each package.
gPadMax,gBlockLen,gTagLen,gCtrSize :: Int
gPadMax = 16
gBlockLen = 16
gTagLen = 16
gCtrSize = 8

-- IPSec inspired packet format:
--
-- [CNT (used for both the IV and seq) | CT of Payload + Pad | ICV]
--
-- The GCM IV is the 4-byte salt followed by the 8-byte CNT, both
-- big-endian; only CNT travels on the wire.

-- | A context useful for
-- sending data.  'aesCtr' is the counter for the next package (it feeds
-- the GCM IV and is sent in the clear as the sequence number), 'saltOut'
-- is the IV salt and 'outKey' the AES key.
data OutContext =
    Out { aesCtr  :: {-# UNPACK #-} !Word64
        , saltOut :: {-# UNPACK #-} !Word32
        , outKey  :: AESKey
        }

-- | A context useful for receiving data.  The constructor selects how the
-- counter of each incoming package is checked:
--
--  * 'In' consults a 'BitWindow' via 'updateBitWindow' (out-of-order
--    delivery allowed);
--  * 'InStrict' requires the counter to exceed the last accepted one;
--  * 'InSequential' requires the counter to be exactly one more than the
--    last accepted one.
data InContext
    = In           { bitWindow :: {-# UNPACK #-} !BitWindow
                   , saltIn    :: {-# UNPACK #-} !Word32
                   , inKey     :: AESKey
                   }
    | InStrict     { seqVal    :: {-# UNPACK #-} !Word64
                   , saltIn    :: {-# UNPACK #-} !Word32
                   , inKey     :: AESKey
                   }
    | InSequential { seqVal    :: {-# UNPACK #-} !Word64
                   , saltIn    :: {-# UNPACK #-} !Word32
                   , inKey     :: AESKey
                   }

-- Split a shared secret into the IV salt (its first four bytes, read
-- big-endian) and the AES key built from the bytes that follow.
secretSaltAndKey :: ByteString -> (Word32, AESKey)
secretSaltAndKey secret = (salt, key)
  where
    salt = unsafePerformIO $ B.unsafeUseAsCString secret (peekBE32 . castPtr)
    key  = fromMaybe (error "Could not build a key")
         $ buildKey (B.drop (sizeOf salt) secret)

-- | Given at least 24 bytes of entropy, produce an out context that can
-- communicate with an identically initialized in context.  The first
-- package will carry counter 1.
newOutContext :: ByteString -> OutContext
newOutContext bs
    | B.length bs < 24 = error $ "Not enough entropy: " ++ show (B.length bs)
    | otherwise =
        let (saltOut, outKey) = secretSaltAndKey bs
            aesCtr = 1
        in Out {..}

-- | Given at least 24 bytes of entropy, produce an in context that can
-- communicate with an identically initialized out context.  The
-- 'SequenceMode' picks the constructor and hence the replay policy.
newInContext :: ByteString -> SequenceMode -> InContext
newInContext bs md
    | B.length bs < 24 = error $ "Not enough entropy: " ++ show (B.length bs)
    | otherwise =
        let (saltIn, inKey) = secretSaltAndKey bs
            bitWindow = zeroWindow
            seqVal = 0
        in case md of
             AllowOutOfOrder -> In {..}
             StrictOrdering  -> InStrict {..}
             Sequential      -> InSequential {..}

-- | Encrypts the plaintext with AES-GCM, returning a bytestring of the
-- form [ctr | ct | tag].
encryptGCM :: AESKey
           -> Word64      -- ^ AES GCM Counter (IV)
           -> Word32      -- ^ Salt
           -> ByteString  -- ^ Plaintext
           -> ByteString
encryptGCM key ctr salt pt =
    -- Pure wrapper around 'encryptGCMPtr' that writes into a freshly
    -- created ByteString sized for [ctr | ct | tag].
    unsafePerformIO $
      B.unsafeUseAsCString pt $ \ptrPT ->
        B.create (sizeOf ctr + ptLen + gTagLen) $ \pkg ->
          encryptGCMPtr key ctr salt (castPtr ptrPT) ptLen pkg
  where
    ptLen = B.length pt

-- | Encrypt @ptLen@ bytes at @ptPtr@ with AES-GCM.  @pkgPtr@ receives the
-- counter (big-endian), then the ciphertext, then the tag produced by
-- 'AES.encryptGCM'.  The IV is the salt followed by the counter.
encryptGCMPtr :: AESKey
              -> Word64     -- ^ AES GCM Counter (IV)
              -> Word32     -- ^ Salt
              -> Ptr Word8  -- ^ Plaintext buffer
              -> Int        -- ^ Plaintext length
              -> Ptr Word8  -- ^ ciphertext buffer (at least encBytes large)
              -> IO ()
encryptGCMPtr key ctr salt ptPtr ptLen pkgPtr = do
    -- The counter goes out in the clear so the receiver can rebuild the IV.
    pokeBE pkgPtr ctr
    allocaBytes ivLen $ \ivPtr -> do
      pokeBE32 ivPtr salt
      pokeBE (ivPtr `plusPtr` saltLen) ctr
      AES.encryptGCM key ivPtr ivLen nullPtr 0
                     (castPtr ptPtr) ptLen (castPtr ctStart) tagStart
  where
    saltLen  = sizeOf salt
    ivLen    = saltLen + sizeOf ctr
    ctStart  = pkgPtr `plusPtr` sizeOf ctr
    tagStart = ctStart `plusPtr` ptLen

-- | GCM-decrypt a ciphertext buffer into a plaintext buffer and verify the
-- ICV (authentication tag) against the one supplied.
decryptGCMPtr :: AESKey
              -> Word64     -- ^ AES GCM Counter (IV)
              -> Word32     -- ^ Salt
              -> Ptr Word8  -- ^ Ciphertext
              -> Int        -- ^ Ciphertext length
              -> Ptr Word8  -- ^ Tag
              -> Int        -- ^ Tag length
              -> Ptr Word8  -- ^ Plaintext result ptr (at least 'decBytes' large)
              -> IO (Either CommSecError ())
decryptGCMPtr key ctr salt ctPtr ctLen tagPtr tagLen ptPtr
    | tagLen /= gTagLen = return $ Left InvalidICV
    -- A negative length (e.g. from a truncated package) must never reach
    -- the C decryption routine.
    | ctLen < 0         = return $ Left InvalidICV
    | otherwise = do
        let ivLen = sizeOf ctr + sizeOf salt
        allocaBytes ivLen $ \ptrIV ->
          allocaBytes tagLen $ \ctagPtr -> do
            -- Build the IV: salt followed by the counter, both big-endian.
            pokeBE32 ptrIV salt
            pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
            AES.decryptGCM key ptrIV ivLen nullPtr 0
                           (castPtr ctPtr) ctLen (castPtr ptPtr) ctagPtr
            w1 <- peekBE ctagPtr
            w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
            y1 <- peekBE (castPtr tagPtr)
            y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
            -- Fold both halves before deciding so the comparison does not
            -- short-circuit on the first mismatching word.
            if (w1 `xor` y1) .|. (w2 `xor` y2) == 0
              then return (Right ())
              else do
                -- Do not leave unauthenticated plaintext in the caller's buffer.
                wipePlaintext
                return (Left InvalidICV)
  where
    wipePlaintext = mapM_ (\o -> pokeElemOff ptPtr o (0 :: Word8)) [0 .. ctLen - 1]

-- | Decrypt a ciphertext and verify its tag, returning the (still padded)
-- plaintext, or 'InvalidICV' if the tag does not match.
decryptGCM :: AESKey
           -> Word64      -- ^ AES GCM Counter (IV)
           -> Word32      -- ^ Salt
           -> ByteString  -- ^ Ciphertext
           -> ByteString  -- ^ Tag (only the first 'gTagLen' bytes are used)
           -> Either CommSecError ByteString -- ^ Padded plaintext, or 'InvalidICV'
decryptGCM key ctr salt ct tag
    | B.length tag < gTagLen = Left InvalidICV
    | otherwise = unsafePerformIO $ do
        let ivLen     = sizeOf ctr + sizeOf salt
            tagLen    = gTagLen
            paddedLen = B.length ct
        allocaBytes ivLen $ \ptrIV ->
          allocaBytes tagLen $ \ctagPtr -> do
            -- Build the IV: salt followed by the counter, both big-endian.
            pokeBE32 ptrIV salt
            pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
            B.unsafeUseAsCString tag $ \tagPtr ->
              B.unsafeUseAsCString ct $ \ptrCT -> do
                pt <- B.create paddedLen $ \ptrPT ->
                        AES.decryptGCM key ptrIV ivLen nullPtr 0
                                       (castPtr ptrCT) paddedLen (castPtr ptrPT) ctagPtr
                w1 <- peekBE ctagPtr
                w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
                y1 <- peekBE (castPtr tagPtr)
                y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
                -- Fold both halves before deciding so the comparison does
                -- not short-circuit on the first mismatching word.
                return $ if (w1 `xor` y1) .|. (w2 `xor` y2) == 0
                           then Right pt
                           else Left InvalidICV

-- | Use an 'OutContext' to protect a message for transport.
-- Package format: @[ctr | ct | tag]@, where @ct@ is the ciphertext of the
-- message plus its padding ('encBytes' gives the total size).
--
-- The counter advances by exactly one per package, as in 'encodePtr'.
-- This lets receivers that demand @cnt == previous + 1@ accept the stream.
-- It also guarantees that the 'maxBound' check is reached before the
-- counter could wrap around and repeat a GCM nonce.
--
-- This routine can throw an exception of 'OldContext' if the context being
-- used has expired.
encode :: OutContext -> ByteString -> (ByteString, OutContext)
encode ctx@(Out {..}) pt
    | aesCtr == maxBound = throw OldContext
    | otherwise =
        let !iv_ct_tag = encryptGCM outKey aesCtr saltOut (pad pt)
        in (iv_ct_tag, ctx { aesCtr = aesCtr + 1 })

-- | Given a message length, returns the number of bytes an encoded message
-- will consume: counter, message, padding ('padLen', as used by 'pad' and
-- 'encodePtr') and tag.
encBytes :: Int -> Int
encBytes lenMsg = gCtrSize + lenMsg + padLen lenMsg + gTagLen

-- |Given a package length, returns the maximum number of bytes the
-- underlying message could be (including padding).
decBytes :: Int -> Int
decBytes lenPkg = lenPkg - (gCtrSize + gTagLen)

-- | @encodePtr outCtx msg result msgLen@ will encode @msgLen@ bytes at
-- location @msg@, placing the result at location @result@.  The buffer
-- pointed to by @result@ must be at least @encBytes msgLen@ bytes large;
-- the actual package will be exactly @encBytes msgLen@ in size.
--
-- Throws 'OldContext' once the context's counter has reached 'maxBound'.
encodePtr :: OutContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO OutContext
encodePtr ctx@(Out {..}) ptPtr pkgPtr ptLen
    | aesCtr == maxBound = throw OldContext
    | otherwise =
        allocaBytes paddedLen $ \scratch -> do
          copyBytes scratch ptPtr ptLen
          fillPadding (scratch `plusPtr` ptLen)
          encryptGCMPtr outKey aesCtr saltOut scratch paddedLen pkgPtr
          -- The padded length is always a whole number of 16-byte blocks,
          -- so the counter simply advances by one per package.
          return (ctx { aesCtr = aesCtr + 1 })
  where
    padding   = padLen ptLen
    paddedLen = ptLen + padding
    padByte   = fromIntegral padding :: Word8
    -- Each pad byte holds the pad length.
    fillPadding p = mapM_ (\i -> pokeElemOff p i padByte) [0 .. padding - 1]

-- | @decodePtr inCtx pkg msg pkgLen@ decrypts and verifies a package at
-- location @pkg@ of size @pkgLen@.  The resulting message is placed at
-- location @msg@ (which must be at least @decBytes pkgLen@ bytes) and its
-- size is returned along with a new context (or error).
decodePtr :: InContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO (Either CommSecError (Int,InContext))
decodePtr ctx pkgPtr msgPtr pkgLen
    -- Too short to hold a counter and a tag: reject before touching the
    -- buffer, otherwise we would read past it and derive a negative
    -- ciphertext length.
    | pkgLen < gCtrSize + gTagLen = return (Left InvalidICV)
    | otherwise = do
        cnt <- peekBE pkgPtr
        let !ctPtr  = pkgPtr `plusPtr` sizeOf cnt
            !ctLen  = pkgLen - tagLen - sizeOf cnt
            !tagPtr = pkgPtr `plusPtr` (pkgLen - tagLen)
            tagLen  = gTagLen
        r <- decryptGCMPtr (inKey ctx) cnt (saltIn ctx) ctPtr ctLen tagPtr tagLen msgPtr
        case r of
          Left err -> return (Left err)
          Right () -> helper ctx cnt ctLen
  where
    -- Strip the padding from the verified plaintext at @msgPtr@ and pair
    -- the resulting message length with the successor context.
    finish :: InContext -> Int -> IO (Either CommSecError (Int,InContext))
    finish next paddedLen = do
        pdLen <- padLenPtr msgPtr paddedLen
        return $ case pdLen of
          Nothing -> Left BadPadding
          Just l  -> Right (paddedLen - l, next)

    -- Apply the context's sequence-number policy to the received counter.
    {-# INLINE helper #-}
    helper :: InContext -> Word64 -> Int -> IO (Either CommSecError (Int,InContext))
    helper (InStrict {..}) cnt paddedLen
      | cnt > seqVal = finish (InStrict cnt saltIn inKey) paddedLen
      | otherwise    = return (Left DuplicateSeq)
    helper (InSequential {..}) cnt paddedLen
      | cnt == seqVal + 1 = finish (InSequential cnt saltIn inKey) paddedLen
      | otherwise         = return (Left DuplicateSeq)
    helper (In {..}) cnt paddedLen =
      case updateBitWindow bitWindow cnt of
        Left e        -> return (Left e)
        Right newMask -> finish (In newMask saltIn inKey) paddedLen

-- | Use an 'InContext' to decrypt a message, verifying the ICV and sequence
-- number.  Unlike sending, receiving is more likely to result in an
-- exceptional condition and thus it returns an 'Either' value.
--
-- Package format: @[ctr | ct | tag]@, where @ct@ covers the message and its
-- padding.  Packages too short to hold a counter and a tag are rejected
-- with 'InvalidICV'.
decode :: InContext -> ByteString -> Either CommSecError (ByteString, InContext)
decode ctx pkg
    | B.length pkg < cntLen + tagLen = Left InvalidICV
    | otherwise = helper ctx cnt ptpd
  where
    cntLen = sizeOf cnt
    tagLen = gTagLen
    -- Only evaluated after the length guard, so at least 'cntLen' bytes exist.
    cnt  = unsafePerformIO $ B.unsafeUseAsCString pkg (peekBE . castPtr)
    tag  = B.drop (B.length pkg - tagLen) pkg
    ct   = B.take (B.length pkg - cntLen - tagLen) (B.drop cntLen pkg)
    ptpd = decryptGCM (inKey ctx) cnt (saltIn ctx) ct tag

    -- Propagate a decryption error, otherwise strip the padding and pair
    -- the message with the successor context.
    accept :: InContext -> Either CommSecError ByteString
           -> Either CommSecError (ByteString, InContext)
    accept next decrypted = do
        ptPad <- decrypted
        pt    <- maybe (Left BadPadding) Right (unpad ptPad)
        return (pt, next)

    -- Apply the context's sequence-number policy to the received counter.
    {-# INLINE helper #-}
    helper :: InContext -> Word64 -> Either CommSecError ByteString
           -> Either CommSecError (ByteString, InContext)
    helper (In {..}) c decrypted =
        case updateBitWindow bitWindow c of
          Left e   -> Left e
          Right bw -> accept (In bw saltIn inKey) decrypted
    helper (InStrict {..}) c decrypted
        | c > seqVal = accept (InStrict c saltIn inKey) decrypted
        | otherwise  = Left DuplicateSeq
    helper (InSequential {..}) c decrypted
        | c == seqVal + 1 = accept (InSequential c saltIn inKey) decrypted
        | otherwise       = Left DuplicateSeq

-- | Pad a bytestring to a whole number of blocks: between 1 and 'gBlockLen'
-- bytes are appended, each holding the pad length.
pad :: ByteString -> ByteString
pad bs =
    let pdLen   = padLen (B.length bs)
        pdValue = fromIntegral pdLen
    in B.append bs (B.replicate pdLen pdValue)

-- | Given length of a plaintext message, return the length of the padding
-- needed (always between 1 and 'gBlockLen' for a non-negative length).
padLen :: Int -> Int
padLen ptLen = gBlockLen - (ptLen `rem` gBlockLen)
{{-# INLINE padLen #-}

-- | Remove padding from a padded bytestring.  This is a variant of PKCS#7
-- padding that does not check every pad byte, only the final one.  It
-- mirrors 'padLenPtr': input shorter than 'gPadMax' bytes, or a final byte
-- larger than 'gPadMax', is rejected.
unpad :: ByteString -> Maybe ByteString
unpad bs
    | len < gPadMax = Nothing
    | p > gPadMax   = Nothing
    | otherwise     = Just (B.take (len - p) bs)
  where
    len = B.length bs
    p   = fromIntegral (B.last bs)

-- | Given a pointer to padded data and the length of the data, determine
-- the length of the padding.
-- A buffer shorter than 'gPadMax' bytes, or one whose final byte exceeds
-- 'gPadMax', yields 'Nothing'.  Perhaps this should be called 'unpadPtr'.
padLenPtr :: Ptr Word8 -> Int -> IO (Maybe Int)
padLenPtr ptr len
    | len < gPadMax = return Nothing
    | otherwise = do
        final <- peekElemOff ptr (len - 1)
        let r = fromIntegral (final :: Word8) :: Int
        return (if r > gPadMax then Nothing else Just r)

-- | Read a 'Word64' stored big-endian (most significant byte first).
peekBE :: Ptr Word8 -> IO Word64
peekBE p = go 0 0
  where
    go :: Int -> Word64 -> IO Word64
    go 8 !acc = return acc
    go i !acc = do
        b <- peekElemOff p i
        go (i + 1) ((acc `shiftL` 8) .|. fromIntegral b)
{-# INLINE peekBE #-}

-- | Write a 'Word64' big-endian (most significant byte first).
pokeBE :: Ptr Word8 -> Word64 -> IO ()
pokeBE p w = mapM_ put (zip [0 ..] [56, 48 .. 0])
  where
    put (i, s) = pokeElemOff p i (fromIntegral (w `shiftR` s))
{-# INLINE pokeBE #-}

-- | Write a 'Word32' big-endian (most significant byte first).
pokeBE32 :: Ptr Word8 -> Word32 -> IO ()
pokeBE32 p w = mapM_ put (zip [0 ..] [24, 16 .. 0])
  where
    put (i, s) = pokeElemOff p i (fromIntegral (w `shiftR` s))
{-# INLINE pokeBE32 #-}

-- | Read a 'Word32' stored big-endian (most significant byte first).
peekBE32 :: Ptr Word8 -> IO Word32
peekBE32 p = go 0 0
  where
    go :: Int -> Word32 -> IO Word32
    go 4 !acc = return acc
    go i !acc = do
        b <- peekElemOff p i
        go (i + 1) ((acc `shiftL` 8) .|. fromIntegral b)
{-# INLINE peekBE32 #-}