{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE Strict     #-}

module Foreign.Erlang.Handshake
    ( HandshakeData(..)
    , doConnect
    , doAccept
    , Name(..)
    , Status(..)
    , Challenge(..)
    , ChallengeReply(..)
    , ChallengeAck(..)
    ) where

import           Control.Monad           ( unless, when )
import           Util.IOExtra
import           Data.Ix                 ( inRange )

import qualified Data.ByteString         as BS
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.Put
import           Util.Binary
import           Foreign.Erlang.Digest
import           Foreign.Erlang.NodeData

data HandshakeData = HandshakeData { name     :: Name
                                   , nodeData :: NodeData
                                   , cookie   :: BS.ByteString
                                   }

nodeTypeR6, challengeStatus, challengeReply, challengeAck :: Char
nodeTypeR6 = 'n'

challengeStatus = 's'

challengeReply = 'r'

challengeAck = 'a'

data Name = Name { n_distVer   :: DistributionVersion
                 , n_distFlags :: DistributionFlags
                 , n_nodeName  :: BS.ByteString
                 }
    deriving (Eq, Show)

instance Binary Name where
    put Name{n_distVer,n_distFlags,n_nodeName} =
        putWithLength16be $ do
            putChar8 nodeTypeR6
            put n_distVer
            put n_distFlags
            putByteString n_nodeName
    get = do
        len <- getWord16be
        (((), n_distVer, n_distFlags), l) <- getWithLength16be $
                                                 (,,) <$> matchChar8 nodeTypeR6
                                                      <*> get
                                                      <*> get
        n_nodeName <- getByteString (fromIntegral (len - l))
        return Name { n_distVer, n_distFlags, n_nodeName }

data Status = Ok | OkSimultaneous | Nok | NotAllowed | Alive
    deriving (Eq, Show, Bounded, Enum)

instance Binary Status where
    put status = putWithLength16be $ do
        putChar8 challengeStatus
        case status of
            Ok -> putByteString "ok"
            OkSimultaneous -> putByteString "ok_simultaneous"
            Nok -> putByteString "nok"
            NotAllowed -> putByteString "not_allowed"
            Alive -> putByteString "alive"
    get = do
        len <- getWord16be
        ((), l) <- getWithLength16be $ matchChar8 challengeStatus
        status <- getByteString (fromIntegral (len - l))
        case status of
            "ok" -> return Ok
            "ok_simultaneous" -> return OkSimultaneous
            "nok" -> return Nok
            "not_allowed" -> return NotAllowed
            "alive" -> return Alive
            _ -> fail $ "Bad status: " ++ show status

data Challenge = Challenge { c_distVer   :: DistributionVersion
                           , c_distFlags :: DistributionFlags
                           , c_challenge :: Word32
                           , c_nodeName  :: BS.ByteString
                           }
    deriving (Eq, Show)

instance Binary Challenge where
    put Challenge{c_distVer,c_distFlags,c_challenge,c_nodeName} =
        putWithLength16be $ do
            putChar8 nodeTypeR6
            put c_distVer
            put c_distFlags
            putWord32be c_challenge
            putByteString c_nodeName
    get = do
        len <- getWord16be
        (((), c_distVer, c_distFlags, c_challenge), l) <- getWithLength16be $
                                                              (,,,) <$> matchChar8 nodeTypeR6
                                                                    <*> get
                                                                    <*> get
                                                                    <*> getWord32be
        c_nodeName <- getByteString (fromIntegral (len - l))
        return Challenge { c_distVer, c_distFlags, c_challenge, c_nodeName }

data ChallengeReply = ChallengeReply { cr_challenge :: Word32
                                     , cr_digest    :: BS.ByteString
                                     }
    deriving (Eq, Show)

instance Binary ChallengeReply where
    put ChallengeReply{cr_challenge,cr_digest} =
        putWithLength16be $ do
            putChar8 challengeReply
            putWord32be cr_challenge
            putByteString cr_digest
    get = do
        len <- getWord16be
        (((), cr_challenge), l) <- getWithLength16be $
                                       (,) <$> matchChar8 challengeReply
                                           <*> getWord32be
        cr_digest <- getByteString (fromIntegral (len - l))
        return ChallengeReply { cr_challenge, cr_digest }

data ChallengeAck = ChallengeAck { ca_digest :: BS.ByteString }
    deriving (Eq, Show)

instance Binary ChallengeAck where
    put ChallengeAck{ca_digest} =
        putWithLength16be $ do
            putChar8 challengeAck
            putByteString ca_digest
    get = do
        len <- getWord16be
        ((), l) <- getWithLength16be $ matchChar8 challengeAck
        ca_digest <- getByteString (fromIntegral (len - l))
        return ChallengeAck { ca_digest }

doConnect :: (MonadCatch m, MonadIO m)
          => (forall o. Binary o => o -> m ())
          -> (forall i. (Binary i) => m i)
          -> HandshakeData
          -> m ()
doConnect send recv HandshakeData{name,nodeData = NodeData{loVer,hiVer},cookie} = do
    send name
    do
        her_status <- recv
        when (her_status /= Ok) (throwM (BadHandshakeStatus her_status))

    Challenge{c_distVer = her_distVer,c_challenge = her_challenge} <- recv
    checkVersionRange her_distVer loVer hiVer

    our_challenge <- liftIO genChallenge
    send ChallengeReply { cr_challenge = our_challenge
                        , cr_digest = genDigest her_challenge cookie
                        }
    ChallengeAck{ca_digest = her_digest} <- recv
    checkCookie her_digest our_challenge cookie

newtype BadHandshakeStatus = BadHandshakeStatus Status
    deriving Show

instance Exception BadHandshakeStatus

doAccept :: (MonadCatch m, MonadIO m)
         => (forall o. Binary o => o -> m ()) -- TODO
         -> (forall i. (Binary i) => m i)
         -> HandshakeData
         -> m BS.ByteString
doAccept send recv HandshakeData{name = Name{n_distFlags,n_nodeName},nodeData = NodeData{loVer,hiVer},cookie} = do
    Name{n_distVer = her_distVer,n_nodeName = her_nodeName} <- recv
    checkVersionRange her_distVer loVer hiVer

    send Ok

    our_challenge <- liftIO genChallenge
    send Challenge { c_distVer = R6B
                   , c_distFlags = n_distFlags
                   , c_challenge = our_challenge
                   , c_nodeName = n_nodeName
                   }

    ChallengeReply{cr_challenge = her_challenge,cr_digest = her_digest} <- recv
    checkCookie her_digest our_challenge cookie

    send ChallengeAck { ca_digest = genDigest her_challenge cookie }
    return her_nodeName

checkVersionRange :: MonadThrow m
                  => DistributionVersion
                  -> DistributionVersion
                  -> DistributionVersion
                  -> m ()
checkVersionRange herVersion lowVersion highVersion =
    unless (inRange (lowVersion, highVersion) herVersion)
           (throwM DistributionVersionMismatch { herVersion
                                               , lowVersion
                                               , highVersion
                                               })

checkCookie :: MonadThrow m => BS.ByteString -> Word32 -> BS.ByteString -> m ()
checkCookie her_digest our_challenge cookie =
    unless (her_digest == genDigest our_challenge cookie)
           (throwM CookieMismatch)

data DistributionVersionMismatch =
      DistributionVersionMismatch { herVersion  :: DistributionVersion
                                  , lowVersion  :: DistributionVersion
                                  , highVersion :: DistributionVersion
                                  }
    deriving Show

instance Exception DistributionVersionMismatch

data CookieMismatch = CookieMismatch
    deriving Show

instance Exception CookieMismatch