module Network.Transport.Sockets (
Bindings,
SocketTransport(..),
Connection(..),
IdentifyMessage(..),
Messenger(..),
newMessenger,
addMessenger,
deliver,
closeMessenger,
dispatcher,
sender,
socketSendTo,
receiver,
receiveSocketMessage,
receiveSocketMessages,
SocketSend,
parseSocketAddress,
lookupAddresses,
lookupAddress
) where
import Network.Transport
import Network.Transport.Internal
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import qualified Data.ByteString as B
import qualified Data.Map as M
import Data.Serialize
import qualified Data.Set as S
import qualified Data.Text as T
import GHC.Generics
import Network.Simple.TCP (recv)
import Network.Socket hiding (recv,socket)
import System.Log.Logger
-- | Logger name passed to every hslogger call ('infoM', 'warningM',
-- 'errorM') made by this module.
_log :: String
_log = "transport.sockets"
-- | Transactional map from locally bound names to the mailboxes that receive
-- their messages; read by 'socketSendTo' (local delivery) and has the same
-- type as the table consulted by 'dispatcher'.
type Bindings = TVar (M.Map Name Mailbox)
-- | Runtime state of the socket-based transport.
data SocketTransport = SocketTransport
  { socketMessengers :: TVar (M.Map Address Messenger)
    -- ^ Outbound messengers, keyed by the remote address each one talks to.
  , socketBindings :: Bindings
    -- ^ Locally bound names; 'socketSendTo' writes to these directly.
  , socketConnection :: Address -> IO Connection
    -- ^ Builds a 'Connection' for a remote address ('socketSendTo' replaces
    --   its 'connSocket' with a fresh empty TMVar).
  , socketMessenger :: Connection -> Mailbox -> IO Messenger
    -- ^ Starts a 'Messenger' over a connection; the mailbox argument is
    --   presumably where received messages go (TODO confirm for non-default
    --   implementations).
  , socketInbound :: Mailbox
    -- ^ Mailbox handed to every messenger created by 'socketSendTo'.
  , socketDispatchers :: S.Set (Async ())
    -- ^ Dispatcher threads; not referenced by this module's visible code
    --   (TODO confirm who starts and cancels them).
  , socketResolver :: Resolver
    -- ^ Used by 'socketSendTo' to map a 'Name' to the 'Address' hosting it.
  }
-- | Everything needed to talk to one remote address: the current socket (if
-- any) plus the operations used to open, use and release it.
data Connection = Connection
  { connAddress :: Address
    -- ^ Remote address this connection targets.
  , connSocket :: TMVar Socket
    -- ^ Holds the open socket; empty while disconnected. 'sender' fills it
    --   after 'connConnect' and empties it after a send error.
  , connConnect :: IO Socket
    -- ^ Opens a new socket; called by 'sender' when 'connSocket' is empty.
  , connSend :: Socket -> B.ByteString -> IO ()
    -- ^ Sends one serialised message on the socket (used by 'sender').
  , connReceive :: Socket -> Int -> IO (Maybe B.ByteString)
    -- ^ Receive operation; not used by this module's visible code (presumably
    --   reads up to the given number of bytes — TODO confirm).
  , connClose :: IO ()
    -- ^ Releases the connection; invoked by 'closeMessenger'.
  }
-- | A live link to one remote address: an outbound mailbox drained by a
-- sending thread and a receiving thread, both sharing one 'Connection'.
data Messenger = Messenger
  { messengerOut :: Mailbox
    -- ^ Outbound queue; 'deliver' writes here and the sender thread reads.
  , messengerAddress :: Address
    -- ^ Remote address this messenger talks to.
  , messengerSender :: Async ()
    -- ^ Sending thread (runs 'sender' when built by 'newMessenger').
  , messengerReceiver :: Async ()
    -- ^ Receiving thread (runs 'receiver' when built by 'newMessenger').
  , messengerConnection :: Connection
    -- ^ Connection shared by both threads; closed by 'closeMessenger'.
  }
-- | A message carrying a single 'Address', serialisable through its
-- Generic-derived 'Serialize' instance. It is not used by the rest of this
-- module's visible code; presumably peers exchange it to announce their own
-- address (TODO confirm against callers).
data IdentifyMessage = IdentifyMessage Address deriving (Generic)
instance Serialize IdentifyMessage
-- | Split an address of the form @host:port@ into its host and service parts.
--
-- * An empty host becomes @"localhost"@.
-- * A missing port (no @:@ in the address) becomes @"0"@.
-- * Any text after a second @:@ is ignored.
--
-- Total: pattern matching replaces the partial @!!@ indexing.
parseSocketAddress :: Address -> (HostName,ServiceName)
parseSocketAddress address =
  case T.splitOn ":" (T.pack address) of
    (hostPart : portPart : _) -> (hostOrDefault (T.unpack hostPart), T.unpack portPart)
    [hostPart] -> (hostOrDefault (T.unpack hostPart), "0")
    -- T.splitOn never yields an empty list; kept only for totality.
    [] -> (hostOrDefault "", "0")
  where
    hostOrDefault :: HostName -> HostName
    hostOrDefault "" = "localhost"
    hostOrDefault h = h
-- | Resolve a host and numeric service into every matching socket address,
-- restricted to address families configured on this machine.
lookupAddresses :: (HostName,ServiceName) -> IO [SockAddr]
lookupAddresses (host, port) =
  fmap (map addrAddress) (getAddrInfo (Just hints) (Just host) (Just port))
  where
    hints = defaultHints { addrFlags = [AI_ADDRCONFIG, AI_CANONNAME, AI_NUMERICSERV] }
-- | Resolve a host and numeric service to the first matching socket address.
--
-- 'getAddrInfo' normally throws rather than returning no results. The empty
-- case is still handled explicitly, with a descriptive 'userError', instead of
-- relying on the partial @!! 0@.
lookupAddress :: (HostName,ServiceName) -> IO SockAddr
lookupAddress hostAndPort = do
  addresses <- lookupAddresses hostAndPort
  case addresses of
    (address : _) -> return address
    [] -> ioError $ userError $ "No socket addresses found for " ++ show hostAndPort
-- | Shape of a socket send function; identical to the type of 'connSend'.
type SocketSend = Socket -> B.ByteString -> IO ()
-- | Render a messenger as @Messenger(<address>)@, showing only its remote
-- address.
instance Show Messenger where
  showsPrec _ messenger =
    showString "Messenger(" . shows (messengerAddress messenger) . showChar ')'
-- | Start a messenger over @conn@.
--
-- It creates a fresh outbound mailbox and forks two threads: one running
-- 'sender' over that mailbox, and one running 'receiver', which writes
-- received messages into @inc@.
newMessenger :: Connection -> Mailbox -> IO Messenger
newMessenger conn inc = do
  outbox <- newMailbox
  sendThread <- async (sender conn outbox)
  receiveThread <- async (receiver conn inc)
  return Messenger
    { messengerOut = outbox
    , messengerAddress = connAddress conn
    , messengerSender = sendThread
    , messengerReceiver = receiveThread
    , messengerConnection = conn
    }
-- | Record @msngr@ as the messenger for @address@, replacing any previous
-- entry, then log the updated table.
--
-- The map is updated with the strict 'modifyTVar''. With the lazy variant the
-- insert stays an unevaluated thunk inside the TVar whenever the info-level
-- log below is not emitted (nothing else forces it).
addMessenger :: SocketTransport -> Address -> Messenger -> IO ()
addMessenger transport address msngr = do
  updated <- atomically $ do
    modifyTVar' messengers (M.insert address msngr)
    readTVar messengers
  infoM _log $ "Added messenger to " ++ (show address) ++ "; messengers are " ++ (show updated)
  where
    messengers = socketMessengers transport
-- | Queue an already-encoded message on the messenger's outbound mailbox;
-- the sender thread transmits it asynchronously.
deliver :: Messenger -> Message -> IO ()
deliver msngr = atomically . writeTQueue (messengerOut msngr)
-- | Route inbound envelopes to locally bound mailboxes.
--
-- Each message read from @mbox@ is decoded as an 'Envelope'. Its contents go
-- to the mailbox bound to the envelope's destination in @bindings@. Messages
-- for unbound destinations are dropped silently; undecodable messages are
-- logged at error level and skipped.
--
-- The loop runs until an exception reaches 'catchExceptions'; that exception
-- is logged and the dispatcher returns. The handler wraps the whole loop
-- once. Previously the recursive call sat inside the handler, which stacked
-- one extra catch frame per dispatched message, so the stack grew without
-- bound.
dispatcher :: TVar (M.Map Name Mailbox) -> Mailbox -> IO ()
dispatcher bindings mbox = catchExceptions dispatchMessages onError
  where
    dispatchMessages = do
      infoM _log $ "Dispatching messages"
      env <- atomically $ readTQueue mbox
      dispatchMessage env
      dispatchMessages
    onError e =
      warningM _log $ "Dispatch error: " ++ (show (e :: SomeException))
    dispatchMessage env = do
      infoM _log $ "Dispatching message"
      case decode env of
        Left err ->
          errorM _log $ "Error decoding message for dispatch: " ++ err
        Right (Envelope destination msg) -> atomically $ do
          dests <- readTVar bindings
          case M.lookup destination dests of
            Nothing -> return ()
            Just dest -> writeTQueue dest msg
-- | Sending loop for one 'Connection'.
--
-- Each iteration does three things:
--
-- 1. Make sure a socket is open, calling 'connConnect' if 'connSocket' is
--    empty.
-- 2. Take the next message from @mailbox@.
-- 3. Write it with 'connSend'.
--
-- Failures are handled per step:
--
-- * A connect failure is logged via 'catchExceptions' rather than propagated.
--   Previously it escaped and ended this thread, leaving @mailbox@ with no
--   reader. Reconnection is now retried at most once per message, so an
--   unreachable peer does not cause a busy loop.
-- * A send failure is logged, and the socket is closed and removed from
--   'connSocket' so the next iteration reconnects.
-- * A message taken while no socket is available is dropped with a warning.
--
-- NOTE(review): 'receiver' reads 'connSocket' only once, so after a
-- reconnect nothing reads from the new socket — confirm this is intended.
sender :: Connection -> Mailbox -> IO ()
sender conn mailbox = sendMessages
  where
    sendMessages = do
      catchExceptions reconnect onConnectError
      catchExceptions sendNext onSendError
      sendMessages
    sendNext = do
      infoM _log $ "Waiting to send to " ++ (show $ connAddress conn)
      msg <- atomically $ readTQueue mailbox
      infoM _log $ "Sending message to " ++ (show $ connAddress conn)
      connected <- atomically $ tryReadTMVar $ connSocket conn
      case connected of
        Just socket -> (connSend conn) socket msg
        Nothing ->
          warningM _log $ "Not connected to " ++ (show $ connAddress conn) ++ "; dropping message"
    onConnectError e =
      warningM _log $ "Connect error: " ++ (show (e :: SomeException))
    onSendError e = do
      warningM _log $ "Send error: " ++ (show (e :: SomeException))
      disconnect
    reconnect = do
      connected <- atomically $ tryReadTMVar $ connSocket conn
      case connected of
        Just _ -> return ()
        Nothing -> do
          let (host,port) = parseSocketAddress $ connAddress conn
          infoM _log $ "Connecting to " ++ (show host) ++ ":" ++ (show port)
          socket <- connConnect conn
          infoM _log $ "Connected to " ++ (show $ connAddress conn)
          atomically $ putTMVar (connSocket conn) socket
    disconnect = do
      connected <- atomically $ tryTakeTMVar $ connSocket conn
      case connected of
        Just socket -> sClose socket
        Nothing -> return ()
-- | Send @msg@ to the endpoint called @name@.
--
-- * If @name@ is bound locally, the message is written straight into its
--   mailbox.
-- * Otherwise @name@ is resolved to an 'Address'. An 'Envelope' (destination
--   plus contents) is encoded and queued on that address's messenger; a new
--   messenger is created and registered if none exists yet.
--
-- Throws a 'userError' when the resolver cannot find @name@. That is the same
-- exception type the old failing @Just address <- …@ pattern produced, but
-- now with a meaningful message.
--
-- NOTE(review): the "messenger exists?" check and 'addMessenger' run in
-- separate transactions. Two concurrent first sends to one address can each
-- create a messenger, and the later insert overwrites (and leaks) the earlier.
socketSendTo :: SocketTransport -> Name -> Message -> IO ()
socketSendTo transport name msg = do
  isLocal <- local
  if isLocal
    then return ()
    else remote
  where
    -- Lookup and write share one transaction, so the message cannot land in
    -- a mailbox whose binding was removed between the two steps.
    local = atomically $ do
      bindings <- readTVar $ socketBindings transport
      case M.lookup name bindings of
        Nothing -> return False
        Just mbox -> do
          writeTQueue mbox msg
          return True
    remote = do
      resolved <- resolve (socketResolver transport) name
      case resolved of
        Nothing -> ioError $ userError $ "Unable to resolve an address for " ++ show name
        Just address -> sendRemote address
    sendRemote address = do
      let env = encode Envelope
            { envelopeDestination = name
            , envelopeContents = msg
            }
      amsngr <- atomically $ do
        msngrs <- readTVar $ socketMessengers transport
        return $ M.lookup address msngrs
      case amsngr of
        Just msngr -> deliver msngr env
        Nothing -> do
          msngrs <- atomically $ readTVar $ socketMessengers transport
          infoM _log $ "No messenger for " ++ (show address) ++ " in " ++ (show msngrs)
          socketVar <- atomically $ newEmptyTMVar
          newConn <- (socketConnection transport) address
          let conn = newConn {connSocket = socketVar}
          msngr <- (socketMessenger transport) conn (socketInbound transport)
          addMessenger transport address msngr
          deliver msngr env
-- | Block until the connection has a socket, then run the receive loop on it,
-- writing every received message into @mailbox@.
receiver :: Connection -> Mailbox -> IO ()
receiver conn mailbox =
  atomically (readTMVar (connSocket conn)) >>= \sock ->
    receiveSocketMessages sock (connAddress conn) mailbox
-- | Read framed messages from @sock@ and write each one into @mailbox@.
--
-- * When the peer closes the connection, the socket is closed and the loop
--   ends.
-- * Any exception reaching 'catchExceptions' is logged and ends the loop.
--
-- The handler wraps the whole loop once. Previously the recursive call sat
-- inside it, adding one catch frame per received message. @addr@ is used only
-- for logging.
receiveSocketMessages :: Socket -> Address -> Mailbox -> IO ()
receiveSocketMessages sock addr mailbox = catchExceptions receiveLoop onError
  where
    receiveLoop = do
      infoM _log $ "Waiting to receive on " ++ (show addr)
      maybeMsg <- receiveSocketMessage sock
      case maybeMsg of
        Nothing -> do
          infoM _log $ "Connection closed on " ++ (show addr)
          sClose sock
        Just msg -> do
          infoM _log $ "Received message on " ++ (show addr)
          atomically $ writeTQueue mailbox msg
          receiveLoop
    onError e =
      warningM _log $ "Receive error: " ++ (show (e :: SomeException))
-- | Read one length-prefixed message from @socket@.
--
-- Framing is an 8-byte 'Serialize'-encoded 'Int' length followed by that many
-- payload bytes. Returns 'Nothing' if the peer closes the connection before a
-- full header or payload arrives.
--
-- Errors are thrown as 'ErrorCall':
--
-- * the header cannot be decoded;
-- * the decoded length is negative.
--
-- 'recv' returns /up to/ the requested number of bytes, so both reads go
-- through 'recvExactly'. It keeps reading until the full count has arrived,
-- because a single 'recv' can split a frame on a TCP stream.
receiveSocketMessage :: Socket -> IO (Maybe Message)
receiveSocketMessage socket = do
  maybeLen <- recvExactly 8
  case maybeLen of
    Nothing -> do
      infoM _log $ "No length received"
      return Nothing
    Just len -> do
      size <- msgLength (decode len)
      maybeMsg <- recvExactly size
      infoM _log $ "Received message"
      return maybeMsg
  where
    msgLength :: Either String Int -> IO Int
    msgLength (Right size)
      | size >= 0 = return size
      | otherwise = throwIO $ ErrorCall $ "Invalid message length: " ++ show size
    msgLength (Left err) = throwIO $ ErrorCall err
    -- Accumulate chunks until exactly @total@ bytes arrive; Nothing on EOF.
    recvExactly :: Int -> IO (Maybe B.ByteString)
    recvExactly total = go total []
      where
        go remaining chunks
          | remaining <= 0 = return $ Just $ B.concat $ reverse chunks
          | otherwise = do
              received <- recv socket remaining
              case received of
                Nothing -> return Nothing
                Just bytes -> go (remaining - B.length bytes) (bytes : chunks)
-- | Shut a messenger down: cancel its sender, then its receiver, and finally
-- release the underlying connection via 'connClose'.
closeMessenger :: Messenger -> IO ()
closeMessenger msngr = do
  mapM_ cancel [messengerSender msngr, messengerReceiver msngr]
  connClose (messengerConnection msngr)