module Network.Transport.UDP (
newUDPTransport
) where
import Network.Transport
import Network.Transport.Internal
import Network.Transport.Sockets
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import qualified Data.ByteString as B
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Network.Socket as NS
import Network.Socket.ByteString(sendAllTo,recvFrom)
import System.Log.Logger
-- | Logger name under which every message from the UDP transport is emitted.
_log :: String
_log = "transport.udp"
-- | The scheme reported by the UDP transport; 'newUDPTransport' stores it in
-- the 'scheme' field of the 'Transport' it returns.
udpScheme :: Scheme
udpScheme = "udp"
-- | Resolve a transport 'Address' to a socket address for an IPv4 datagram
-- socket, delegating to 'lookupAddress'. Used as the destination of sends.
lookupUDPAddress :: Address -> IO NS.SockAddr
lookupUDPAddress = lookupAddress NS.AF_INET NS.Datagram
-- | Resolve a transport 'Address' for an IPv4 datagram socket via
-- 'lookupWildcardAddress'; the result is what 'udpBind' binds its listening
-- socket to.
lookupWildcardUDPAddress :: Address -> IO NS.SockAddr
lookupWildcardUDPAddress = lookupWildcardAddress NS.AF_INET NS.Datagram
-- | Create a UDP 'Transport' that resolves names with the given 'Resolver'.
--
-- Sets up empty messenger and binding tables, a socket-binding registry and
-- a shared inbound mailbox, then starts one 'dispatcher' thread over the
-- binding table and that mailbox. The dispatcher thread is recorded in
-- 'socketDispatchers' so that the transport's 'shutdown' can stop it.
newUDPTransport :: Resolver -> IO Transport
newUDPTransport resolver = do
  msngrsVar <- newTVarIO M.empty
  bindingsVar <- newTVarIO M.empty
  socketTable <- newSocketBindings
  inboundBox <- atomically newMailbox
  dispatchThread <- async (dispatcher bindingsVar inboundBox)
  let sockTransport = SocketTransport
        { socketMessengers = msngrsVar
        , socketBindings = bindingsVar
        , socketConnection = newUDPConnection
        , socketMessenger = newUDPMessenger
        , socketInbound = inboundBox
        , socketDispatchers = S.singleton dispatchThread
        , socketResolver = resolver
        }
  return Transport
    { scheme = udpScheme
    , handles = udpHandles sockTransport
    , bind = udpBind sockTransport socketTable
    , sendTo = socketSendTo sockTransport
    , shutdown = udpShutdown sockTransport socketTable
    }
-- | Report whether this transport can handle a name: 'True' exactly when the
-- transport's resolver yields an address for it.
udpHandles :: SocketTransport -> Name -> IO Bool
udpHandles transport name =
  fmap (maybe False (const True)) (resolve (socketResolver transport) name)
-- | Bind a mailbox to a name on this UDP transport.
--
-- The name is resolved to an 'Address' first. If it cannot be resolved,
-- 'Left' is returned and nothing is registered. Otherwise the mailbox is
-- recorded under the name in the transport's binding table, and
-- 'bindAddress' is given a factory that creates an IPv4 datagram socket,
-- binds it to the address from 'lookupWildcardUDPAddress', and starts a
-- receiver ('udpReceiveSocketMessages') that writes incoming datagrams to
-- the transport's shared inbound mailbox. If 'bindAddress' throws, the name
-- is unregistered again and the exception propagates.
--
-- The returned 'Binding' removes the name from the binding table and
-- releases the address via 'unbindAddress'.
udpBind :: SocketTransport -> SocketBindings -> Mailbox Message -> Name -> IO (Either String Binding)
udpBind transport sockets inc name = do
  maybeAddress <- resolve (socketResolver transport) name
  case maybeAddress of
    -- An unresolvable name used to crash with a pattern-match failure (after
    -- already registering the mailbox); report it through the Either instead.
    Nothing -> do
      warningM _log "Could not resolve name for UDP binding"
      return $ Left "Could not resolve name for UDP binding"
    Just address -> do
      registerName
      -- Don't leave the name pointing at a mailbox that will never be fed
      -- if the socket cannot be set up.
      _ <- bindAddress sockets address (openSocket address) `onException` unregisterName
      return $ Right Binding {
          bindingName = name,
          unbind = do
            infoM _log $ "Unbinding from UDP port " ++ show address
            -- Stop routing messages for this name to the mailbox.
            unregisterName
            unbindAddress sockets address
            infoM _log $ "Unbound from UDP port " ++ show address
          }
  where
    bindingsVar = socketBindings transport
    registerName = atomically $ modifyTVar' bindingsVar (M.insert name inc)
    unregisterName = atomically $ modifyTVar' bindingsVar (M.delete name)
    -- Create, configure and bind the listening socket and start its
    -- receiver. The socket is closed if any step throws before it is
    -- handed back to 'bindAddress'.
    openSocket address = do
      sockaddr <- lookupWildcardUDPAddress address
      bracketOnError
        (NS.socket NS.AF_INET NS.Datagram NS.defaultProtocol)
        NS.sClose
        (\sock -> do
            infoM _log $ "Binding to " ++ show address ++ " over UDP"
            NS.setSocketOption sock NS.ReuseAddr 1
            NS.bindSocket sock sockaddr
            infoM _log $ "Bound to " ++ show address ++ " over UDP"
            rcvr <- async $ udpReceiveSocketMessages sock address (socketInbound transport)
            return (sock, rcvr))
-- | Create a UDP 'Connection' to the given address.
--
-- The connection's socket slot starts empty. 'connConnect' creates a fresh
-- IPv4 datagram socket. Every send re-resolves the destination with
-- 'lookupUDPAddress' before calling 'sendAllTo'. Receiving uses
-- 'udpRecvFrom'. 'connClose' takes whatever socket is currently in the slot,
-- if any, and closes it.
newUDPConnection :: Address -> IO Connection
newUDPConnection address = do
  slot <- newEmptyTMVarIO
  return Connection
    { connAddress = address
    , connSocket = slot
    , connConnect = NS.socket NS.AF_INET NS.Datagram NS.defaultProtocol
    , connSend = sendDatagram
    , connReceive = udpRecvFrom
    , connClose = atomically (tryTakeTMVar slot) >>= maybe (return ()) NS.sClose
    }
  where
    sendDatagram s bs = do
      destination <- lookupUDPAddress address
      infoM _log $ "Sending via UDP to " ++ show destination
      sendAllTo s bs destination
      infoM _log $ "Sent via UDP to " ++ show destination
-- | Build the messenger for a UDP connection by delegating directly to
-- 'newMessenger'; there is no UDP-specific setup.
newUDPMessenger :: Connection -> Mailbox Message -> IO Messenger
newUDPMessenger = newMessenger
-- | Receive loop for a bound UDP socket.
--
-- Each non-empty datagram is written to the given mailbox. An empty
-- datagram makes the loop close the socket and return. Any exception ends
-- the loop after being logged as a warning.
--
-- The handler is installed once around a tail-recursive loop. The previous
-- version called itself inside the protected action, which stacked one
-- handler frame per received datagram (unbounded stack growth), assuming
-- 'catchExceptions' behaves like 'catch' — TODO confirm.
udpReceiveSocketMessages :: NS.Socket -> Address -> Mailbox Message -> IO ()
udpReceiveSocketMessages sock addr mailbox = catchExceptions receiveLoop reportError
  where
    receiveLoop = do
      infoM _log $ "Waiting to receive via UDP on " ++ show addr
      maybeMsg <- udpRecvFrom sock maxDatagramSize
      infoM _log "Received message"
      infoM _log $ "Received message via UDP on " ++ show addr
      case maybeMsg of
        -- NOTE(review): a zero-length datagram is a legal UDP packet, so any
        -- peer can stop this listener by sending one. Confirm this is intended.
        Nothing -> NS.sClose sock
        Just msg -> do
          atomically $ writeMailbox mailbox msg
          receiveLoop
    reportError e = warningM _log $ "Receive error: " ++ show (e :: SomeException)
    -- Large enough for any UDP payload. The old 512-byte buffer silently
    -- truncated longer datagrams.
    maxDatagramSize :: Int
    maxDatagramSize = 65535
-- | Read from the socket with 'recvFrom', asking for at most @count@ bytes.
--
-- The sender and payload are logged at info level. Returns 'Nothing' when
-- the payload is empty and 'Just' the payload otherwise.
udpRecvFrom :: NS.Socket -> Int -> IO (Maybe B.ByteString)
udpRecvFrom sock count = do
  (payload, sender) <- recvFrom sock count
  infoM _log $ "Received UDP message from " ++ show sender ++ ": " ++ show payload
  return $ if B.null payload then Nothing else Just payload
-- | Shut the transport down.
--
-- Runs three steps in order:
--
-- 1. Release every socket binding ('closeBindings').
-- 2. Close every messenger currently in the messenger table.
-- 3. Cancel the dispatcher thread(s) and wait for them to finish.
udpShutdown :: SocketTransport -> SocketBindings -> IO ()
udpShutdown transport sockets = do
  infoM _log "Unbinding transport"
  closeBindings sockets
  infoM _log "Closing messengers"
  msngrs <- atomically $ readTVar $ socketMessengers transport
  mapM_ closeMessenger $ M.elems msngrs
  infoM _log "Closing dispatcher"
  mapM_ cancel dispatchers
  -- 'wait' re-throws the exception an Async died with. A cancelled
  -- dispatcher dies with the cancellation exception unless it catches it
  -- itself, so 'wait' could make shutdown throw. 'waitCatch' still blocks
  -- until each dispatcher has finished but never re-throws.
  mapM_ waitCatch dispatchers
  where
    dispatchers = S.toList $ socketDispatchers transport