module SSH.Session where
import Control.Concurrent.Chan (Chan)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State (StateT, get, gets, modify, runState)
import Data.Binary (decode, encode)
import Data.Word (Word8, Word32)
import System.IO (Handle)
import qualified Codec.Crypto.SimpleAES as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Map as M
import SSH.Channel (ChannelConfig, ChannelMessage)
import SSH.Crypto (Cipher(..), HMAC(..), KeyPair(..), PublicKey(..), CipherType(..), CipherMode(..), toBlocks)
import SSH.Debug (dump)
import SSH.NetReader (NetReader)
import SSH.Packet (io)
import SSH.Sender (Sender(..), SenderMessage)
import SSH.Util (strictLBS)
-- | The session monad: a 'StateT' over 'IO' that threads the
-- per-connection 'SessionState' through every session action.
type Session = StateT SessionState IO
-- | Per-connection session state. Each constructor carries a different set
-- of fields. 'decrypt' only works in the 'Final' state, and 'getPacket'
-- only reads encrypted packets when the state is 'Final' with
-- 'ssGotNEWKEYS' set.
--
-- NOTE(review): the constructor names suggest the progression
-- 'Initial' -> 'GotKEXInit' -> 'Final', but the transitions happen outside
-- this chunk — confirm.
data SessionState
    = Initial
        { ssConfig :: SessionConfig
          -- ^ Configuration this session runs with.
        , ssChannelConfig :: ChannelConfig
          -- ^ Channel configuration (type from "SSH.Channel").
        , ssThem :: Handle
          -- ^ Handle that 'getPacket' reads incoming packets from.
        , ssSend :: SenderMessage -> IO ()
          -- ^ Callback the 'Sender' instance uses to emit messages.
        , ssPayload :: LBS.ByteString
          -- ^ Payload of the last packet read; consumed by 'net'.
        , ssTheirVersion :: String
          -- ^ Presumably the peer's version string — confirm.
        , ssOurKEXInit :: LBS.ByteString
          -- ^ Presumably our own KEXINIT message bytes — confirm.
        , ssInSeq :: Word32
          -- ^ Incoming sequence number; 'getPacket' prepends it to the
          -- decrypted packet when computing the incoming MAC.
        }
    | GotKEXInit
        { ssConfig :: SessionConfig
          -- ^ Configuration this session runs with.
        , ssChannelConfig :: ChannelConfig
          -- ^ Channel configuration (type from "SSH.Channel").
        , ssThem :: Handle
          -- ^ Handle that 'getPacket' reads incoming packets from.
        , ssSend :: SenderMessage -> IO ()
          -- ^ Callback the 'Sender' instance uses to emit messages.
        , ssPayload :: LBS.ByteString
          -- ^ Payload of the last packet read; consumed by 'net'.
        , ssTheirVersion :: String
          -- ^ Presumably the peer's version string — confirm.
        , ssOurKEXInit :: LBS.ByteString
          -- ^ Presumably our own KEXINIT message bytes — confirm.
        , ssInSeq :: Word32
          -- ^ Incoming sequence number (see the 'Initial' field).
        , ssTheirKEXInit :: LBS.ByteString
          -- ^ Presumably the peer's KEXINIT message bytes — confirm.
        , ssOutCipher :: Cipher
          -- ^ Cipher for the outgoing direction (by name; confirm).
        , ssInCipher :: Cipher
          -- ^ Cipher for the incoming direction.
        , ssOutHMACPrep :: LBS.ByteString -> HMAC
          -- ^ Builds the outgoing 'HMAC' from a byte string
          -- (presumably key material — confirm).
        , ssInHMACPrep :: LBS.ByteString -> HMAC
          -- ^ Builds the incoming 'HMAC' from a byte string
          -- (presumably key material — confirm).
        }
    | Final
        { ssConfig :: SessionConfig
          -- ^ Configuration this session runs with.
        , ssChannelConfig :: ChannelConfig
          -- ^ Channel configuration (type from "SSH.Channel").
        , ssChannels :: M.Map Word32 (Chan ChannelMessage)
          -- ^ Open channels keyed by ID; see 'newChannelID' and
          -- 'getChannel'.
        , ssID :: LBS.ByteString
          -- ^ Presumably the session identifier — confirm.
        , ssThem :: Handle
          -- ^ Handle that 'getPacket' reads incoming packets from.
        , ssSend :: SenderMessage -> IO ()
          -- ^ Callback the 'Sender' instance uses to emit messages.
        , ssPayload :: LBS.ByteString
          -- ^ Payload of the last packet read; consumed by 'net'.
        , ssGotNEWKEYS :: Bool
          -- ^ When 'True', 'getPacket' decrypts incoming packets and
          -- reads a trailing MAC.
        , ssInSeq :: Word32
          -- ^ Incoming sequence number, mixed into the MAC check.
        , ssInCipher :: Cipher
          -- ^ Incoming cipher; 'decrypt' only handles AES-CBC whose
          -- third field (presumably block size) is 16.
        , ssInHMAC :: HMAC
          -- ^ MAC length (bytes read after each packet) and MAC
          -- function for incoming packets.
        , ssInKey :: BS.ByteString
          -- ^ AES key used by 'decrypt'.
        , ssInVector :: BS.ByteString
          -- ^ CBC IV for the next 'decrypt'; replaced with the last
          -- ciphertext block after every call.
        , ssUser :: Maybe String
          -- ^ Presumably the authenticated user, if any — confirm.
        }
-- | Static configuration for a session.
data SessionConfig =
    SessionConfig
        { scAuthMethods :: [String]
          -- ^ Names of the authentication methods on offer
          -- ('defaultSessionConfig' uses @["publickey"]@).
        , scAuthorize :: Authorize -> Session Bool
          -- ^ Decides whether an 'Authorize' attempt is accepted.
        , scKeyPair :: KeyPair
          -- ^ Key pair held by the configuration; the default is an
          -- all-zero RSA placeholder.
        }
-- | A credential presented for authorization, judged by 'scAuthorize'.
data Authorize
    = Password String String
      -- ^ Presumably user name and password — confirm against callers.
    | PublicKey String PublicKey
      -- ^ Presumably user name and the offered public key — confirm.
-- | Outgoing messages are handed to the session's 'ssSend' callback,
-- lifted into the session via 'io'.
instance Sender Session where
    send msg = do
        deliver <- gets ssSend
        io (deliver msg)
-- | A permissive default configuration. It offers only the @publickey@
-- method and accepts every authorization attempt. It carries an all-zero
-- RSA key pair that presumably has to be replaced before real use.
defaultSessionConfig :: SessionConfig
defaultSessionConfig =
    SessionConfig
        { scAuthMethods = methods
        , scAuthorize = acceptAll
        , scKeyPair = placeholderKey
        }
  where
    methods = ["publickey"]
    acceptAll _ = return True
    placeholderKey = RSAKeyPair (RSAPublicKey 0 0) 0 0 0 0 0 0
-- | Run a 'NetReader' against the current packet payload. The reader's
-- final state is written back into 'ssPayload' and its result is returned.
net :: NetReader a -> Session a
net reader = do
    -- The lazy pattern keeps the parse unevaluated until its result is
    -- demanded, exactly like binding it with @let@.
    ~(result, leftover) <- gets (runState reader . ssPayload)
    modify $ \st -> st { ssPayload = leftover }
    return result
-- | Pick the smallest channel ID not currently used in 'ssChannels'.
--
-- 'M.keys' yields the IDs in ascending order without duplicates. The first
-- gap can therefore be found in a single pass, instead of testing every
-- candidate against the whole key list.
newChannelID :: Session Word32
newChannelID = fmap (firstFree 0 . M.keys) (gets ssChannels)
  where
    -- Invariant: candidate <= k for the head key k. When they differ,
    -- candidate lies in a gap and is free.
    firstFree :: Word32 -> [Word32] -> Word32
    firstFree candidate (k:ks)
        | candidate == k =
            if candidate == maxBound
                then error "newChannelID: no free channel IDs"
                else firstFree (candidate + 1) ks
    firstFree candidate _ = candidate
-- | Look up the message queue of an open channel by its ID. Calls 'error'
-- when the ID is not present in 'ssChannels'.
getChannel :: Word32 -> Session (Chan ChannelMessage)
getChannel i =
    gets (M.lookup i . ssChannels)
        >>= maybe (error $ "unknown channel: " ++ show i) return
-- | Decrypt bytes received from the peer with AES-CBC.
--
-- It uses the 'ssInKey' and 'ssInVector' of a 'Final' state whose incoming
-- cipher is @Cipher AES CBC 16 _@. Afterwards 'ssInVector' is set to the
-- last block produced by 'toBlocks' on the input, so the next call
-- continues the CBC chain. Empty input is returned unchanged and the state
-- is not touched. Any other state is an 'error'.
decrypt :: LBS.ByteString -> Session LBS.ByteString
decrypt m
    | LBS.null m = return m
    | otherwise = get >>= go
  where
    go Final
        { ssInCipher = Cipher AES CBC bs@16 _
        , ssInKey = key
        , ssInVector = iv
        } = do
            -- CBC chaining: the IV for the next call is this call's final
            -- ciphertext block.
            modify $ \st -> st { ssInVector = strictLBS (last (toBlocks bs m)) }
            return (A.crypt A.CBC key iv A.Decrypt m)
    go _ = error "no decrypt for current state"
-- | Read one binary packet from 'ssThem' and store its payload in
-- 'ssPayload'.
--
-- In a 'Final' state with 'ssGotNEWKEYS' set, the packet is read encrypted:
--
-- * a first chunk of @max 8 blocksize@ bytes is decrypted to learn the
--   length fields;
-- * the remainder of the packet is read and decrypted;
-- * finally a MAC of the length given by 'ssInHMAC' is read.
--
-- In every other state the packet is read in the clear and no MAC is read.
--
-- Decrypted layout: 4-byte packet length, 1-byte padding length, payload,
-- padding. The packet length counts every byte after the length field, so
-- the payload is @packetLen - paddingLen - 1@ bytes long.
--
-- NOTE(review): 'ssInSeq' is read but not advanced here; presumably the
-- caller increments it — confirm.
getPacket :: Session ()
getPacket = do
    s <- get
    h <- gets ssThem
    case s of
        Final
            { ssGotNEWKEYS = True
            , ssInCipher = Cipher _ _ bs _
            , ssInHMAC = HMAC ms f
            , ssInSeq = is
            } -> do
            let firstChunk = max 8 bs

            firstEnc <- liftIO $ LBS.hGet h firstChunk
            first <- decrypt firstEnc

            let (packetLen, paddingLen) = header first

            dump ("got packet", is, first, packetLen, paddingLen)

            -- The first chunk already held the 4-byte length field plus
            -- (firstChunk - 4) bytes of the packet body.
            restEnc <- liftIO $ LBS.hGet h (fromIntegral packetLen - firstChunk + 4)

            dump ("got rest", restEnc)

            rest <- decrypt restEnc

            dump ("decrypted", rest)

            let decrypted = first `LBS.append` rest
                payload = extract packetLen paddingLen decrypted

            dump ("getting hmac", ms)

            mac <- liftIO $ LBS.hGet h ms
            dump ("got mac", mac, decrypted, is)
            dump ("hmac'd", f decrypted)
            -- NOTE(review): the MAC result is only logged; a mismatch does
            -- not reject the packet. SSH relies on this MAC for data
            -- integrity (RFC 4253 §6.4), so it should presumably abort
            -- the session.
            dump ("got mac, valid?", verify mac is decrypted f)

            modify (\ss -> ss { ssPayload = payload })
        _ -> do
            first <- liftIO $ LBS.hGet h 5

            let (packetLen, paddingLen) = header first

            -- The 5 bytes read so far are the length field (4) and one
            -- body byte (the padding length).
            rest <- liftIO $ LBS.hGet h (fromIntegral packetLen - 5 + 4)

            let payload = LBS.take (payloadLen packetLen paddingLen) rest

            modify (\ss -> ss { ssPayload = payload })
  where
    -- Packet length (big-endian Word32) and padding length from the first
    -- five bytes of a (decrypted) packet.
    header :: LBS.ByteString -> (Word32, Word8)
    header bytes = (decode (LBS.take 4 bytes), decode (LBS.drop 4 bytes))

    -- Payload size: the body minus the padding and the padding-length byte.
    payloadLen pkl pdl = fromIntegral pkl - fromIntegral pdl - 1

    -- Payload slice of a whole decrypted packet (skipping the 5 header bytes).
    extract pkl pdl d = LBS.take (payloadLen pkl pdl) (LBS.drop 5 d)

    -- MAC check over the 32-bit sequence number followed by the packet.
    verify m is d f = m == f (encode (fromIntegral is :: Word32) `LBS.append` d)