-- | Outgoing half of an SSH connection: a loop that takes 'SenderMessage's
-- from a channel, frames each payload as a packet, and writes it to the
-- peer's handle, encrypting and appending a MAC once keys are active.
module SSH.Sender where

import Control.Concurrent.Chan (Chan, readChan)
import Control.Monad (replicateM)
import Data.Word (Word8, Word32)
import System.IO (Handle, hFlush)
import System.Random (randomRIO)
import qualified Codec.Crypto.SimpleAES as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS

import SSH.Debug (dump)
import SSH.Crypto (Cipher(..), HMAC(..), CipherType(..), CipherMode(..), fromBlocks, toBlocks)
import SSH.Packet (Packet, doPacket, raw, byte, long)
import SSH.Util (strictLBS)

-- | State carried by the 'sender' loop between messages.
data SenderState
    = NoKeys
        { senderThem :: Handle
        -- ^ handle that outgoing packets are written to
        , senderOutSeq :: Word32
        -- ^ sequence number of the next outgoing packet
        }
    -- ^ No key material has been installed yet; packets go out unencrypted.
    | GotKeys
        { senderThem :: Handle
        , senderOutSeq :: Word32
        , senderEncrypting :: Bool
        -- ^ whether outgoing packets are encrypted and MAC'd yet
        , senderCipher :: Cipher
        , senderKey :: BS.ByteString
        , senderVector :: BS.ByteString
        -- ^ IV for the next packet; replaced after every encrypted send
        , senderHMAC :: HMAC
        }
    -- ^ Key material is installed; used only once 'senderEncrypting' is set.

-- | Commands understood by the 'sender' loop.
data SenderMessage
    = Prepare Cipher BS.ByteString BS.ByteString HMAC
    -- ^ install cipher, key, IV and HMAC (encryption is not switched on yet)
    | StartEncrypting
    -- ^ switch on encryption using the previously installed keys
    | Send LBS.ByteString
    -- ^ frame the payload and write it to the peer
    | Stop
    -- ^ terminate the loop

-- | Contexts from which messages can be handed to a sender.
class Sender a where
    -- | Deliver a single 'SenderMessage'.
    send :: SenderMessage -> a ()

    -- | Serialise a 'Packet' with 'doPacket' and deliver it as a 'Send'.
    sendPacket :: Packet () -> a ()
    sendPacket = send . Send . doPacket

-- | Run the sender loop: read one message from the channel, act on it, and
-- recurse with the updated state until 'Stop' is received.
--
-- 'StartEncrypting' is only valid after a 'Prepare'. If it arrives while the
-- state is still 'NoKeys', the loop fails with an 'IOError' rather than
-- carrying on without encryption.
sender :: Chan SenderMessage -> SenderState -> IO ()
sender ms ss = do
    m <- readChan ms
    case m of
        Stop -> return ()

        Prepare cipher key iv hmac -> do
            -- NOTE(review): the key and IV are handed to 'dump' here; confirm
            -- that 'dump' does not emit them outside of debug builds.
            dump ("initiating encryption", key, iv)
            -- NOTE(review): the encrypting flag is reset to False, so if this
            -- arrives while already encrypting (e.g. a re-key) packets are
            -- sent unencrypted until 'StartEncrypting' -- confirm intended.
            sender ms (GotKeys (senderThem ss) (senderOutSeq ss) False cipher key iv hmac)

        StartEncrypting -> do
            dump ("starting encryption")
            case ss of
                GotKeys {} -> sender ms (ss { senderEncrypting = True })
                -- 'NoKeys' has no 'senderEncrypting' field, so a record
                -- update on it would only blow up later, when the state is
                -- next forced. Fail here, with a message that names the cause.
                NoKeys {} -> ioError . userError $
                    "SSH.Sender.sender: StartEncrypting received before Prepare (no keys installed)"

        Send msg -> do
            let padLen = paddingLen msg
            pad <- fmap (LBS.pack . map fromIntegral) $
                replicateM (fromIntegral padLen) (randomRIO (0, 255 :: Int))

            let f = full msg pad

            case ss of
                GotKeys h os True cipher key iv (HMAC _ mac) -> do
                    dump ("sending encrypted", os, f)
                    let (encrypted, newVector) = encrypt cipher key iv f
                    -- The MAC covers the sequence number followed by the
                    -- unencrypted packet, and is appended after the ciphertext.
                    LBS.hPut h . LBS.concat $
                        [ encrypted
                        , mac . doPacket $ long os >> raw f
                        ]
                    hFlush h
                    sender ms $ ss
                        { senderOutSeq = senderOutSeq ss + 1
                        , senderVector = newVector
                        }

                -- No keys yet, or keys installed but not switched on.
                _ -> do
                    dump ("sending unencrypted", senderOutSeq ss, f)
                    LBS.hPut (senderThem ss) f
                    hFlush (senderThem ss)
                    sender ms (ss { senderOutSeq = senderOutSeq ss + 1 })
  where
    -- Alignment unit for packets: the cipher's block size when it exceeds 8,
    -- otherwise 8.
    blockSize =
        case ss of
            GotKeys { senderCipher = Cipher _ _ bs _ }
                | bs > 8 -> bs
            _ -> 8

    -- Complete packet: length field, padding-length byte, payload, padding.
    full msg pad = doPacket $ do
        long (len msg)
        byte (paddingLen msg)
        raw msg
        raw pad

    -- Value of the length field: padding-length byte + payload + padding.
    len :: LBS.ByteString -> Word32
    len msg = 1 + fromIntegral (LBS.length msg) + fromIntegral (paddingLen msg)

    -- Bytes needed to round the packet up to a multiple of 'blockSize'. The 5
    -- accounts for what 'full' writes before the payload via 'long' and 'byte'
    -- (presumably 4 + 1 bytes). Yields a full block when already aligned.
    paddingNeeded :: LBS.ByteString -> Word8
    paddingNeeded msg = fromIntegral blockSize - (fromIntegral $ (5 + LBS.length msg) `mod` fromIntegral blockSize)

    -- Padding actually used: never fewer than 4 bytes, so one extra block is
    -- added whenever the alignment padding alone would be shorter than that.
    paddingLen :: LBS.ByteString -> Word8
    paddingLen msg =
        if paddingNeeded msg < 4
            then paddingNeeded msg + fromIntegral blockSize
            else paddingNeeded msg

-- | Encrypt a message with AES in CBC mode.
--
-- Returns the ciphertext together with its last ciphertext block, which the
-- caller stores as the IV for the next packet. Calls 'error' if encryption
-- yields no blocks at all.
encrypt :: Cipher -> BS.ByteString -> BS.ByteString -> LBS.ByteString -> (LBS.ByteString, BS.ByteString)
encrypt (Cipher AES CBC bs _) key vector m =
    (fromBlocks encrypted, nextVector)
  where
    encrypted = toBlocks bs $ A.crypt A.CBC key vector A.Encrypt m

    nextVector =
        case encrypted of
            (_:_) -> strictLBS (last encrypted)
            [] -> error ("encrypted data empty for `" ++ show m ++ "' in encrypt")