{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NamedFieldPuns #-}

-- | Connecting side of the node-to-node handshake: the wire formats of the
-- handshake messages ('Name', 'Status', 'Challenge', 'ChallengeReply',
-- 'ChallengeAck') and the 'connectNodes' entry point that runs the exchange.
module Language.Erlang.Handshake
    ( connectNodes
    , Name (..)
    , Status (..)
    , Challenge (..)
    , ChallengeReply (..)
    , ChallengeAck (..)
    ) where

import Control.Monad (unless)
import qualified Data.ByteString as BS
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put

import Util.Socket
import Util.BufferedSocket
import Util.Binary
import Util.Util
import Util.IOx

import Language.Erlang.Digest
import Language.Erlang.NodeState
import Language.Erlang.NodeData
import Language.Erlang.Epmd
import Language.Erlang.Term
import Language.Erlang.Mailbox
import Language.Erlang.Connection

--------------------------------------------------------------------------------

-- | One-character tags written as the first byte of each handshake message
-- body (after the 16-bit length prefix).
nodeTypeR6, challengeStatus, challengeReply, challengeAck :: Char
nodeTypeR6 = 'n'
challengeStatus = 's'
challengeReply = 'r'
challengeAck = 'a'

-- | Read the variable-length tail of a length-prefixed handshake message.
--
-- @len@ is the 16-bit length prefix read from the wire; @used@ is the second
-- component returned by 'getWithLength16be' for the fixed-size header
-- (presumably the number of bytes that header parser consumed -- confirm
-- against "Util.Binary"). Both are 'Word16', so @len - used@ would silently
-- wrap around to a large value when a peer sends a prefix that is smaller
-- than the header. That case is rejected with a parse failure instead of
-- attempting to read tens of kilobytes of unrelated data.
getRemainder :: Word16 -> Word16 -> Get BS.ByteString
getRemainder len used
    | used > len =
        fail $ "Bad message length: " ++ show len
            ++ " is shorter than the " ++ show used ++ " byte header"
    | otherwise = getByteString (fromIntegral (len - used))

--------------------------------------------------------------------------------

-- | First message of the handshake; 'handshake' sends it to the peer.
data Name = Name
    { n_distVer   :: DistributionVersion
    , n_distFlags :: DistributionFlags
    , n_nodeName  :: BS.ByteString
    } deriving (Eq, Show)

-- | Layout: 16-bit length, tag 'nodeTypeR6', version, flags, then the node
-- name filling the rest of the frame.
instance Binary Name where
    put Name {..} =
        putWithLength16be $ do
            putChar8 nodeTypeR6
            put n_distVer
            putDistributionFlags n_distFlags
            putByteString n_nodeName
    get = do
        len <- getWord16be
        (((), version, flags), l) <- getWithLength16be $
            (,,) <$> matchChar8 nodeTypeR6
                 <*> get
                 <*> getDistributionFlags
        -- The name has no length field of its own; it is whatever is left.
        name <- getRemainder len l
        return $ Name version flags name

--------------------------------------------------------------------------------

-- | Status message received in response to 'Name'. 'handshake' only proceeds
-- on 'Ok'.
data Status
    = Ok
    | OkSimultaneous
    | Nok
    | NotAllowed
    | Alive
    deriving (Eq, Show, Bounded, Enum)

-- | Layout: 16-bit length, tag 'challengeStatus', then the status as text
-- filling the rest of the frame.
instance Binary Status where
    put status =
        putWithLength16be $ do
            putChar8 challengeStatus
            case status of
                Ok             -> putByteString "ok"
                OkSimultaneous -> putByteString "ok_simultaneous"
                Nok            -> putByteString "nok"
                NotAllowed     -> putByteString "not_allowed"
                Alive          -> putByteString "alive"
    get = do
        len <- getWord16be
        ((), l) <- getWithLength16be $ matchChar8 challengeStatus
        status <- getRemainder len l
        case status of
            "ok"              -> return Ok
            "ok_simultaneous" -> return OkSimultaneous
            "nok"             -> return Nok
            "not_allowed"     -> return NotAllowed
            "alive"           -> return Alive
            _                 -> fail $ "Bad status: " ++ show status

--------------------------------------------------------------------------------

-- | Challenge message received from the peer after a successful 'Status'.
data Challenge = Challenge
    { c_distVer   :: DistributionVersion
    , c_distFlags :: DistributionFlags
    , c_challenge :: Word32
    , c_nodeName  :: BS.ByteString
    } deriving (Eq, Show)

-- | Layout: 16-bit length, tag 'nodeTypeR6', version, flags, 32-bit
-- big-endian challenge, then the node name filling the rest of the frame.
instance Binary Challenge where
    put Challenge {..} =
        putWithLength16be $ do
            putChar8 nodeTypeR6
            put c_distVer
            putDistributionFlags c_distFlags
            putWord32be c_challenge
            putByteString c_nodeName
    get = do
        len <- getWord16be
        (((), version, flags, challenge), l) <- getWithLength16be $
            (,,,) <$> matchChar8 nodeTypeR6
                  <*> get
                  <*> getDistributionFlags
                  <*> getWord32be
        name <- getRemainder len l
        return $ Challenge version flags challenge name

--------------------------------------------------------------------------------

-- | Reply to a 'Challenge': our own challenge for the peer plus the digest
-- answering the peer's challenge.
data ChallengeReply = ChallengeReply
    { cr_challenge :: Word32
    , cr_digest    :: BS.ByteString
    } deriving (Eq, Show)

-- | Layout: 16-bit length, tag 'challengeReply', 32-bit big-endian challenge,
-- then the digest filling the rest of the frame.
instance Binary ChallengeReply where
    put ChallengeReply {..} =
        putWithLength16be $ do
            putChar8 challengeReply
            putWord32be cr_challenge
            putByteString cr_digest
    get = do
        len <- getWord16be
        (((), challenge), l) <- getWithLength16be $
            (,) <$> matchChar8 challengeReply
                <*> getWord32be
        digest <- getRemainder len l
        return $ ChallengeReply challenge digest

--------------------------------------------------------------------------------

-- | Final handshake message: the peer's digest answering the challenge we
-- sent in 'ChallengeReply'.
data ChallengeAck = ChallengeAck
    { ca_digest :: BS.ByteString
    } deriving (Eq, Show)

-- | Layout: 16-bit length, tag 'challengeAck', then the digest filling the
-- rest of the frame.
instance Binary ChallengeAck where
    put ChallengeAck {..} =
        putWithLength16be $ do
            putChar8 challengeAck
            putByteString ca_digest
    get = do
        len <- getWord16be
        ((), l) <- getWithLength16be $ matchChar8 challengeAck
        digest <- getRemainder len l
        return $ ChallengeAck digest

--------------------------------------------------------------------------------

-- | Look up a remote node, open a socket to it, run the handshake and wrap
-- the socket in a new 'Connection'.
--
-- Fails (via 'maybeErrorX' / 'errorX') when the local and remote nodes have
-- no matching distribution version, when the handshake itself fails, or when
-- the node name announced in the peer's 'Challenge' is not
-- @remoteAlive\@remoteHost@.
connectNodes
    :: BS.ByteString                            -- ^ local node name sent in 'Name'
    -> NodeData                                 -- ^ local node data, used for version matching
    -> DistributionFlags                        -- ^ local distribution flags sent in 'Name'
    -> Term                                     -- ^ remote node name
    -> BS.ByteString                            -- ^ shared cookie
    -> NodeState Term Term Mailbox Connection   -- ^ node state handed to 'newConnection'
    -> IOx Connection
connectNodes localName localNode localFlags remoteName cookie nodeState = do
    let (remoteAlive, remoteHost) = splitNodeName remoteName
    remoteNode@NodeData {portNo = remotePort} <- lookupNode remoteAlive remoteHost
    -- Settle the version before opening the socket: this check needs no
    -- connection, and failing here must not leave an open socket behind.
    localVersion <-
        maybeErrorX illegalOperationErrorType "version mismatch"
            (matchDistributionVersion localNode remoteNode)
    sock <- connectSocket remoteHost remotePort >>= makeBuffered
    -- NOTE(review): if the handshake or the name check below fails, nothing
    -- in this function closes `sock` -- confirm whether the caller or IOx
    -- takes care of releasing it.
    Challenge {c_nodeName} <-
        handshake sock (Name localVersion localFlags localName) cookie
    unless (c_nodeName == remoteAlive `BS.append` "@" `BS.append` remoteHost) $
        errorX userErrorType "Remote node name mismatch"
    newConnection sock nodeState remoteName

-- | Run the handshake message exchange over an already connected socket and
-- return the peer's 'Challenge'.
--
-- Sequence: send 'Name', receive 'Status' (must be 'Ok'), receive
-- 'Challenge' (its version must equal ours), send 'ChallengeReply', receive
-- 'ChallengeAck' (its digest must match the challenge we issued).
handshake :: BufferedSocket -> Name -> BS.ByteString -> IOx Challenge
handshake sock n cookie = do
    send n
    s <- recv
    case s of
        Ok -> return ()
        _  -> errorX userErrorType $ "Bad status: " ++ show s
    c <- recv
    checkVersion n c
    r <- reply c
    send r
    a <- recv
    checkCookie r a
    return c
  where
    -- Serialise a message with its 'Binary' instance and write it to the socket.
    send :: (Binary a) => a -> IOx ()
    send = runPutSocket sock . put

    -- Read one message from the socket; the message type is fixed by the
    -- use site.
    recv :: (Binary a) => IOx a
    recv = runGetSocket sock get

    -- The peer must speak exactly the distribution version we announced.
    checkVersion :: Name -> Challenge -> IOx ()
    checkVersion Name {n_distVer} Challenge {c_distVer} =
        unless (n_distVer == c_distVer) $
            errorX userErrorType "Version mismatch"

    -- Answer the peer's challenge with a digest over it and the cookie, and
    -- issue a fresh challenge of our own.
    reply :: Challenge -> IOx ChallengeReply
    reply Challenge {c_challenge} = do
        localChallenge <- genChallenge
        return $ ChallengeReply localChallenge (genDigest c_challenge cookie)

    -- The peer proves it knows the cookie by returning the digest of the
    -- challenge we sent in the reply.
    checkCookie :: ChallengeReply -> ChallengeAck -> IOx ()
    checkCookie ChallengeReply {cr_challenge} ChallengeAck {ca_digest} =
        unless (ca_digest == genDigest cr_challenge cookie) $
            errorX userErrorType "Cookie mismatch"

--------------------------------------------------------------------------------