{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MultiWayIf                 #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TupleSections              #-}
{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.SSH.Server.Service.Connection
    ( Connection ()
    , ConnectionConfig (..)
    , SessionRequest (..)
    , SessionHandler (..)
    , Environment (..)
    , TermInfo (..)
    , Command (..)
    , DirectTcpIpRequest (..)
    , DirectTcpIpHandler (..)
    , ConnectionMsg (..)
    , serveConnection
    ) where

import           Control.Applicative
import qualified Control.Concurrent.Async     as Async
import           Control.Concurrent.STM.TVar
import           Control.Concurrent.STM.TMVar
import           Control.Monad                (join, when, forever, unless)
import           Control.Monad.STM            (STM, atomically, check, throwSTM)
import           Control.Exception            (bracket, bracketOnError)
import qualified Data.ByteString              as BS
import qualified Data.ByteString.Short        as SBS
import           Data.Default
import qualified Data.Map.Strict              as M
import           Data.Word
import           Data.String
import           System.Exit

import           Network.SSH.Encoding
import           Network.SSH.Exception
import           Network.SSH.Constants
import           Network.SSH.Message
import qualified Network.SSH.Stream           as S
import qualified Network.SSH.TStreamingQueue  as Q

data ConnectionConfig identity
    = ConnectionConfig
    { onSessionRequest      :: identity -> SessionRequest -> IO (Maybe SessionHandler)
      -- ^ This callback will be executed for every session request.
      --
      --   Return a `SessionHandler` or `Nothing` to reject the request (default).
    , onDirectTcpIpRequest  :: identity -> DirectTcpIpRequest -> IO (Maybe DirectTcpIpHandler)
      -- ^ This callback will be executed for every direct-tcpip request.
      --
      --   Return a `DirectTcpIpHandler` or `Nothing` to reject the request (default).
    , channelMaxCount       :: Word16
      -- ^ The maximum number of channels that may be active simultaneously (default: 256).
      --
      --   Any requests that would exceed the limit will be rejected.
      --   Setting the limit to high values might expose the server to denial
      --   of service issues!
    , channelMaxQueueSize   :: Word32
      -- ^ The maximum size of the internal buffers in bytes (also
      --   limits the maximum window size, default: 32 kB)
      --
      --   Increasing this value might help with performance issues
      --   (if connection delay is in a bad ratio with the available bandwidth the window
      --   resizing might cause unnecessary throttling).
    , channelMaxPacketSize  :: Word32
      -- ^ The maximum size of inbound channel data payload (default: 32 kB)
      --
      --   Values that are larger than `channelMaxQueueSize` or the
      --   maximum message size (35000 bytes) will be automatically adjusted
      --   to the maximum possible value.
    }

instance Default (ConnectionConfig identity) where
    def = ConnectionConfig
        { onSessionRequest      = \_ _ -> pure Nothing
        , onDirectTcpIpRequest  = \_ _ -> pure Nothing
        , channelMaxCount       = 256
        , channelMaxQueueSize   = 32 * 1024
        , channelMaxPacketSize  = 32 * 1024
        }

-- | Information associated with the session request.
--
--   Might be extended in the future.
data SessionRequest
    = SessionRequest
    deriving (Eq, Ord, Show)

-- | The session handler contains the application logic that serves a client's
--   shell or exec request.
--
--   * The `Command` parameter will be present if this is an exec request and absent
--     for shell requests.
--   * The `TermInfo` parameter will be present if the client requested a pty.
--   * The `Environment` parameter contains the set of all env requests
--     the client issued before the actual shell or exec request.
--   * @stdin@, @stdout@ and @stderr@ are streams. The former can only be read
--     from while the latter can only be written to.
--     After the handler has gracefully terminated, the implementation assures
--     that all bytes will be sent before sending an eof and actually closing the
--     channel.
--   * A @SIGILL@ exit signal will be sent if the handler terminates with an exception.
--     Otherwise the client will receive the returned exit code.
--
-- @
-- handler :: SessionHandler
-- handler = SessionHandler $ \\env mterm mcmd stdin stdout stderr -> case mcmd of
--     Just "echo" -> do
--         bs <- `receive` stdin 1024
--         `sendAll` stdout bs
--         pure `ExitSuccess`
--     Nothing ->
--         pure (`ExitFailure` 1)
-- @
newtype SessionHandler =
    SessionHandler (forall stdin stdout stderr. (S.InputStream stdin, S.OutputStream stdout, S.OutputStream stderr)
        => Environment -> Maybe TermInfo -> Maybe Command -> stdin -> stdout -> stderr -> IO ExitCode)

-- | The `Environment` is a list of key-value pairs.
--
--   > Environment [ ("LC_ALL", "en_US.UTF-8") ]
newtype Environment
    = Environment [(BS.ByteString, BS.ByteString)]
    deriving (Eq, Ord, Show)

-- | The `TermInfo` describes the client's terminal settings if it requested a pty.
--
--   NOTE: This will follow in a future release. You may access the constructor
--   through the `Network.SSH.Internal` module, but should not rely on it yet.
data TermInfo
    = TermInfo PtySettings

-- | The `Command` is what the client wants to execute when making an exec request
--   (shell requests don't have a command).
newtype Command
    = Command BS.ByteString
    deriving (Eq, Ord, Show, IsString)

-- | When the client makes a `DirectTcpIpRequest` it requests a TCP port forwarding.
data DirectTcpIpRequest
    = DirectTcpIpRequest
    { dstAddress   :: BS.ByteString
      -- ^ The destination address.
    , dstPort      :: Word32
      -- ^ The destination port.
    , srcAddress   :: BS.ByteString
      -- ^ The source address (usually the IP the client will bind the local listening socket to).
    , srcPort      :: Word32
      -- ^ The source port (usually the port the client will bind the local listening socket to).
    } deriving (Eq, Ord, Show)

-- | The `DirectTcpIpHandler` contains the application logic
--   that handles port forwarding requests.
--
--   There is of course no need to actually do a real forwarding - this
--   mechanism might also be used to give access to process internal services
--   like integrated web servers etc.
--
--   * When the handler exits gracefully, the implementation assures that
--     all bytes will be sent to the client before terminating the stream
--     with an eof and actually closing the channel.
newtype DirectTcpIpHandler =
    DirectTcpIpHandler (forall stream. S.DuplexStream stream => stream -> IO ())

-- | Per-connection state: the configuration, the authenticated identity and
--   the table of currently registered channels (keyed by local channel id).
data Connection identity
    = Connection
    { connConfig       :: ConnectionConfig identity
    , connIdentity     :: identity
    , connChannels     :: TVar (M.Map ChannelId Channel)
    }

-- | State of a single open channel.
data Channel
    = Channel
    { chanApplication         :: ChannelApplication
    , chanIdLocal             :: ChannelId
    , chanIdRemote            :: ChannelId
    , chanMaxPacketSizeRemote :: Word32
      -- ^ Upper bound used when dequeuing outbound data for a single message.
    , chanClosed              :: TVar Bool
      -- ^ Set to `True` by the supervisor when it schedules the server's own
      --   close message; consulted when the client's close arrives.
    , chanThread              :: TMVar (Async.Async ())
      -- ^ The supervisor thread (if one has been registered).
    }

data ChannelApplication
    = ChannelApplicationSession SessionState
    | ChannelApplicationDirectTcpIp DirectTcpIpState

data SessionState
    = SessionState
    { sessHandler     :: SessionHandler
    , sessEnvironment :: TVar Environment
    , sessPtySettings :: TVar (Maybe PtySettings)
    , sessStdin       :: Q.TStreamingQueue
    , sessStdout      :: Q.TStreamingQueue
    , sessStderr      :: Q.TStreamingQueue
    }

data DirectTcpIpState
    = DirectTcpIpState
    { dtiStreamIn     :: Q.TStreamingQueue
    , dtiStreamOut    :: Q.TStreamingQueue
    }

-- Reading is delegated to the inbound queue.
instance S.InputStream DirectTcpIpState where
    peek x = S.peek (dtiStreamIn x)
    receive x = S.receive (dtiStreamIn x)

-- Writing is delegated to the outbound queue.
instance S.OutputStream DirectTcpIpState where
    send x = S.send (dtiStreamOut x)

instance S.DuplexStream DirectTcpIpState where

-- | The inbound messages handled by the connection service.
data ConnectionMsg
    = ConnectionChannelOpen         ChannelOpen
    | ConnectionChannelClose        ChannelClose
    | ConnectionChannelEof          ChannelEof
    | ConnectionChannelData         ChannelData
    | ConnectionChannelRequest      ChannelRequest
    | ConnectionChannelWindowAdjust ChannelWindowAdjust
    deriving (Eq, Show)

instance Encoding ConnectionMsg where
    put (ConnectionChannelOpen x) = put x
    put (ConnectionChannelClose x) = put x
    put (ConnectionChannelEof x) = put x
    put (ConnectionChannelData x) = put x
    put (ConnectionChannelRequest x) = put x
    put (ConnectionChannelWindowAdjust x) = put x
    -- The alternatives are tried in this order; the first decoder that
    -- succeeds determines the constructor.
    get =   ConnectionChannelOpen         <$> get
        <|> ConnectionChannelClose        <$> get
        <|> ConnectionChannelEof          <$> get
        <|> ConnectionChannelData         <$> get
        <|> ConnectionChannelRequest      <$> get
        <|> ConnectionChannelWindowAdjust <$> get

-- | Serve the connection protocol on the given message stream.
--
--   Receives messages in an endless loop and dispatches each one to its
--   handler. The loop only ends by exception; in any case all supervisor
--   threads registered with the connection's channels are cancelled on exit.
serveConnection :: forall stream identity. MessageStream stream =>
    ConnectionConfig identity -> stream -> identity -> IO ()
serveConnection config stream idnt = bracket open close $ \connection ->
    forever $ receiveMessage stream >>= \case
        ConnectionChannelOpen         req -> connectionChannelOpen         connection stream req
        ConnectionChannelClose        req -> connectionChannelClose        connection stream req
        ConnectionChannelEof          req -> connectionChannelEof          connection        req
        ConnectionChannelData         req -> connectionChannelData         connection        req
        ConnectionChannelRequest      req -> connectionChannelRequest      connection stream req
        ConnectionChannelWindowAdjust req -> connectionChannelWindowAdjust connection        req
  where
    open :: IO (Connection identity)
    open = Connection
        <$> pure config
        <*> pure idnt
        <*> newTVarIO mempty

    close :: Connection identity -> IO ()
    close connection = do
        channels <- readTVarIO (connChannels connection)
        mapM_ terminate (M.elems channels)
      where
        -- Channels without a registered thread are simply skipped.
        terminate channel =
            maybe (pure ()) Async.cancel =<< atomically (tryReadTMVar $ chanThread channel)

-- | Handle a channel open request.
--
--   The request is answered with an open failure when the application callback
--   rejects it (administratively prohibited), when no local channel id is
--   available (resource shortage) or when the channel type is unknown.
--   Otherwise the channel is registered and an open confirmation is sent.
connectionChannelOpen :: forall stream identity. MessageStream stream =>
    Connection identity -> stream -> ChannelOpen -> IO ()
connectionChannelOpen connection stream (ChannelOpen remoteChannelId remoteWindowSize remotePacketSize channelType) =
    case channelType of
        ChannelOpenSession ->
            onSessionRequest (connConfig connection) (connIdentity connection) SessionRequest >>= \case
                Nothing ->
                    sendMessage stream $ openFailure ChannelOpenAdministrativelyProhibited
                Just handler -> do
                    env      <- newTVarIO (Environment [])
                    pty      <- newTVarIO Nothing
                    wsLocal  <- newTVarIO maxQueueSize
                    wsRemote <- newTVarIO remoteWindowSize
                    stdIn    <- atomically $ Q.newTStreamingQueue maxQueueSize wsLocal
                    -- stdout and stderr are both created with the same remote window variable.
                    stdOut   <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    stdErr   <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    let app = ChannelApplicationSession SessionState
                            { sessHandler     = handler
                            , sessEnvironment = env
                            , sessPtySettings = pty
                            , sessStdin       = stdIn
                            , sessStdout      = stdOut
                            , sessStderr      = stdErr
                            }
                    atomically (openApplicationChannel app) >>= \case
                        Left failure           -> sendMessage stream failure
                        Right (_,confirmation) -> sendMessage stream confirmation
        ChannelOpenDirectTcpIp da dp oa op -> do
            let req = DirectTcpIpRequest (SBS.fromShort da) dp (SBS.fromShort oa) op
            onDirectTcpIpRequest (connConfig connection) (connIdentity connection) req >>= \case
                Nothing ->
                    sendMessage stream $ openFailure ChannelOpenAdministrativelyProhibited
                Just (DirectTcpIpHandler handler) -> do
                    wsLocal   <- newTVarIO maxQueueSize
                    wsRemote  <- newTVarIO remoteWindowSize
                    streamIn  <- atomically $ Q.newTStreamingQueue maxQueueSize wsLocal
                    streamOut <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    let st = DirectTcpIpState
                            { dtiStreamIn  = streamIn
                            , dtiStreamOut = streamOut
                            }
                    let app = ChannelApplicationDirectTcpIp st
                    atomically (openApplicationChannel app) >>= \case
                        Left failure -> sendMessage stream failure
                        Right (c,confirmation) -> do
                            -- The confirmation must go out before the handler is
                            -- forked: the supervisor thread sends data, window
                            -- adjust, eof and close messages on its own and these
                            -- must never precede the open confirmation.
                            sendMessage stream confirmation
                            forkDirectTcpIpHandler stream c st (handler st)
        ChannelOpenOther {} ->
            sendMessage stream $ openFailure ChannelOpenUnknownChannelType
  where
    openFailure :: ChannelOpenFailureReason -> ChannelOpenFailure
    openFailure reason = ChannelOpenFailure remoteChannelId reason mempty mempty

    -- Build the channel record for the given application and register it.
    openApplicationChannel :: ChannelApplication -> STM (Either ChannelOpenFailure (Channel, ChannelOpenConfirmation))
    openApplicationChannel application = tryRegisterChannel $ \localChannelId -> do
        closed <- newTVar False
        thread <- newEmptyTMVar
        pure Channel
            { chanApplication         = application
            , chanIdLocal             = localChannelId
            , chanIdRemote            = remoteChannelId
            , chanMaxPacketSizeRemote = remotePacketSize
            , chanClosed              = closed
            , chanThread              = thread
            }

    -- Pick a free local channel id, create the channel with it and insert it
    -- into the connection's channel table (all within one transaction).
    tryRegisterChannel :: (ChannelId -> STM Channel) -> STM (Either ChannelOpenFailure (Channel, ChannelOpenConfirmation))
    tryRegisterChannel createChannel = do
        channels <- readTVar (connChannels connection)
        case selectFreeLocalChannelId channels of
            Nothing ->
                pure $ Left $ openFailure ChannelOpenResourceShortage
            Just localChannelId -> do
                channel <- createChannel localChannelId
                writeTVar (connChannels connection) $! M.insert localChannelId channel channels
                pure $ Right $ (channel,) $ ChannelOpenConfirmation
                    remoteChannelId
                    localChannelId
                    maxQueueSize
                    maxPacketSize

    -- The maxQueueSize must at least be one (even if 0 in the config)
    -- and must not exceed the range of Int (might happen on 32bit systems
    -- as Int's guaranteed upper bound is only 2^29 -1).
    -- The value is adjusted silently as this won't be a problem
    -- for real use cases and is just the safest thing to do.
    maxQueueSize :: Word32
    maxQueueSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxQueueSize $ connConfig connection)

    maxPacketSize :: Word32
    maxPacketSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxPacketSize $ connConfig connection)

    -- Returns the smallest unused channel id or Nothing if the configured
    -- maximum number of channels is already reached. Relies on the keys
    -- being delivered in ascending order.
    selectFreeLocalChannelId :: M.Map ChannelId a -> Maybe ChannelId
    selectFreeLocalChannelId m
        | M.size m >= fromIntegral maxCount = Nothing
        | otherwise = f (ChannelId 0) $ M.keys m
      where
        f i [] = Just i
        f (ChannelId i) (ChannelId k:ks)
            | i == k    = f (ChannelId $ i+1) ks
            | otherwise = Just (ChannelId i)
        maxCount = channelMaxCount (connConfig connection)

-- | Handle an eof from the client by terminating the channel's inbound queue.
--
--   Throws `exceptionInvalidChannelId` if the channel is unknown.
connectionChannelEof ::
    Connection identity -> ChannelEof -> IO ()
connectionChannelEof connection (ChannelEof localChannelId) = atomically $ do
    channel <- getChannelSTM connection localChannelId
    let queue = case chanApplication channel of
            ChannelApplicationSession     st -> sessStdin   st
            ChannelApplicationDirectTcpIp st -> dtiStreamIn st
    Q.terminate queue

-- | Handle a close from the client: cancel the channel's thread (if any),
--   remove the channel from the table and confirm the close if necessary.
--
--   Throws `exceptionInvalidChannelId` if the channel is unknown.
connectionChannelClose :: forall stream identity. MessageStream stream =>
    Connection identity -> stream -> ChannelClose -> IO ()
connectionChannelClose connection stream (ChannelClose localChannelId) = do
    channel <- atomically $ getChannelSTM connection localChannelId
    maybe (pure ()) Async.cancel =<< atomically (tryReadTMVar $ chanThread channel)
    atomically $ do
        channels <- readTVar (connChannels connection)
        writeTVar (connChannels connection) $! M.delete localChannelId channels
    -- When the channel is not marked as already closed then the close
    -- must have been initiated by the client and the server needs to send
    -- a confirmation (both sides may issue close messages simultaneously
    -- and receive them afterwards).
    closeAlreadySent <- readTVarIO (chanClosed channel)
    unless closeAlreadySent $
        sendMessage stream $ ChannelClose $ chanIdRemote channel

-- | Handle inbound channel data by enqueueing it into the channel's inbound queue.
--
--   Throws if the packet exceeds the maximum packet size, if the channel is
--   unknown, if the data does not fit (window size underrun) or if the
--   enqueue operation accepted zero bytes (reported as data after eof).
connectionChannelData ::
    Connection identity -> ChannelData -> IO ()
connectionChannelData connection (ChannelData localChannelId packet) = atomically $ do
    when (packetSize > maxPacketSize) (throwSTM exceptionPacketSizeExceeded)
    channel <- getChannelSTM connection localChannelId
    let queue = case chanApplication channel of
            ChannelApplicationSession     st -> sessStdin   st
            ChannelApplicationDirectTcpIp st -> dtiStreamIn st
    i <- Q.enqueue queue (SBS.fromShort packet) <|> throwSTM exceptionWindowSizeUnderrun
    when (i == 0) (throwSTM exceptionDataAfterEof)
    when (i /= packetSize) (throwSTM exceptionWindowSizeUnderrun)
  where
    packetSize :: Word32
    packetSize = fromIntegral $ SBS.length packet

    -- Same adjustment as applied to the value announced in the open confirmation.
    maxPacketSize :: Word32
    maxPacketSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxPacketSize $ connConfig connection)

-- | Handle a window adjust from the client by adding window space to the
--   channel's outbound queue.
--
--   Throws `exceptionWindowSizeOverflow` if the window space cannot be added
--   and `exceptionInvalidChannelId` if the channel is unknown.
connectionChannelWindowAdjust ::
    Connection identity -> ChannelWindowAdjust -> IO ()
connectionChannelWindowAdjust connection (ChannelWindowAdjust channelId increment) = atomically $ do
    channel <- getChannelSTM connection channelId
    let queue = case chanApplication channel of
            ChannelApplicationSession     st -> sessStdout   st
            ChannelApplicationDirectTcpIp st -> dtiStreamOut st
    Q.addWindowSpace queue increment <|> throwSTM exceptionWindowSizeOverflow

-- | Handle a channel request.
--
--   The request is evaluated within a single transaction which yields the IO
--   action to perform afterwards (sending a reply and/or forking the handler).
--   Only session channels accept requests (@env@, @pty-req@, @shell@, @exec@);
--   everything else is answered with a failure (if a reply is wanted).
--   A request whose payload cannot be decoded throws
--   `exceptionInvalidChannelRequest`.
connectionChannelRequest :: forall identity stream. MessageStream stream =>
    Connection identity -> stream -> ChannelRequest -> IO ()
connectionChannelRequest connection stream (ChannelRequest channelId typ wantReply dat) = join $ atomically $ do
    channel <- getChannelSTM connection channelId
    case chanApplication channel of
        ChannelApplicationSession sessionState -> case typ of
            "env" -> interpret $ \(ChannelRequestEnv name value) -> do
                Environment env <- readTVar (sessEnvironment sessionState)
                writeTVar (sessEnvironment sessionState) $! Environment $ (SBS.fromShort name, SBS.fromShort value):env
                pure $ success channel
            "pty-req" -> interpret $ \(ChannelRequestPty settings) -> do
                writeTVar (sessPtySettings sessionState) (Just settings)
                pure $ success channel
            "shell" -> interpret $ \ChannelRequestShell ->
                runSession channel sessionState Nothing
            "exec" -> interpret $ \(ChannelRequestExec command) ->
                runSession channel sessionState (Just (Command $ SBS.fromShort command))
            -- "signal"        ->
            -- "exit-status"   ->
            -- "exit-signal"   ->
            -- "window-change" ->
            _ -> pure $ failure channel
        ChannelApplicationDirectTcpIp {} -> pure $ failure channel
  where
    interpret f     = maybe (throwSTM exceptionInvalidChannelRequest) f (runGet dat)
    success channel
        | wantReply = sendMessage stream $ ChannelSuccess (chanIdRemote channel)
        | otherwise = pure ()
    failure channel
        | wantReply = sendMessage stream $ ChannelFailure (chanIdRemote channel)
        | otherwise = pure ()

    -- Shared by shell and exec requests (which only differ in the presence
    -- of a command): snapshot environment and pty settings and yield the
    -- action that forks the session handler and then replies with success.
    runSession :: Channel -> SessionState -> Maybe Command -> STM (IO ())
    runSession channel sessionState mcommand = do
        env    <- readTVar (sessEnvironment sessionState)
        pty    <- readTVar (sessPtySettings sessionState)
        stdin  <- pure (sessStdin  sessionState)
        stdout <- pure (sessStdout sessionState)
        stderr <- pure (sessStderr sessionState)
        let SessionHandler handler = sessHandler sessionState
        pure $ do
            forkSessionHandler stream channel stdin stdout stderr $
                handler env (TermInfo <$> pty) mcommand stdin stdout stderr
            success channel

-- | Fork the handler of a direct-tcpip channel together with its supervisor.
--
--   The supervisor forwards outbound data, announces window adjustments and,
--   when the handler terminates, sends eof (only if the handler returned
--   without exception) followed by close.
forkDirectTcpIpHandler :: forall stream. MessageStream stream =>
    stream -> Channel -> DirectTcpIpState -> IO () -> IO ()
forkDirectTcpIpHandler stream channel st handle = do
    registerThread channel handle supervise
  where
    supervise :: Async.Async () -> IO ()
    supervise thread = do
        -- Output is tried before exit so that pending data is sent
        -- before eof/close.
        continue <- join $ atomically $
                waitOutput
            <|> waitExit thread
            <|> waitLocalWindowAdjust
        when continue $ supervise thread

    waitExit :: Async.Async () -> STM (IO Bool)
    waitExit thread = do
        eof <- Async.waitCatchSTM thread >>= \case
            Right _ -> pure True
            Left  _ -> pure False
        writeTVar (chanClosed channel) True
        pure $ do
            when eof $ sendMessage stream $ ChannelEof (chanIdRemote channel)
            sendMessage stream $ ChannelClose (chanIdRemote channel)
            pure False

    waitOutput :: STM (IO Bool)
    waitOutput = do
        bs <- Q.dequeueShort (dtiStreamOut st) (chanMaxPacketSizeRemote channel)
        pure $ do
            sendMessage stream $ ChannelData (chanIdRemote channel) bs
            pure True

    waitLocalWindowAdjust :: STM (IO Bool)
    waitLocalWindowAdjust = do
        check =<< Q.askWindowSpaceAdjustRecommended (dtiStreamIn st)
        increaseBy <- Q.fillWindowSpace (dtiStreamIn st)
        pure $ do
            sendMessage stream $ ChannelWindowAdjust (chanIdRemote channel) increaseBy
            pure True

-- | Fork the handler of a session channel together with its supervisor.
--
--   When the handler terminates the supervisor sends eof, then an
--   @exit-status@ request (or an @exit-signal@ request with signal @ILL@ if
--   the handler terminated with an exception) and finally close.
forkSessionHandler :: forall stream. MessageStream stream =>
    stream -> Channel -> Q.TStreamingQueue -> Q.TStreamingQueue -> Q.TStreamingQueue -> IO ExitCode -> IO ()
forkSessionHandler stream channel stdin stdout stderr run = do
    registerThread channel run supervise
  where
    -- The supervisor thread waits for several event sources simultaneously,
    -- handles them and loops until the session thread has terminated and exit
    -- has been signaled or the channel/connection got closed.
    supervise :: Async.Async ExitCode -> IO ()
    supervise thread = do
        continue <- join $ atomically $
            -- NB: The order is critical: Another order would cause a close
            -- or eof to be sent before all data has been flushed.
                waitStdout
            <|> waitStderr
            <|> waitExit thread
            <|> waitLocalWindowAdjust
        when continue $ supervise thread

    waitExit :: Async.Async ExitCode -> STM (IO Bool)
    waitExit thread = do
        exitMessage <- Async.waitCatchSTM thread >>= \case
            Right c -> pure $ req "exit-status" $ runPut $ put $ ChannelRequestExitStatus c
            Left  _ -> pure $ req "exit-signal" $ runPut $ put $ ChannelRequestExitSignal "ILL" False "" ""
        writeTVar (chanClosed channel) True
        pure $ do
            sendMessage stream eofMessage
            sendMessage stream exitMessage
            sendMessage stream closeMessage
            pure False
      where
        req t        = ChannelRequest (chanIdRemote channel) t False
        eofMessage   = ChannelEof (chanIdRemote channel)
        closeMessage = ChannelClose (chanIdRemote channel)

    waitStdout :: STM (IO Bool)
    waitStdout = do
        bs <- Q.dequeueShort stdout (chanMaxPacketSizeRemote channel)
        pure $ do
            sendMessage stream $ ChannelData (chanIdRemote channel) bs
            pure True

    waitStderr :: STM (IO Bool)
    waitStderr = do
        bs <- Q.dequeueShort stderr (chanMaxPacketSizeRemote channel)
        pure $ do
            -- stderr is sent as extended data with data type code 1.
            sendMessage stream $ ChannelExtendedData (chanIdRemote channel) 1 bs
            pure True

    waitLocalWindowAdjust :: STM (IO Bool)
    waitLocalWindowAdjust = do
        check =<< Q.askWindowSpaceAdjustRecommended stdin
        increaseBy <- Q.fillWindowSpace stdin
        pure $ do
            sendMessage stream $ ChannelWindowAdjust (chanIdRemote channel) increaseBy
            pure True

-- | Look up a channel by its local id.
--
--   Throws `exceptionInvalidChannelId` if no such channel is registered.
getChannelSTM :: Connection identity -> ChannelId -> STM Channel
getChannelSTM connection channelId = do
    channels <- readTVar (connChannels connection)
    case M.lookup channelId channels of
        Just channel -> pure channel
        Nothing      -> throwSTM exceptionInvalidChannelId

-- Two threads are forked: a worker thread running as Async and a
-- supervisor thread which is registered with the channel.
--   -> The worker thread never outlives the supervisor thread (`withAsync`).
--   -> The supervisor thread terminates itself when either the worker thread
--      has terminated (`waitExit`) or gets cancelled when the channel/connection
--      gets closed.
--   -> The supervisor thread is started even if a thread is already running.
--      It is blocked until it is notified that it is the only one
--      running and its Async has been registered with the channel (meaning
--      it will be reliably cancelled on main thread termination).
--   -> If a thread is already registered with the channel the new thread is
--      cancelled and `exceptionAlreadyExecuting` is thrown.
registerThread :: Channel -> IO a -> (Async.Async a -> IO ()) -> IO ()
registerThread channel run supervise = do
    barrier <- newTVarIO False
    let prepare = Async.async $ do
            atomically $ readTVar barrier >>= check
            Async.withAsync run supervise
    let abort = Async.cancel
    let register thread = putTMVar (chanThread channel) thread
            <|> throwSTM exceptionAlreadyExecuting
    bracketOnError prepare abort $ \thread ->
        atomically $ register thread >> writeTVar barrier True