{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Foreign.Erlang.LocalNode
( LocalNode()
, NodeT()
, LocalNodeConfig(..)
, askCreation
, askNodeName
, askNodeState
, askNodeRegistration
, askLocalNode
, runNodeT
, make_pid
, make_ref
, make_port
, make_mailbox
, register_pid
, send
, sendReg
) where
import Prelude hiding ( id )
import Control.Monad
import Control.Monad.Reader
import Control.Concurrent.STM
import Control.Monad.Base
import qualified Data.ByteString.Char8 as CS
import Data.Word
import Util.IOExtra
import Util.BufferedIOx
import Util.Socket
import Network.BufferedSocket
import Foreign.Erlang.ControlMessage ( ControlMessage(..) )
import Foreign.Erlang.NodeState
import Foreign.Erlang.NodeData
import Foreign.Erlang.Epmd
import Foreign.Erlang.Handshake
import Foreign.Erlang.Term
import Foreign.Erlang.Connection
import Foreign.Erlang.Mailbox
-- | Settings needed to start a local node with 'runNodeT'.
--
-- NOTE(review): the derived 'Show' instance prints every field, including
-- the cookie; avoid logging values of this type.
data LocalNodeConfig = LocalNodeConfig
    { aliveName :: String
      -- ^ part of the node name before the @\@@; 'runNodeT' requires it
      -- to be non-empty
    , hostName  :: String
      -- ^ part of the node name after the @\@@; also passed to the server
      -- socket and to 'registerNode'. Must be non-empty.
    , cookie    :: String
      -- ^ cookie placed into the handshake data for inbound and outbound
      -- connections
    }
    deriving Show
-- | Monad transformer for code running against a started local node.
--
-- It is a thin wrapper around a reader over the 'RegisteredNode'
-- environment; run it with 'runNodeT'. All listed classes are derived from
-- the underlying 'ReaderT' via generalized newtype deriving.
newtype NodeT m a = NodeT { unNodeT :: ReaderT RegisteredNode m a }
    deriving ( Functor
             , Applicative
             , Monad
             , MonadCatch
             , MonadThrow
             , MonadMask
             , MonadLogger
             , MonadIO
             )
-- Remaining capability classes of 'NodeT', each taken over from the
-- wrapped 'ReaderT' (newtype-derived via StandaloneDeriving).
deriving instance MonadBase b m => MonadBase b (NodeT m)

deriving instance MonadLoggerIO m => MonadLoggerIO (NodeT m)

deriving instance (MonadBase IO (NodeT m), MonadResource m)
    => MonadResource (NodeT m)
-- | Lets control operations of the base monad @b@ (e.g. bracketing,
-- forking) run 'NodeT' code. The saved state is that of @m@, because the
-- reader layer adds no state of its own.
instance (MonadBaseControl b m) =>
         MonadBaseControl b (NodeT m) where
    type StM (NodeT m) a = StM m a
    -- The runner is rank-2, so the lambda stays inline (no helper binding).
    liftBaseWith f = NodeT (liftBaseWith (\runInBase -> f (runInBase . unNodeT)))
    restoreM st = NodeT (restoreM st)
-- | Runtime data of a started local node, built by 'runNodeT'.
data LocalNode = LocalNode
    { handshakeData  :: HandshakeData
      -- ^ our name, node data and cookie, used by 'doAccept' / 'doConnect'
    , nodeState      :: NodeState Pid Term Mailbox Connection
      -- ^ shared state: mailboxes (by pid and by name), connections, and
      -- the source of new pid/ref/port ids
    , acceptorSocket :: Socket
      -- ^ listening socket from which inbound connections are accepted
    }
-- | Reader environment of 'NodeT': the started node together with the
-- registration value handed back by 'registerNode'.
data RegisteredNode = RegisteredNode
    { localNode        :: LocalNode
    , nodeRegistration :: NodeRegistration
    }
-- | The 'LocalNode' this computation runs on.
askLocalNode :: Monad m => NodeT m LocalNode
askLocalNode = NodeT (fmap localNode ask)
-- | The registration value obtained when the node was started.
askNodeRegistration :: Monad m => NodeT m NodeRegistration
askNodeRegistration = NodeT (nodeRegistration <$> ask)
-- | The creation number from the node registration, converted to 'Word8'.
--
-- NOTE(review): 'fromIntegral' silently truncates if 'nr_creation' can
-- exceed 255; confirm the registration never yields larger values.
askCreation :: Monad m => NodeT m Word8
askCreation = do
    reg <- askNodeRegistration
    return (fromIntegral (nr_creation reg))
-- | The shared node state (mailboxes, connections, id sources).
askNodeState :: Monad m => NodeT m (NodeState Pid Term Mailbox Connection)
askNodeState = fmap nodeState askLocalNode
-- | This node's full name (the @alive\@host@ string built in 'runNodeT'),
-- read from the handshake data.
askNodeName :: Monad m => NodeT m CS.ByteString
askNodeName = do
    LocalNode{handshakeData = hs} <- askLocalNode
    return (n_nodeName (name hs))
-- | Allocate a fresh pid on this node.
--
-- The id/serial pair comes from 'new_pid' on the node state; the pid is
-- stamped with this node's name and creation number.
make_pid :: MonadIO m => NodeT m Pid
make_pid = do
    localName <- askNodeName
    creation <- askCreation
    st <- askNodeState
    (pidId, pidSerial) <- liftIO (new_pid st)
    return (pid localName pidId pidSerial creation)
-- | Register the mailbox owned by @pid'@ under the name @name@.
--
-- Returns 'True' when a mailbox for the pid was found and registered, and
-- 'False' (changing nothing) when the pid has no mailbox.
register_pid :: (MonadIO m) => Term -> Pid -> NodeT m Bool
register_pid name pid' = do
    st <- askNodeState
    liftIO $ getMailboxForPid st pid' >>= \found -> case found of
        Nothing   -> return False
        Just mbox -> putMailboxForName st name mbox >> return True
-- | Create a fresh reference term on this node from the three ids returned
-- by 'new_ref', tagged with this node's name and creation number.
make_ref :: (MonadIO m) => NodeT m Term
make_ref = do
    localName <- askNodeName
    st <- askNodeState
    (r0, r1, r2) <- liftIO (new_ref st)
    creation <- askCreation
    return (ref localName creation [r0, r1, r2])
-- | Create a fresh port term on this node from the id returned by
-- 'new_port', tagged with this node's name and creation number.
make_port :: (MonadIO m) => NodeT m Term
make_port = do
    localName <- askNodeName
    creation <- askCreation
    st <- askNodeState
    portId <- liftIO (new_port st)
    return (port localName portId creation)
-- | Start a local node and run a 'NodeT' computation on it.
--
-- The sequence is:
--
--   1. Check (via 'requireM') that 'aliveName' and 'hostName' are non-empty.
--   2. Open a listening socket for 'hostName' and build the handshake data
--      for the node name @aliveName\@hostName@.
--   3. Start a background accept loop with 'withAsync', linked to the
--      calling thread with 'link'.
--   4. Register the node with 'registerNode' and run the computation with
--      the resulting registration in its environment.
--
-- When the computation ends (normally or by exception), the 'bracket'
-- release closes every connection recorded in the node state. The acceptor
-- socket itself is registered with the surrounding 'ResourceT' scope via
-- 'allocate', so it is closed when that scope ends.
runNodeT :: forall m a.
            (MonadResource m, MonadThrow m, MonadMask m, MonadLogger m, MonadLoggerIO m, MonadBaseControl IO m)
         => LocalNodeConfig
         -> NodeT m a
         -> m a
runNodeT LocalNodeConfig{aliveName,hostName,cookie} NodeT{unNodeT} = do
    requireM "(aliveName /= \"\")" (aliveName /= "")
    requireM "(hostName /= \"\")" (hostName /= "")
    bracket setupAcceptorSock stopAllConnections acceptRegisterAndRun
  where
    -- Open the acceptor socket and assemble the static node description.
    setupAcceptorSock = do
        let nodeNameBS = CS.pack (aliveName ++ "@" ++ hostName)
        -- Released by the enclosing ResourceT scope, not by 'bracket'.
        (_, (acceptorSocket, portNo)) <- allocate (serverSocket (CS.pack hostName))
                                                  (closeSock . fst)
        let dFlags = DistributionFlags [ EXTENDED_REFERENCES
                                       , FUN_TAGS
                                       , NEW_FUN_TAGS
                                       , EXTENDED_PIDS_PORTS
                                       , BIT_BINARIES
                                       , NEW_FLOATS
                                       ]
            name = Name { n_distVer = R6B
                        , n_distFlags = dFlags
                        , n_nodeName = nodeNameBS
                        }
            nodeData = NodeData { portNo = portNo
                                , nodeType = HiddenNode
                                , protocol = TcpIpV4
                                , hiVer = R6B
                                , loVer = R6B
                                , aliveName = CS.pack aliveName
                                , extra = ""
                                }
            handshakeData = HandshakeData { name
                                          , nodeData
                                          , cookie = CS.pack cookie
                                          }
        nodeState <- liftIO newNodeState
        return LocalNode { acceptorSocket, handshakeData, nodeState }

    -- Run the accept loop in the background while registering the node and
    -- running the user computation in the foreground.
    acceptRegisterAndRun localNode@LocalNode{acceptorSocket,handshakeData = hsn@HandshakeData{nodeData},nodeState} =
        withAsync accept (\accepted -> link accepted >> registerAndRun)
      where
        accept =
            forever (bracketOnErrorLog (liftIO (acceptSocket acceptorSocket >>= makeBuffered))
                                       (liftIO . closeBuffered)
                                       onConnect)

        -- In the 'Nothing' case the socket is not handed to 'newConnection',
        -- and 'tryAndLogAll' presumably swallowed the handshake failure, so
        -- 'bracketOnErrorLog' never runs its cleanup. Close the socket here
        -- to avoid leaking one descriptor per failed inbound handshake.
        onConnect sock =
            tryAndLogAll (doAccept (runPutBuffered sock) (runGetBuffered sock) hsn)
                >>= maybe (liftIO (closeBuffered sock))
                          (void . newConnection sock nodeState . atom)

        registerAndRun = registerNode nodeData (CS.pack hostName) go

        go nodeRegistration =
            runReaderT unNodeT (RegisteredNode { localNode, nodeRegistration })

    -- 'bracket' release action: close every connection in the node state.
    stopAllConnections LocalNode{nodeState} = do
        cs <- liftIO $ getConnectedNodes nodeState
        mapM_ (liftIO . closeConnection . snd) cs
-- | Create a new mailbox: a fresh pid from 'make_pid' plus an empty
-- 'TQueue' for its messages. The mailbox is stored in the node state under
-- that pid before being returned.
make_mailbox :: (MonadResource m) => NodeT m Mailbox
make_mailbox = do
    self <- make_pid
    queue <- liftIO newTQueueIO
    st <- askNodeState
    let mbox = MkMailbox { self = self, msgQueue = queue }
    liftIO (putMailboxForPid st self mbox)
    return mbox
-- | Send @message@ to the process @toPid@.
--
-- The connection is found (or created) by the node name embedded in the
-- pid. If no connection can be obtained, the message is dropped without
-- error.
send :: (MonadMask m, MonadBaseControl IO m, MonadResource m, MonadLoggerIO m)
     => Pid
     -> Term
     -> NodeT m ()
send toPid message = do
    mConn <- getOrCreateConnection (atom_name (node (toTerm toPid)))
    mapM_ (sendControlMessage (SEND toPid message)) mConn
-- | Send @message@ to the process registered as @regName@ on the node whose
-- name is the atom @nodeName@.
--
-- The mailbox's own pid is included in the 'REG_SEND' control message. If
-- no connection to the node can be obtained, the message is dropped
-- without error.
sendReg :: (MonadMask m, MonadBaseControl IO m, MonadResource m, MonadLoggerIO m)
        => Mailbox
        -> Term
        -> Term
        -> Term
        -> NodeT m ()
sendReg MkMailbox{self} regName nodeName message = do
    mConn <- getOrCreateConnection (atom_name nodeName)
    mapM_ (sendControlMessage (REG_SEND self regName message)) mConn
-- | Split a node name of the form @alive\@host@ into @(alive, host)@.
--
-- Calls 'error' unless the name contains exactly one @\'\@\'@. Either part
-- may be empty.
splitNodeName :: CS.ByteString -> (CS.ByteString, CS.ByteString)
splitNodeName a
    | CS.count '@' a == 1 = (alive, CS.drop 1 atHost)
    | otherwise = error $ "Illegal node name: " ++ show a
  where
    (alive, atHost) = CS.break (== '@') a
-- | Get a connection to the node named @remoteName@ (an @alive\@host@
-- name).
--
-- A connection already recorded in the node state is reused. Otherwise the
-- node is looked up with 'lookupNode' and, if found, a socket is opened and
-- an outgoing handshake ('doConnect') is performed. Returns 'Nothing' (after
-- logging a warning) when the node cannot be found. A name without exactly
-- one @\'\@\'@ makes 'splitNodeName' call 'error' once the lookup path is
-- taken.
getOrCreateConnection :: (MonadMask m, MonadBaseControl IO m, MonadResource m, MonadLoggerIO m)
                      => CS.ByteString
                      -> NodeT m (Maybe Connection)
getOrCreateConnection remoteName = do
    cached <- getExistingConnection
    case cached of
        Just conn -> return (Just conn)
        Nothing   -> lookupAndConnect
  where
    remoteAtom = atom remoteName

    getExistingConnection = do
        logInfoStr (printf "getExistingConnection %s" (show remoteAtom))
        st <- askNodeState
        logNodeState st
        getConnectionForNode st remoteAtom

    -- Lazy pattern binding: only forced (and validated) on the lookup path.
    (remoteAlive, remoteHost) = splitNodeName remoteName

    lookupAndConnect =
        lookupNode remoteAlive remoteHost >>= maybe warnNotFound connect

    warnNotFound = do
        logWarnStr (printf "Connection failed: Node '%s' not found on '%s'."
                           (CS.unpack remoteAlive)
                           (CS.unpack remoteHost))
        return Nothing

    connect NodeData{portNo = remotePort} =
        bracketOnErrorLog dialRemote closeOnError finishConnect
      where
        dialRemote = liftIO (connectSocket remoteHost remotePort >>= makeBuffered)
        closeOnError sock = do
            liftIO (closeBuffered sock)
            return Nothing
        finishConnect sock = do
            st <- askNodeState
            LocalNode{handshakeData} <- askLocalNode
            doConnect (runPutBuffered sock) (runGetBuffered sock) handshakeData
            conn <- newConnection sock st remoteAtom
            return (Just conn)