-- | Multicast utilities
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Control.Distributed.Process.Backend.SimpleLocalnet.Internal.Multicast (initMulticast) where

import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Binary (Binary, decode, encode)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef, modifyIORef')
import qualified Data.ByteString as BSS (ByteString, concat)
import qualified Data.ByteString.Lazy as BSL
  ( ByteString
  , empty
  , append
  , fromChunks
  , toChunks
  , length
  , splitAt
  )
import Data.Accessor (Accessor, (^:), (^.), (^=))
import qualified Data.Accessor.Container as DAC (mapDefault)
import Control.Applicative ((<$>))
import Network.Socket (HostName, PortNumber, Socket, SockAddr)
import qualified Network.Socket.ByteString as NBS (recvFrom, sendManyTo)
import Network.Transport.Internal (decodeInt32, encodeInt32)
import Network.Multicast (multicastSender, multicastReceiver)

--------------------------------------------------------------------------------
-- Top-level API                                                              --
--------------------------------------------------------------------------------

-- | Given a hostname and a port number, initialize the multicast system.
--
-- Note: it is important that you never send messages larger than the maximum
-- message size; if you do, all subsequent communication will probably fail.
-- Nothing in this module enforces that limit on the sending side; it is the
-- caller's responsibility.
--
-- Returns a reader and a writer. The reader blocks until a complete
-- length-prefixed message has been buffered for some sender and returns the
-- decoded value together with the address it came from. The writer encodes a
-- value, prefixes it with its encoded length and sends it to the multicast
-- address.
--
-- NOTE: By rights the two functions should be "locally" polymorphic in 'a',
-- but this requires impredicative types.
initMulticast :: forall a. Binary a
              => HostName    -- ^ Multicast IP
              -> PortNumber  -- ^ Port number
              -> Int         -- ^ Maximum message size
              -> IO (IO (a, SockAddr), a -> IO ())
initMulticast host port bufferSize = do
    (sendSock, sendAddr) <- multicastSender host port
    readSock <- multicastReceiver host port
    -- Per-sender receive buffers, shared by every invocation of the reader.
    st <- newIORef Map.empty
    return (recvBinary readSock st bufferSize, writer sendSock sendAddr)
  where
    -- Serialise the value and send it as: encoded length, then the payload
    -- chunks, all through a single 'NBS.sendManyTo' call.
    writer :: forall a. Binary a => Socket -> SockAddr -> a -> IO ()
    writer sock addr val = do
      let bytes = encode val
          len   = encodeInt32 (BSL.length bytes)
      NBS.sendManyTo sock (len : BSL.toChunks bytes) addr

--------------------------------------------------------------------------------
-- UDP multicast read, dealing with multiple senders                          --
--------------------------------------------------------------------------------

-- | Bytes received so far but not yet consumed, kept separately per sender.
type UDPState = Map SockAddr BSL.ByteString

#if MIN_VERSION_network(2,4,0)
-- network-2.4.0 provides the Ord instance for us
#else
-- Older versions of network have no Ord instance for SockAddr; order the
-- addresses by their textual representation so they can be used as Map keys.
instance Ord SockAddr where
  compare = compare `on` show
#endif

-- | Accessor for the receive buffer of one sender. Built with
-- 'DAC.mapDefault', using the empty bytestring as the default value for
-- senders that have no entry in the map yet.
bufferFor :: SockAddr -> Accessor UDPState BSL.ByteString
bufferFor = DAC.mapDefault BSL.empty

-- | Append a freshly received strict chunk to the end of a sender's buffer.
bufferAppend :: SockAddr -> BSS.ByteString -> UDPState -> UDPState
bufferAppend addr bytes =
  bufferFor addr ^: flip BSL.append (BSL.fromChunks [bytes])

-- | Receive one length-prefixed message and decode it.
--
-- NOTE(review): 'decode' is applied lazily and is partial; a payload that is
-- not a valid encoding presumably surfaces as an error only when the caller
-- forces the returned value -- confirm against the callers.
recvBinary :: Binary a => Socket -> IORef UDPState -> Int -> IO (a, SockAddr)
recvBinary sock st bufferSize = do
  (bytes, addr) <- recvWithLength sock st bufferSize
  return (decode bytes, addr)

-- | Receive one length-prefixed message: first the 4-byte length header from
-- whichever sender the next packet comes from, then exactly that many payload
-- bytes from the same sender.
recvWithLength :: Socket
               -> IORef UDPState
               -> Int
               -> IO (BSL.ByteString, SockAddr)
recvWithLength sock st bufferSize = do
  (len, addr) <- recvExact sock 4 st bufferSize
  let n = decodeInt32 . BSS.concat . BSL.toChunks $ len
  bytes <- recvExactFrom addr sock n st bufferSize
  return (bytes, addr)

-- | Perform a single 'NBS.recvFrom' (reading at most @bufferSize@ bytes),
-- append whatever was received to the buffer of the sender it came from, and
-- return that sender's address.
recvAll :: Socket -> IORef UDPState -> Int -> IO SockAddr
recvAll sock st bufferSize = do
  (bytes, addr) <- NBS.recvFrom sock bufferSize
  -- Strict update: do not leave a chain of unevaluated map updates in the ref.
  modifyIORef' st $ bufferAppend addr bytes
  return addr

-- | Receive one packet to learn which sender to serve, then take exactly @n@
-- bytes from that sender's buffer (receiving more packets if necessary).
--
-- Note that this always receives from the socket first, even when some
-- sender's buffer already holds @n@ bytes.
recvExact :: Socket
          -> Int
          -> IORef UDPState
          -> Int
          -> IO (BSL.ByteString, SockAddr)
recvExact sock n st bufferSize = do
  addr <- recvAll sock st bufferSize
  bytes <- recvExactFrom addr sock n st bufferSize
  return (bytes, addr)

-- | Take exactly @n@ bytes from the buffer of the given sender, receiving
-- further packets until that buffer is long enough. Packets that arrive from
-- other senders in the meantime are appended to their own buffers and are not
-- consumed here.
recvExactFrom :: SockAddr
              -> Socket
              -> Int
              -> IORef UDPState
              -> Int
              -> IO BSL.ByteString
recvExactFrom addr sock n st bufferSize = go
  where
    go :: IO BSL.ByteString
    go = do
      accAddr <- (^. bufferFor addr) <$> readIORef st
      if BSL.length accAddr >= fromIntegral n
        then do
          -- Enough data buffered: hand out the first n bytes and keep the
          -- remainder for the next request from this sender.
          let (bytes, accAddr') = BSL.splitAt (fromIntegral n) accAddr
          modifyIORef' st $ bufferFor addr ^= accAddr'
          return bytes
        else do
          -- Not enough yet: pull in one more packet (from any sender) and
          -- check again.
          _ <- recvAll sock st bufferSize
          go