{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.SSH.Server.Service.Connection
( Connection ()
, ConnectionConfig (..)
, SessionRequest (..)
, SessionHandler (..)
, Environment (..)
, TermInfo (..)
, Command (..)
, DirectTcpIpRequest (..)
, DirectTcpIpHandler (..)
, ConnectionMsg (..)
, serveConnection
) where
import Control.Applicative
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM.TMVar
import Control.Monad (join, when, forever, unless)
import Control.Monad.STM (STM, atomically, check, throwSTM)
import Control.Exception (bracket, bracketOnError)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as SBS
import Data.Default
import qualified Data.Map.Strict as M
import Data.Word
import Data.String
import System.Exit
import Network.SSH.Encoding
import Network.SSH.Exception
import Network.SSH.Constants
import Network.SSH.Message
import qualified Network.SSH.Stream as S
import qualified Network.SSH.TStreamingQueue as Q
-- | Configuration of the connection service.
--
-- The @identity@ is the value passed to 'serveConnection'; it is handed
-- unchanged to both request callbacks.
data ConnectionConfig identity
    = ConnectionConfig
    { onSessionRequest :: identity -> SessionRequest -> IO (Maybe SessionHandler)
      -- ^ Consulted when the client opens a @session@ channel. Returning
      -- 'Nothing' rejects the open with 'ChannelOpenAdministrativelyProhibited'.
    , onDirectTcpIpRequest :: identity -> DirectTcpIpRequest -> IO (Maybe DirectTcpIpHandler)
      -- ^ Consulted when the client opens a @direct-tcpip@ channel. Returning
      -- 'Nothing' rejects the open with 'ChannelOpenAdministrativelyProhibited'.
    , channelMaxCount :: Word16
      -- ^ Maximum number of channels open at the same time; further open
      -- requests are rejected with 'ChannelOpenResourceShortage'.
    , channelMaxQueueSize :: Word32
      -- ^ Size given to each per-channel streaming queue and advertised as the
      -- initial window size. Clamped to at least 1 and at most
      -- 'maxBoundIntWord32' before use.
    , channelMaxPacketSize :: Word32
      -- ^ Maximum packet size advertised to the client; incoming data packets
      -- larger than this raise 'exceptionPacketSizeExceeded'. Clamped like
      -- 'channelMaxQueueSize'.
    }
-- | Conservative defaults: every session and direct-tcpip request is rejected,
-- at most 256 channels may be open at once, and both the queue size and the
-- packet size limit are 32 KiB.
instance Default (ConnectionConfig identity) where
    def = ConnectionConfig
        { onSessionRequest     = rejectAll
        , onDirectTcpIpRequest = rejectAll
        , channelMaxCount      = 256
        , channelMaxQueueSize  = kib32
        , channelMaxPacketSize = kib32
        }
      where
        -- Ignore identity and request alike; never hand out a handler.
        rejectAll _ _ = pure Nothing
        kib32 :: Word32
        kib32 = 32 * 1024
-- | Payload of a @session@ channel-open request as passed to
-- 'onSessionRequest'. It currently carries no information.
data SessionRequest
    = SessionRequest
    deriving (Eq, Ord, Show)
-- | Application code run for a session once the client sends a @shell@ or an
-- @exec@ channel request.
--
-- Arguments, in order:
--
-- * the environment collected from @env@ requests,
-- * the terminal settings of a preceding @pty-req@, if any,
-- * the command of an @exec@ request ('Nothing' for @shell@),
-- * the stdin, stdout and stderr streams.
--
-- The returned 'ExitCode' is reported to the client as @exit-status@. If the
-- handler throws, an @exit-signal@ is reported instead.
newtype SessionHandler =
    SessionHandler (forall stdin stdout stderr. (S.InputStream stdin, S.OutputStream stdout, S.OutputStream stderr)
        => Environment -> Maybe TermInfo -> Maybe Command -> stdin -> stdout -> stderr -> IO ExitCode)
-- | Environment variables as (name, value) pairs, collected from the client's
-- @env@ requests. Each request is prepended, so the most recent one comes
-- first.
newtype Environment = Environment [(BS.ByteString, BS.ByteString)]
    deriving (Eq, Ord, Show)
-- | Terminal settings from the client's most recent @pty-req@ request.
data TermInfo = TermInfo PtySettings
-- | The command sent with an @exec@ request, as raw bytes.
newtype Command = Command BS.ByteString
    deriving (Eq, Ord, Show, IsString)
-- | A client's @direct-tcpip@ channel-open request, passed to
-- 'onDirectTcpIpRequest': forward a connection to @dstAddress:dstPort@ that
-- originates from @srcAddress:srcPort@.
data DirectTcpIpRequest
    = DirectTcpIpRequest
    { dstAddress :: BS.ByteString
    , dstPort :: Word32
    , srcAddress :: BS.ByteString
    , srcPort :: Word32
    } deriving (Eq, Ord, Show)
-- | Application code run for an accepted @direct-tcpip@ channel.
--
-- It receives a duplex stream: reading yields data sent by the client, and
-- writing sends data to the client. When the handler returns normally, EOF
-- and close are sent to the client. If it throws, only close is sent.
newtype DirectTcpIpHandler =
    DirectTcpIpHandler (forall stream. S.DuplexStream stream => stream -> IO ())
-- | Per-connection state; abstract to users of this module.
data Connection identity
    = Connection
    { connConfig :: ConnectionConfig identity
    , connIdentity :: identity
      -- ^ Passed to the request callbacks of 'connConfig'.
    , connChannels :: TVar (M.Map ChannelId Channel)
      -- ^ Currently open channels, keyed by the local channel id.
    }
-- | Server-side state of one open channel.
data Channel
    = Channel
    { chanApplication :: ChannelApplication
      -- ^ What the channel is used for, together with its I/O queues.
    , chanIdLocal :: ChannelId
      -- ^ Our id for the channel (its key in 'connChannels').
    , chanIdRemote :: ChannelId
      -- ^ The client's id for the channel; the recipient of messages we send.
    , chanMaxPacketSizeRemote :: Word32
      -- ^ Maximum packet size announced by the client; outgoing data is
      -- dequeued in chunks of at most this size.
    , chanClosed :: TVar Bool
      -- ^ Set by the supervisor once it has committed to sending our own
      -- close. A later close from the client is then not answered again.
    , chanThread :: TMVar (Async.Async ())
      -- ^ The channel's worker thread once one has been registered; empty
      -- before that.
    }
-- | The kind of a channel together with its kind-specific state.
data ChannelApplication
    = ChannelApplicationSession SessionState
    | ChannelApplicationDirectTcpIp DirectTcpIpState
-- | State of a @session@ channel.
data SessionState
    = SessionState
    { sessHandler :: SessionHandler
      -- ^ Started by a @shell@ or @exec@ request.
    , sessEnvironment :: TVar Environment
      -- ^ Accumulated from @env@ requests.
    , sessPtySettings :: TVar (Maybe PtySettings)
      -- ^ Set by a @pty-req@ request.
    , sessStdin :: Q.TStreamingQueue
      -- ^ Data received from the client; terminated when the client sends EOF.
    , sessStdout :: Q.TStreamingQueue
      -- ^ Handler output, forwarded to the client as 'ChannelData'.
    , sessStderr :: Q.TStreamingQueue
      -- ^ Handler error output, forwarded as 'ChannelExtendedData' (type 1).
    }
-- | State of a @direct-tcpip@ channel.
data DirectTcpIpState
    = DirectTcpIpState
    { dtiStreamIn :: Q.TStreamingQueue
      -- ^ Data received from the client; terminated when the client sends EOF.
    , dtiStreamOut :: Q.TStreamingQueue
      -- ^ Data written by the handler, forwarded to the client as 'ChannelData'.
    }
-- | Reading from a direct-tcpip channel reads the client's inbound queue.
instance S.InputStream DirectTcpIpState where
    peek = S.peek . dtiStreamIn
    receive = S.receive . dtiStreamIn

-- | Writing to a direct-tcpip channel fills the queue that the supervisor
-- forwards to the client.
instance S.OutputStream DirectTcpIpState where
    send = S.send . dtiStreamOut

instance S.DuplexStream DirectTcpIpState
-- | The connection-protocol messages that 'serveConnection' receives and
-- dispatches.
data ConnectionMsg
    = ConnectionChannelOpen ChannelOpen
    | ConnectionChannelClose ChannelClose
    | ConnectionChannelEof ChannelEof
    | ConnectionChannelData ChannelData
    | ConnectionChannelRequest ChannelRequest
    | ConnectionChannelWindowAdjust ChannelWindowAdjust
    deriving (Eq, Show)
-- | Encoding delegates to the wrapped message. Decoding tries each message
-- type in constructor order and keeps the first one that parses.
instance Encoding ConnectionMsg where
    put = \case
        ConnectionChannelOpen msg         -> put msg
        ConnectionChannelClose msg        -> put msg
        ConnectionChannelEof msg          -> put msg
        ConnectionChannelData msg         -> put msg
        ConnectionChannelRequest msg      -> put msg
        ConnectionChannelWindowAdjust msg -> put msg
    get = fmap ConnectionChannelOpen get
        <|> fmap ConnectionChannelClose get
        <|> fmap ConnectionChannelEof get
        <|> fmap ConnectionChannelData get
        <|> fmap ConnectionChannelRequest get
        <|> fmap ConnectionChannelWindowAdjust get
-- | Serve the connection protocol over an established message stream.
--
-- Messages are received and dispatched one at a time, forever; the loop only
-- ends when something throws (e.g. 'receiveMessage' or a channel handler).
-- On the way out, every channel worker thread that was registered is
-- cancelled.
serveConnection :: forall stream identity. MessageStream stream =>
    ConnectionConfig identity -> stream -> identity -> IO ()
serveConnection config stream idnt =
    bracket acquire release (forever . dispatchNext)
  where
    acquire :: IO (Connection identity)
    acquire = Connection config idnt <$> newTVarIO mempty

    -- Receive one message and route it to the matching channel handler.
    dispatchNext :: Connection identity -> IO ()
    dispatchNext conn = receiveMessage stream >>= \msg -> case msg of
        ConnectionChannelOpen x         -> connectionChannelOpen conn stream x
        ConnectionChannelClose x        -> connectionChannelClose conn stream x
        ConnectionChannelEof x          -> connectionChannelEof conn x
        ConnectionChannelData x         -> connectionChannelData conn x
        ConnectionChannelRequest x      -> connectionChannelRequest conn stream x
        ConnectionChannelWindowAdjust x -> connectionChannelWindowAdjust conn x

    -- Cancel the worker of every channel that has one (in key order).
    release :: Connection identity -> IO ()
    release conn = readTVarIO (connChannels conn) >>= mapM_ cancelWorker
      where
        cancelWorker ch =
            atomically (tryReadTMVar (chanThread ch)) >>= mapM_ Async.cancel
-- | Handle an incoming channel-open request.
--
-- * @session@: asks 'onSessionRequest' for a handler. On success a channel
--   with fresh stdin\/stdout\/stderr queues is registered and confirmed. No
--   thread is started here; that happens on a later @shell@\/@exec@ request.
-- * @direct-tcpip@: asks 'onDirectTcpIpRequest' for a handler. On success the
--   channel is registered and confirmed, and then its handler thread is
--   started.
-- * Any other channel type is rejected with 'ChannelOpenUnknownChannelType'.
--
-- A callback returning 'Nothing' is answered with
-- 'ChannelOpenAdministrativelyProhibited'. Having 'channelMaxCount' channels
-- open already is answered with 'ChannelOpenResourceShortage'.
connectionChannelOpen :: forall stream identity. MessageStream stream =>
    Connection identity -> stream -> ChannelOpen -> IO ()
connectionChannelOpen connection stream (ChannelOpen remoteChannelId remoteWindowSize remotePacketSize channelType) =
    case channelType of
        ChannelOpenSession ->
            onSessionRequest (connConfig connection) (connIdentity connection) SessionRequest >>= \case
                Nothing ->
                    sendMessage stream $ openFailure ChannelOpenAdministrativelyProhibited
                Just handler -> do
                    env <- newTVarIO (Environment [])
                    pty <- newTVarIO Nothing
                    -- Window counter of the inbound (stdin) queue; it starts at
                    -- the window advertised in the confirmation.
                    wsLocal <- newTVarIO maxQueueSize
                    -- Window counter granted by the client; stdout and stderr
                    -- share it.
                    wsRemote <- newTVarIO remoteWindowSize
                    stdIn <- atomically $ Q.newTStreamingQueue maxQueueSize wsLocal
                    stdOut <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    stdErr <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    let app = ChannelApplicationSession SessionState
                            { sessHandler = handler
                            , sessEnvironment = env
                            , sessPtySettings = pty
                            , sessStdin = stdIn
                            , sessStdout = stdOut
                            , sessStderr = stdErr
                            }
                    atomically (openApplicationChannel app) >>= \case
                        Left failure -> sendMessage stream failure
                        Right (_, confirmation) -> sendMessage stream confirmation
        ChannelOpenDirectTcpIp da dp oa op -> do
            let req = DirectTcpIpRequest (SBS.fromShort da) dp (SBS.fromShort oa) op
            onDirectTcpIpRequest (connConfig connection) (connIdentity connection) req >>= \case
                Nothing ->
                    sendMessage stream $ openFailure ChannelOpenAdministrativelyProhibited
                Just (DirectTcpIpHandler handler) -> do
                    wsLocal <- newTVarIO maxQueueSize
                    wsRemote <- newTVarIO remoteWindowSize
                    streamIn <- atomically $ Q.newTStreamingQueue maxQueueSize wsLocal
                    streamOut <- atomically $ Q.newTStreamingQueue maxQueueSize wsRemote
                    let st = DirectTcpIpState
                            { dtiStreamIn = streamIn
                            , dtiStreamOut = streamOut
                            }
                    let app = ChannelApplicationDirectTcpIp st
                    atomically (openApplicationChannel app) >>= \case
                        Left failure -> sendMessage stream failure
                        Right (c, confirmation) -> do
                            -- Confirm BEFORE forking: the supervisor may send
                            -- ChannelData/ChannelEof/ChannelClose as soon as the
                            -- handler runs, and none of these may reach the
                            -- client before it has seen the confirmation.
                            sendMessage stream confirmation
                            forkDirectTcpIpHandler stream c st (handler st)
        ChannelOpenOther {} ->
            sendMessage stream $ openFailure ChannelOpenUnknownChannelType
  where
    -- Rejection addressed to the client's channel id; description and
    -- language tag are left empty.
    openFailure :: ChannelOpenFailureReason -> ChannelOpenFailure
    openFailure reason = ChannelOpenFailure remoteChannelId reason mempty mempty

    -- Register a new channel for @application@ with no thread attached yet.
    openApplicationChannel :: ChannelApplication -> STM (Either ChannelOpenFailure (Channel, ChannelOpenConfirmation))
    openApplicationChannel application = tryRegisterChannel $ \localChannelId -> do
        closed <- newTVar False
        thread <- newEmptyTMVar
        pure Channel
            { chanApplication = application
            , chanIdLocal = localChannelId
            , chanIdRemote = remoteChannelId
            , chanMaxPacketSizeRemote = remotePacketSize
            , chanClosed = closed
            , chanThread = thread
            }

    -- Allocate a free local id, insert the channel built by @createChannel@
    -- and produce the matching confirmation, or a resource-shortage failure
    -- when no id is available.
    tryRegisterChannel :: (ChannelId -> STM Channel) -> STM (Either ChannelOpenFailure (Channel, ChannelOpenConfirmation))
    tryRegisterChannel createChannel = do
        channels <- readTVar (connChannels connection)
        case selectFreeLocalChannelId channels of
            Nothing -> pure $ Left $ openFailure ChannelOpenResourceShortage
            Just localChannelId -> do
                channel <- createChannel localChannelId
                writeTVar (connChannels connection) $! M.insert localChannelId channel channels
                pure $ Right $ (channel,) $ ChannelOpenConfirmation
                    remoteChannelId
                    localChannelId
                    maxQueueSize
                    maxPacketSize

    -- Configured limits, clamped to [1, maxBoundIntWord32].
    maxQueueSize :: Word32
    maxQueueSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxQueueSize $ connConfig connection)

    maxPacketSize :: Word32
    maxPacketSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxPacketSize $ connConfig connection)

    -- Lowest channel id not in use (the map's keys come out ascending), or
    -- Nothing once channelMaxCount channels are open.
    selectFreeLocalChannelId :: M.Map ChannelId a -> Maybe ChannelId
    selectFreeLocalChannelId m
        | M.size m >= fromIntegral maxCount = Nothing
        | otherwise = f (ChannelId 0) $ M.keys m
      where
        f i [] = Just i
        f (ChannelId i) (ChannelId k:ks)
            | i == k = f (ChannelId $ i+1) ks
            | otherwise = Just (ChannelId i)
        maxCount = channelMaxCount (connConfig connection)
-- | Handle the client's EOF: terminate the channel's inbound queue so that
-- the application sees end of input. Unknown channel ids raise
-- 'exceptionInvalidChannelId'.
connectionChannelEof ::
    Connection identity -> ChannelEof -> IO ()
connectionChannelEof connection (ChannelEof localChannelId) = atomically $
    Q.terminate . inboundQueue . chanApplication
        =<< getChannelSTM connection localChannelId
  where
    inboundQueue (ChannelApplicationSession st) = sessStdin st
    inboundQueue (ChannelApplicationDirectTcpIp st) = dtiStreamIn st
-- | Handle the client's close.
--
-- The channel's worker thread (if any) is cancelled, the channel is removed
-- from the connection, and a close is sent back unless the supervisor has
-- already committed to sending one ('chanClosed'). Unknown channel ids raise
-- 'exceptionInvalidChannelId'.
connectionChannelClose :: forall stream identity. MessageStream stream =>
    Connection identity -> stream -> ChannelClose -> IO ()
connectionChannelClose connection stream (ChannelClose localChannelId) = do
    channel <- atomically $ getChannelSTM connection localChannelId
    atomically (tryReadTMVar $ chanThread channel) >>= mapM_ Async.cancel
    atomically $ modifyTVar' (connChannels connection) (M.delete localChannelId)
    alreadyClosed <- readTVarIO (chanClosed channel)
    unless alreadyClosed $
        sendMessage stream $ ChannelClose $ chanIdRemote channel
-- | Handle data from the client by appending it to the channel's inbound
-- queue.
--
-- Throws 'exceptionPacketSizeExceeded' for packets above the configured
-- maximum, and 'exceptionWindowSizeUnderrun' when the queue cannot take the
-- whole packet. Throws 'exceptionDataAfterEof' when the queue accepts nothing
-- (checked first).
connectionChannelData ::
    Connection identity -> ChannelData -> IO ()
connectionChannelData connection (ChannelData localChannelId packet) = atomically $ do
    when (packetSize > maxPacketSize) (throwSTM exceptionPacketSizeExceeded)
    channel <- getChannelSTM connection localChannelId
    accepted <- Q.enqueue (inboundQueue $ chanApplication channel) (SBS.fromShort packet)
        <|> throwSTM exceptionWindowSizeUnderrun
    if | accepted == 0 -> throwSTM exceptionDataAfterEof
       | accepted /= packetSize -> throwSTM exceptionWindowSizeUnderrun
       | otherwise -> pure ()
  where
    inboundQueue (ChannelApplicationSession st) = sessStdin st
    inboundQueue (ChannelApplicationDirectTcpIp st) = dtiStreamIn st

    packetSize :: Word32
    packetSize = fromIntegral $ SBS.length packet

    -- Same clamp as the packet size advertised on channel open.
    maxPacketSize :: Word32
    maxPacketSize = max 1 $ fromIntegral $ min maxBoundIntWord32
        (channelMaxPacketSize $ connConfig connection)
-- | Handle the client's window adjustment by granting more space to the
-- channel's outbound queue. Throws 'exceptionWindowSizeOverflow' if the
-- queue refuses the increment.
connectionChannelWindowAdjust ::
    Connection identity -> ChannelWindowAdjust -> IO ()
connectionChannelWindowAdjust connection (ChannelWindowAdjust channelId increment) =
    atomically $ getChannelSTM connection channelId >>= \channel ->
        Q.addWindowSpace (outboundQueue $ chanApplication channel) increment
            <|> throwSTM exceptionWindowSizeOverflow
  where
    outboundQueue (ChannelApplicationSession st) = sessStdout st
    outboundQueue (ChannelApplicationDirectTcpIp st) = dtiStreamOut st
-- | Handle a channel request from the client.
--
-- On session channels:
--
-- * @env@ prepends a variable to the session environment,
-- * @pty-req@ stores the terminal settings,
-- * @shell@ and @exec@ start the session handler via 'forkSessionHandler'.
--
-- Every other request type, and any request on a direct-tcpip channel, is
-- refused. Success and failure replies are only sent when the client asked
-- for a reply. A payload that does not decode raises
-- 'exceptionInvalidChannelRequest'.
connectionChannelRequest :: forall identity stream. MessageStream stream =>
    Connection identity -> stream -> ChannelRequest -> IO ()
connectionChannelRequest connection stream (ChannelRequest channelId typ wantReply dat) = join $ atomically $ do
    channel <- getChannelSTM connection channelId
    case chanApplication channel of
        ChannelApplicationSession sessionState -> case typ of
            "env" -> interpret $ \(ChannelRequestEnv name value) -> do
                Environment env <- readTVar (sessEnvironment sessionState)
                writeTVar (sessEnvironment sessionState) $! Environment $ (SBS.fromShort name, SBS.fromShort value):env
                pure $ success channel
            "pty-req" -> interpret $ \(ChannelRequestPty settings) -> do
                writeTVar (sessPtySettings sessionState) (Just settings)
                pure $ success channel
            "shell" -> interpret $ \ChannelRequestShell ->
                startSession channel sessionState Nothing
            "exec" -> interpret $ \(ChannelRequestExec command) ->
                startSession channel sessionState (Just $ Command $ SBS.fromShort command)
            _ -> pure $ failure channel
        ChannelApplicationDirectTcpIp {} -> pure $ failure channel
  where
    -- Decode the request payload and continue with it, or fail the connection.
    interpret f = maybe (throwSTM exceptionInvalidChannelRequest) f (runGet dat)

    -- Shared by "shell" (no command) and "exec": snapshot environment and pty
    -- settings, then (outside the transaction) fork the handler and reply.
    startSession :: Channel -> SessionState -> Maybe Command -> STM (IO ())
    startSession channel sessionState command = do
        env <- readTVar (sessEnvironment sessionState)
        pty <- readTVar (sessPtySettings sessionState)
        let stdin = sessStdin sessionState
            stdout = sessStdout sessionState
            stderr = sessStderr sessionState
        let SessionHandler handler = sessHandler sessionState
        pure $ do
            forkSessionHandler stream channel stdin stdout stderr $
                handler env (TermInfo <$> pty) command stdin stdout stderr
            success channel

    success channel
        | wantReply = sendMessage stream $ ChannelSuccess (chanIdRemote channel)
        | otherwise = pure ()

    failure channel
        | wantReply = sendMessage stream $ ChannelFailure (chanIdRemote channel)
        | otherwise = pure ()
-- | Start a direct-tcpip handler thread for @channel@ and pump its I/O.
--
-- Each round waits in one STM transaction, trying the alternatives in this
-- order, for:
--
-- 1. handler output, forwarded as 'ChannelData',
-- 2. the handler finishing: EOF (only on normal return) and close are sent,
--    and the loop stops,
-- 3. the inbound window becoming due for an adjustment, which is announced
--    with 'ChannelWindowAdjust'.
forkDirectTcpIpHandler :: forall stream. MessageStream stream =>
    stream -> Channel -> DirectTcpIpState -> IO () -> IO ()
forkDirectTcpIpHandler stream channel st handle =
    registerThread channel handle loop
  where
    remoteId = chanIdRemote channel

    loop :: Async.Async () -> IO ()
    loop worker = do
        again <- join . atomically $
            forwardOutput <|> finish worker <|> replenishWindow
        when again (loop worker)

    finish :: Async.Async () -> STM (IO Bool)
    finish worker = do
        cleanExit <- either (const False) (const True) <$> Async.waitCatchSTM worker
        writeTVar (chanClosed channel) True
        pure $ do
            when cleanExit $ sendMessage stream $ ChannelEof remoteId
            sendMessage stream $ ChannelClose remoteId
            pure False

    forwardOutput :: STM (IO Bool)
    forwardOutput = do
        chunk <- Q.dequeueShort (dtiStreamOut st) (chanMaxPacketSizeRemote channel)
        pure (sendMessage stream (ChannelData remoteId chunk) >> pure True)

    replenishWindow :: STM (IO Bool)
    replenishWindow = do
        Q.askWindowSpaceAdjustRecommended (dtiStreamIn st) >>= check
        delta <- Q.fillWindowSpace (dtiStreamIn st)
        pure (sendMessage stream (ChannelWindowAdjust remoteId delta) >> pure True)
-- | Start a session handler thread for @channel@ and pump its I/O.
--
-- Each round waits in one STM transaction, trying the alternatives in this
-- order, for:
--
-- 1. stdout data, forwarded as 'ChannelData',
-- 2. stderr data, forwarded as 'ChannelExtendedData' with data type 1,
-- 3. the handler finishing: EOF, then @exit-status@ (or @exit-signal@
--    \"ILL\" if the handler threw), then close are sent, and the loop stops,
-- 4. the stdin window becoming due for an adjustment, which is announced
--    with 'ChannelWindowAdjust'.
forkSessionHandler :: forall stream. MessageStream stream =>
    stream -> Channel -> Q.TStreamingQueue -> Q.TStreamingQueue -> Q.TStreamingQueue -> IO ExitCode -> IO ()
forkSessionHandler stream channel stdin stdout stderr run =
    registerThread channel run loop
  where
    remoteId = chanIdRemote channel
    packetLimit = chanMaxPacketSizeRemote channel

    loop :: Async.Async ExitCode -> IO ()
    loop worker = do
        again <- join . atomically $
            drainStdout <|> drainStderr <|> finish worker <|> replenishWindow
        when again (loop worker)

    finish :: Async.Async ExitCode -> STM (IO Bool)
    finish worker = do
        exitMessage <- either (const crashed) exited <$> Async.waitCatchSTM worker
        writeTVar (chanClosed channel) True
        pure $ do
            sendMessage stream (ChannelEof remoteId)
            sendMessage stream exitMessage
            sendMessage stream (ChannelClose remoteId)
            pure False
      where
        -- Exit notifications never ask for a reply.
        notify t = ChannelRequest remoteId t False
        exited code = notify "exit-status" $ runPut $ put $ ChannelRequestExitStatus code
        crashed = notify "exit-signal" $ runPut $ put $ ChannelRequestExitSignal "ILL" False "" ""

    drainStdout :: STM (IO Bool)
    drainStdout = do
        chunk <- Q.dequeueShort stdout packetLimit
        pure (sendMessage stream (ChannelData remoteId chunk) >> pure True)

    drainStderr :: STM (IO Bool)
    drainStderr = do
        chunk <- Q.dequeueShort stderr packetLimit
        pure (sendMessage stream (ChannelExtendedData remoteId 1 chunk) >> pure True)

    replenishWindow :: STM (IO Bool)
    replenishWindow = do
        Q.askWindowSpaceAdjustRecommended stdin >>= check
        delta <- Q.fillWindowSpace stdin
        pure (sendMessage stream (ChannelWindowAdjust remoteId delta) >> pure True)
-- | Look up an open channel by its local id; throws
-- 'exceptionInvalidChannelId' when there is none.
getChannelSTM :: Connection identity -> ChannelId -> STM Channel
getChannelSTM connection channelId =
    readTVar (connChannels connection)
        >>= maybe (throwSTM exceptionInvalidChannelId) pure . M.lookup channelId
-- | Start @run@ under @supervise@ in a new thread and record that thread in
-- the channel's 'chanThread'.
--
-- The new thread blocks until registration has succeeded. So neither @run@
-- nor @supervise@ executes when the channel already has a thread; in that
-- case 'exceptionAlreadyExecuting' is thrown and the blocked thread is
-- cancelled. @run@ is started with 'Async.withAsync', so it is cancelled when
-- @supervise@ returns or is itself cancelled.
registerThread :: Channel -> IO a -> (Async.Async a -> IO ()) -> IO ()
registerThread channel run supervise = do
    released <- newTVarIO False
    let spawn = Async.async $ do
            atomically (check =<< readTVar released)
            Async.withAsync run supervise
        claim worker =
            putTMVar (chanThread channel) worker <|> throwSTM exceptionAlreadyExecuting
    bracketOnError spawn Async.cancel $ \worker ->
        atomically (claim worker >> writeTVar released True)