module Control.Parallel.HdpH.Internal.Comm
(
CommM,
run_,
liftIO,
nodes,
allNodes,
myNode,
isMain,
Message,
send,
receive,
shutdown,
waitShutdown
) where
import Prelude hiding (error)
import qualified Prelude (error)

import Control.Concurrent
import Control.Exception (throw)
import Control.Exception (SomeException,try)
import Control.Exception (bracket)
import Control.Monad (unless,void,when,forever)
import Data.Functor ((<$>))
import Data.IORef (writeIORef,atomicModifyIORef)
import Data.IORef (IORef, newIORef, readIORef)
import Data.List ((\\))
import Data.List (sort)
import Data.Maybe
import Data.Word (Word8)
import System.Exit (ExitCode(..),exitWith)
import System.IO (hPutStrLn,stderr)
import System.IO.Unsafe (unsafePerformIO)
import System.Timeout (timeout)

import Control.DeepSeq (NFData(rnf),force)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.Trans (lift)
import qualified Data.ByteString.Lazy as Lazy (ByteString,toChunks,fromChunks)
import qualified Data.Map as Map
import qualified Data.Serialize (Serialize, put, get)
import Network.Info
import Network.Multicast
import Network.Socket (Socket)
import qualified Network.Socket as NS (close)
import Network.Socket.ByteString (sendTo, recvFrom)
import qualified Network.Transport as NT
import qualified Network.Transport.TCP as TCP
import System.Random (randomRs, newStdGen)

import Control.Parallel.HdpH.Conf
       (RTSConf(debugLvl),numProcs, networkInterface)
import Control.Parallel.HdpH.Internal.Location
       (NodeId, MyNodeException(NodeIdUnset), error, dbgNone)
import Control.Parallel.HdpH.Internal.Misc (encodeLazy, decodeLazy)
import Control.Parallel.HdpH.Internal.State.Location (myNodeRef, debugRef)
-- | Per-node state made available to every 'CommM' action via the reader.
data State = State
  { s_conf     :: RTSConf   -- ^ runtime configuration given to 'run_'
  , s_nodes    :: Int       -- ^ number of entries in 's_allNodes'
  , s_allNodes :: [NodeId]  -- ^ every node found by discovery
  , s_myNode   :: NodeId    -- ^ this node's endpoint address
  , s_isMain   :: Bool      -- ^ True iff this node is the main node
  , s_msgQ     :: MessageQ  -- ^ payloads queued by 'receiveServer'
  , s_shutdown :: MVar ()   -- ^ filled by 'receiveServer' on 'Shutdown'
  }
-- | Channel through which 'receiveServer' hands payloads to 'receive'.
type MessageQ = Chan Message

-- | An opaque payload as carried inside a 'Payload' message.
type Message = Lazy.ByteString

-- | The communication monad: read-only access to 'State' over 'IO'.
type CommM = ReaderT State IO
-- | Run an 'IO' action inside 'CommM'.
liftIO :: IO a -> CommM a
liftIO action = lift action

-- | Number of nodes ('s_nodes').
nodes :: CommM Int
nodes = fmap s_nodes ask

-- | Addresses of all nodes ('s_allNodes').
allNodes :: CommM [NodeId]
allNodes = fmap s_allNodes ask

-- | This node's address ('s_myNode').
myNode :: CommM NodeId
myNode = fmap s_myNode ask

-- | Whether this node is the main node ('s_isMain').
isMain :: CommM Bool
isMain = fmap s_isMain ask

-- | Queue of incoming payloads ('s_msgQ').
msgQ :: CommM MessageQ
msgQ = fmap s_msgQ ask

-- | Debug level taken from the runtime configuration.
debug :: CommM Int
debug = fmap (debugLvl . s_conf) ask
-- | Block until 'receiveServer' has seen a 'Shutdown' message (it fills the
-- shutdown barrier). Then close all connections, the endpoint and the
-- transport.
waitShutdown :: CommM ()
waitShutdown = do
  stopBarrier <- fmap s_shutdown ask
  liftIO (takeMVar stopBarrier)
  shutdownTransport
-- | 'CommM' wrapper around 'shutdownTransportIO'.
shutdownTransport :: CommM ()
shutdownTransport = liftIO shutdownTransportIO
-- | Tear down networking in three steps:
-- 1. close every stored connection;
-- 2. close this node's endpoint;
-- 3. close the transport.
shutdownTransportIO :: IO ()
shutdownTransportIO = do
  conns <- connectionLookup
  killConnections conns
  endpoint <- myEndPoint
  NT.closeEndPoint endpoint
  transport <- lclTransport
  NT.closeTransport transport
-- | Close every connection in the node-to-connection map, in ascending key
-- order.
--
-- Iterates the map's values directly. This replaces a per-key lookup whose
-- result was unwrapped with the partial 'fromJust'.
killConnections :: Map.Map NodeId NT.Connection -> IO ()
killConnections remoteConnections =
  mapM_ NT.close (Map.elems remoteConnections)
-- | Send a 'Shutdown' control message to every node in 's_allNodes'.
shutdown :: CommM ()
shutdown = allNodes >>= \targets -> liftIO (broadcastMsg targets Shutdown)
-- | Emergency exit used by 'recv' when the connection to the main node is
-- lost. It closes connections, the endpoint and the transport, then calls
-- 'exitWith' with @ExitFailure 9@.
--
-- NOTE(review): the System.Exit documentation states that 'exitWith' only
-- terminates the process when called from the main thread. 'recv' is also
-- run on the thread forked for 'receiveServer'. There this would merely
-- end that thread. Confirm the intended behaviour.
uncleanShutdown :: IO ()
uncleanShutdown = do
#ifdef HDPH_DEBUG
  dbg "Shutting down as main process died"
#endif
  shutdownTransportIO
  exitWith failureCode
  where
    failureCode = ExitFailure 9
-- | Initialise the communication layer, run @action@ in 'CommM', then reset
-- the global node ID and debug level.
--
-- Steps, in order:
--
--   1. Check the configured debug level and store it in 'debugRef'.
--   2. Create a TCP transport on the configured interface and a fresh
--      endpoint; store both in the module-level refs.
--   3. Discover all nodes by multicast ('nodeInfo').
--   4. Open a connection to every discovered node.
--   5. Bootstrap barrier:
--      * the main node waits for 'Ready' from every other node, then sends
--        them 'Booted';
--      * every other node sends 'Ready' to the main node and waits for
--        'Booted'.
--   6. Fork 'receiveServer' and run @action@.
run_ :: RTSConf -> CommM () -> IO ()
run_ conf action = do
#ifdef HDPH_DEBUG
  dbg "run_.1"
#endif
  let debugLevel = debugLvl conf
  unless (debugLevel >= dbgNone) $
    Prelude.error "HdpH.Internal.Comm_MPI.run_: debug level < none"
  writeIORef debugRef debugLevel
#ifdef HDPH_DEBUG
  dbg "run_.2"
#endif
#ifdef HDPH_DEBUG
  dbg "run_.3"
#endif
  ip <- discoverMyIP conf
  transport <- tryCreateTransport ip conf
  atomicModifyIORef lclTransportRef (const (transport, ()))
  -- Report endpoint creation failure as a descriptive IOError rather than
  -- an anonymous do-block pattern-match failure.
  epResult <- NT.newEndPoint transport
  myEP <- case epResult of
    Right endpoint -> return endpoint
    Left e -> ioError (userError
      ("HdpH.Internal.Comm.run_: cannot create endpoint: " ++ show e))
  let me = NT.address myEP
  atomicModifyIORef myEndPointRef (const (myEP, ()))
  (allNodes, main) <- nodeInfo conf
  let iAmMain = me == main
  atomicModifyIORef myNodeRef (const (me, ()))
#ifdef HDPH_DEBUG
  dbg "run_.4"
#endif
#ifdef HDPH_DEBUG
  dbg "run_.5"
#endif
  q <- newChan
  startBarrier <- newEmptyMVar
  stopBarrier <- newEmptyMVar
#ifdef HDPH_DEBUG
  dbg "run_.6"
#endif
#ifdef HDPH_DEBUG
  dbg $ "run_.7 receiveServerTid = "
#endif
  if iAmMain
    then do
#ifdef HDPH_DEBUG
      dbg "run_.7.root"
#endif
      nodeConnections <- remoteEndPointAddrMap allNodes
      atomicModifyIORef connectionLookupRef (const (nodeConnections, ()))
      -- Wait for 'Ready' from every node except this one.
      recvAllReady (length allNodes - 1)
      broadcastMsg (allNodes \\ [me]) Booted
      -- Store the main node's address ('me'), not its endpoint: 'recv'
      -- reads this ref back and compares it with an 'NT.EndPointAddress'.
      atomicModifyIORef mainEndpointAddrRef (const (me, ()))
    else do
#ifdef HDPH_DEBUG
      dbg "run_.7.other"
#endif
      nodeConnections <- remoteEndPointAddrMap allNodes
      atomicModifyIORef connectionLookupRef (const (nodeConnections, ()))
      atomicModifyIORef mainEndpointAddrRef (const (main, ()))
      broadcastMsg [main] Ready
      waitForBootstrapConfirmation
  let s0 = State
        { s_conf = conf
        , s_nodes = length allNodes
        , s_allNodes = allNodes
        , s_myNode = me
        , s_isMain = iAmMain
        , s_msgQ = q
        , s_shutdown = stopBarrier
        }
#ifdef HDPH_DEBUG
  dbg "run_.8"
#endif
  _ <- forkIO $ receiveServer q startBarrier stopBarrier
#ifdef HDPH_DEBUG
  dbg "run_.8b"
#endif
  runReaderT action s0
#ifdef HDPH_DEBUG
  dbg "run_.9"
#endif
  -- Reset the node ID to its unset state (forcing it throws NodeIdUnset).
  atomicModifyIORef myNodeRef (const (throw NodeIdUnset, ()))
  writeIORef debugRef dbgNone
#ifdef HDPH_DEBUG
  dbg "run_.10"
#endif
-- | Messages exchanged between nodes. Constructor order matters for the
-- derived 'Ord' instance, so do not reorder.
data Msg
  = Startup          -- ^ skipped by 'receiveServer'
  | Shutdown         -- ^ sent by 'shutdown'; ends 'receiveServer'
  | Booted           -- ^ main node to others once all are ready
  | Ready            -- ^ non-main node to the main node during bootstrap
  | Payload Message  -- ^ user payload sent by 'send_'
  deriving (Eq, Ord, Show)
-- | Only 'Payload' carries data that needs forcing; every other constructor
-- is nullary.
instance NFData Msg where
  rnf msg = case msg of
    Payload work -> rnf work
    _            -> ()
-- | Wire format: a one-byte tag, followed by the payload bytes for
-- 'Payload'. Tags are 0 = Startup, 1 = Booted, 2 = Ready, 3 = Shutdown and
-- 4 = Payload.
--
-- An unknown tag now fails inside the 'Get' monad. Previously it raised a
-- non-exhaustive pattern error.
instance Data.Serialize.Serialize Msg where
  put Startup        = Data.Serialize.put (0 :: Word8)
  put Booted         = Data.Serialize.put (1 :: Word8)
  put Ready          = Data.Serialize.put (2 :: Word8)
  put Shutdown       = Data.Serialize.put (3 :: Word8)
  put (Payload work) = Data.Serialize.put (4 :: Word8) >>
                       Data.Serialize.put work
  get = do
    tag <- Data.Serialize.get
    case tag :: Word8 of
      0 -> return Startup
      1 -> return Booted
      2 -> return Ready
      3 -> return Shutdown
      4 -> Payload <$> Data.Serialize.get
      _ -> fail ("HdpH.Internal.Comm: unknown Msg tag " ++ show tag)
-- | Send a payload to node @dest@; this is 'send_' lifted into 'CommM'.
send :: NodeId -> Message -> CommM ()
send dest = lift . send_ dest
-- | Wrap @message@ in 'Payload', serialise it and send it over the stored
-- connection to @dest@.
--
-- Best effort: nothing is thrown to the caller. Failures are printed on
-- stdout:
--   * any exception raised while sending;
--   * a missing connection for @dest@;
--   * a 'Left' returned by 'NT.send'. This case was previously discarded
--     silently.
send_ :: NodeId -> Message -> IO ()
send_ dest message = do
  result <- try $ do
    remoteConnections <- connectionLookup
    case Map.lookup dest remoteConnections of
      Nothing -> return (Just ("no connection to " ++ show dest))
      Just conn ->
        either (Just . show) (const Nothing)
          <$> NT.send conn (Lazy.toChunks (encodeLazy (Payload message)))
  case result of
    Left (e :: SomeException) -> print e
    Right (Just err) -> putStrLn ("HdpH.Internal.Comm.send_: " ++ err)
    Right Nothing -> return ()
-- | Send the control message @msg@ to each node in @dests@, in list order.
--
-- The message is serialised once and the chunks are reused for every
-- destination. Each send is best effort: a failure for one destination is
-- printed on stdout and the broadcast continues with the next. Reported
-- failures are:
--   * exceptions;
--   * a missing connection;
--   * a 'Left' returned by 'NT.send'. This case was previously discarded.
broadcastMsg :: [NodeId] -> Msg -> IO ()
broadcastMsg dests msg = mapM_ sendOne dests
  where
    chunks = Lazy.toChunks (encodeLazy msg)
    sendOne dest = do
      result <- try $ do
        remoteConnections <- connectionLookup
        case Map.lookup dest remoteConnections of
          Nothing -> return (Just ("no connection to " ++ show dest))
          Just conn -> either (Just . show) (const Nothing) <$> NT.send conn chunks
      case result of
        Left (e :: SomeException) -> print e
        Right (Just err) -> putStrLn ("HdpH.Internal.Comm.broadcastMsg: " ++ err)
        Right Nothing -> return ()
-- | Block until 'receiveServer' queues the next payload, then return it.
receive :: CommM Message
receive = msgQ >>= liftIO . readChan
-- | Wait for the next event on this node's endpoint and turn it into a 'Msg'.
--
-- * Received data is decoded and wrapped in 'force', so the whole 'Msg' is
--   evaluated when the result is demanded.
-- * Losing the connection to the main node triggers 'uncleanShutdown'.
-- * Every other event is skipped by receiving again. This includes other
--   lost connections and other errors.
recv :: IO Msg
recv = myEndPoint >>= NT.receive >>= onEvent
  where
    onEvent (NT.Received _ chunks) =
      return (force (decodeLazy (Lazy.fromChunks chunks)))
    onEvent (NT.ErrorEvent (NT.TransportError (NT.EventConnectionLost lostAddr) _)) = do
      mainAddr <- mainEndpointAddr
      if mainAddr == lostAddr
        then uncleanShutdown >> return Shutdown
        else recv
    onEvent _ = recv
-- | Receive loop, run on its own thread by 'run_'.
--
-- * 'Payload' contents are pushed onto @q@.
-- * 'Startup' is ignored.
-- * 'Shutdown' fills @stopBarrier@ and ends the loop.
-- * Any other message is reported as an error.
--
-- @startBarrier@ is not used here.
receiveServer :: MessageQ -> MVar () -> MVar () -> IO ()
receiveServer q startBarrier stopBarrier = loop
  where
    loop = recv >>= dispatch
    dispatch msg = case msg of
      Startup         -> loop
      Shutdown        -> putMVar stopBarrier ()
      Payload message -> writeChan q message >> loop
      _               -> error $ "Unexpected message in `receiveServer' " ++ show msg
#ifdef HDPH_DEBUG
-- | Print a debug trace line on stderr, prefixed with the module tag.
dbg :: String -> IO ()
dbg = hPutStrLn stderr . (": HdpH.Internal.Comm_TCP." ++)
#endif
-- | Connect to each node in turn and return the node-to-connection map.
-- Blocks until every connection succeeds (see 'connectToAllNodes').
remoteEndPointAddrMap :: [NodeId] -> IO (Map.Map NodeId NT.Connection)
remoteEndPointAddrMap targets =
  newMVar Map.empty >>= \acc ->
    mapM_ (connectToAllNodes acc) targets >> takeMVar acc
-- | Discard received messages until a 'Booted' message arrives.
waitForBootstrapConfirmation :: IO ()
waitForBootstrapConfirmation = recv >>= awaitBooted
  where
    awaitBooted Booted = return ()
    awaitBooted _      = waitForBootstrapConfirmation
-- | Open a reliable, ordered connection from this node's endpoint to
-- @remoteNode@ and insert it into the map held in @mvar@.
--
-- A failed attempt is retried after 100 ms, with no retry limit. The
-- previous version retried immediately in a tight loop, spinning the CPU
-- and hammering the remote node.
connectToAllNodes :: MVar (Map.Map NodeId NT.Connection) -> NodeId -> IO ()
connectToAllNodes mvar remoteNode = do
  myEP <- myEndPoint
  let attempt = do
        x <- NT.connect myEP remoteNode NT.ReliableOrdered NT.defaultConnectHints
        case x of
          Right newConnection ->
            modifyMVar_ mvar (return . Map.insert remoteNode newConnection)
          Left _ -> threadDelay 100000 >> attempt
  attempt
-- | Receive @i@ 'Ready' messages, counting down to zero. Does nothing when
-- @i <= 0@.
--
-- Any other message prints a warning on stdout and returns immediately,
-- without waiting for the remaining 'Ready' messages.
-- NOTE(review): that early return lets the caller continue before all nodes
-- reported 'Ready'; confirm this is intended.
recvAllReady :: Int -> IO ()
recvAllReady i =
  when (i > 0) $ do
    msg <- recv
    case msg of
      Ready -> recvAllReady (i - 1)
      _     -> putStrLn $ "unexpected msg in recvAllReady: " ++ show msg
-- | Create a TCP transport on @ip@, using 'numProcs' from the configuration
-- as the attempt budget for 'createTrans'.
tryCreateTransport :: IPv4 -> RTSConf -> IO NT.Transport
tryCreateTransport ip conf = createTrans ip attemptBudget 0
  where
    attemptBudget = numProcs conf
-- | Try to create a TCP transport bound to @ip@ on a random port. On failure
-- it retries with a freshly chosen port.
--
-- @tasks@ is the attempt budget (the caller passes 'numProcs'). @attempts@
-- counts the failures so far. Once the failure count reaches @tasks@, an
-- error carrying the last failure is raised.
createTrans :: IPv4 -> Int -> Int -> IO NT.Transport
createTrans ip tasks attempts = do
  port <- genRandomSocket
  t <- TCP.createTransport (show ip) (show port) TCP.defaultTCPParameters
  case t of
    Right transport -> return transport
    Left e
      -- '>=' rather than '==': a non-positive budget now fails instead of
      -- retrying forever.
      | attempts' >= tasks -> error ("Error creating transport: " ++ show e)
      | otherwise          -> createTrans ip tasks attempts'
      where
        attempts' = attempts + 1
-- | Discover the nodes of the computation and pick the main node.
--
-- * A background thread multicasts this node's address for at most 10
--   seconds.
-- * Meanwhile, this thread collects 'numProcs' distinct addresses.
-- * The smallest collected address is chosen as the main node.
nodeInfo :: RTSConf -> IO (SlaveNodes, MainNode)
nodeInfo conf = do
  _ <- forkIO (void (timeout 10000000 broadcastMyNode))
  discovered <- findSlaves (numProcs conf)
  return (discovered, head (sort discovered))
-- | IPv4 address of the network interface named in the configuration.
discoverMyIP :: RTSConf -> IO IPv4
discoverMyIP conf =
  (\ifaces -> myIP ifaces (networkInterface conf)) <$> getNetworkInterfaces
-- | IPv4 address of the first interface called @interfaceName@.
--
-- If no such interface exists, an error naming both the requested interface
-- and the available ones is raised when the result is forced. This replaces
-- the previous uninformative 'head' failure.
myIP :: [NetworkInterface] -> String -> IPv4
myIP interfaces interfaceName =
  case filter ((== interfaceName) . name) interfaces of
    (iface : _) -> ipv4 iface
    []          -> Prelude.error $
      "HdpH.Internal.Comm.myIP: no network interface named "
        ++ show interfaceName ++ "; available: " ++ show (map name interfaces)
-- | Address of the node selected as main by 'nodeInfo'.
type MainNode = NodeId

-- | All addresses collected by discovery. Despite the name, this list also
-- holds the main node.
type SlaveNodes = [NodeId]
-- | Multicast this node's endpoint address to 224.0.0.99:9999 every 100 ms,
-- until the thread is interrupted ('nodeInfo' runs it under 'timeout').
--
-- The sender socket is opened once and released by 'bracket' when the loop
-- is interrupted. The previous version opened a new socket on every
-- iteration and never closed any of them.
broadcastMyNode :: IO ()
broadcastMyNode = do
  myEP <- myEndPoint
  let announcement = NT.endPointAddressToByteString (NT.address myEP)
  bracket (multicastSender "224.0.0.99" 9999) (NS.close . fst) $ \(sock, addr) ->
    forever $ do
      void (sendTo sock announcement addr)
      threadDelay 100000
-- | Listen on the multicast group 224.0.0.99:9999 until
-- @numNodesExpected@ distinct addresses have been received.
--
-- The receiver socket is closed afterwards, and also on exceptions, via
-- 'bracket'. Previously it was left open.
findSlaves :: Int -> IO SlaveNodes
findSlaves numNodesExpected =
  bracket (multicastReceiver "224.0.0.99" 9999) NS.close $ \sock ->
    listenForNodes sock [] numNodesExpected
-- | Accumulate distinct endpoint addresses from datagrams (up to 1024 bytes
-- each) on @sock@ until at least @expected@ are known.
--
-- The count is checked before each receive. A non-positive @expected@
-- therefore returns immediately instead of looping forever.
listenForNodes :: Socket -> SlaveNodes -> Int -> IO SlaveNodes
listenForNodes sock ns expected
  | length ns >= expected = return ns
  | otherwise = do
      (msg, _) <- recvFrom sock 1024
      let remoteEndPointAddr = NT.EndPointAddress msg
          ns' | remoteEndPointAddr `elem` ns = ns
              | otherwise                    = remoteEndPointAddr : ns
      listenForNodes sock ns' expected
-- | A pseudo-random port number in [8000, 40000]. Despite the name, it
-- returns a port number, not a socket.
genRandomSocket :: IO Int
genRandomSocket = fmap (head . randomRs (8000, 40000)) newStdGen
-- | Read this node's transport endpoint (stored by 'run_').
myEndPoint :: IO NT.EndPoint
myEndPoint = readIORef myEndPointRef

-- | Global ref holding this node's endpoint. Its initial contents throw
-- 'NodeIdUnset' when forced.
--
-- NOINLINE is required for a top-level 'unsafePerformIO' ref. If GHC
-- inlined it, different use sites could each allocate their own 'IORef'.
{-# NOINLINE myEndPointRef #-}
myEndPointRef :: IORef NT.EndPoint
myEndPointRef = unsafePerformIO $ newIORef $ throw NodeIdUnset
-- | Read the map from node address to open connection (filled by 'run_').
connectionLookup :: IO (Map.Map NodeId NT.Connection)
connectionLookup = readIORef connectionLookupRef

-- | Global connection map, initially empty.
--
-- The type is monomorphic: a polymorphic top-level 'unsafePerformIO' ref
-- would let a value be written at one type and read at another. NOINLINE
-- ensures a single shared 'IORef'.
{-# NOINLINE connectionLookupRef #-}
connectionLookupRef :: IORef (Map.Map NodeId NT.Connection)
connectionLookupRef = unsafePerformIO $ newIORef Map.empty
-- | Read the main node's endpoint address, which 'recv' compares against
-- lost connections. The getter is monomorphic so that reads are always at
-- 'NodeId'.
mainEndpointAddr :: IO NodeId
mainEndpointAddr = readIORef mainEndpointAddrRef

-- | Global ref for the main node's address. Its initial contents throw
-- 'NodeIdUnset' when forced. NOINLINE ensures a single shared 'IORef'.
--
-- NOTE(review): the polymorphic type defeats type checking, since a value
-- written at one type can be read at another (an unsafe coercion). It
-- should become @IORef NodeId@ once every writer is confirmed to store a
-- 'NodeId'.
{-# NOINLINE mainEndpointAddrRef #-}
mainEndpointAddrRef :: forall a. IORef a
mainEndpointAddrRef = unsafePerformIO $ newIORef $ throw NodeIdUnset
-- | Read this node's transport (stored by 'run_').
lclTransport :: IO NT.Transport
lclTransport = readIORef lclTransportRef

-- | Global ref holding the transport. Its initial contents throw
-- 'NodeIdUnset' when forced. NOINLINE ensures a single shared 'IORef'.
{-# NOINLINE lclTransportRef #-}
lclTransportRef :: IORef NT.Transport
lclTransportRef = unsafePerformIO $ newIORef $ throw NodeIdUnset