-- Node to node communication (via TCP) -- -- Uses the transport layer abstraction for distributed Haskell communication -- Hackage: http://hackage.haskell.org/package/distributed-process -- GitHub: https://github.com/haskell-distributed/distributed-process -- -- Author: Rob Stewart, Patrick Maier ----------------------------------------------------------------------------- {-# LANGUAGE CPP #-} {-# LANGUAGE ScopedTypeVariables #-} -- for 'IOException' module Control.Parallel.HdpH.Internal.Comm ( -- * CommM monad CommM, -- synonym: Control.Monad.Reader.ReaderT IO run_, -- :: RTSConf -> CommM () -> IO () liftIO, -- :: IO a -> CommM a -- * information about the virtual machine nodes, -- :: CommM Int allNodes, -- :: CommM [NodeId] myNode, -- :: CommM NodeId isMain, -- :: CommM Bool -- * sending and receiving messages Message, -- synomyn: MPI.Msg (= Data.ByteString.Lazy.ByteString) send, -- :: NodeId -> Message -> CommM () receive, -- :: CommM Message shutdown, -- :: CommM () waitShutdown -- :: CommM () ) where import Prelude hiding (error) import qualified Prelude (error) import Control.DeepSeq (NFData(rnf),force) import Control.Exception (throw) import Control.Monad (unless,void,when,forever) import Control.Monad.Reader (ReaderT, runReaderT, ask) import Control.Monad.Trans (lift) import Data.Functor ((<$>)) import Data.IORef (writeIORef,atomicModifyIORef) import qualified Data.Serialize (Serialize, put, get) import Data.Word (Word8) import Control.Parallel.HdpH.Internal.Misc (encodeLazy, decodeLazy) import Control.Parallel.HdpH.Conf (RTSConf(debugLvl),numProcs, networkInterface) import Control.Parallel.HdpH.Internal.Location (NodeId, MyNodeException(NodeIdUnset), error, dbgNone) import Control.Parallel.HdpH.Internal.State.Location (myNodeRef, debugRef) import Control.Concurrent import System.IO (hPutStrLn,stderr) -- Used by HDPH_DEBUG import System.Exit (ExitCode(..),exitWith) import System.Timeout (timeout) import qualified Data.ByteString.Lazy as Lazy (ByteString,toChunks,fromChunks) import qualified Network.Transport as NT import qualified Network.Transport.TCP as TCP import Data.IORef (IORef, newIORef, readIORef) import System.IO.Unsafe (unsafePerformIO) import Data.Maybe import qualified Data.Map as Map import Data.List ((\\)) import Control.Exception (SomeException,try) import Network.Multicast import Network.Socket.ByteString (sendTo, recvFrom) import Network.Socket (Socket) import Network.Info import Data.List (sort) import System.Random (randomRs, newStdGen) ----------------------------------------------------------------------------- -- state representation data State = State { s_conf :: RTSConf, -- config data s_nodes :: Int, -- # nodes of virtual machine s_allNodes :: [NodeId], -- all nodes of virtual machine s_myNode :: NodeId, -- currently executing node s_isMain :: Bool, -- True iff currently executing is main node s_msgQ :: MessageQ, -- queue holding received payload messages s_shutdown :: MVar () } -- shutdown signal -- concurrent message queue (storing payload messages) type MessageQ = Chan Message -- a message is an MPI message (here: a lazy byte string) type Message = Lazy.ByteString ----------------------------------------------------------------------------- -- CommM monad -- CommM is a reader monad on top of the IO monad; mutable parts of the state -- (namely the message queue) are implemented via mutable references. type CommM = ReaderT State IO -- lifting lower layers liftIO :: IO a -> CommM a liftIO = lift ----------------------------------------------------------------------------- -- access to individual bits of state -- Number of nodes in the virtual machine. nodes :: CommM Int nodes = s_nodes <$> ask -- List of all nodes in the virtual machine; head should be main node. allNodes :: CommM [NodeId] allNodes = s_allNodes <$> ask -- The currently executing node. myNode :: CommM NodeId myNode = s_myNode <$> ask -- True iff the currently executing node is the main node. isMain :: CommM Bool isMain = s_isMain <$> ask -- internal use only: queue of received payload messages msgQ :: CommM MessageQ msgQ = s_msgQ <$> ask -- |Internal use only: debug level debug :: CommM Int debug = debugLvl <$> s_conf <$> ask -- |Block until receiving a shutdown signal. waitShutdown :: CommM () waitShutdown = do mvar <- s_shutdown <$> ask liftIO $ takeMVar mvar shutdownTransport -- |Called from 'waitShutdown', this closes -- all connections, the local endpoint, and -- then the transport layer, in that order shutdownTransport :: CommM () shutdownTransport = do {- -- Kill all connections liftIO $ killConnections <$> connectionLookup -- Kill my endpoint liftIO $ NT.closeEndPoint <$> myEndPoint -- Kill the transport layer trans <- liftIO lclTransport liftIO $ NT.closeTransport trans -} liftIO shutdownTransportIO -- |Used in an unclean shutdown when -- the connection with the master node -- has been unexpectedly closed. shutdownTransportIO :: IO () shutdownTransportIO = do -- Kill all connections -- remoteConnections <- connectionLookup -- killConnections remoteConnections killConnections =<< connectionLookup -- Kill my endpoint -- myEP <- myEndPoint -- NT.closeEndPoint myEP NT.closeEndPoint =<< myEndPoint -- Kill the transport layer -- trans <- lclTransport -- NT.closeTransport trans NT.closeTransport =<< lclTransport -- |Used during shutdown, to close -- connections with all processes safely killConnections :: Map.Map NodeId NT.Connection -> IO () killConnections remoteConnections = do let nodes = Map.keys remoteConnections mapM_ killConn nodes where killConn node = do let remoteConnection = fromJust $ Map.lookup node remoteConnections NT.close remoteConnection -- |Initiate a shutdown from master process shutdown :: CommM () shutdown = do targets <- allNodes -- broadcast Shutdown message to all nodes but sender liftIO $ broadcastMsg targets Shutdown -- |Connection with the master process has been closed -- unexpectedly, close transport layer uncleanShutdown :: IO () uncleanShutdown = do #ifdef HDPH_DEBUG dbg "Shutting down as main process died" #endif -- The main process has terminated. Let's clean up. shutdownTransportIO exitWith (ExitFailure 9) -- force process termination ----------------------------------------------------------------------------- -- running the CommM monad -- Run the given 'CommM' action in the IO monad with the given config data -- (which determines the debug level, see module HdpH.Internal.Location). -- The 'action' must call 'waitShutdown' before it terminates on all nodes, -- and at least one node must call 'shutdown'. run_ :: RTSConf -> CommM () -> IO () run_ conf action = do #ifdef HDPH_DEBUG dbg "run_.1" #endif -- check debug level let debugLevel = debugLvl conf unless (debugLevel >= dbgNone) $ Prelude.error "HdpH.Internal.Comm_MPI.run_: debug level < none" -- set debug level in HdpH.Internal.State.Location writeIORef debugRef debugLevel #ifdef HDPH_DEBUG dbg "run_.2" #endif #ifdef HDPH_DEBUG dbg "run_.3" #endif myIP <- discoverMyIP conf -- uses network-info for local IP address identification -- Networking step 1: Create a transport transport <- tryCreateTransport myIP conf atomicModifyIORef lclTransportRef (\r -> (transport,r)) -- Networking step 2: Create a local endpoint; write to myEndPointRef IORef Right myEP <- NT.newEndPoint transport let me = NT.address myEP -- Sets 'myEndPointRef'. Receive events from endpoint in transport layer atomicModifyIORef myEndPointRef (\r -> (myEP,r)) (allNodes,main) <- nodeInfo conf let iAmMain = me == main -- set node ID in HdpH.Internal.State.Location atomicModifyIORef myNodeRef (\r -> (me,r)) #ifdef HDPH_DEBUG dbg "run_.4" #endif #ifdef HDPH_DEBUG dbg "run_.5" #endif -- create initial state q <- newChan startBarrier <- newEmptyMVar -- (not used in TCP backend) stopBarrier <- newEmptyMVar #ifdef HDPH_DEBUG dbg "run_.6" #endif #ifdef HDPH_DEBUG dbg $ "run_.7 receiveServerTid = " #endif if iAmMain then do #ifdef HDPH_DEBUG dbg "run_.7.root" #endif -- Networking step 4: Create NodeId's for all nodes -- including endpoint address info. Also, create -- connections between local endpoint and endpoint addresses -- of all nodes. Write connections map to connectionlookupRef -- and all NodeId's to allNodesRef nodeConnections <- remoteEndPointAddrMap allNodes atomicModifyIORef connectionLookupRef (\r -> (nodeConnections,r)) recvAllReady (length allNodes - 1 ) -- blocking broadcastMsg (allNodes \\ [me]) Booted atomicModifyIORef mainEndpointAddrRef (const (myEP, ())) else do #ifdef HDPH_DEBUG dbg "run_.7.other" #endif -- See networking step 4, above. let mainEP = main nodeConnections <- remoteEndPointAddrMap allNodes atomicModifyIORef connectionLookupRef (const (nodeConnections, ())) atomicModifyIORef mainEndpointAddrRef (const (mainEP, ())) -- Tells master that this node is ready broadcastMsg [main] Ready -- waits for `Booted', means that main is connected to all nodes waitForBootstrapConfirmation let s0 = State { s_conf = conf, s_nodes = length allNodes, s_allNodes = allNodes, s_myNode = me, s_isMain = iAmMain, s_msgQ = q, s_shutdown = stopBarrier } #ifdef HDPH_DEBUG dbg "run_.8" #endif forkIO $ receiveServer q startBarrier stopBarrier #ifdef HDPH_DEBUG dbg "run_.8b" #endif -- run monad runReaderT action s0 #ifdef HDPH_DEBUG dbg "run_.9" #endif -- reset HdpH.Internal.State.Location atomicModifyIORef myNodeRef (\r -> (throw NodeIdUnset,r)) writeIORef debugRef dbgNone #ifdef HDPH_DEBUG dbg "run_.10" #endif ----------------------------------------------------------------------------- -- internal messages data Msg = Startup -- startup completed message (main -> other) | Shutdown -- shutdown system message (main -> other) | Booted | Ready | Payload Message -- non-system message; arg (payload) to be queued deriving (Eq, Ord, Show) -- Show inst only for debugging instance NFData Msg where rnf Startup = () rnf (Booted) = () rnf (Ready) = () rnf (Shutdown) = () rnf (Payload work) = rnf work instance Data.Serialize.Serialize Msg where put Startup = Data.Serialize.put (0 :: Word8) put (Booted) = Data.Serialize.put (1 :: Word8) put (Ready) = Data.Serialize.put (2 :: Word8) put (Shutdown) = Data.Serialize.put (3 :: Word8) put (Payload work) = Data.Serialize.put (4 :: Word8) >> Data.Serialize.put work get = do tag <- Data.Serialize.get case tag :: Word8 of 0 -> do return $ Startup 1 -> do return $ Booted 2 -> do return $ Ready 3 -> do return $ Shutdown 4 -> do work <- Data.Serialize.get return $ Payload work ----------------------------------------------------------------------------- -- sending messages (incl system messages) -- Send a payload message. send :: NodeId -> Message -> CommM () send dest message = lift $ send_ dest message send_ :: NodeId -> Message -> IO () send_ dest message = do result <- try $ do remoteConnections <- connectionLookup let conn = fromJust $ Map.lookup dest remoteConnections -- Actual type of 'send' is Either (NT.FailedWith NT.SendErrorCode) () NT.send conn (Lazy.toChunks (encodeLazy (Payload message))) case result of Left (e::SomeException) -> void (print e) Right _ -> return () -- This needs to be separate because in `send_' (Payload _) -- constructor is automatically added. broadcastMsg :: [NodeId] -> Msg -> IO () broadcastMsg dests msg = mapM_ broadcastMsg' dests where serialized_msg = encodeLazy msg broadcastMsg' dest = do result <- try $ do remoteConnections <- connectionLookup let conn = fromJust $ Map.lookup dest remoteConnections -- Actual type of 'send' is Either (NT.FailedWith NT.SendErrorCode) () _ <- NT.send conn (Lazy.toChunks serialized_msg) return () case result of Left (e::SomeException) -> void (print e) Right _ -> return () ----------------------------------------------------------------------------- -- receiving messages (incl system messages and message queue server) -- Block to receive a message; -- the sender must be encoded into the message, otherwise it is unknown. receive :: CommM Message receive = do q <- msgQ liftIO $ readChan q recv :: IO Msg recv = do ep <- myEndPoint event <- NT.receive ep case event of NT.Received _ msg -> return ((force . decodeLazy . Lazy.fromChunks) msg) NT.ErrorEvent (NT.TransportError e _) -> case e of NT.EventConnectionLost ep -> do mainEP <- mainEndpointAddr -- Let's check if the main node has died. -- If it has, we should give up. if mainEP == ep then do -- Main process has terminated prematurely. Fatal. uncleanShutdown return Shutdown else do {- Enabled in ft-scheduler branch -- Send a message to scheduler remoteConnections <- connectionLookup let x = Map.filterWithKey (\node _ -> ep == node) remoteConnections deadNode = head $ Map.keys x -- should only be one msg = Payload $ encodeLazy (Payload.DEADNODE deadNode) return msg -} recv _ -> recv -- links to "case e of" _ -> do -- links to "case event of" -- ignore remaining NT.Event constructors for now -- i.e. [ConnectionClosed,ConnectionOpened,ReceivedMulticast,EndPointClosed] recv -- Non-terminating computation, to be run in a separate thread. -- Continually receives message, which it puts into the given -- message queue or handles immediately (in the case of system messages). -- * 'Shutdown' unblocks the shutdown barrier (thus terminating all actions). receiveServer :: MessageQ -> MVar () -> MVar () -> IO () receiveServer q startBarrier stopBarrier = do hdl <- recv handleMsg hdl -- NOTE: Changed from previous 'forkIO $ handleMsg hdl' -- Rationale: 'handleMsg' should be sufficiently lazy to run sequentially', -- plus it poses less danger of corruption on shutdown where handleMsg hdl = -- receive message and dispatch on constructor case hdl of -- 'Startup' not used in TCP backend Startup -> receiveServer q startBarrier stopBarrier Shutdown -> -- lift shutdown barrier putMVar stopBarrier () Payload message -> -- queue the payload do writeChan q message receiveServer q startBarrier stopBarrier _ -> error $ "Unexpected message in `receiveServer' " ++ show hdl ----------------------------------------------------------------------------- -- debugging #ifdef HDPH_DEBUG dbg :: String -> IO () dbg s = do hPutStrLn stderr $ ": HdpH.Internal.Comm_TCP." ++ s #endif -- |Used to setup and store a Map of NodeId -> NT.Connection -- And also, creates a list of [NodeId] that is written -- to the allNodesRef IORef remoteEndPointAddrMap :: [NodeId] -> IO (Map.Map NodeId NT.Connection) remoteEndPointAddrMap nodes = do mvar <- newMVar Map.empty -- to store connections mapM_ (connectToAllNodes mvar) nodes takeMVar mvar ------ -- Start up utilities -- |Before the 'receiveServer' function is forked, -- each node must receive a Booted payload from the master process. -- This indicates that all nodes have sent an ALIVE payload to the master process. waitForBootstrapConfirmation :: IO () waitForBootstrapConfirmation = do msg <- recv case msg of Booted -> return () _ -> waitForBootstrapConfirmation -- Hangover from UDP broadcasts, still wait for `Booted' -- |Writing into an MVar the connection that has been -- made with the remote node, to be written into -- the connectionLookupRef IORef connectToAllNodes :: MVar (Map.Map NodeId NT.Connection) -> NodeId -> IO () connectToAllNodes mvar remoteNode = do myEP <- myEndPoint x <- NT.connect myEP remoteNode NT.ReliableOrdered NT.defaultConnectHints case x of (Right newConnection) -> modifyMVar_ mvar $ \m -> return $ Map.insert remoteNode newConnection m (Left _) -> connectToAllNodes mvar remoteNode -- keep retrying -- |'action' is only executed once the master node -- has receve NT.ConnectionOpened from all other nodes recvAllReady :: Int -> IO () recvAllReady i = when (i > 0) $ do msg <- recv case msg of Ready -> recvAllReady (i-1) _ -> putStrLn $ "unexpected msg in recvAllReady: " ++ show msg --------------------------- -- Transport layer creation tryCreateTransport :: IPv4 -> RTSConf -> IO NT.Transport tryCreateTransport myIP conf = createTrans myIP (numProcs conf) 0 createTrans :: IPv4 -> Int -> Int -> IO NT.Transport createTrans myIP tasks attempts = do rndsock <- genRandomSocket t <- TCP.createTransport (show myIP) (show rndsock) TCP.defaultTCPParameters case t of Right transport -> return transport Left e -> do let attempts' = attempts+1 if attempts' == tasks then error ("Error creating transport: " ++ show e) else do createTrans myIP tasks attempts' ----------------------- -- UDP based Node discovery nodeInfo :: RTSConf -> IO (SlaveNodes, MainNode) nodeInfo conf = do _ <- forkIO $ broadcastTimeout 10000000 all <- findSlaves (numProcs conf) let mainNode = head (sort all) -- election protocol return (all,mainNode) where broadcastTimeout i = do -- broadcast endpoint address via UDP for 10 seconds _ <- timeout i broadcastMyNode return () discoverMyIP :: RTSConf -> IO IPv4 discoverMyIP conf = do ns <- getNetworkInterfaces return $ myIP ns (networkInterface conf) myIP :: [NetworkInterface] -> String -> IPv4 myIP interfaces interfaceName = let eth = filter (\x -> name x == interfaceName) interfaces in ipv4 $ head eth type MainNode = NodeId type SlaveNodes = [NodeId] broadcastMyNode :: IO () broadcastMyNode = do myEP <- myEndPoint forever $ do (sock, addr) <- multicastSender "224.0.0.99" 9999 sendTo sock (NT.endPointAddressToByteString (NT.address myEP)) addr threadDelay 100000 findSlaves :: Int -> IO SlaveNodes findSlaves numNodesExpected = do sock <- multicastReceiver "224.0.0.99" 9999 listenForNodes sock [] numNodesExpected listenForNodes :: Socket -> SlaveNodes -> Int -> IO SlaveNodes listenForNodes sock ns expected = do (msg, _) <- recvFrom sock 1024 let remoteEndPointAddr = NT.EndPointAddress msg let n = if remoteEndPointAddr `elem` ns then [] else [remoteEndPointAddr] ns' = n ++ ns if length ns' == expected then return ns' else listenForNodes sock ns' expected genRandomSocket :: IO Int genRandomSocket = do gen <- newStdGen return $ head (randomRs (8000,40000) gen) ------ -- Enpoint and connection lookup IORefs myEndPoint :: IO NT.EndPoint myEndPoint = readIORef myEndPointRef myEndPointRef :: IORef NT.EndPoint myEndPointRef = unsafePerformIO $ newIORef $ throw NodeIdUnset {-# NOINLINE myEndPointRef #-} -- required to protect unsafePerformIO hack connectionLookup :: IO (Map.Map NodeId NT.Connection) connectionLookup = readIORef connectionLookupRef connectionLookupRef :: forall k a. IORef (Map.Map k a) connectionLookupRef = unsafePerformIO $ newIORef Map.empty {-# NOINLINE connectionLookupRef #-} -- required to protect unsafePerformIO hack -- Used to watch when main node has failed. -- If main node fails, shutdown transport layer and terminate. mainEndpointAddr :: forall a. IO a mainEndpointAddr = readIORef mainEndpointAddrRef mainEndpointAddrRef :: forall a. IORef a mainEndpointAddrRef = unsafePerformIO $ newIORef $ throw NodeIdUnset {-# NOINLINE mainEndpointAddrRef #-} -- required to protect unsafePerformIO hack lclTransport :: IO NT.Transport lclTransport = readIORef lclTransportRef lclTransportRef :: IORef NT.Transport lclTransportRef = unsafePerformIO $ newIORef $ throw NodeIdUnset {-# NOINLINE lclTransportRef #-} -- required to protect unsafePerformIO hack