{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE Strict #-}
module Foreign.Erlang.Handshake
( HandshakeData(..)
, doConnect
, doAccept
, Name(..)
, Status(..)
, Challenge(..)
, ChallengeReply(..)
, ChallengeAck(..)
) where
import Control.Monad ( unless, when )
import Util.IOExtra
import Data.Ix ( inRange )
import qualified Data.ByteString as BS
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Util.Binary
import Foreign.Erlang.Digest
import Foreign.Erlang.NodeData
data HandshakeData = HandshakeData { name :: Name
, nodeData :: NodeData
, cookie :: BS.ByteString
}
nodeTypeR6, challengeStatus, challengeReply, challengeAck :: Char
nodeTypeR6 = 'n'
challengeStatus = 's'
challengeReply = 'r'
challengeAck = 'a'
data Name = Name { n_distVer :: DistributionVersion
, n_distFlags :: DistributionFlags
, n_nodeName :: BS.ByteString
}
deriving (Eq, Show)
instance Binary Name where
put Name{n_distVer,n_distFlags,n_nodeName} =
putWithLength16be $ do
putChar8 nodeTypeR6
put n_distVer
put n_distFlags
putByteString n_nodeName
get = do
len <- getWord16be
(((), n_distVer, n_distFlags), l) <- getWithLength16be $
(,,) <$> matchChar8 nodeTypeR6
<*> get
<*> get
n_nodeName <- getByteString (fromIntegral (len - l))
return Name { n_distVer, n_distFlags, n_nodeName }
data Status = Ok | OkSimultaneous | Nok | NotAllowed | Alive
deriving (Eq, Show, Bounded, Enum)
instance Binary Status where
put status = putWithLength16be $ do
putChar8 challengeStatus
case status of
Ok -> putByteString "ok"
OkSimultaneous -> putByteString "ok_simultaneous"
Nok -> putByteString "nok"
NotAllowed -> putByteString "not_allowed"
Alive -> putByteString "alive"
get = do
len <- getWord16be
((), l) <- getWithLength16be $ matchChar8 challengeStatus
status <- getByteString (fromIntegral (len - l))
case status of
"ok" -> return Ok
"ok_simultaneous" -> return OkSimultaneous
"nok" -> return Nok
"not_allowed" -> return NotAllowed
"alive" -> return Alive
_ -> fail $ "Bad status: " ++ show status
data Challenge = Challenge { c_distVer :: DistributionVersion
, c_distFlags :: DistributionFlags
, c_challenge :: Word32
, c_nodeName :: BS.ByteString
}
deriving (Eq, Show)
instance Binary Challenge where
put Challenge{c_distVer,c_distFlags,c_challenge,c_nodeName} =
putWithLength16be $ do
putChar8 nodeTypeR6
put c_distVer
put c_distFlags
putWord32be c_challenge
putByteString c_nodeName
get = do
len <- getWord16be
(((), c_distVer, c_distFlags, c_challenge), l) <- getWithLength16be $
(,,,) <$> matchChar8 nodeTypeR6
<*> get
<*> get
<*> getWord32be
c_nodeName <- getByteString (fromIntegral (len - l))
return Challenge { c_distVer, c_distFlags, c_challenge, c_nodeName }
data ChallengeReply = ChallengeReply { cr_challenge :: Word32
, cr_digest :: BS.ByteString
}
deriving (Eq, Show)
instance Binary ChallengeReply where
put ChallengeReply{cr_challenge,cr_digest} =
putWithLength16be $ do
putChar8 challengeReply
putWord32be cr_challenge
putByteString cr_digest
get = do
len <- getWord16be
(((), cr_challenge), l) <- getWithLength16be $
(,) <$> matchChar8 challengeReply
<*> getWord32be
cr_digest <- getByteString (fromIntegral (len - l))
return ChallengeReply { cr_challenge, cr_digest }
data ChallengeAck = ChallengeAck { ca_digest :: BS.ByteString }
deriving (Eq, Show)
instance Binary ChallengeAck where
put ChallengeAck{ca_digest} =
putWithLength16be $ do
putChar8 challengeAck
putByteString ca_digest
get = do
len <- getWord16be
((), l) <- getWithLength16be $ matchChar8 challengeAck
ca_digest <- getByteString (fromIntegral (len - l))
return ChallengeAck { ca_digest }
doConnect :: (MonadCatch m, MonadIO m)
=> (forall o. Binary o => o -> m ())
-> (forall i. (Binary i) => m i)
-> HandshakeData
-> m ()
doConnect send recv HandshakeData{name,nodeData = NodeData{loVer,hiVer},cookie} = do
send name
do
her_status <- recv
when (her_status /= Ok) (throwM (BadHandshakeStatus her_status))
Challenge{c_distVer = her_distVer,c_challenge = her_challenge} <- recv
checkVersionRange her_distVer loVer hiVer
our_challenge <- liftIO genChallenge
send ChallengeReply { cr_challenge = our_challenge
, cr_digest = genDigest her_challenge cookie
}
ChallengeAck{ca_digest = her_digest} <- recv
checkCookie her_digest our_challenge cookie
newtype BadHandshakeStatus = BadHandshakeStatus Status
deriving Show
instance Exception BadHandshakeStatus
doAccept :: (MonadCatch m, MonadIO m)
=> (forall o. Binary o => o -> m ())
-> (forall i. (Binary i) => m i)
-> HandshakeData
-> m BS.ByteString
doAccept send recv HandshakeData{name = Name{n_distFlags,n_nodeName},nodeData = NodeData{loVer,hiVer},cookie} = do
Name{n_distVer = her_distVer,n_nodeName = her_nodeName} <- recv
checkVersionRange her_distVer loVer hiVer
send Ok
our_challenge <- liftIO genChallenge
send Challenge { c_distVer = R6B
, c_distFlags = n_distFlags
, c_challenge = our_challenge
, c_nodeName = n_nodeName
}
ChallengeReply{cr_challenge = her_challenge,cr_digest = her_digest} <- recv
checkCookie her_digest our_challenge cookie
send ChallengeAck { ca_digest = genDigest her_challenge cookie }
return her_nodeName
checkVersionRange :: MonadThrow m
=> DistributionVersion
-> DistributionVersion
-> DistributionVersion
-> m ()
checkVersionRange herVersion lowVersion highVersion =
unless (inRange (lowVersion, highVersion) herVersion)
(throwM DistributionVersionMismatch { herVersion
, lowVersion
, highVersion
})
checkCookie :: MonadThrow m => BS.ByteString -> Word32 -> BS.ByteString -> m ()
checkCookie her_digest our_challenge cookie =
unless (her_digest == genDigest our_challenge cookie)
(throwM CookieMismatch)
data DistributionVersionMismatch =
DistributionVersionMismatch { herVersion :: DistributionVersion
, lowVersion :: DistributionVersion
, highVersion :: DistributionVersion
}
deriving Show
instance Exception DistributionVersionMismatch
data CookieMismatch = CookieMismatch
deriving Show
instance Exception CookieMismatch