-- | A small SSH-2 server.
--
-- 'start' listens on a TCP port and hands every accepted connection to its
-- own thread.  That thread sends our identification string, runs the
-- @diffie-hellman-group1-sha1@ key exchange, authenticates the user through
-- the callbacks in 'SessionConfig', and forwards channel messages to the
-- per-channel handlers created by 'newChannel'.
module SSH where

import Control.Concurrent (forkIO)
import Control.Concurrent.Chan
import Control.Monad (replicateM, when)
import Control.Monad.Trans.State
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Crypto.HMAC
import Crypto.Hash.MD5
import Crypto.Hash.SHA1
import Data.List (intercalate)
import Data.List.Split (splitOn)
import Network
import OpenSSL.BN (randIntegerOneToNMinusOne)
import System.IO
import System.Random
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Map as M
import qualified Data.Serialize as S

import SSH.Channel
import SSH.Crypto
import SSH.Debug
import SSH.NetReader
import SSH.Packet
import SSH.Sender
import SSH.Session
import SSH.Util


-- | Identification string sent to every client before key exchange.
version :: String
version = "SSH-2.0-DarcsDen"

-- | Key-exchange methods advertised in our KEXINIT.
--
-- NOTE(review): only group1 is implemented ('kexDHInit'); the client's
-- kex-method list is not checked against this in 'kexInit'.
supportedKeyExchanges :: [String]
supportedKeyExchanges =
    {-"diffie-hellman-group-exchange-sha1," ++-}
    ["diffie-hellman-group1-sha1"]

-- | Host-key algorithms advertised in our KEXINIT.
supportedKeyAlgorithms :: [String]
supportedKeyAlgorithms = ["ssh-rsa", "ssh-dss"]

-- | Ciphers we offer, most preferred first.  The numeric arguments are the
-- 16-byte AES block and the key length in bytes (256/192/128 bits).
supportedCiphers :: [(String, Cipher)]
supportedCiphers =
    [ ("aes256-cbc", aesCipher CBC 32)
    , ("aes192-cbc", aesCipher CBC 24)
    , ("aes128-cbc", aesCipher CBC 16)
    ]
  where
    aesCipher m s = Cipher AES m 16 s

-- | MACs we offer.  Each entry builds an 'HMAC' from raw key material,
-- truncating the key to the digest length (20 bytes for SHA-1, 16 for MD5).
supportedMACs :: [(String, LBS.ByteString -> HMAC)]
supportedMACs =
    [ ("hmac-sha1", makeHMAC True)
    , ("hmac-md5", makeHMAC False)
    ]
  where
    makeHMAC True k = HMAC 20 $ \b -> bsToLBS . S.runPut $ S.put (hmac (MacKey (strictLBS (LBS.take 20 k))) b :: SHA1)
    makeHMAC False k = HMAC 16 $ \b -> bsToLBS . S.runPut $ S.put (hmac (MacKey (strictLBS (LBS.take 16 k))) b :: MD5)

    bsToLBS = LBS.fromChunks . (: [])

-- | The only compression method we support.
supportedCompression :: String
supportedCompression = "none"

-- | Language tags we advertise (none).
supportedLanguages :: String
supportedLanguages = ""


-- | Listen on the given port and serve SSH connections forever.
start :: SessionConfig -> ChannelConfig -> PortNumber -> IO ()
start sc cc p = withSocketsDo $ do
    sock <- listenOn (PortNumber p)
    putStrLn $ "ssh server listening on port " ++ show p
    waitLoop sc cc sock

-- | Accept connections one at a time, forking a session thread for each.
waitLoop :: SessionConfig -> ChannelConfig -> Socket -> IO ()
waitLoop sc cc s = do
    (handle, hostName, port) <- accept s
    hSetBinaryMode handle True

    dump ("got connection from", hostName, port)

    _ <- forkIO $ do
        -- send SSH server version
        hPutStr handle (version ++ "\r\n")
        hFlush handle

        done <- hIsEOF handle
        if done
            -- The client left before identifying itself; no sender thread
            -- owns the handle yet, so release it here instead of leaking it.
            then hClose handle
            else do
                -- get the version response (strip the trailing CR of CRLF)
                theirVersion <- fmap (takeWhile (/= '\r')) (hGetLine handle)

                cookie <- fmap (LBS.pack . map fromIntegral) $
                    replicateM 16 (randomRIO (0, 255 :: Int))

                let ourKEXInit = doPacket $ pKEXInit cookie

                -- All outgoing packets go through a dedicated sender thread.
                out <- newChan
                _ <- forkIO (sender out (NoKeys handle 0))

                evalStateT (send (Send ourKEXInit) >> readLoop)
                    (Initial
                        { ssConfig = sc
                        , ssChannelConfig = cc
                        , ssThem = handle
                        , ssSend = writeChan out
                        , ssPayload = LBS.empty
                        , ssTheirVersion = theirVersion
                        , ssOurKEXInit = ourKEXInit
                        , ssInSeq = 0
                        })

    waitLoop sc cc s
  where
    -- SSH_MSG_KEXINIT (RFC 4253 §7.1).  The name-lists are, in order:
    -- kex, host key, cipher c->s, cipher s->c, MAC c->s, MAC s->c,
    -- compression c->s, compression s->c, language c->s, language s->c.
    pKEXInit :: LBS.ByteString -> Packet ()
    pKEXInit cookie = do
        byte 20

        raw cookie

        mapM_ string
            [ intercalate "," $ supportedKeyExchanges
            , intercalate "," $ supportedKeyAlgorithms
            , intercalate "," $ map fst supportedCiphers
            , intercalate "," $ map fst supportedCiphers
            , intercalate "," $ map fst supportedMACs
            , intercalate "," $ map fst supportedMACs
            , supportedCompression
            , supportedCompression
            , supportedLanguages
            , supportedLanguages
            ]

        byte 0 -- first_kex_packet_follows (boolean)
        long 0


-- | Main receive loop: read one packet, dispatch on its message number, bump
-- the incoming sequence number, repeat until EOF, DISCONNECT or CHANNEL_CLOSE.
readLoop :: Session ()
readLoop = do
    done <- gets ssThem >>= io . hIsEOF
    if done
        then shutdownChannels
        else do
            getPacket

            msg <- net readByte

            -- NOTE(review): CHANNEL_CLOSE (97) tears down the whole
            -- connection, not just the channel being closed.
            if msg == 1 || msg == 97 -- disconnect || close
                then shutdownChannels
                else do
                    case msg of
                        5 -> serviceRequest
                        20 -> kexInit
                        21 -> newKeys
                        30 -> kexDHInit
                        50 -> userAuthRequest
                        90 -> channelOpen
                        94 -> dataReceived
                        96 -> eofReceived
                        98 -> channelRequest
                        u -> dump $ "unknown message: " ++ show u

                    modify (\s -> s { ssInSeq = ssInSeq s + 1 })

                    readLoop
  where
    -- Interrupt every open channel (only the Final state has channels),
    -- then tell the sender thread to stop.
    shutdownChannels = do
        s <- get

        case s of
            Final { ssChannels = cs } ->
                mapM_ (io . flip writeChan Interrupt) (M.elems cs)
            _ -> return ()

        io $ ssSend s Stop

-- | Handle the client's SSH_MSG_KEXINIT: negotiate ciphers and MACs in both
-- directions and move from 'Initial' to 'GotKEXInit'.
--
-- NOTE(review): first_kex_packet_follows is read but ignored, and the kex and
-- host-key name-lists are not negotiated.
kexInit :: Session ()
kexInit = do
    cookie <- net (readBytes 16)
    nameLists <- fmap (map (splitOn "," . fromLBS)) (replicateM 10 (net readLBS))
    kpf <- net readByte
    dummy <- net readULong

    let theirKEXInit = reconstruct cookie nameLists kpf dummy
        ocn = match (nameLists !! 3) (map fst supportedCiphers)
        icn = match (nameLists !! 2) (map fst supportedCiphers)
        omn = match (nameLists !! 5) (map fst supportedMACs)
        imn = match (nameLists !! 4) (map fst supportedMACs)

    dump ("KEXINIT", theirKEXInit, ocn, icn, omn, imn)

    modify $ \st ->
        case st of
            Initial c cc h s p cv sk is ->
                case
                    ( lookup ocn supportedCiphers
                    , lookup icn supportedCiphers
                    , lookup omn supportedMACs
                    , lookup imn supportedMACs
                    ) of
                    (Just oc, Just ic, Just om, Just im) ->
                        GotKEXInit
                            { ssConfig = c
                            , ssChannelConfig = cc
                            , ssThem = h
                            , ssSend = s
                            , ssPayload = p
                            , ssTheirVersion = cv
                            , ssOurKEXInit = sk
                            , ssTheirKEXInit = theirKEXInit
                            , ssOutCipher = oc
                            , ssInCipher = ic
                            , ssOutHMACPrep = om
                            , ssInHMACPrep = im
                            , ssInSeq = is
                            }

                    _ ->
                        error . concat $
                            [ "impossible: lookup failed for ciphers/macs: "
                            , show (ocn, icn, omn, imn)
                            ]

            _ -> error "impossible state transition; expected Initial"
  where
    -- RFC 4253 §7.1: choose the first algorithm on the client's list that we
    -- also support.  Fail with a clear message instead of 'head []'.
    match theirs ours =
        case filter (`elem` ours) theirs of
            (n:_) -> n
            [] ->
                error . concat $
                    [ "no common algorithm: client offered "
                    , show theirs
                    , ", we support "
                    , show ours
                    ]

    -- Re-encode the client's KEXINIT payload; it is hashed in 'kexDHInit'.
    reconstruct c nls kpf dummy = doPacket $ do
        byte 20
        raw c
        mapM_ (string . intercalate ",") nls
        byte kpf
        long dummy

-- | Handle SSH_MSG_KEXDH_INIT: compute the shared secret and exchange hash,
-- install our outgoing cipher/MAC, move to 'Final', and reply with
-- SSH_MSG_KEXDH_REPLY.
kexDHInit :: Session ()
kexDHInit = do
    e <- net readInteger
    dump ("KEXDH_INIT", e)

    -- RFC 4253 §8: e outside [1, p-1] must not be accepted.  1 and p-1 are
    -- rejected as well: they lie in subgroups of order at most 2, which would
    -- make the shared secret predictable.
    when (e < 2 || e > safePrime - 2) $
        error "invalid client DH public value e; aborting key exchange"

    -- For a safe prime p = 2q + 1 this bound is q.
    y <- io $ randIntegerOneToNMinusOne ((safePrime - 1) `div` 2)

    let f = modexp generator y safePrime
        k = modexp e y safePrime

    keyPair <- gets (scKeyPair . ssConfig)

    let pub =
            case keyPair of
                RSAKeyPair { rprivPub = p } -> p
                DSAKeyPair { dprivPub = p } -> p

    d <- digest e f k pub

    -- RFC 4253 §7.2 key letters: A/B = IV c->s / s->c, C/D = encryption key
    -- c->s / s->c, E/F = integrity key c->s / s->c.  (Key material is
    -- deliberately not written to the debug log.)
    let [civ, siv, ckey, skey, cinteg, sinteg] = map (makeKey k d) ['A'..'F']

    oc <- gets ssOutCipher
    om <- gets ssOutHMACPrep
    send $ Prepare
        oc
        (strictLBS $ LBS.take (fromIntegral $ cKeySize oc) $ skey)
        (strictLBS $ LBS.take (fromIntegral $ cBlockSize oc) $ siv)
        (om sinteg)

    modify $ \st ->
        case st of
            GotKEXInit c cc h s p _ _ is _ _ ic _ im ->
                Final
                    { ssConfig = c
                    , ssChannelConfig = cc
                    , ssChannels = M.empty
                    , ssID = d
                    , ssThem = h
                    , ssSend = s
                    , ssPayload = p
                    , ssGotNEWKEYS = False
                    , ssInSeq = is
                    , ssInCipher = ic
                    , ssInHMAC = im cinteg
                    , ssInKey = strictLBS $ LBS.take (fromIntegral $ cKeySize ic) $ ckey
                    , ssInVector = strictLBS $ LBS.take (fromIntegral $ cBlockSize ic) $ civ
                    , ssUser = Nothing
                    }

            _ -> error "impossible state transition; expected GotKEXInit"

    signed <- io $ sign keyPair d
    let reply = doPacket (kexDHReply f signed pub)

    dump ("KEXDH_REPLY", reply)

    send (Send reply)
  where
    kexDHReply f s p = do
        byte 31
        byteString (blob p)
        integer f
        byteString s

    -- Exchange hash H = SHA1(V_C || V_S || I_C || I_S || K_S || e || f || K)
    -- (RFC 4253 §8).
    digest e f k p = do
        cv <- gets ssTheirVersion
        ck <- gets ssTheirKEXInit
        sk <- gets ssOurKEXInit
        return . bytestringDigest . sha1 . doPacket $ do
            string cv
            string version
            byteString ck
            byteString sk
            byteString (blob p)
            integer e
            integer f
            integer k

-- | Handle SSH_MSG_NEWKEYS: answer with our own NEWKEYS, start encrypting
-- outgoing traffic, and flag that incoming traffic is now keyed.
newKeys :: Session ()
newKeys = do
    sendPacket (byte 21)

    send StartEncrypting

    modify (\ss -> ss { ssGotNEWKEYS = True })

-- | Accept any SSH_MSG_SERVICE_REQUEST by echoing the service name back in
-- SSH_MSG_SERVICE_ACCEPT.
serviceRequest :: Session ()
serviceRequest = do
    name <- net readLBS
    sendPacket $ do
        byte 6
        byteString name

-- | Handle SSH_MSG_USERAUTH_REQUEST (RFC 4252).  Only a successful signed
-- public-key request or a correct password sets 'ssUser'.  A public-key
-- query gets PK_OK without authenticating.  Everything else gets
-- USERAUTH_FAILURE.
userAuthRequest :: Session ()
userAuthRequest = do
    user <- net readLBS
    service <- net readLBS
    method <- net readLBS

    auth <- gets (scAuthorize . ssConfig)
    authMethods <- gets (scAuthMethods . ssConfig)

    dump ("userauth attempt", user, service, method)

    let grantAccess = do
            sendPacket userAuthOK
            modify (\s -> s { ssUser = Just (fromLBS user) })
        denyAccess = sendPacket (userAuthFail authMethods)

    case fromLBS method of
        x | x `notElem` authMethods -> denyAccess

        "publickey" -> do
            b <- net readByte
            name <- net readLBS
            key <- net readLBS
            ok <- auth (PublicKey (fromLBS user) (blobToKey key))

            case (ok, b) of
                (False, _) -> denyAccess
                -- Query only (RFC 4252 §7): confirm the key is acceptable but
                -- do NOT authenticate; the client must send a signed request.
                (True, 0) -> sendPacket (userAuthPKOK name key)
                -- SECURITY(review): the signature carried by this request is
                -- never read or verified, so any client that knows an
                -- authorized public key is accepted.  RFC 4252 §7 requires
                -- verifying it over the session identifier (ssID).
                (True, _) -> grantAccess

        "password" -> do
            changeRequested <- net readByte
            if changeRequested /= 0
                -- Password change (RFC 4252 §8) is unsupported: refuse it
                -- rather than crash on a pattern-match failure.
                then denyAccess
                else do
                    password <- net readLBS
                    ok <- auth (Password (fromLBS user) (fromLBS password))
                    if ok then grantAccess else denyAccess

        -- A configured method with no handler here: refuse, don't crash.
        u -> do
            dump ("unhandled authorization type", u)
            denyAccess
  where
    userAuthFail ms = do
        byte 51
        string (intercalate "," ms)
        byte 0

    userAuthPKOK name key = do
        byte 60
        byteString name
        byteString key

    userAuthOK = byte 52

-- | Handle SSH_MSG_CHANNEL_OPEN: for an authenticated session create a new
-- channel handler and register it under a fresh local id; otherwise refuse
-- with SSH_MSG_CHANNEL_OPEN_FAILURE.
channelOpen :: Session ()
channelOpen = do
    name <- net readLBS
    them <- net readULong
    windowSize <- net readULong
    maxPacketLength <- net readULong

    dump ("channel open", name, them, windowSize, maxPacketLength)

    st <- get
    case st of
        Final { ssUser = Just u } -> do
            us <- newChannelID

            chan <- do
                c <- gets ssChannelConfig
                s <- gets ssSend
                io $ newChannel c s us them windowSize maxPacketLength u

            modify (\s -> s { ssChannels = M.insert us chan (ssChannels s) })

        -- Not authenticated (or key exchange unfinished): refuse instead of
        -- crashing the session (RFC 4254 §5.1; reason 1 =
        -- SSH_OPEN_ADMINISTRATIVELY_PROHIBITED).
        _ -> sendPacket $ do
            byte 92
            long them
            long 1
            string "not authenticated"
            string ""

-- | Handle SSH_MSG_CHANNEL_REQUEST: decode the request body by type and
-- forward it to the target channel's handler.
channelRequest :: Session ()
channelRequest = do
    chan <- net readULong >>= getChannel
    typ <- net readLBS
    wantReply <- net readBool

    let sendRequest = io . writeChan chan . Request wantReply

    case fromLBS typ of
        "pty-req" -> do
            term <- net readString
            cols <- net readULong
            rows <- net readULong
            width <- net readULong
            height <- net readULong
            modes <- net readString
            sendRequest (PseudoTerminal term cols rows width height modes)

        "x11-req" -> sendRequest X11Forwarding

        "shell" -> sendRequest Shell

        "exec" -> do
            command <- net readString
            dump ("execute command", command)
            sendRequest (Execute command)

        "subsystem" -> do
            name <- net readString
            dump ("subsystem request", name)
            sendRequest (Subsystem name)

        "env" -> do
            name <- net readString
            value <- net readString
            dump ("environment request", name, value)
            sendRequest (Environment name value)

        "window-change" -> do
            cols <- net readULong
            rows <- net readULong
            width <- net readULong
            height <- net readULong
            sendRequest (WindowChange cols rows width height)

        "xon-xoff" -> do
            b <- net readBool
            sendRequest (FlowControl b)

        "signal" -> do
            name <- net readString
            sendRequest (Signal name)

        "exit-status" -> do
            status <- net readULong
            sendRequest (ExitStatus status)

        "exit-signal" -> do
            name <- net readString
            dumped <- net readBool
            msg <- net readString
            lang <- net readString
            sendRequest (ExitSignal name dumped msg lang)

        u -> sendRequest (Unknown u)

    dump ("request processed")

-- | Handle SSH_MSG_CHANNEL_DATA: forward the payload to the channel.
dataReceived :: Session ()
dataReceived = do
    dump "got data"
    chan <- net readULong >>= getChannel
    msg <- net readLBS
    io $ writeChan chan (Data msg)
    dump "data processed"

-- | Handle SSH_MSG_CHANNEL_EOF: tell the channel the client sent EOF.
eofReceived :: Session ()
eofReceived = do
    chan <- net readULong >>= getChannel
    io $ writeChan chan EOF