{-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE LambdaCase #-} module Network.SSH.Transport ( Transport() , TransportConfig (..) , Disconnected (..) , withTransport , plainEncryptionContext , plainDecryptionContext , newChaCha20Poly1305EncryptionContext , newChaCha20Poly1305DecryptionContext ) where import Control.Applicative import Control.Concurrent ( threadDelay ) import Control.Concurrent.Async import Control.Concurrent.MVar import Control.Exception ( throwIO, handle, catch, fromException) import Control.Monad ( when, void ) import Control.Monad.STM import Data.Default import Data.List import Data.Monoid ( (<>) ) import Data.Word import GHC.Clock import qualified Crypto.Hash as Hash import qualified Crypto.PubKey.Curve25519 as Curve25519 import qualified Crypto.PubKey.Ed25519 as Ed25519 import qualified Data.ByteArray as BA import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import qualified Data.List.NonEmpty as NEL import Network.SSH.Algorithms import qualified Network.SSH.Builder as B import Network.SSH.AuthAgent import Network.SSH.Constants import Network.SSH.Transport.Crypto import Network.SSH.Encoding import Network.SSH.Exception import Network.SSH.Message import Network.SSH.Name import Network.SSH.Stream data Transport = forall stream agent. (DuplexStream stream, AuthAgent agent) => TransportEnv { tStream :: stream , tConfig :: TransportConfig , tAuthAgent :: Maybe agent , tClientVersion :: Version , tServerVersion :: Version , tBytesSent :: MVar Word64 , tPacketsSent :: MVar Word64 , tBytesReceived :: MVar Word64 , tPacketsReceived :: MVar Word64 , tEncryptionCtx :: MVar EncryptionContext , tEncryptionCtxNext :: MVar EncryptionContext , tDecryptionCtx :: MVar DecryptionContext , tDecryptionCtxNext :: MVar DecryptionContext , tKexContinuation :: MVar KexContinuation , tSessionId :: MVar SessionId , tLastRekeyingTime :: MVar Word64 , tLastRekeyingDataSent :: MVar Word64 , tLastRekeyingDataReceived :: MVar Word64 } data TransportConfig = TransportConfig { serverHostKeyAlgorithms :: NEL.NonEmpty HostKeyAlgorithm , kexAlgorithms :: NEL.NonEmpty KeyExchangeAlgorithm , encryptionAlgorithms :: NEL.NonEmpty EncryptionAlgorithm , maxTimeBeforeRekey :: Word64 , maxDataBeforeRekey :: Word64 , onSend :: BS.ByteString -> IO () , onReceive :: BS.ByteString -> IO () } instance Default TransportConfig where def = TransportConfig { serverHostKeyAlgorithms = pure SshEd25519 , kexAlgorithms = pure Curve25519Sha256AtLibsshDotOrg , encryptionAlgorithms = pure Chacha20Poly1305AtOpensshDotCom , maxTimeBeforeRekey = 3600 , maxDataBeforeRekey = 1000 * 1000 * 1000 , onSend = const (pure ()) , onReceive = const (pure ()) } data KexStep = Init KexInit | EcdhInit KexEcdhInit | EcdhReply KexEcdhReply newtype KexContinuation = KexContinuation (Maybe KexStep -> IO KexContinuation) instance MessageStream Transport where sendMessage t msg = do kexIfNecessary t transportSendMessage t msg receiveMessage t = do kexIfNecessary t transportReceiveMessage t withTransport :: (DuplexStream stream, AuthAgent agent) => TransportConfig -> Maybe agent -> stream -> (Transport -> SessionId -> IO a) -> IO (Either Disconnect a) withTransport config magent stream runWith = withFinalExceptionHandler $ do (clientVersion, serverVersion) <- case magent of -- Receive the peer version and reject immediately if this -- is not an SSH connection attempt (before allocating -- any more resources); respond with the server version string. Just {} -> do cv <- receiveVersion stream sv <- sendVersion stream pure (cv, sv) -- Start with sending local version and then wait for response. Nothing -> do cv <- sendVersion stream sv <- receiveVersion stream pure (cv, sv) xBytesSent <- newMVar 0 xPacketsSent <- newMVar 0 xBytesReceived <- newMVar 0 xPacketsReceived <- newMVar 0 xEncryptionCtx <- newMVar (plainEncryptionContext stream) xEncryptionCtxNext <- newMVar (plainEncryptionContext stream) xDecryptionCtx <- newMVar (plainDecryptionContext stream) xDecryptionCtxNext <- newMVar (plainDecryptionContext stream) xKexContinuation <- newEmptyMVar xSessionId <- newEmptyMVar xRekeyTime <- newMVar =<< getEpochSeconds xRekeySent <- newMVar 0 xRekeyRcvd <- newMVar 0 let env = TransportEnv { tStream = stream , tConfig = config , tAuthAgent = magent , tClientVersion = clientVersion , tServerVersion = serverVersion , tBytesSent = xBytesSent , tPacketsSent = xPacketsSent , tBytesReceived = xBytesReceived , tPacketsReceived = xPacketsReceived , tEncryptionCtx = xEncryptionCtx , tEncryptionCtxNext = xEncryptionCtxNext , tDecryptionCtx = xDecryptionCtx , tDecryptionCtxNext = xDecryptionCtxNext , tKexContinuation = xKexContinuation , tSessionId = xSessionId , tLastRekeyingTime = xRekeyTime , tLastRekeyingDataSent = xRekeySent , tLastRekeyingDataReceived = xRekeyRcvd } withRespondingExceptionHandler env $ do sessionId <- kexInitialize env a <- runWith env sessionId sendMessage env (Disconnected DisconnectByApplication mempty mempty) pure a where withFinalExceptionHandler :: IO (Either Disconnect a) -> IO (Either Disconnect a) withFinalExceptionHandler = handle $ \e -> maybe (throwIO e) (pure . Left) (fromException e) withRespondingExceptionHandler :: Transport -> IO a -> IO (Either Disconnect a) withRespondingExceptionHandler env run = (Right <$> run) `catch` \e-> case e of Disconnect _ DisconnectConnectionLost _ -> pure (Left e) Disconnect Local r (DisconnectMessage m) -> withAsync (threadDelay (1000*1000)) $ \thread1 -> withAsync (transportSendMessage env $ Disconnected r (SBS.toShort m) mempty) $ \thread2 -> do atomically $ void (waitCatchSTM thread1) <|> void (waitCatchSTM thread2) pure (Left e) _ -> pure (Left e) transportSendMessage :: Encoding msg => Transport -> msg -> IO () transportSendMessage env msg = modifyMVar_ (tEncryptionCtx env) $ \sendEncrypted -> do onSend (tConfig env) (runPut payload) packets <- modifyMVar (tPacketsSent env) $ \p -> pure . (,p) $! p + 1 sent <- sendEncrypted packets payload modifyMVar_ (tBytesSent env) $ \bytes -> pure $! bytes + fromIntegral sent if B.babLength payload == 1 && runGet (runPut payload) == Just KexNewKeys then readMVar (tEncryptionCtxNext env) else pure sendEncrypted where payload = put msg transportReceiveMessage :: Encoding msg => Transport -> IO msg transportReceiveMessage env = do raw <- transportReceiveRawMessage env maybe (throwIO $ exceptionUnexpectedMessage raw) pure (runGet raw) transportReceiveRawMessage :: Transport -> IO BS.ByteString transportReceiveRawMessage env = maybe (transportReceiveRawMessage env) pure =<< transportReceiveRawMessageMaybe env transportReceiveRawMessageMaybe :: Transport -> IO (Maybe BS.ByteString) transportReceiveRawMessageMaybe env = modifyMVar (tDecryptionCtx env) $ \decrypt -> do packets <- readMVar (tPacketsReceived env) plainText <- decrypt packets onReceive (tConfig env) plainText modifyMVar_ (tPacketsReceived env) $ \pacs -> pure $! pacs + 1 case interpreter plainText of Just i -> i >> pure (decrypt, Nothing) Nothing -> case runGet plainText of Just KexNewKeys -> do (,Nothing) <$> readMVar (tDecryptionCtxNext env) Nothing -> pure (decrypt, Just plainText) where interpreter plainText = f i0 <|> f i1 <|> f i2 <|> f i3 <|> f i4 <|> f i5 <|> f i6 where f i = i <$> runGet plainText i0 (Disconnected r m _) = throwIO $ Disconnect Remote r (DisconnectMessage $ SBS.fromShort m) i1 Debug {} = pure () i2 Ignore {} = pure () i3 Unimplemented {} = pure () i4 x@KexInit {} = kexContinue env (Init x) i5 x@KexEcdhInit {} = kexContinue env (EcdhInit x) i6 x@KexEcdhReply {} = kexContinue env (EcdhReply x) ------------------------------------------------------------------------------- -- CRYPTO --------------------------------------------------------------------- ------------------------------------------------------------------------------- setChaCha20Poly1305Context :: Transport -> KeyStreams -> IO () setChaCha20Poly1305Context env@TransportEnv { tStream = stream, tAuthAgent = agent } (KeyStreams keys) = do modifyMVar_ (tEncryptionCtxNext env) $ const $ case agent of Just {} -> newChaCha20Poly1305EncryptionContext stream headerKeySC mainKeySC Nothing -> newChaCha20Poly1305EncryptionContext stream headerKeyCS mainKeyCS modifyMVar_ (tDecryptionCtxNext env) $ const $ case agent of Just {} -> newChaCha20Poly1305DecryptionContext stream headerKeyCS mainKeyCS Nothing -> newChaCha20Poly1305DecryptionContext stream headerKeySC mainKeySC where -- Derive the required encryption/decryption keys. -- The integrity keys etc. are not needed with chacha20. mainKeyCS : headerKeyCS : _ = keys "C" mainKeySC : headerKeySC : _ = keys "D" ------------------------------------------------------------------------------- -- KEY EXCHANGE --------------------------------------------------------------- ------------------------------------------------------------------------------- kexInitialize :: Transport -> IO SessionId kexInitialize env@TransportEnv { tAuthAgent = agent } = do cookie <- newCookie putMVar (tKexContinuation env) $ case agent of Just aa -> kexServerContinuation env cookie aa Nothing -> kexClientContinuation env cookie kexTrigger env dontAcceptMessageUntilKexComplete where dontAcceptMessageUntilKexComplete = do transportReceiveRawMessageMaybe env >>= \case Just _ -> throwIO exceptionKexInvalidTransition Nothing -> tryReadMVar (tSessionId env) >>= \case Nothing -> dontAcceptMessageUntilKexComplete Just sid -> pure sid kexTrigger :: Transport -> IO () kexTrigger env = do modifyMVar_ (tKexContinuation env) $ \(KexContinuation f) -> f Nothing kexIfNecessary :: Transport -> IO () kexIfNecessary env = do kexRekeyingRequired env >>= \case False -> pure () True -> do void $ swapMVar (tLastRekeyingTime env) =<< getEpochSeconds void $ swapMVar (tLastRekeyingDataSent env) =<< readMVar (tBytesSent env) void $ swapMVar (tLastRekeyingDataReceived env) =<< readMVar (tBytesReceived env) kexTrigger env kexContinue :: Transport -> KexStep -> IO () kexContinue env step = do modifyMVar_ (tKexContinuation env) $ \(KexContinuation f) -> f (Just step) -- NB: Uses transportSendMessage to avoid rekeying-loop kexClientContinuation :: Transport -> Cookie -> KexContinuation kexClientContinuation env cookie = clientKex0 where clientKex0 :: KexContinuation clientKex0 = KexContinuation $ \case Nothing -> do transportSendMessage env cki pure (clientKex1 cki) Just (Init ski) -> do cekSecret <- Curve25519.generateSecretKey let cek = Curve25519.toPublic cekSecret transportSendMessage env cki transportSendMessage env (KexEcdhInit cek) pure (clientKex2 cki ski cek cekSecret) _ -> throwIO exceptionKexInvalidTransition where cki = kexInit (tConfig env) cookie clientKex1 :: KexInit -> KexContinuation clientKex1 cki = KexContinuation $ \case Nothing -> pure (clientKex1 cki) Just (Init ski) -> do cekSecret <- Curve25519.generateSecretKey let cek = Curve25519.toPublic cekSecret transportSendMessage env (KexEcdhInit cek) pure (clientKex2 cki ski cek cekSecret) _ -> throwIO exceptionKexInvalidTransition clientKex2 :: KexInit -> KexInit -> Curve25519.PublicKey -> Curve25519.SecretKey -> KexContinuation clientKex2 cki ski cek cekSecret = KexContinuation $ \case Nothing -> pure (clientKex2 cki ski cek cekSecret) Just (EcdhReply ecdhReply) -> do consumeEcdhReply cki ski cek cekSecret ecdhReply pure clientKex0 _ -> throwIO exceptionKexInvalidTransition consumeEcdhReply :: KexInit -> KexInit -> Curve25519.PublicKey -> Curve25519.SecretKey -> KexEcdhReply -> IO () consumeEcdhReply cki ski cek cekSecret ecdhReply = do kexAlgorithm <- kexCommonKexAlgorithm ski cki encAlgorithmCS <- kexCommonEncAlgorithm ski cki kexEncryptionAlgorithmsClientToServer encAlgorithmSC <- kexCommonEncAlgorithm ski cki kexEncryptionAlgorithmsServerToClient case (kexAlgorithm, encAlgorithmCS, encAlgorithmSC) of (Curve25519Sha256AtLibsshDotOrg, Chacha20Poly1305AtOpensshDotCom, Chacha20Poly1305AtOpensshDotCom) -> kexWithVerifiedSignature shk hash sig $ do sid <- trySetSessionId env (SessionId $ SBS.toShort $ BA.convert hash) setChaCha20Poly1305Context env $ kexKeys sec hash sid transportSendMessage env KexNewKeys where cv = tClientVersion env sv = tServerVersion env shk = kexServerHostKey ecdhReply sek = kexServerEphemeralKey ecdhReply sec = Curve25519.dh sek cekSecret sig = kexHashSignature ecdhReply hash = kexHash cv sv cki ski shk cek sek sec -- NB: Uses transportSendMessage to avoid rekeying-loop kexServerContinuation :: AuthAgent agent => Transport -> Cookie -> agent -> KexContinuation kexServerContinuation env cookie authAgent = serverKex0 where serverKex0 :: KexContinuation serverKex0 = KexContinuation $ \case Nothing -> do transportSendMessage env ski pure (serverKex1 ski) Just (Init cki) -> do transportSendMessage env ski pure (serverKex2 cki ski) _ -> throwIO exceptionKexInvalidTransition where ski = kexInit (tConfig env) cookie serverKex1 :: KexInit -> KexContinuation serverKex1 ski = KexContinuation $ \case Nothing-> do pure (serverKex1 ski) Just (Init cki) -> pure (serverKex2 cki ski) _ -> throwIO exceptionKexInvalidTransition serverKex2 :: KexInit -> KexInit -> KexContinuation serverKex2 cki ski = KexContinuation $ \case Nothing -> do pure (serverKex2 cki ski) Just (EcdhInit (KexEcdhInit cek)) -> do emitEcdhReply cki ski cek pure serverKex0 _ -> throwIO exceptionKexInvalidTransition emitEcdhReply :: KexInit -> KexInit -> Curve25519.PublicKey -> IO () emitEcdhReply cki ski cek = do kexAlgorithm <- kexCommonKexAlgorithm ski cki encAlgorithmCS <- kexCommonEncAlgorithm ski cki kexEncryptionAlgorithmsClientToServer encAlgorithmSC <- kexCommonEncAlgorithm ski cki kexEncryptionAlgorithmsServerToClient getPublicKeys authAgent >>= \case [] -> throwIO exceptionKexNoSignature shk:_ -> case (kexAlgorithm, encAlgorithmCS, encAlgorithmSC) of (Curve25519Sha256AtLibsshDotOrg, Chacha20Poly1305AtOpensshDotCom, Chacha20Poly1305AtOpensshDotCom) -> do sekSecret <- Curve25519.generateSecretKey let cv = tClientVersion env sv = tServerVersion env sek = Curve25519.toPublic sekSecret sec = Curve25519.dh cek sekSecret hash = kexHash cv sv cki ski shk cek sek sec sig <- maybe (throwIO exceptionKexNoSignature) pure =<< getSignature authAgent shk hash sid <- trySetSessionId env (SessionId $ SBS.toShort $ BA.convert hash) setChaCha20Poly1305Context env $ kexKeys sec hash sid transportSendMessage env (KexEcdhReply shk sek sig) transportSendMessage env KexNewKeys kexCommonKexAlgorithm :: KexInit -> KexInit -> IO KeyExchangeAlgorithm kexCommonKexAlgorithm ski cki = case kexKexAlgorithms cki `intersect` kexKexAlgorithms ski of (x:_) | x == name Curve25519Sha256AtLibsshDotOrg -> pure Curve25519Sha256AtLibsshDotOrg _ -> throwIO exceptionKexNoCommonKexAlgorithm kexCommonEncAlgorithm :: KexInit -> KexInit -> (KexInit -> [Name]) -> IO EncryptionAlgorithm kexCommonEncAlgorithm ski cki f = case f cki `intersect` f ski of (x:_) | x == name Chacha20Poly1305AtOpensshDotCom -> pure Chacha20Poly1305AtOpensshDotCom _ -> throwIO exceptionKexNoCommonEncryptionAlgorithm kexInit :: TransportConfig -> Cookie -> KexInit kexInit config cookie = KexInit { kexCookie = cookie , kexServerHostKeyAlgorithms = NEL.toList $ fmap name (serverHostKeyAlgorithms config) , kexKexAlgorithms = NEL.toList $ fmap name (kexAlgorithms config) , kexEncryptionAlgorithmsClientToServer = NEL.toList $ fmap name (encryptionAlgorithms config) , kexEncryptionAlgorithmsServerToClient = NEL.toList $ fmap name (encryptionAlgorithms config) , kexMacAlgorithmsClientToServer = [] , kexMacAlgorithmsServerToClient = [] , kexCompressionAlgorithmsClientToServer = [name None] , kexCompressionAlgorithmsServerToClient = [name None] , kexLanguagesClientToServer = [] , kexLanguagesServerToClient = [] , kexFirstPacketFollows = False } kexRekeyingRequired :: Transport -> IO Bool kexRekeyingRequired env = do tNow <- getEpochSeconds t <- readMVar (tLastRekeyingTime env) sNow <- readMVar (tBytesSent env) s <- readMVar (tLastRekeyingDataSent env) rNow <- readMVar (tBytesReceived env) r <- readMVar (tLastRekeyingDataReceived env) pure $ t + interval < tNow || s + threshold < sNow || r + threshold < rNow where -- For reasons of fool-proofness the rekeying interval/threshold -- shall never be greater than 1 hour or 1GB. -- NB: This is security critical as some algorithms like ChaCha20 -- use the packet counter as nonce and an overflow will lead to -- nonce reuse! interval = min (maxTimeBeforeRekey $ tConfig env) 3600 threshold = min (maxDataBeforeRekey $ tConfig env) (1024 * 1024 * 1024) trySetSessionId :: Transport -> SessionId -> IO SessionId trySetSessionId env sidDef = tryReadMVar (tSessionId env) >>= \case Nothing -> putMVar (tSessionId env) sidDef >> pure sidDef Just sid -> pure sid kexHash :: Version -> -- client version string Version -> -- server version string KexInit -> -- client kex init msg KexInit -> -- server kex init msg PublicKey -> -- server host key Curve25519.PublicKey -> -- client ephemeral key Curve25519.PublicKey -> -- server ephemeral key Curve25519.DhSecret -> -- dh secret Hash.Digest Hash.SHA256 kexHash (Version vc) (Version vs) ic is ks qc qs k = Hash.hash $ runPut $ putShortString vc <> putShortString vs <> B.word32BE (len ic) <> put ic <> B.word32BE (len is) <> put is <> put ks <> put qc <> put qs <> putAsMPInt k where len = fromIntegral . B.length . put kexKeys :: Curve25519.DhSecret -> Hash.Digest Hash.SHA256 -> SessionId -> KeyStreams kexKeys secret hash (SessionId sess) = KeyStreams $ \i -> BA.convert <$> k1 i : f [k1 i] where k1 i = Hash.hashFinalize $ flip Hash.hashUpdate (SBS.fromShort sess) $ Hash.hashUpdate st i :: Hash.Digest Hash.SHA256 f ks = kx : f (ks ++ [kx]) where kx = Hash.hashFinalize (foldl Hash.hashUpdate st ks) st = flip Hash.hashUpdate hash $ Hash.hashUpdate Hash.hashInit (runPut $ putAsMPInt secret) kexWithVerifiedSignature :: BA.ByteArrayAccess hash => PublicKey -> hash -> Signature -> IO a -> IO a kexWithVerifiedSignature key hash sig action = case (key, sig) of (PublicKeyEd25519 k, SignatureEd25519 s) | Ed25519.verify k hash s -> action _ -> throwIO exceptionKexInvalidSignature ------------------------------------------------------------------------------- -- UTIL ----------------------------------------------------------------------- ------------------------------------------------------------------------------- sendVersion :: (OutputStream stream) => stream -> IO Version sendVersion stream = do sendAll stream $ runPut $ put version pure version -- The maximum length of the version string is 255 chars including CR+LF. -- The version string is usually short and transmitted within -- a single TCP segment. receiveVersion :: (InputStream stream) => stream -> IO Version receiveVersion stream = do bs <- peek stream 255 when (BS.null bs) e0 case BS.elemIndex 0x0a bs of Nothing -> e1 Just i -> maybe e1 pure . runGet =<< receive stream (i+1) where e0 = throwIO exceptionConnectionLost e1 = throwIO exceptionProtocolVersionNotSupported getEpochSeconds :: IO Word64 getEpochSeconds = (`div` 1000000000) <$> getMonotonicTimeNSec