{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE LambdaCase                #-}
{-# LANGUAGE MultiWayIf                #-}
-- | Packet-level encryption and decryption contexts for the SSH transport
--   layer: a plain (unencrypted) pair and a ChaCha20-Poly1305 pair.
module Network.SSH.Transport.Crypto
    ( KeyStreams (..)
    , EncryptionContext
    , DecryptionContext
    , plainEncryptionContext
    , plainDecryptionContext
    , newChaCha20Poly1305EncryptionContext
    , newChaCha20Poly1305DecryptionContext
    ) where

import           Control.Exception              ( throwIO )
import           Control.Monad                  ( when )
import           Data.Bits                      ( unsafeShiftL, (.|.) )
import           Data.Memory.PtrMethods         ( memCopy, memConstEqual )
import           Data.Monoid                    ( (<>) )
import           Data.Word
import           Foreign.Marshal.Alloc          ( allocaBytes )
import           Foreign.Ptr
import           Foreign.Storable               ( peekByteOff )
import qualified Data.ByteArray                as BA
import qualified Data.ByteString               as BS

import           Network.SSH.Constants
import           Network.SSH.Encoding
import           Network.SSH.Exception
import           Network.SSH.Stream
import qualified Network.SSH.Builder           as B
import qualified Network.SSH.Transport.Crypto.ChaCha   as ChaChaM
import qualified Network.SSH.Transport.Crypto.Poly1305 as Poly1305M

-- | Wraps a function from a 'BS.ByteString' to a list of byte arrays.
--   NOTE(review): presumably the argument selects a key stream and the
--   result is the derived key material -- confirm against the call sites.
newtype KeyStreams = KeyStreams (BS.ByteString -> [BA.Bytes])

-- | Receives one packet and returns its payload (without padding).
--   The 'Word64' is the number of packets received so far; the
--   ChaCha20-Poly1305 context uses it as nonce, the plain context ignores it.
type DecryptionContext = Word64 -> IO BS.ByteString

-- | Sends one packet built from the given payload builder.
--   The 'Word64' is the number of packets sent so far; the
--   ChaCha20-Poly1305 context uses it as nonce, the plain context ignores it.
--
--   NOTE(review): the plain context returns the packet length (without the
--   4 byte header) whereas the ChaCha20-Poly1305 context returns the total
--   number of bytes written (header + packet + MAC) -- confirm that callers
--   expect this difference.
type EncryptionContext = Word64 -> B.ByteArrayBuilder -> IO Int

-- | Encryption context that sends packets unencrypted and without MAC.
--
--   Wire layout: 4 byte big-endian packet length, 1 byte padding length,
--   payload, zero padding. The padding is chosen such that the whole
--   message length is a multiple of 8 (padding length is in the range 9..16).
plainEncryptionContext :: OutputStream stream => stream -> EncryptionContext
plainEncryptionContext stream _ payload =
    allocaBytes totalLen $ \buf -> do
        B.copyToPtr message buf
        sendAllUnsafe stream (BA.MemView buf totalLen)
        pure packetLen
  where
    payloadLen = B.babLength payload
    padLen     = 16 - ((headerLen + 1 + payloadLen) `mod` 8)
    packetLen  = 1 + payloadLen + padLen
    totalLen   = headerLen + packetLen
    message    = B.word32BE (fromIntegral packetLen)
              <> putWord8 (fromIntegral padLen)
              <> payload
              <> B.zeroes (fromIntegral padLen)

-- | Decryption context for unencrypted packets without MAC.
--
--   Reads the 4 byte header, validates the announced packet length, reads
--   the packet and returns the payload with padding stripped.
--   Throws (via 'throwIO') when the length or the padding is invalid or
--   when the stream ends prematurely.
plainDecryptionContext :: InputStream stream => stream -> DecryptionContext
plainDecryptionContext stream _ =
    allocaBytes headerLen $ \headerPtr -> do
        receiveAllUnsafe stream (BA.MemView headerPtr headerLen)
        packetLen <- peekPacketLen headerPtr
        (payloadLen, packet) <- BA.allocRet packetLen $ \packetPtr -> do
            receiveAllUnsafe stream (BA.MemView packetPtr packetLen)
            unpaddedPayloadLen packetPtr packetLen
        -- Skip the padding length byte, then keep only the payload.
        pure $! BS.take payloadLen (BS.drop 1 packet)

-- | Creates an encryption context for chacha20-poly1305.
--
--   The first key encrypts the packet length header, the second key
--   encrypts the packet and yields the per-packet Poly1305 key.
--   The returned context yields the total number of bytes written
--   (header + packet + MAC).
newChaCha20Poly1305EncryptionContext ::
    (OutputStream stream, BA.ByteArrayAccess key)
    => stream -> key -> key -> IO EncryptionContext
newChaCha20Poly1305EncryptionContext stream headerKey mainKey = do
    -- Cipher/MAC states and the key stream buffer are allocated once and
    -- re-used for every packet sent through this context.
    chaChaState <- ChaChaM.new
    polyState   <- Poly1305M.new
    polyKeyBuf  <- BA.alloc (2 * polyKeyLen) (\_ -> pure ()) :: IO BA.Bytes
    pure $ \packetsSent plainBuilder -> do
        let plainLen   = B.babLength plainBuilder :: Int
            padLen     = paddingLenFor plainLen
            packetLen  = 1 + plainLen + padLen
            messageLen = headerLen + packetLen + macLen
        allocaBytes messageLen $ \messagePtr -> do
            let headerPtr     = messagePtr
                packetPtr     = headerPtr `plusPtr` headerLen
                macPtr        = packetPtr `plusPtr` packetLen
                noncePtr      = macPtr
                nonceView     = BA.MemView noncePtr nonceLen
                packetBuilder = B.word8 (fromIntegral padLen)
                             <> plainBuilder
                             <> B.zeroes padLen
            -- Use the MAC area to store the nonce temporarily and
            -- save an allocation (made up 8% of all allocations in benchmark).
            B.copyToPtr (B.word64BE packetsSent) noncePtr
            -- Header: write the packet length and encrypt it in-place
            -- with the header key.
            ChaChaM.initialize chaChaState chaChaRounds headerKey nonceView
            B.copyToPtr (B.word32BE $ fromIntegral packetLen) headerPtr
            ChaChaM.combineUnsafe chaChaState headerPtr headerPtr headerLen
            -- Packet: serialize padding length, payload and padding.
            B.copyToPtr packetBuilder packetPtr
            BA.withByteArray polyKeyBuf $ \polyKeyPtr -> do
                -- The first 2 * polyKeyLen bytes of the main key stream are
                -- taken off before the packet gets encrypted in-place.
                ChaChaM.initialize chaChaState chaChaRounds mainKey nonceView
                ChaChaM.generateUnsafe chaChaState polyKeyPtr (2 * polyKeyLen)
                ChaChaM.combineUnsafe chaChaState packetPtr packetPtr packetLen
                -- MAC: authenticate encrypted header and encrypted packet.
                -- This overwrites the nonce that was parked in the MAC area.
                Poly1305M.authUnsafe
                    polyState
                    (BA.MemView polyKeyPtr polyKeyLen)
                    (BA.MemView headerPtr $ headerLen + packetLen)
                    macPtr
            sendAllUnsafe stream (BA.MemView messagePtr messageLen)
            pure messageLen

-- | Creates a decryption context for chacha20-poly1305.
--
--   The first key decrypts the packet length header, the second key
--   decrypts the packet and yields the per-packet Poly1305 key.
--   The returned context throws (via 'throwIO') on an invalid length,
--   a MAC mismatch, invalid padding or a prematurely ending stream.
newChaCha20Poly1305DecryptionContext ::
    InputStream stream => BA.ByteArrayAccess key
    => stream -> key -> key -> IO DecryptionContext
newChaCha20Poly1305DecryptionContext stream headerKey mainKey = do
    -- The mutable states for ChaCha and Poly1305 are allocated once
    -- per new decryption context. They are re-used for the decryption of
    -- subsequent messages. This is safe as long as the context is used by
    -- only one thread at a time.
    -- Both states get scrubbed on connection loss or after rekeying.
    -- The states do contain secret data while they are alive, but
    -- the ephemeral keys are stored in memory anyway.
    chaChaState <- ChaChaM.new
    polyState   <- Poly1305M.new
    -- A piece of memory is allocated once for the lifetime of this
    -- decryption context. It does not contain confidential data and
    -- does not need to be scrubbed.
    -- NOTE(review): the buffer also holds the per-packet Poly1305 key
    -- (see polyKeyPtr below) -- confirm that leaving it unscrubbed is fine.
    temp <- BA.alloc
        (nonceLen + 2 * headerLen + macLen + 2 * polyKeyLen)
        (\_ -> pure ()) :: IO BA.Bytes
    pure $ \packetsReceived -> BA.withByteArray temp $ \tempPtr -> do
        -- Layout of the scratch buffer:
        --   nonce | encrypted header | plain header | computed MAC | key stream
        let noncePtr       = tempPtr
            nonceView      = BA.MemView noncePtr nonceLen
            headerCryptPtr = noncePtr       `plusPtr` nonceLen
            headerPlainPtr = headerCryptPtr `plusPtr` headerLen
            macTrustedPtr  = headerPlainPtr `plusPtr` headerLen
            polyKeyPtr     = macTrustedPtr  `plusPtr` macLen
        -- Poke the current nonce to the pre-allocated memory location (big-endian).
        -- It is the caller's responsibility to avoid nonce-reuse by timely rekeying.
        B.copyToPtr (B.word64BE packetsReceived) noncePtr
        -- Receive and decrypt the header (packet length).
        -- The encrypted packet header is also needed for integrity check (below).
        receiveAllUnsafe stream (BA.MemView headerCryptPtr headerLen)
        ChaChaM.initialize chaChaState chaChaRounds headerKey nonceView
        ChaChaM.combineUnsafe chaChaState headerPlainPtr headerCryptPtr headerLen
        packetLen <- peekPacketLen headerPlainPtr
        -- 64 (2*polyKeyLen) bytes shall be taken from the main key stream of which
        -- the first 32 are used for Poly1305. The other 32 bytes are
        -- not needed, but generated in order to get the correct ChaCha state.
        ChaChaM.initialize chaChaState chaChaRounds mainKey nonceView
        ChaChaM.generateUnsafe chaChaState polyKeyPtr (2 * polyKeyLen)
        -- Receive and authenticate the remaining packet.
        (payloadLen, message) <-
            BA.allocRet (headerLen + packetLen + macLen) $ \bsPtr -> do
                let packetPtr       = bsPtr     `plusPtr` headerLen
                    macUntrustedPtr = packetPtr `plusPtr` packetLen
                -- Copy the ciphered header for inclusion in integrity check.
                memCopy bsPtr headerCryptPtr headerLen
                -- Receive the announced packet len + mac.
                receiveAllUnsafe stream (BA.MemView packetPtr (packetLen + macLen))
                Poly1305M.authUnsafe
                    polyState
                    (BA.MemView polyKeyPtr polyKeyLen)          -- authentication key
                    (BA.MemView bsPtr (headerLen + packetLen))  -- authenticated data
                    macTrustedPtr                               -- mac destination
                -- CRITICAL: check the message integrity before touching
                -- the ciphertext any further!
                macOk <- memConstEqual macTrustedPtr macUntrustedPtr macLen
                when (not macOk) (throwIO exceptionMacError)
                -- Decrypt the message in-place, then validate the padding and
                -- return the length of the actual message without padding.
                ChaChaM.combineUnsafe chaChaState packetPtr packetPtr packetLen
                unpaddedPayloadLen packetPtr packetLen
        -- The resulting message is a slice of the `BS.ByteString` (without padding and mac).
        -- The header, padding and mac are not confidential and remain in memory until the
        -- whole `BS.ByteString` gets collected. This saves allocations.
        pure $! BS.take payloadLen (BS.drop (headerLen + 1) message)

-------------------------------------------------------------------------------
-- UTIL
-------------------------------------------------------------------------------

-- | Size of the packet length field in bytes.
headerLen :: Int
headerLen = 4

-- | Size of the Poly1305 tag in bytes.
macLen :: Int
macLen = 16

-- | Size of the nonce (big-endian packet counter) in bytes.
nonceLen :: Int
nonceLen = 8

-- | Size of the Poly1305 key in bytes.
polyKeyLen :: Int
polyKeyLen = 32

-- | Number of rounds passed to the ChaCha state initialisation.
chaChaRounds :: Int
chaChaRounds = 20

-- | Smallest padding length accepted and produced.
minPaddingLen :: Int
minPaddingLen = 4

-- | Padding length for a payload of the given length such that
--   @1 + plainLen + padding@ is a multiple of 8 and the padding is
--   at least 'minPaddingLen' (result is in the range 4..11).
paddingLenFor :: Int -> Int
paddingLenFor plainLen
    | fill < minPaddingLen = fill + blockSize
    | otherwise            = fill
  where
    blockSize = 8
    fill      = blockSize - ((1 + plainLen) `mod` blockSize)

-- | Reads the padding length from the first byte of a (plain) packet,
--   validates it and returns the payload length without padding.
--   Throws 'exceptionInvalidPacket' when the padding is invalid.
unpaddedPayloadLen :: Ptr a -> Int -> IO Int
unpaddedPayloadLen packetPtr packetLen = do
    -- the first byte of the packet announces the number of padding bytes
    paddingLen <- fromIntegral <$> (peekByteOff packetPtr 0 :: IO Word8)
    -- RFC: the padding must be >=4 && <= 255
    when (paddingLen < minPaddingLen) (throwIO exceptionInvalidPacket)
    -- the padding must not exceed the packet length
    when (paddingLen + 1 >= packetLen) (throwIO exceptionInvalidPacket)
    -- the length of the actual message without padding
    pure (packetLen - 1 - paddingLen)

-- | Receives until the given memory view is completely filled.
--   Throws 'exceptionConnectionLost' when a receive yields no bytes.
receiveAllUnsafe :: InputStream stream => stream -> BA.MemView -> IO ()
receiveAllUnsafe stream view@(BA.MemView ptr remaining) =
    when (remaining > 0) $ do
        received <- receiveUnsafe stream view
        when (received <= 0) (throwIO exceptionConnectionLost)
        receiveAllUnsafe stream
            (BA.MemView (plusPtr ptr received) (remaining - received))

-- | Sends until the given memory view has been written completely.
--   Throws 'exceptionConnectionLost' when a send accepts no bytes.
sendAllUnsafe :: OutputStream stream => stream -> BA.MemView -> IO ()
sendAllUnsafe stream view@(BA.MemView ptr remaining) =
    when (remaining > 0) $ do
        sent <- sendUnsafe stream view
        when (sent <= 0) (throwIO exceptionConnectionLost)
        sendAllUnsafe stream
            (BA.MemView (plusPtr ptr sent) (remaining - sent))

-- | Reads the 4 byte big-endian packet length at the given pointer and
--   validates it. Throws 'exceptionMacError' when the length exceeds
--   'maxPacketLength' or is smaller than 6.
peekPacketLen :: Ptr Word8 -> IO Int
peekPacketLen ptr = do
    b0 <- byteAt 0
    b1 <- byteAt 1
    b2 <- byteAt 2
    b3 <- byteAt 3
    let packetLen = shifted b0 24
                .|. shifted b1 16
                .|. shifted b2 8
                .|. shifted b3 0
    -- Any manipulation of the ciphered packet header will
    -- (with extreme likelihood) result in a huge designated packet size
    -- after decryption. In this case, do not try to receive this packet
    -- and allocate memory for it but throw an exception and disconnect
    -- before even trying to authenticate the packet.
    when (packetLen > fromIntegral maxPacketLength) (throwIO exceptionMacError)
    -- Packet always consists of at least padding size byte, 1 byte payload
    -- and 4 bytes padding.
    when (packetLen < 1 + 1 + 4) (throwIO exceptionMacError)
    pure packetLen
  where
    byteAt :: Int -> IO Word8
    byteAt = peekByteOff ptr

    shifted :: Word8 -> Int -> Int
    shifted byte = unsafeShiftL (fromIntegral byte)