{-# LANGUAGE TypeSynonymInstances #-}
module SSH.Session where

import Control.Concurrent.Chan
import Control.Monad.State
import Data.Binary (decode, encode)
import Data.Word
import System.IO
import qualified Codec.Crypto.SimpleAES as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Map as M

import SSH.Channel
import SSH.Crypto
import SSH.Debug
import SSH.NetReader
import SSH.Packet
import SSH.Sender
import SSH.Util


-- | The monad session code runs in: a 'StateT' over 'IO' that threads the
-- current 'SessionState'.
type Session = StateT SessionState IO

-- | Per-connection protocol state threaded through 'Session'.
--
-- The constructor in use records how far the connection has progressed; the
-- names suggest Initial, then GotKEXInit, then Final, but the transitions are
-- not visible here (TODO confirm).  Fields that appear in several
-- constructors share one record selector.
--
-- NOTE(review): selectors carried by only some constructors (for example
-- 'ssChannels', 'ssGotNEWKEYS', 'ssInKey', 'ssInVector') are partial and
-- throw when applied to a value built with a different constructor.
data SessionState
    = Initial
        { ssConfig :: SessionConfig
        , ssChannelConfig :: ChannelConfig
        , ssThem :: Handle
          -- ^ Peer connection; 'getPacket' reads incoming packets from it.
        , ssSend :: SenderMessage -> IO ()
          -- ^ Callback the 'Sender' instance uses to emit outgoing messages.
        , ssPayload :: LBS.ByteString
          -- ^ Payload of the current packet not yet consumed by 'net'.
        , ssTheirVersion :: String
          -- ^ Presumably the peer's version string; TODO confirm.
        , ssOurKEXInit :: LBS.ByteString
          -- ^ Presumably our own KEXINIT message; TODO confirm.
        , ssInSeq :: Word32
          -- ^ Incoming sequence number; 'getPacket' prepends it to the data
          -- the received MAC is checked against.
        }
    | GotKEXInit
        { ssConfig :: SessionConfig
        , ssChannelConfig :: ChannelConfig
        , ssThem :: Handle
        , ssSend :: SenderMessage -> IO ()
        , ssPayload :: LBS.ByteString
        , ssTheirVersion :: String
        , ssOurKEXInit :: LBS.ByteString
        , ssInSeq :: Word32
        , ssTheirKEXInit :: LBS.ByteString
          -- ^ Presumably the peer's KEXINIT message; TODO confirm.
        , ssOutCipher :: Cipher
        , ssInCipher :: Cipher
        , ssOutHMACPrep :: LBS.ByteString -> HMAC
          -- ^ Builds an 'HMAC' from a byte string (presumably the derived
          -- MAC key; TODO confirm).
        , ssInHMACPrep :: LBS.ByteString -> HMAC
        }
    | Final
        { ssConfig :: SessionConfig
        , ssChannelConfig :: ChannelConfig
        , ssChannels :: M.Map Word32 (Chan ChannelMessage)
          -- ^ Open channels keyed by channel ID (see 'newChannelID' and
          -- 'getChannel').
        , ssID :: LBS.ByteString
          -- ^ Presumably the session identifier; TODO confirm.
        , ssThem :: Handle
        , ssSend :: SenderMessage -> IO ()
        , ssPayload :: LBS.ByteString
        , ssGotNEWKEYS :: Bool
          -- ^ When 'True', 'getPacket' reads encrypted, MAC-protected packets.
        , ssInSeq :: Word32
        , ssInCipher :: Cipher
          -- ^ Incoming cipher; 'getPacket' sizes its first read from it, and
          -- 'decrypt' only accepts AES in CBC mode with 16-byte blocks.
        , ssInHMAC :: HMAC
          -- ^ Incoming MAC length and function, used by 'getPacket'.
        , ssInKey :: BS.ByteString
          -- ^ AES key used by 'decrypt'.
        , ssInVector :: BS.ByteString
          -- ^ Current CBC IV; 'decrypt' replaces it with the last ciphertext
          -- block of each chunk it decrypts.
        , ssUser :: Maybe String
          -- ^ Presumably the authenticated user, once known; TODO confirm.
        }

-- | Server-side settings for a session.
data SessionConfig =
    SessionConfig
        { scAuthMethods :: [String]
          -- ^ Names of the authentication methods, presumably advertised to
          -- clients ('defaultSessionConfig' uses @["publickey"]@).
        , scAuthorize :: Authorize -> Session Bool
          -- ^ Judges an 'Authorize' request inside 'Session' (presumably
          -- 'True' means accept).
        , scKeyPair :: KeyPair
          -- ^ Presumably the server's host key pair; TODO confirm.
        }

-- | An authentication request, judged by 'scAuthorize'.
data Authorize
    = Password String String
      -- ^ Presumably user name and password; TODO confirm.
    | PublicKey String PublicKey
      -- ^ Presumably user name and the offered public key; TODO confirm.

-- | Outgoing messages are delivered through the session's 'ssSend' callback.
instance Sender Session where
    send msg = do
        deliver <- gets ssSend
        io (deliver msg)


-- | A placeholder configuration.  It offers only @"publickey"@
-- authentication, accepts every 'Authorize' request, and carries an RSA key
-- pair whose components are all zero.
--
-- NOTE(review): 'scAuthorize' and 'scKeyPair' are presumably meant to be
-- overridden before real use.  An earlier, commented-out authorizer was:
--
-- > \(Password u p) -> return $ u == "test" && p == "test"
defaultSessionConfig :: SessionConfig
defaultSessionConfig = SessionConfig
    { scAuthMethods = ["publickey"]
    , scAuthorize = \_ -> return True
    , scKeyPair = RSAKeyPair (RSAPublicKey 0 0) 0
    }

-- | Run a 'NetReader' over the pending packet payload.  The reader's final
-- state (presumably the unconsumed remainder) becomes the new 'ssPayload'.
net :: NetReader a -> Session a
net reader = do
    -- Lazy pattern: the reader is only run when its result or the new
    -- payload is actually demanded.
    ~(result, leftover) <- gets (runState reader . ssPayload)
    modify $ \st -> st { ssPayload = leftover }
    return result

-- | Pick the smallest channel ID not currently present in 'ssChannels'.
--
-- Each candidate is tested against the map directly (O(log n)) instead of
-- scanning the list of keys, which made the search quadratic.
--
-- NOTE(review): 'ssChannels' is only a field of 'Final'; in any other state
-- the selector throws.
newChannelID :: Session Word32
newChannelID = gets (firstFree . ssChannels)
  where
    firstFree :: M.Map Word32 a -> Word32
    firstFree chans =
        case filter (`M.notMember` chans) [0 ..] of
            (n : _) -> n
            -- Only reachable if all 2^32 IDs are in use.
            [] -> error "newChannelID: every Word32 channel ID is in use"

-- | Look up the message channel registered under the given ID.
--
-- Raises an 'error' naming the ID when no such channel exists.
-- NOTE(review): 'ssChannels' is only a field of 'Final', so this must run
-- once the session is in that state.
getChannel :: Word32 -> Session (Chan ChannelMessage)
getChannel i =
    gets (M.lookup i . ssChannels) >>= maybe unknown return
  where
    unknown = error $ "unknown channel: " ++ show i

-- | Decrypt a chunk of incoming ciphertext using the AES-CBC key and IV
-- stored in the session state.
--
-- An empty chunk is returned as-is and leaves the state untouched.
-- Otherwise the IV is replaced by the chunk's last ciphertext block (as split
-- by 'toBlocks'), so consecutive calls continue a single CBC chain.
-- Any state other than 'Final' with a 16-byte-block AES-CBC cipher is
-- rejected with 'error'.
decrypt :: LBS.ByteString -> Session LBS.ByteString
decrypt ciphertext
    | LBS.null ciphertext = return ciphertext
    | otherwise = get >>= decryptWith
  where
    decryptWith
        Final
            { ssInCipher = Cipher AES CBC blockSize@16 _
            , ssInKey = aesKey
            , ssInVector = iv
            } = do
                let nextIV = strictLBS . last $ toBlocks blockSize ciphertext
                modify $ \st -> st { ssInVector = nextIV }
                return $ A.crypt A.CBC aesKey iv A.Decrypt ciphertext
    decryptWith _ = error "no decrypt for current state"

-- | Read one SSH binary packet from the peer and store its payload in
-- 'ssPayload'.
--
-- Until NEWKEYS has been received ('ssGotNEWKEYS' is not 'True'), the packet
-- is read as plaintext with no MAC.  Afterwards it is read in two pieces,
-- each passed through 'decrypt', followed by a MAC whose length comes from
-- 'ssInHMAC'.  That MAC is checked against the MAC function applied to
-- @uint32 sequence_number || unencrypted packet@ (RFC 4253, section 6.4).
-- On a mismatch an 'IOError' is thrown and 'ssPayload' is left unchanged.
--
-- NOTE(review): this function never advances 'ssInSeq'.  Confirm that the
-- caller increments it once per received packet; otherwise every MAC after
-- the first encrypted packet will fail to verify.
getPacket :: Session ()
getPacket = do
    s <- get
    h <- gets ssThem
    case s of
        Final
            { ssGotNEWKEYS = True
            , ssInCipher = Cipher _ _ bs _
            , ssInHMAC = HMAC ms f
            , ssInSeq = is
            } -> do
                -- The first chunk must contain the 4-byte packet_length and
                -- the 1-byte padding_length, so read at least 8 bytes.
                let firstChunk = max 8 bs

                firstEnc <- liftIO $ LBS.hGet h firstChunk
                first <- decrypt firstEnc

                let packetLen = decode (LBS.take 4 first) :: Word32
                    paddingLen = decode (LBS.drop 4 first) :: Word8

                dump ("got packet", is, first, packetLen, paddingLen)

                -- packet_length does not count its own 4 bytes, and
                -- firstChunk bytes have already been consumed.
                restEnc <- liftIO $ LBS.hGet h (fromIntegral packetLen - firstChunk + 4)

                dump ("got rest", restEnc)

                rest <- decrypt restEnc

                dump ("decrypted", rest)
                let decrypted = first `LBS.append` rest
                    payload = extract packetLen paddingLen decrypted

                dump ("getting hmac", ms)

                mac <- liftIO $ LBS.hGet h ms

                let valid = verify mac is decrypted f

                dump ("got mac", mac, decrypted, is)
                dump ("hmac'd", f decrypted)
                dump ("got mac, valid?", valid)

                -- Never hand an unauthenticated payload to the rest of the
                -- session.
                if valid
                    then modify (\ss -> ss { ssPayload = payload })
                    else liftIO . ioError . userError $
                        "getPacket: MAC verification failed for packet " ++ show is
        _ -> do
            -- Plaintext header: 4-byte packet_length + 1-byte padding_length.
            first <- liftIO $ LBS.hGet h 5

            let packetLen = decode (LBS.take 4 first) :: Word32
                paddingLen = decode (LBS.drop 4 first) :: Word8

            -- The whole packet is packet_length + 4 bytes; 5 are already read.
            rest <- liftIO $ LBS.hGet h (fromIntegral packetLen - 5 + 4)
            let payload = LBS.take (fromIntegral packetLen - fromIntegral paddingLen - 1) rest
            modify (\ss -> ss { ssPayload = payload })
  where
    -- Strip the 5-byte header and the trailing padding from a decrypted packet.
    extract :: Word32 -> Word8 -> LBS.ByteString -> LBS.ByteString
    extract pkl pdl d = LBS.take (fromIntegral pkl - fromIntegral pdl - 1) (LBS.drop 5 d)

    -- Compare a received MAC with the MAC of (big-endian sequence number ++ packet).
    verify :: LBS.ByteString -> Word32 -> LBS.ByteString -> (LBS.ByteString -> LBS.ByteString) -> Bool
    verify m is d f = m == f (encode is `LBS.append` d)