{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
module SSH.Session where

import Control.Concurrent.Chan (Chan)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State (StateT, get, gets, modify, runState)
import Data.Binary (decode, encode)
import Data.Word (Word8, Word32)
import System.IO (Handle)
import qualified Codec.Crypto.SimpleAES as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Map as M

import SSH.Channel (ChannelConfig, ChannelMessage)
import SSH.Crypto (Cipher(..), HMAC(..), KeyPair(..), PublicKey(..), CipherType(..), CipherMode(..), toBlocks)
import SSH.Debug (dump)
import SSH.NetReader (NetReader)
import SSH.Packet (io)
import SSH.Sender (Sender(..), SenderMessage)
import SSH.Util (strictLBS)


-- | Session actions: 'IO' with a 'SessionState' threaded through as mutable
-- state.
type Session = StateT SessionState IO

-- | State threaded through every 'Session' action.
--
-- Judging by the constructor names, a session presumably starts as
-- 'Initial', becomes 'GotKEXInit' once the peer's key-exchange init has
-- arrived, and ends up as 'Final' (NOTE(review): confirm against the code
-- that performs these transitions).
--
-- Fields missing from some constructors (e.g. 'ssChannels', 'ssInKey',
-- 'ssTheirKEXInit') are partial selectors: reading one while the state is a
-- constructor without that field raises a runtime error.
data SessionState
    = Initial
        { ssConfig :: SessionConfig -- ^ configuration for this session
        , ssChannelConfig :: ChannelConfig -- ^ channel configuration
        , ssThem :: Handle -- ^ handle incoming packets are read from (see 'getPacket')
        , ssSend :: SenderMessage -> IO () -- ^ callback that 'send' uses to dispatch outgoing messages
        , ssPayload :: LBS.ByteString -- ^ payload of the current packet, consumed by 'net'
        , ssTheirVersion :: String -- ^ presumably the peer's version string
        , ssOurKEXInit :: LBS.ByteString -- ^ presumably our own KEXINIT message
        , ssInSeq :: Word32 -- ^ incoming packet sequence number
        }
    | GotKEXInit
        { ssConfig :: SessionConfig
        , ssChannelConfig :: ChannelConfig
        , ssThem :: Handle
        , ssSend :: SenderMessage -> IO ()
        , ssPayload :: LBS.ByteString
        , ssTheirVersion :: String
        , ssOurKEXInit :: LBS.ByteString
        , ssInSeq :: Word32
        , ssTheirKEXInit :: LBS.ByteString -- ^ presumably the peer's KEXINIT message
        , ssOutCipher :: Cipher -- ^ cipher for outgoing traffic
        , ssInCipher :: Cipher -- ^ cipher for incoming traffic
        , ssOutHMACPrep :: LBS.ByteString -> HMAC -- ^ builds the outgoing 'HMAC' from bytes (presumably key material)
        , ssInHMACPrep :: LBS.ByteString -> HMAC -- ^ builds the incoming 'HMAC' from bytes (presumably key material)
        }
    | Final
        { ssConfig :: SessionConfig
        , ssChannelConfig :: ChannelConfig
        , ssChannels :: M.Map Word32 (Chan ChannelMessage) -- ^ open channels by number; see 'newChannelID' and 'getChannel'
        , ssID :: LBS.ByteString -- ^ presumably the session identifier
        , ssThem :: Handle
        , ssSend :: SenderMessage -> IO ()
        , ssPayload :: LBS.ByteString
        , ssGotNEWKEYS :: Bool -- ^ once 'True', 'getPacket' decrypts incoming packets and reads their MAC
        , ssInSeq :: Word32 -- ^ mixed into the incoming MAC computation in 'getPacket'
        , ssInCipher :: Cipher -- ^ incoming cipher; 'decrypt' only handles AES-CBC with 16-byte blocks
        , ssInHMAC :: HMAC -- ^ MAC length and function for incoming packets
        , ssInKey :: BS.ByteString -- ^ key used by 'decrypt'
        , ssInVector :: BS.ByteString -- ^ CBC IV used by 'decrypt'; replaced by the last ciphertext block after each call
        , ssUser :: Maybe String -- ^ presumably the authenticated user name, once known
        }

-- | Configuration supplied when a session is set up.
data SessionConfig =
    SessionConfig
        { scAuthMethods :: [String] -- ^ authentication method names, e.g. @"publickey"@
        , scAuthorize :: Authorize -> Session Bool -- ^ decides whether an authorization attempt succeeds
        , scKeyPair :: KeyPair -- ^ presumably the server's host key pair
        }

-- | An authorization attempt, handed to 'scAuthorize'.
data Authorize
    = Password String String -- ^ presumably user name and password
    | PublicKey String PublicKey -- ^ presumably user name and the offered public key

-- | Sending from inside a 'Session' passes the message to the session's
-- 'ssSend' callback.
instance Sender Session where
    send m = do
        deliver <- gets ssSend
        io (deliver m)


-- | A placeholder configuration.
--
-- It offers only the @"publickey"@ method, approves every authorization
-- request ('scAuthorize' always yields 'True'), and carries an all-zero RSA
-- key pair.
--
-- NOTE(review): approving everything is only safe for testing, and the zero
-- key presumably must be replaced before use — confirm callers override both.
defaultSessionConfig :: SessionConfig
defaultSessionConfig =
    SessionConfig
        { scAuthMethods = methods
        , scAuthorize = allowAll
        , scKeyPair = placeholderKey
        }
  where
    methods = ["publickey"]

    -- Accept every request.  A test-only alternative that was left here:
    --   \(Password u p) -> return $ u == "test" && p == "test"
    allowAll _ = return True

    placeholderKey = RSAKeyPair (RSAPublicKey 0 0) 0 0 0 0 0 0

-- | Run a 'NetReader' over the current 'ssPayload'.
--
-- Whatever input the reader leaves unconsumed becomes the new 'ssPayload',
-- and the reader's result is returned.
net :: NetReader a -> Session a
net r = do
    -- lazy pattern: keep the reader's result and leftover input as thunks
    ~(result, remaining) <- gets (runState r . ssPayload)
    modify $ \st -> st { ssPayload = remaining }
    return result

-- | The smallest channel number not currently present in 'ssChannels'.
--
-- 'M.keys' yields distinct keys in ascending order. After matching keys
-- @0 .. next-1@, the first key that differs from @next@ must be larger than
-- it, so @next@ is free. This is one linear scan, not an 'elem' lookup per
-- candidate.
newChannelID :: Session Word32
newChannelID = gets (firstGap 0 . M.keys . ssChannels)
  where
    firstGap :: Word32 -> [Word32] -> Word32
    firstGap next (k : ks)
        | k == next = firstGap (next + 1) ks
    firstGap next _ = next

-- | Look up the 'Chan' registered under a channel number in 'ssChannels'.
--
-- Calls 'error' with @"unknown channel: <n>"@ when the number is absent.
getChannel :: Word32 -> Session (Chan ChannelMessage)
getChannel i =
    gets (M.lookup i . ssChannels)
        >>= maybe (error $ "unknown channel: " ++ show i) return

-- | Decrypt a chunk of incoming ciphertext with the session's inbound key.
--
-- The empty string is returned unchanged without touching the state.
-- Otherwise the state must be 'Final' with an AES-CBC cipher of block size
-- 16; any other state or cipher calls 'error'.
--
-- After decrypting, 'ssInVector' is replaced by the last block of the input
-- (as split by 'toBlocks'). This carries the CBC chain into the next call.
decrypt :: LBS.ByteString -> Session LBS.ByteString
decrypt m
    | LBS.null m = return m
    | otherwise = get >>= decryptIn
  where
    decryptIn st = case st of
        Final
            { ssInCipher = Cipher AES CBC bs@16 _
            , ssInKey = key
            , ssInVector = iv
            } -> do
                let chained = strictLBS (last (toBlocks bs m))
                modify (\ss -> ss { ssInVector = chained })
                return (A.crypt A.CBC key iv A.Decrypt m)
        _ -> error "no decrypt for current state"

-- | Read one binary packet from 'ssThem' and store its payload in
-- 'ssPayload'.
--
-- A packet starts with a 4-byte length and a 1-byte padding length. The
-- payload is the @packetLen - paddingLen - 1@ bytes after that 5-byte header.
--
-- In the 'Final' state with 'ssGotNEWKEYS' set, the packet is decrypted with
-- 'decrypt'. It is followed on the wire by a MAC whose length comes from
-- 'ssInHMAC'. That MAC must equal the HMAC of 'ssInSeq' (binary-encoded as a
-- 'Word32') followed by the decrypted packet. On a mismatch an 'IOError' is
-- raised and 'ssPayload' is left untouched. In every other state the packet
-- is read in the clear.
--
-- An 'IOError' is also raised if the length header is too small to cover
-- bytes already read, or if input ends in the middle of a packet.
--
-- 'ssInSeq' is only read here, never advanced.
-- NOTE(review): presumably the caller increments it per packet — confirm.
getPacket :: Session ()
getPacket = do
    s <- get
    h <- gets ssThem
    case s of
        Final
            { ssGotNEWKEYS = True
            , ssInCipher = Cipher _ _ bs _
            , ssInHMAC = HMAC ms f
            , ssInSeq = is
            } -> do
                -- Read at least one whole cipher block so that the 5-byte
                -- header can be decrypted on its own.
                let firstChunk = max 8 bs

                firstEnc <- readExactly h firstChunk
                first <- decrypt firstEnc

                let packetLen = decode (LBS.take 4 first) :: Word32
                    paddingLen = decode (LBS.drop 4 first) :: Word8

                dump ("got packet", is, first, packetLen, paddingLen)

                -- The packet occupies 4 + packetLen bytes on the wire, of
                -- which firstChunk have already been consumed.
                restEnc <- readExactly h (fromIntegral packetLen - firstChunk + 4)

                dump ("got rest", restEnc)

                rest <- decrypt restEnc

                dump ("decrypted", rest)
                let decrypted = first `LBS.append` rest
                    payload = extract packetLen paddingLen decrypted

                dump ("getting hmac", ms)

                mac <- readExactly h ms

                let valid = verify mac is decrypted f

                dump ("got mac", mac, decrypted, is)
                dump ("hmac'd", f decrypted)
                dump ("got mac, valid?", valid)

                -- Never hand an unauthenticated payload to the caller.
                if valid
                    then modify (\ss -> ss { ssPayload = payload })
                    else failWith "invalid MAC"
        _ -> do
            first <- readExactly h 5

            let packetLen = decode (LBS.take 4 first) :: Word32
                paddingLen = decode (LBS.drop 4 first) :: Word8

            -- 4 + packetLen bytes in total; the 5-byte header is already read.
            rest <- readExactly h (fromIntegral packetLen - 5 + 4)
            let payload = LBS.take (fromIntegral packetLen - fromIntegral paddingLen - 1) rest
            modify (\ss -> ss { ssPayload = payload })
  where
    -- The payload is the (packetLen - paddingLen - 1) bytes after the 5-byte
    -- header.
    extract pkl pdl d = LBS.take (fromIntegral pkl - fromIntegral pdl - 1) (LBS.drop 5 d)

    -- The MAC covers the sequence number followed by the whole unencrypted
    -- packet, header included.
    -- NOTE(review): ByteString (==) is not a constant-time comparison.
    verify m is d f = m == f (encode (is :: Word32) `LBS.append` d)

    -- Read exactly n bytes. Fail on a negative size (a malformed length
    -- header) or on end of input before n bytes arrived.
    readExactly hdl n
        | n < 0 = failWith "bad packet length"
        | otherwise = do
            bytes <- liftIO (LBS.hGet hdl n)
            if LBS.length bytes == fromIntegral n
                then return bytes
                else failWith "unexpected end of input"

    -- Raise an IOError tagged with this function's name.
    failWith msg = liftIO . ioError . userError $ "getPacket: " ++ msg