module Network.Transport.Tests.Multicast where

import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar, readMVar)
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (replicateM, replicateM_, forM_, when)
import Data.ByteString (ByteString)
import Data.List (elemIndex)

import Network.Transport
import Network.Transport.Tests.Auxiliary (runTests)

-- | Node for the "No confusion" test.
--
-- The node creates its own multicast group and publishes that group's
-- address in the first element of @groups@. It then resolves and subscribes
-- to every group whose address appears in the remaining elements. Once
-- subscribed, it signals readiness through the first element of @ready@ and
-- waits until all the other nodes have signalled too.
--
-- Next, a separate thread sends @numPings@ copies of the node's own message
-- to its own group. Meanwhile the node receives @numPings@ messages per
-- subscribed group and checks that each message carries the payload expected
-- for the group it arrived on. On success it fills @done@.
--
-- The three lists must be non-empty. Any unexpected event, source or payload
-- aborts the node with 'error'.
noConfusionNode :: Transport               -- ^ Transport
                -> [MVar MulticastAddress] -- ^ my group : groups to subscribe to
                -> [MVar ()]               -- ^ I'm ready : others ready
                -> Int                     -- ^ number of pings
                -> [ByteString]            -- ^ my message : messages from subscribed groups (same order as 'groups to subscribe to')
                -> MVar ()                 -- ^ I'm done
                -> IO ()
noConfusionNode transport groups ready numPings msgs done =
  case (groups, ready, msgs) of
    (myGroupVar : subGroupVars, myReadyVar : otherReadyVars, myMsg : subMsgs) -> do
      -- Create a new endpoint
      endpoint <- orFail "newEndPoint" =<< newEndPoint transport

      -- Create a new multicast group and broadcast its address
      myGroup <- orFail "newMulticastGroup" =<< newMulticastGroup endpoint
      putMVar myGroupVar (multicastAddress myGroup)

      -- Subscribe to the given multicast groups
      addrs <- mapM readMVar subGroupVars
      forM_ addrs $ \addr -> do
        group <- orFail "resolveMulticastGroup"
                   =<< resolveMulticastGroup endpoint addr
        multicastSubscribe group

      -- Indicate that we're ready and wait for everybody else to be ready
      putMVar myReadyVar ()
      mapM_ readMVar otherReadyVars

      -- Send messages..
      _ <- forkIO . replicateM_ numPings $ multicastSend myGroup [myMsg]

      -- ..while checking that the messages we receive are the right ones.
      -- Each subscribed group's owner sends numPings messages, so we expect
      -- numPings per subscription. This assumes the subscribed groups are
      -- distinct.
      replicateM_ (numPings * length addrs) $ do
        event <- receive endpoint
        case event of
          ReceivedMulticast addr [msg] ->
            case addr `elemIndex` addrs of
              Nothing -> error "Message from unexpected source"
              -- subMsgs is aligned with subGroupVars, hence with addrs.
              Just ix -> case drop ix subMsgs of
                expected : _ ->
                  when (expected /= msg) $ error "Unexpected message"
                [] ->
                  error "noConfusionNode: no expected message for subscribed group"
          _ ->
            error "Unexpected event"

      -- Success
      putMVar done ()
    _ ->
      error "noConfusionNode: groups, ready and msgs must all be non-empty"
  where
    -- Unwrap a transport result. On failure, report the operation name and
    -- the transport's own error instead of an opaque pattern-match failure.
    orFail :: Show e => String -> Either e a -> IO a
    orFail what =
      either (\err -> fail ("noConfusionNode: " ++ what ++ " failed: " ++ show err))
             return

-- | Test that distinct multicast groups are not confused.
--
-- Three nodes A, B and C each create their own multicast group:
--
--   * A subscribes to its own group and to B's group;
--   * B subscribes to A's group and to C's group;
--   * C subscribes to B's group and to its own group.
--
-- Every node checks that each message it receives carries the payload of the
-- group it arrived on. Each node runs in a forked thread. If a node throws,
-- its exception is rethrown here, so the test fails rather than blocking
-- forever on a completion signal that will never arrive.
testNoConfusion :: Transport -> Int -> IO ()
testNoConfusion transport numPings = do
  [group1, group2, group3] <- replicateM 3 newEmptyMVar
  [readyA, readyB, readyC] <- replicateM 3 newEmptyMVar
  let msgA, msgB, msgC :: ByteString
      msgA = "A says hi"
      msgB = "B says hi"
      msgC = "C says hi"

  -- Every node reports its outcome (success or exception) on this MVar.
  results <- newEmptyMVar
  let spawn nodeGroups nodeReady nodeMsgs = do
        done <- newEmptyMVar
        _ <- forkIO $ do
          outcome <- try $ do
            noConfusionNode transport nodeGroups nodeReady numPings nodeMsgs done
            takeMVar done
          putMVar results (outcome :: Either SomeException ())
        return ()

  spawn [group1, group1, group2] [readyA, readyB, readyC] [msgA, msgA, msgB]
  spawn [group2, group1, group3] [readyB, readyC, readyA] [msgB, msgA, msgC]
  spawn [group3, group2, group3] [readyC, readyA, readyB] [msgC, msgB, msgC]

  -- Wait for all three nodes; rethrow the first failure reported.
  replicateM_ 3 $ takeMVar results >>= either throwIO return

-- | Run the multicast test suite against the given transport.
--
-- The suite currently contains only the "NoConfusion" test, using 10000
-- pings per node.
testMulticast :: Transport -> IO ()
testMulticast transport = runTests suite
  where
    suite :: [(String, IO ())]
    suite = [("NoConfusion", testNoConfusion transport 10000)]