{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Main (main) where

import           Control.Arrow ((&&&))
import           Control.Concurrent
import           Control.Concurrent.EQueue
import           Control.Concurrent.EQueue.Class
import           Control.Concurrent.EQueue.Simple
import           Control.Concurrent.EQueue.STMEQueue
import           Control.Concurrent.STM
import           Control.Monad
import           Control.Time
import           Data.Functor.Contravariant
import qualified Data.Map as Map
import           Data.Semigroup
import           Data.Time
import           Test.Tasty
import           Test.Tasty.HUnit

-- | Program entry point: hand the complete "EQueue" test tree to tasty's
-- default console runner.
main :: IO ()
main = defaultMain suite
  where
    suite = tests

-- | Root test tree. Runs the STMEQueue-based groups first, then the shared
-- 'simpleTests' suite once per channel-backed queue implementation, and
-- finally the 'IOEQueue' group.
tests :: TestTree
tests = testGroup "EQueue" $
  [ stmeqTests
  , anyTests
  , forceEdgeTests
  , mappedTests
  ] ++ channelBacked ++ [ioTests]
  where
    -- One instantiation of the generic edge/level suite per queue flavour.
    channelBacked =
      [ simpleTests "ChanEQueue Tests" (CEQ <$> newChan)
      , simpleTests "TChanEQueue Tests" (TCEQ <$> newTChanIO)
      , simpleTests "TQueueEQueue Tests" (TQEQ <$> newTQueueIO)
      ]

-- | Number of sources currently registered with the queue, i.e. the size of
-- its '_eqActiveSources' map, read in a single STM transaction.
numSources :: STMEQueue a -> IO Int
numSources eq = atomically $ Map.size <$> readTVar (_eqActiveSources eq)

-- | An event source that is always ready: every poll yields @Just ()@.
trivialSource :: STM (Maybe ())
trivialSource = pure $ Just ()

-- | Assert that a blocking ('RequireEvent') 'waitEQ' on the queue produces
-- nothing before a 0.05 timeout expires.
--
-- NOTE(review): the timeout unit is whatever "Control.Time"'s 'timeout'
-- uses for a 'Double' (presumably seconds) — confirm.
waitEQBlocks :: (Eq a, Show a) => STMEQueue a -> IO ()
waitEQBlocks eq = do
  r <- timeout (0.05::Double) $ waitEQ eq RequireEvent
  -- Expected value first: on failure HUnit reports the unexpected batch under
  -- "but got" instead of labelling it as the expected value.
  assertEqual "waitEQ should block when no events are pending" Nothing r

-- | Build a test that a blocking 'waitEQ' on an empty queue returns only once
-- a source is registered from another thread after a delay of @d@.
--
-- The measured wait must lie in @[d, 2*d)@. The failure message includes the
-- measured time so that scheduler-induced overruns are easy to diagnose.
-- NOTE(review): 'delay' is assumed to take seconds, matching
-- 'NominalDiffTime' — confirm against "Control.Time".
waitEQDelaysForAdd :: NominalDiffTime -> TestTree
waitEQDelaysForAdd d = testCase ("waitEQ delay for register ("++show d++")") $ do
  eq <- newSTMEQueue
  st <- getCurrentTime
  -- Until this thread registers a source, the wait below has nothing to return.
  void $ forkIO $ do
    delay (realToFrac d :: Double)
    void $ register eq trivialSource
  void $ waitEQ eq RequireEvent
  t <- (`diffUTCTime` st) <$> getCurrentTime
  assertBool
    ("In time window: waited " ++ show t
       ++ ", expected at least " ++ show d
       ++ " and less than " ++ show (2 * d))
    (t >= d && t < 2 * d)

-- | Tests for 'STMEQueue' used directly.
--
-- Covers source registration and removal, blocking ('RequireEvent') versus
-- non-blocking ('ReturnImmediate') waits, and how the two registration
-- helpers are drained:
--
-- * 'registerSemi' sources combine pending values with '<>' and then apply
--   the mapping function, so several adds before a wait yield one event.
-- * 'registerQueued' sources hand out one pending value per wait.
--
-- Source-count checks use the expected-first '@=?' form, like every other
-- assertion here, so HUnit labels the observed count as the actual value.
stmeqTests :: TestTree
stmeqTests = testGroup "STMEQueue"
  [ testCase "No initial Sources" $ do
      (0 @=?) =<< numSources =<< newSTMEQueue
  , testCase "register adds" $ do
      eq <- newSTMEQueue
      (0 @=?) =<< numSources eq
      void $ register eq trivialSource
      (1 @=?) =<< numSources eq
  , testCase "Empty waitEQ blocks" $ do
      eq <- newSTMEQueue::IO (STMEQueue ())
      waitEQBlocks eq
  , testCase "Empty waitEQ returns immediate" $ do
      eq <- newSTMEQueue::IO (STMEQueue ())
      ([] @=?) =<< waitEQ eq ReturnImmediate
  , testCase "waitEQ retrieves" $ do
      eq <- newSTMEQueue
      void $ register eq trivialSource
      -- trivialSource is always ready, so every wait sees exactly one event.
      ([()] @=?) =<< waitEQ eq RequireEvent
      ([()] @=?) =<< waitEQ eq RequireEvent
  , waitEQDelaysForAdd 0.01
  , waitEQDelaysForAdd 0.05
  , testCase "register killer removes" $ do
      eq <- newSTMEQueue
      k <- register eq trivialSource
      ([()] @=?) =<< waitEQ eq RequireEvent
      -- Running the returned killer unregisters the source; waits now block.
      k
      waitEQBlocks eq
  , testGroup "registerSemi"
    [ testCase "adds" $ do
        eq <- newSTMEQueue
        (add, _) <- registerSemi eq id
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue [Int] <- newSTMEQueue
        (add, _) <- registerSemi eq id
        add [0]
        add [1]
        -- Both pending values are combined with (<>) into a single event.
        ([[0,1]] @=?) =<< waitEQ eq RequireEvent
        add [2]
        ([[2]] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq <- newSTMEQueue
        (add, k) <- registerSemi eq id
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        -- A value still pending when the killer runs is never delivered.
        add ()
        k
        waitEQBlocks eq
    , testCase "maps" $ do
        eq::STMEQueue (Char, Int) <- newSTMEQueue
        (add, _) <- registerSemi eq (('a',) . getMax::Max Int -> (Char, Int))
        add (Max 0)
        add (Max 1)
        -- Values are combined as 'Max' first, then mapped once.
        ([('a', 1)] @=?) =<< waitEQ eq RequireEvent
    ]
  , testGroup "registerQueued"
    [ testCase "adds" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, _) <- registerQueued eq
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue Int <- newSTMEQueue
        (add, _) <- registerQueued eq
        add 0
        add 1
        -- Queued values are delivered one per wait, in insertion order.
        ([0] @=?) =<< waitEQ eq RequireEvent
        ([1] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, k) <- registerQueued eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        add ()
        k
        waitEQBlocks eq
    ]
  , testCase "Gets all latest" $ do
      eq::STMEQueue (Either String Int) <- newSTMEQueue
      (addQ, _) <- registerQueued eq
      (addS, _) <- registerSemi eq id
      addQ (Left "a")
      addQ (Left "b")
      addQ (Left "c")
      ([Left "a"] @=?) =<< waitEQ eq RequireEvent
      addS (Right 2)
      addS (Right 1)
      -- One wait gathers the next queued value plus the semigroup source's
      -- combined value (Either's (<>) keeps the first Right).
      ([Left "b", Right 2] @=?) =<< waitEQ eq RequireEvent
      ([Left "c"] @=?) =<< waitEQ eq RequireEvent
  ]

-- | Tests for registering through the 'AEQ' wrapper around an 'STMEQueue'.
--
-- Sources are registered via @AEQ eq@ and drained from the underlying
-- 'STMEQueue'. The expected results are the same as for direct registration:
-- 'registerSemi' combines pending values with '<>', and 'registerQueued'
-- delivers one value per wait.
--
-- Source-count checks use the expected-first '@=?' form so HUnit labels the
-- observed count as the actual value on failure.
anyTests :: TestTree
anyTests = testGroup "AnyEQueue"
  [ testGroup "registerSemi"
    [ testCase "adds" $ do
        eq <- newSTMEQueue
        (add, _) <- registerSemi (AEQ eq) id
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue [Int] <- newSTMEQueue
        (add, _) <- registerSemi (AEQ eq) id
        add [0]
        add [1]
        -- Both pending values are combined with (<>) into a single event.
        ([[0,1]] @=?) =<< waitEQ eq RequireEvent
        add [2]
        ([[2]] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq <- newSTMEQueue
        (add, k) <- registerSemi (AEQ eq) id
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        -- A value still pending when the killer runs is never delivered.
        add ()
        k
        waitEQBlocks eq
    , testCase "maps" $ do
        eq::STMEQueue (Char, Int) <- newSTMEQueue
        (add, _) <- registerSemi (AEQ eq) (('a',) . getMax::Max Int -> (Char, Int))
        add (Max 0)
        add (Max 1)
        -- Values are combined as 'Max' first, then mapped once.
        ([('a', 1)] @=?) =<< waitEQ eq RequireEvent
    ]
  , testGroup "registerQueued"
    [ testCase "adds" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, _) <- registerQueued (AEQ eq)
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue Int <- newSTMEQueue
        (add, _) <- registerQueued (AEQ eq)
        add 0
        add 1
        -- Queued values are delivered one per wait, in insertion order.
        ([0] @=?) =<< waitEQ eq RequireEvent
        ([1] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, k) <- registerQueued (AEQ eq)
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        add ()
        k
        waitEQBlocks eq
    ]
  , testCase "Gets all latest" $ do
      eq::STMEQueue (Either String Int) <- newSTMEQueue
      (addQ, _) <- registerQueued (AEQ eq)
      (addS, _) <- registerSemi (AEQ eq) id
      addQ (Left "a")
      addQ (Left "b")
      addQ (Left "c")
      ([Left "a"] @=?) =<< waitEQ eq RequireEvent
      addS (Right 2)
      addS (Right 1)
      -- One wait gathers the next queued value plus the semigroup source's
      -- combined value (Either's (<>) keeps the first Right).
      ([Left "b", Right 2] @=?) =<< waitEQ eq RequireEvent
      ([Left "c"] @=?) =<< waitEQ eq RequireEvent
  ]

-- | Tests for registering through the 'EEQ' wrapper around an 'STMEQueue'.
--
-- As the expectations below show, under 'EEQ' each 'registerSemi' add is
-- delivered as its own event instead of being combined with '<>' the way it
-- is when registering on the 'STMEQueue' directly.
--
-- The group is named "ForceEdgeEQueue". It was previously a copy-paste
-- duplicate of the "AnyEQueue" group name, which made tasty output and
-- pattern selection ambiguous. Source-count checks use the expected-first
-- '@=?' form.
forceEdgeTests :: TestTree
forceEdgeTests = testGroup "ForceEdgeEQueue"
  [ testGroup "registerSemi"
    [ testCase "adds" $ do
        eq <- newSTMEQueue
        (add, _) <- registerSemi (EEQ eq) id
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue [Int] <- newSTMEQueue
        (add, _) <- registerSemi (EEQ eq) id
        add [0]
        add [1]
        -- Not combined: each add arrives as a separate event.
        ([[0]] @=?) =<< waitEQ eq RequireEvent
        ([[1]] @=?) =<< waitEQ eq RequireEvent
        add [2]
        ([[2]] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq <- newSTMEQueue
        (add, k) <- registerSemi (EEQ eq) id
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        -- A value still pending when the killer runs is never delivered.
        add ()
        k
        waitEQBlocks eq
    , testCase "maps" $ do
        eq::STMEQueue (Char, Int) <- newSTMEQueue
        (add, _) <- registerSemi (EEQ eq) (('a',) . getMax::Max Int -> (Char, Int))
        add (Max 0)
        add (Max 1)
        -- Each add is mapped and delivered on its own.
        ([('a', 0)] @=?) =<< waitEQ eq RequireEvent
        ([('a', 1)] @=?) =<< waitEQ eq RequireEvent
    ]
  , testGroup "registerQueued"
    [ testCase "adds" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, _) <- registerQueued (EEQ eq)
        (1 @=?) =<< numSources eq
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
    , testCase "appends" $ do
        eq::STMEQueue Int <- newSTMEQueue
        (add, _) <- registerQueued (EEQ eq)
        add 0
        add 1
        ([0] @=?) =<< waitEQ eq RequireEvent
        ([1] @=?) =<< waitEQ eq RequireEvent
    , testCase "killer removes" $ do
        eq::STMEQueue () <- newSTMEQueue
        (add, k) <- registerQueued (EEQ eq)
        add ()
        ([()] @=?) =<< waitEQ eq RequireEvent
        add ()
        k
        waitEQBlocks eq
    ]
  , testCase "Gets all latest" $ do
      eq::STMEQueue (Either String Int) <- newSTMEQueue
      (addQ, _) <- registerQueued (EEQ eq)
      (addS, _) <- registerSemi (EEQ eq) id
      addQ (Left "a")
      addQ (Left "b")
      addQ (Left "c")
      ([Left "a"] @=?) =<< waitEQ eq RequireEvent
      addS (Right 2)
      addS (Right 1)
      -- Each wait pairs the next queued value with the next semigroup value,
      -- since the semigroup source's adds are not combined here.
      ([Left "b", Right 2] @=?) =<< waitEQ eq RequireEvent
      ([Left "c", Right 1] @=?) =<< waitEQ eq RequireEvent
  ]

-- | Tests for the 'MEQ' wrapper over an 'STMEQueue'.
--
-- The "non-mapped" cases wrap the queue with @MEQ id@. The "mapped" cases
-- also apply 'contramap', so added values are converted before they reach
-- the underlying queue. Every case drains the underlying queue ('meqEQ') with
-- 'ReturnImmediate' and then runs the source's killer, which must yield '()'.
mappedTests :: TestTree
mappedTests = testGroup "MappedEQueue Tests"
  [ testGroup "non-mapped"
    [ testCase "Single Edge" $ do
        q <- fmap (MEQ id) newSTMEQueue
        (push, kill) <- registerQueued q
        push 'a'
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= ['a'])
        kill >>= (@?= ())
    , testCase "Double Edge" $ do
        q <- fmap (MEQ id) newSTMEQueue
        (push, kill) <- registerQueued q
        push 'a'
        push 'b'
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= ['a'])
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= ['b'])
        kill >>= (@?= ())
    , testCase "Single Level" $ do
        q <- fmap (MEQ id) newSTMEQueue
        (push, kill) <- registerSemi q (fmap fromEnum)
        push "a"
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [[fromEnum 'a']])
        kill >>= (@?= ())
    , testCase "Double Level" $ do
        q <- fmap (MEQ id) newSTMEQueue
        (push, kill) <- registerSemi q (fmap fromEnum)
        push "a"
        push "b"
        -- The two level pushes are combined into one event.
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [[fromEnum 'a', fromEnum 'b']])
        kill >>= (@?= ())
    ]
  , testGroup "mapped"
    [ testCase "Single Edge" $ do
        q <- fmap (contramap fromEnum . MEQ id) newSTMEQueue
        (push, kill) <- registerQueued q
        push 'a'
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [fromEnum 'a'])
        kill >>= (@?= ())
    , testCase "Double Edge" $ do
        q <- fmap (contramap fromEnum . MEQ id) newSTMEQueue
        (push, kill) <- registerQueued q
        push 'a'
        push 'b'
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [fromEnum 'a'])
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [fromEnum 'b'])
        kill >>= (@?= ())
    , testCase "Single Level" $ do
        q <- fmap (contramap (fmap fromEnum) . MEQ id) newSTMEQueue
        (push, kill) <- registerSemi q ("+"++)
        push "a"
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [fromEnum <$> "+a"])
        kill >>= (@?= ())
    , testCase "Double Level" $ do
        q <- fmap (contramap (fmap fromEnum) . MEQ id) newSTMEQueue
        (push, kill) <- registerSemi q ("+"++)
        push "a"
        push "b"
        -- Pushes are concatenated first; the "+" prefix is applied once.
        waitEQ (meqEQ q) ReturnImmediate >>= (@?= [fromEnum <$> "+ab"])
        kill >>= (@?= ())
    ]
  ]

-- | Generic edge/level checks for any queue whose wait policy is
-- 'JustOneEventually'.
--
-- The caller supplies the group name and a polymorphic action that creates a
-- fresh queue; every case builds its own queue. In the "Double Level" case,
-- the two 'registerSemi' pushes are expected back as two separate events.
simpleTests :: (EQueue eq, EQueueW eq, JustOneEventually ~ WaitPolicy eq)
            => TestName -> (forall a . IO (eq a)) -> TestTree
simpleTests nm new = testGroup nm
  [ testCase "Single Edge" $ do
      q <- new
      (push, kill) <- registerQueued q
      push 'a'
      waitEQ q JustOneEventually >>= (@?= ['a'])
      kill >>= (@?= ())
  , testCase "Double Edge" $ do
      q <- new
      (push, kill) <- registerQueued q
      push 'a'
      push 'b'
      waitEQ q JustOneEventually >>= (@?= ['a'])
      waitEQ q JustOneEventually >>= (@?= ['b'])
      kill >>= (@?= ())
  , testCase "Single Level" $ do
      q <- new
      (push, kill) <- registerSemi q (fmap fromEnum)
      push "a"
      waitEQ q JustOneEventually >>= (@?= [[fromEnum 'a']])
      kill >>= (@?= ())
  , testCase "Double Level" $ do
      q <- new
      (push, kill) <- registerSemi q (fmap fromEnum)
      push "a"
      push "b"
      waitEQ q JustOneEventually >>= (@?= [[fromEnum 'a']])
      waitEQ q JustOneEventually >>= (@?= [[fromEnum 'b']])
      kill >>= (@?= ())
  ]

-- | Edge/level checks for an 'IOEQueue' backed by a 'Chan'.
--
-- The queue writes every delivered value into the channel. The paired wait
-- action reads one value back as a one-element batch, so even "level" pushes
-- are observed one at a time.
ioTests :: TestTree
ioTests = testGroup "IOEQueue Tests"
  [ testCase "Single Edge" $ do
      (q, next) <- newIOEQ
      (push, kill) <- registerQueued q
      push 'a'
      next >>= (@?= ['a'])
      kill >>= (@?= ())
  , testCase "Double Edge" $ do
      (q, next) <- newIOEQ
      (push, kill) <- registerQueued q
      push 'a'
      push 'b'
      next >>= (@?= ['a'])
      next >>= (@?= ['b'])
      kill >>= (@?= ())
  , testCase "Single Level" $ do
      (q, next) <- newIOEQ
      (push, kill) <- registerSemi q (fmap fromEnum)
      push "a"
      next >>= (@?= [[fromEnum 'a']])
      kill >>= (@?= ())
  , testCase "Double Level" $ do
      (q, next) <- newIOEQ
      (push, kill) <- registerSemi q (fmap fromEnum)
      push "a"
      push "b"
      next >>= (@?= [[fromEnum 'a']])
      next >>= (@?= [[fromEnum 'b']])
      kill >>= (@?= ())
  ]
  where
    -- Pair a channel-writing 'IOEQueue' with an action that blocks for the
    -- next written value and returns it as a singleton list.
    newIOEQ :: IO (IOEQueue a, IO [a])
    newIOEQ = fmap (IOEQ . writeChan &&& fmap pure . readChan) newChan
