module SSH.Sender where import Control.Concurrent.Chan import Control.Monad (replicateM) import Data.Word import System.IO import System.Random import qualified Codec.Crypto.SimpleAES as A import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS import SSH.Debug import SSH.Crypto import SSH.Packet import SSH.Util data SenderState = NoKeys { senderThem :: Handle , senderOutSeq :: Word32 } | GotKeys { senderThem :: Handle , senderOutSeq :: Word32 , senderEncrypting :: Bool , senderCipher :: Cipher , senderKey :: BS.ByteString , senderVector :: BS.ByteString , senderHMAC :: HMAC } data SenderMessage = Prepare Cipher BS.ByteString BS.ByteString HMAC | StartEncrypting | Send LBS.ByteString | Stop class Sender a where send :: SenderMessage -> a () sendPacket :: Packet () -> a () sendPacket = send . Send . doPacket sender :: Chan SenderMessage -> SenderState -> IO () sender ms ss = do m <- readChan ms case m of Stop -> return () Prepare cipher key iv hmac -> do dump ("initiating encryption", key, iv) sender ms (GotKeys (senderThem ss) (senderOutSeq ss) False cipher key iv hmac) StartEncrypting -> do dump ("starting encryption") sender ms (ss { senderEncrypting = True }) Send msg -> do pad <- fmap (LBS.pack . map fromIntegral) $ replicateM (fromIntegral $ paddingLen msg) (randomRIO (0, 255 :: Int)) let f = full msg pad case ss of GotKeys h os True cipher key iv (HMAC _ mac) -> do dump ("sending encrypted", os, f) let (encrypted, newVector) = encrypt cipher key iv f LBS.hPut h . LBS.concat $ [ encrypted , mac . doPacket $ long os >> raw f ] hFlush h sender ms $ ss { senderOutSeq = senderOutSeq ss + 1 , senderVector = newVector } _ -> do dump ("sending unencrypted", senderOutSeq ss, f) LBS.hPut (senderThem ss) f hFlush (senderThem ss) sender ms (ss { senderOutSeq = senderOutSeq ss + 1 }) where blockSize = case ss of GotKeys { senderCipher = Cipher _ _ bs _ } | bs > 8 -> bs _ -> 8 full msg pad = doPacket $ do long (len msg) byte (paddingLen msg) raw msg raw pad len :: LBS.ByteString -> Word32 len msg = 1 + fromIntegral (LBS.length msg) + fromIntegral (paddingLen msg) paddingNeeded :: LBS.ByteString -> Word8 paddingNeeded msg = fromIntegral blockSize - (fromIntegral $ (5 + LBS.length msg) `mod` fromIntegral blockSize) paddingLen :: LBS.ByteString -> Word8 paddingLen msg = if paddingNeeded msg < 4 then paddingNeeded msg + fromIntegral blockSize else paddingNeeded msg encrypt :: Cipher -> BS.ByteString -> BS.ByteString -> LBS.ByteString -> (LBS.ByteString, BS.ByteString) encrypt (Cipher AES CBC bs _) key vector m = ( fromBlocks encrypted , case encrypted of (_:_) -> strictLBS (last encrypted) [] -> error ("encrypted data empty for `" ++ show m ++ "' in encrypt") vector ) where encrypted = toBlocks bs $ A.crypt A.CBC key vector A.Encrypt m