module Network.Transport.Tests.Multicast where

import Network.Transport
import Control.Monad (replicateM, replicateM_, forM_, when)
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar, readMVar)
import Data.ByteString (ByteString)
import Data.List (elemIndex)
import Network.Transport.Tests.Auxiliary (runTests)

-- | Node for the "No confusion" test.
--
-- The node:
--
--   1. creates an endpoint and its own multicast group, and publishes the
--      group's address in the first 'MVar' of the group list;
--   2. subscribes to the groups whose addresses are published in the
--      remaining 'MVar's;
--   3. signals readiness in the first ready 'MVar' and waits until all the
--      other ready 'MVar's are filled;
--   4. sends @numPings@ copies of its own message to its group from a forked
--      thread, while receiving @2 * numPings@ events and checking that each
--      one is a 'ReceivedMulticast' carrying the message expected for the
--      group it came from;
--   5. fills the done 'MVar'.
--
-- NOTE(review): the @2 * numPings@ receive count assumes the node subscribes
-- to exactly two groups that each send @numPings@ messages, as arranged by
-- 'testNoConfusion'.
noConfusionNode :: Transport               -- ^ Transport
                -> [MVar MulticastAddress] -- ^ my group : groups to subscribe to
                -> [MVar ()]               -- ^ I'm ready : others ready
                -> Int                     -- ^ number of pings
                -> [ByteString]            -- ^ my message : messages from subscribed groups (same order as 'groups to subscribe to')
                -> MVar ()                 -- ^ I'm done
                -> IO ()
noConfusionNode transport groups ready numPings msgs done = do
  -- Split each "mine : others" argument list without partial head/tail
  (myGroupVar, subscribeVars) <- splitFirst "groups" groups
  (myReadyVar, othersReady)   <- splitFirst "ready" ready
  (myMsg, expectedMsgs)       <- splitFirst "msgs" msgs

  -- Every subscribed group needs exactly one expected message
  when (length subscribeVars /= length expectedMsgs) $
    fail "noConfusionNode: number of subscribed groups and expected messages differ"

  -- Create a new endpoint
  Right endpoint <- newEndPoint transport

  -- Create a new multicast group and broadcast its address
  Right myGroup <- newMulticastGroup endpoint
  putMVar myGroupVar (multicastAddress myGroup)

  -- Subscribe to the given multicast groups
  addrs <- mapM readMVar subscribeVars
  forM_ addrs $ \addr -> do
    Right group <- resolveMulticastGroup endpoint addr
    multicastSubscribe group

  -- Expected message per subscribed group address. 'lookup' returns the
  -- first matching entry, so duplicate addresses resolve the same way an
  -- 'elemIndex'-based lookup would.
  let expected = zip addrs expectedMsgs

  -- Indicate that we're ready and wait for everybody else to be ready
  putMVar myReadyVar ()
  mapM_ readMVar othersReady

  -- Send messages..
  _ <- forkIO . replicateM_ numPings $ multicastSend myGroup [myMsg]

  -- ..while checking that the messages we receive are the right ones
  replicateM_ (2 * numPings) $ do
    event <- receive endpoint
    case event of
      ReceivedMulticast addr [msg] ->
        case lookup addr expected of
          Nothing -> error "Message from unexpected source"
          Just expectedMsg ->
            when (expectedMsg /= msg) $ error "Unexpected message"
      _ ->
        error "Unexpected event"

  -- Success
  putMVar done ()
  where
    -- Split a "mine : others" argument list into its first element and the
    -- rest, failing with a descriptive message when the list is empty.
    splitFirst :: String -> [a] -> IO (a, [a])
    splitFirst what []       = fail ("noConfusionNode: empty '" ++ what ++ "' list")
    splitFirst _    (x : xs) = return (x, xs)

-- | Test that distinct multicast groups are not confused.
--
-- Three nodes A, B and C own group1, group2 and group3 respectively, and
-- each subscribes to two groups:
--
--   * A subscribes to group1 and group2 (expects A's and B's messages)
--   * B subscribes to group1 and group3 (expects A's and C's messages)
--   * C subscribes to group2 and group3 (expects B's and C's messages)
--
-- Each node sends @numPings@ messages. The test returns once all three
-- nodes have filled their done 'MVar'.
testNoConfusion :: Transport -> Int -> IO ()
testNoConfusion transport numPings = do
  [group1, group2, group3] <- replicateM 3 newEmptyMVar
  [readyA, readyB, readyC] <- replicateM 3 newEmptyMVar
  [doneA, doneB, doneC]    <- replicateM 3 newEmptyMVar

  -- Bound individually to avoid an incomplete list pattern in 'let'.
  -- NOTE(review): these literals need OverloadedStrings; no pragma is visible
  -- in this file, so presumably it is enabled by the build configuration.
  let msgA, msgB, msgC :: ByteString
      msgA = "A says hi"
      msgB = "B says hi"
      msgC = "C says hi"

  _ <- forkIO $ noConfusionNode transport
         [group1, group1, group2] [readyA, readyB, readyC]
         numPings [msgA, msgA, msgB] doneA
  _ <- forkIO $ noConfusionNode transport
         [group2, group1, group3] [readyB, readyC, readyA]
         numPings [msgB, msgA, msgC] doneB
  _ <- forkIO $ noConfusionNode transport
         [group3, group2, group3] [readyC, readyA, readyB]
         numPings [msgC, msgB, msgC] doneC

  mapM_ takeMVar [doneA, doneB, doneC]

-- | Test multicast.
--
-- Runs all multicast tests against the given transport via 'runTests'.
testMulticast :: Transport -> IO ()
testMulticast transport =
  runTests
    [ ("NoConfusion", testNoConfusion transport noConfusionPings) ]
  where
    -- Number of messages each node sends in the "NoConfusion" test.
    noConfusionPings :: Int
    noConfusionPings = 10000