{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE LambdaCase #-}
module Network.SSH.Transport.Crypto
( KeyStreams (..)
, EncryptionContext
, DecryptionContext
, plainEncryptionContext
, plainDecryptionContext
, newChaCha20Poly1305EncryptionContext
, newChaCha20Poly1305DecryptionContext
)
where
import Control.Exception ( throwIO )
import Control.Monad ( when )
import Data.Bits ( unsafeShiftL, (.|.) )
import Data.Memory.PtrMethods ( memCopy, memConstEqual )
import Data.Monoid ( (<>) )
import Data.Word
import Foreign.Marshal.Alloc ( allocaBytes )
import Foreign.Ptr
import Foreign.Storable ( peekByteOff )
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import Network.SSH.Constants
import Network.SSH.Encoding
import Network.SSH.Exception
import Network.SSH.Stream
import qualified Network.SSH.Builder as B
import qualified Network.SSH.Transport.Crypto.ChaCha as ChaChaM
import qualified Network.SSH.Transport.Crypto.Poly1305 as Poly1305M
-- | Wraps a function from a 'BS.ByteString' to a list of 'BA.Bytes'.
--
-- NOTE(review): nothing in this module (as shown) constructs or consumes it;
-- going by the name it presumably yields derived key streams for a session
-- — confirm against the callers.
newtype KeyStreams = KeyStreams (BS.ByteString -> [BA.Bytes])

-- | Receives one packet from the underlying stream and returns its payload
-- (the bytes between the padding-length byte and the padding).
--
-- The 'Word64' argument is a packet counter: the ChaCha20-Poly1305 context
-- encodes it into the nonce, the plain context ignores it.
type DecryptionContext = Word64 -> IO BS.ByteString

-- | Frames (and, depending on the implementation, encrypts and
-- authenticates) one payload, writes it to the underlying stream and
-- returns a byte count.
--
-- The 'Word64' argument is a packet counter that the ChaCha20-Poly1305
-- context encodes into the nonce and the plain context ignores.
--
-- NOTE(review): the plain context returns the packet length without the
-- 4-byte length field, while the ChaCha20-Poly1305 context returns the full
-- message length (length field + packet + MAC) — confirm which the callers
-- expect.
type EncryptionContext = Word64 -> B.ByteArrayBuilder -> IO Int
-- | Encryption context for the unencrypted phase (no cipher, no MAC).
--
-- The packet counter is ignored. The payload is framed as
--
-- > uint32 packet_length || byte padding_length || payload || zero padding
--
-- The padding is between 9 and 16 bytes, chosen so that the whole message
-- /including/ the 4-byte length field is a multiple of 8 bytes. The message
-- is assembled in a temporary buffer and written with 'sendAllUnsafe'.
--
-- Returns @packet_length@, i.e. the number of bytes written minus the 4-byte
-- length field. NOTE(review): the ChaCha20-Poly1305 context returns the full
-- message length instead — confirm which one callers expect.
plainEncryptionContext :: OutputStream stream => stream -> EncryptionContext
plainEncryptionContext stream _seqNr payload =
    allocaBytes totalLen $ \bufPtr -> do
        B.copyToPtr framed bufPtr
        sendAllUnsafe stream (BA.MemView bufPtr totalLen)
        pure bodyLen
  where
    bodyLen  = 1 + plainLen + padLen
    totalLen = headerLen + bodyLen
    plainLen = B.babLength payload
    -- 16 - (x mod 8) is always in [9, 16]: enough for the 4-byte minimum
    -- and it aligns length field + body to 8 bytes.
    padLen   = 16 - ((headerLen + 1 + plainLen) `mod` 8)
    framed   = B.word32BE (fromIntegral bodyLen)
            <> putWord8 (fromIntegral padLen)
            <> payload
            <> B.zeroes (fromIntegral padLen)
-- | Decryption context for the unencrypted phase (no cipher, no MAC).
--
-- The packet counter is ignored. Reads the 4-byte length field (validated
-- by 'peekPacketLen'), then reads exactly that many bytes. It checks that
-- the padding-length byte is at least 'minPaddingLen' and leaves at least
-- one payload byte. Returns the payload bytes.
--
-- Throws 'exceptionInvalidPacket' on bad padding; 'receiveAllUnsafe' throws
-- 'exceptionConnectionLost' when the stream ends early.
plainDecryptionContext :: InputStream stream => stream -> DecryptionContext
plainDecryptionContext stream _seqNr =
    allocaBytes headerLen $ \lenPtr -> do
        receiveAllUnsafe stream (BA.MemView lenPtr headerLen)
        bodyLen <- peekPacketLen lenPtr
        (payloadLen, body) <- BA.allocRet bodyLen $ \bodyPtr -> do
            receiveAllUnsafe stream (BA.MemView bodyPtr bodyLen)
            padLen <- fromIntegral <$> (peekByteOff bodyPtr 0 :: IO Word8)
            -- Both violations raise the same exception.
            when (padLen < minPaddingLen || padLen + 1 >= bodyLen) $
                throwIO exceptionInvalidPacket
            pure (bodyLen - 1 - padLen)
        -- Skip the padding-length byte, then cut off the padding.
        pure $! BS.take payloadLen (BS.drop 1 body)
-- | Create an 'EncryptionContext' implementing the OpenSSH-style
-- ChaCha20-Poly1305 packet construction (presumably
-- @chacha20-poly1305\@openssh.com@).
--
-- * @headerKey@ is used with ChaCha to encrypt the 4-byte length field.
-- * @mainKey@ is used to derive the per-packet Poly1305 key (first 32 bytes
--   of keystream block 0) and to encrypt the packet body (starting at the
--   next keystream block).
-- * The Poly1305 tag is computed over the /encrypted/ length field and the
--   /encrypted/ packet and appended to the message.
--
-- The returned context takes the packet counter and the payload. It writes
-- @length || packet || MAC@ to @stream@ and returns the total number of
-- bytes written.
--
-- The ChaCha/Poly1305 states and the key scratch buffer are allocated once
-- and mutated on every call. Concurrent invocations of the returned context
-- would therefore share them; callers presumably serialize sends.
newChaCha20Poly1305EncryptionContext ::
    (OutputStream stream, BA.ByteArrayAccess key) =>
    stream -> key -> key -> IO EncryptionContext
newChaCha20Poly1305EncryptionContext stream headerKey mainKey = do
    chaChaState <- ChaChaM.new
    polyState <- Poly1305M.new
    -- One full 64-byte ChaCha block is generated per packet. Its first
    -- 'polyKeyLen' bytes are the Poly1305 key, and consuming the whole block
    -- makes the body encryption start at the following keystream block.
    poly64 <- BA.alloc (2 * polyKeyLen) (const $ pure ()) :: IO BA.Bytes
    pure $ \packetsSent plainBuilder -> do
        let plainLen = B.babLength plainBuilder :: Int
            paddingLen = paddingLenFor plainLen
            packetLen = 1 + plainLen + paddingLen
            messageLen = headerLen + packetLen + macLen
            -- SSH sequence numbers are uint32 and wrap to zero after 2^32
            -- packets (RFC 4253, section 6.4). The nonce is that value
            -- encoded as a big-endian uint64, so its high word is always 0.
            -- Reducing here is a no-op for counters below 2^32.
            wrappedSeqNr = fromIntegral (fromIntegral packetsSent :: Word32) :: Word64
        allocaBytes messageLen $ \messagePtr -> do
            let headerPtr = messagePtr
                packetPtr = plusPtr headerPtr headerLen
                macPtr = plusPtr packetPtr packetLen
                -- The nonce is staged in the MAC slot. That slot is only
                -- written at the very end, when the tag overwrites the nonce.
                noncePtr = macPtr
                nonceView = BA.MemView noncePtr nonceLen
                packetBuilder = B.word8 (fromIntegral paddingLen) <> plainBuilder <> B.zeroes paddingLen
            B.copyToPtr (B.word64BE wrappedSeqNr) noncePtr
            -- Encrypt the 4-byte length field in place with the header key.
            ChaChaM.initialize chaChaState chaChaRounds headerKey nonceView
            B.copyToPtr (B.word32BE $ fromIntegral packetLen) headerPtr
            ChaChaM.combineUnsafe chaChaState headerPtr headerPtr headerLen
            -- Serialize padding length, payload and zero padding. This fills
            -- exactly 'packetLen' bytes, so the staged nonce stays intact.
            B.copyToPtr packetBuilder packetPtr
            BA.withByteArray poly64 $ \poly64Ptr -> do
                ChaChaM.initialize chaChaState chaChaRounds mainKey nonceView
                ChaChaM.generateUnsafe chaChaState poly64Ptr (2 * polyKeyLen)
                ChaChaM.combineUnsafe chaChaState packetPtr packetPtr packetLen
                -- Tag over encrypted length field + encrypted packet.
                Poly1305M.authUnsafe polyState
                    (BA.MemView poly64Ptr polyKeyLen)
                    (BA.MemView headerPtr $ headerLen + packetLen) macPtr
            sendAllUnsafe stream (BA.MemView messagePtr messageLen)
            pure messageLen
-- | Create a 'DecryptionContext' implementing the OpenSSH-style
-- ChaCha20-Poly1305 packet construction. It is the counterpart of
-- 'newChaCha20Poly1305EncryptionContext'.
--
-- For each packet:
--
-- 1. Read the encrypted 4-byte length field and decrypt it with
--    @headerKey@. Validate it with 'peekPacketLen'.
-- 2. Derive the Poly1305 key from @mainKey@ (first 32 bytes of keystream
--    block 0).
-- 3. Read the encrypted packet and its tag. Compute the expected tag over
--    encrypted length + encrypted packet, and compare the two with
--    'memConstEqual'. This happens /before/ any body decryption.
-- 4. Only on a match, decrypt the body in place and validate the padding.
--
-- Throws 'exceptionMacError' on a tag mismatch (and from 'peekPacketLen' on
-- a bad length), 'exceptionInvalidPacket' on bad padding, and
-- 'exceptionConnectionLost' if the stream ends early. Returns the payload.
--
-- The ChaCha/Poly1305 states and the scratch buffer are allocated once and
-- mutated on every call. Concurrent invocations would share them; callers
-- presumably serialize receives.
newChaCha20Poly1305DecryptionContext ::
    InputStream stream => BA.ByteArrayAccess key =>
    stream -> key -> key -> IO DecryptionContext
newChaCha20Poly1305DecryptionContext stream headerKey mainKey = do
    chaChaState <- ChaChaM.new
    polyState <- Poly1305M.new
    -- Per-packet scratch space, reused across calls:
    --   nonce | encrypted length | decrypted length | expected tag |
    --   one 64-byte keystream block (Poly1305 key in its first 32 bytes)
    temp <- BA.alloc
        (nonceLen + 2 * headerLen + macLen + 2 * polyKeyLen)
        (const $ pure ()) :: IO BA.Bytes
    pure $ \packetsReceived -> BA.withByteArray temp $ \tempPtr -> do
        let noncePtr = tempPtr
            nonceView = BA.MemView noncePtr nonceLen
            headerCryptPtr = noncePtr `plusPtr` nonceLen
            headerPlainPtr = headerCryptPtr `plusPtr` headerLen
            macTrustedPtr = headerPlainPtr `plusPtr` headerLen
            polyKeyPtr = macTrustedPtr `plusPtr` macLen
            -- SSH sequence numbers are uint32 and wrap to zero after 2^32
            -- packets (RFC 4253, section 6.4). The nonce is that value as a
            -- big-endian uint64. This must match the encrypting peer; it is
            -- a no-op for counters below 2^32.
            wrappedSeqNr = fromIntegral (fromIntegral packetsReceived :: Word32) :: Word64
        B.copyToPtr (B.word64BE wrappedSeqNr) noncePtr
        receiveAllUnsafe stream (BA.MemView headerCryptPtr headerLen)
        ChaChaM.initialize chaChaState chaChaRounds headerKey nonceView
        ChaChaM.combineUnsafe chaChaState headerPlainPtr headerCryptPtr headerLen
        packetLen <- peekPacketLen headerPlainPtr
        ChaChaM.initialize chaChaState chaChaRounds mainKey nonceView
        -- Consume the whole first block so body decryption starts at the
        -- next one, exactly as the encryptor does.
        ChaChaM.generateUnsafe chaChaState polyKeyPtr (2 * polyKeyLen)
        (bsLen, bs) <- BA.allocRet (headerLen + packetLen + macLen) $ \bsPtr -> do
            let packetPtr = bsPtr `plusPtr` headerLen
                macUntrustedPtr = packetPtr `plusPtr` packetLen
            -- The tag covers the *encrypted* length field, so put it back in
            -- front of the packet before authenticating.
            memCopy bsPtr headerCryptPtr headerLen
            receiveAllUnsafe stream
                (BA.MemView packetPtr (packetLen + macLen))
            Poly1305M.authUnsafe polyState
                (BA.MemView polyKeyPtr polyKeyLen)
                (BA.MemView bsPtr (headerLen + packetLen))
                macTrustedPtr
            memConstEqual macTrustedPtr macUntrustedPtr macLen >>= \case
                False -> throwIO exceptionMacError
                True -> do
                    ChaChaM.combineUnsafe chaChaState packetPtr packetPtr packetLen
                    paddingLen <- fromIntegral <$> (peekByteOff packetPtr 0 :: IO Word8)
                    when (paddingLen < minPaddingLen) (throwIO exceptionInvalidPacket)
                    when (paddingLen + 1 >= packetLen) (throwIO exceptionInvalidPacket)
                    pure (packetLen - 1 - paddingLen)
        -- Skip length field and padding-length byte, then cut off padding+tag.
        pure $! BS.take bsLen (BS.drop (headerLen + 1) bs)
-- | Size in bytes of the packet length field (written as a big-endian
-- 32-bit word).
headerLen :: Int
headerLen = 4

-- | Size in bytes of the Poly1305 tag appended to each encrypted packet.
macLen :: Int
macLen = 16

-- | Size in bytes of the ChaCha nonce (the packet counter as a big-endian
-- 64-bit word).
nonceLen :: Int
nonceLen = 8

-- | Size in bytes of a Poly1305 one-time key.
polyKeyLen :: Int
polyKeyLen = 32

-- | Number of ChaCha rounds (ChaCha20).
chaChaRounds :: Int
chaChaRounds = 20

-- | Minimum padding length: produced by 'paddingLenFor' and enforced when
-- decoding packets.
minPaddingLen :: Int
minPaddingLen = 4
-- | Padding length for a payload of the given size, such that
-- @1 + payloadLen + padding@ (padding-length byte, payload, padding) is a
-- multiple of 8. If the raw remainder is below 'minPaddingLen', one extra
-- 8-byte block is added.
--
-- The 4-byte length field is not part of this alignment. The result is
-- always in @[4, 11]@.
paddingLenFor :: Int -> Int
paddingLenFor plainLen
    | shortfall < minPaddingLen = shortfall + blockSize
    | otherwise = shortfall
  where
    blockSize = 8
    shortfall = blockSize - ((1 + plainLen) `mod` blockSize)
-- | Fill the whole memory view from the stream, calling 'receiveUnsafe'
-- repeatedly until every byte has arrived.
--
-- Throws 'exceptionConnectionLost' as soon as a receive returns zero or a
-- negative count. A view of non-positive length is a no-op.
receiveAllUnsafe :: InputStream stream => stream -> BA.MemView -> IO ()
receiveAllUnsafe stream = loop
  where
    loop view@(BA.MemView ptr remaining)
        | remaining <= 0 = pure ()
        | otherwise = do
            received <- receiveUnsafe stream view
            if received <= 0
                then throwIO exceptionConnectionLost
                else loop (BA.MemView (ptr `plusPtr` received) (remaining - received))
-- | Write the whole memory view to the stream, calling 'sendUnsafe'
-- repeatedly until every byte has been accepted.
--
-- Throws 'exceptionConnectionLost' as soon as a send returns zero or a
-- negative count. A view of non-positive length is a no-op.
sendAllUnsafe :: OutputStream stream => stream -> BA.MemView -> IO ()
sendAllUnsafe stream = go
  where
    go view@(BA.MemView ptr remaining)
        | remaining <= 0 = pure ()
        | otherwise = sendUnsafe stream view >>= \sent ->
            if sent <= 0
                then throwIO exceptionConnectionLost
                else go (BA.MemView (ptr `plusPtr` sent) (remaining - sent))
-- | Decode the 4-byte big-endian packet length at @ptr@ and validate it.
--
-- Throws 'exceptionMacError' when the length exceeds 'maxPacketLength' or
-- is below 6. Six presumably covers the padding-length byte, at least one
-- payload byte and the 4-byte minimum padding.
--
-- NOTE(review): the plaintext decoder also uses this, where a MAC error is
-- a misleading failure reason — confirm that this is intended.
peekPacketLen :: Ptr Word8 -> IO Int
peekPacketLen ptr = do
    b0 <- byteAt 0
    b1 <- byteAt 1
    b2 <- byteAt 2
    b3 <- byteAt 3
    let len = shiftedBy 24 b0 .|. shiftedBy 16 b1 .|. shiftedBy 8 b2 .|. shiftedBy 0 b3
    when (len > fromIntegral maxPacketLength || len < 1 + 1 + 4) $
        throwIO exceptionMacError
    pure len
  where
    byteAt :: Int -> IO Word8
    byteAt = peekByteOff ptr
    shiftedBy :: Int -> Word8 -> Int
    shiftedBy n w = unsafeShiftL (fromIntegral w) n