{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE MagicHash, UnboxedTuples, PatternGuards, ScopedTypeVariables, RankNTypes #-}
-- | Concurrent queue for single reader, single writer
module Control.Distributed.Process.Internal.CQueue
  ( CQueue
  , BlockSpec(..)
  , MatchOn(..)
  , newCQueue
  , enqueue
  , enqueueSTM
  , dequeue
  , mkWeakCQueue
  , queueSize
  ) where

import Prelude hiding (length, reverse)
import Control.Concurrent.STM
  ( atomically
  , STM
  , TChan
  , TVar
  , modifyTVar'
  , tryReadTChan
  , newTChan
  , newTVarIO
  , writeTChan
  , readTChan
  , readTVarIO
  , orElse
  , retry
  )
import Control.Applicative ((<$>), (<*>))
import Control.Exception (mask_, onException)
import System.Timeout (timeout)
import Control.Distributed.Process.Internal.StrictMVar
  ( StrictMVar(StrictMVar)
  , newMVar
  , takeMVar
  , putMVar
  )
import Control.Distributed.Process.Internal.StrictList
  ( StrictList(..)
  , append
  )
import Data.Maybe (fromJust)
import Data.Traversable (traverse)
import GHC.MVar (MVar(MVar))
import GHC.IO (IO(IO), unIO)
import GHC.Exts (mkWeak#)
import GHC.Weak (Weak(Weak))

-- We use a TCHan rather than a Chan so that we have a non-blocking read
data CQueue a = CQueue (StrictMVar (StrictList a)) -- Arrived
                       (TChan a)                   -- Incoming
                       (TVar Int)                 -- Queue size

newCQueue :: IO (CQueue a)
newCQueue = CQueue <$> newMVar Nil <*> atomically newTChan <*> newTVarIO 0

-- | Enqueue an element
--
-- Enqueue is strict.
enqueue :: CQueue a -> a -> IO ()
enqueue c !a = atomically (enqueueSTM c a)

-- | Variant of enqueue for use in the STM monad.
enqueueSTM :: CQueue a -> a -> STM ()
enqueueSTM (CQueue _arrived incoming size) !a = do
   writeTChan incoming a
   modifyTVar' size succ

data BlockSpec =
    NonBlocking
  | Blocking
  | Timeout Int

-- Match operations
--
-- They can be either a message match or a channel match.
data MatchOn m a
 = MatchMsg  (m -> Maybe a)
 | MatchChan (STM a)
 deriving (Functor)

-- Lists of chunks of matches
--
-- Two consecutive chunks never have the same kind of matches. i.e. if one chunk
-- contains message matches then the next one must contain channel matches and
-- viceversa.
type MatchChunks m a = [Either [m -> Maybe a] [STM a]]

-- Splits a list of matches into chunks.
--
-- > concatMap (either (map MatchMsg) (map MatchChan)) . chunkMatches == id
--
chunkMatches :: [MatchOn m a] -> MatchChunks m a
chunkMatches [] = []
chunkMatches (MatchMsg m : ms) = Left (m : chk) : chunkMatches rest
   where (chk, rest) = spanMatchMsg ms
chunkMatches (MatchChan r : ms) = Right (r : chk) : chunkMatches rest
   where (chk, rest) = spanMatchChan ms

-- | @spanMatchMsg = first (map (\(MatchMsg x) -> x)) . span isMatchMsg@
spanMatchMsg :: [MatchOn m a] -> ([m -> Maybe a], [MatchOn m a])
spanMatchMsg [] = ([],[])
spanMatchMsg (m : ms)
    | MatchMsg msg <- m = (msg:msgs, rest)
    | otherwise         = ([], m:ms)
    where !(msgs,rest) = spanMatchMsg ms

-- | @spanMatchMsg = first (map (\(MatchChan x) -> x)) . span isMatchChan@
spanMatchChan :: [MatchOn m a] -> ([STM a], [MatchOn m a])
spanMatchChan [] = ([],[])
spanMatchChan (m : ms)
    | MatchChan stm <- m = (stm:stms, rest)
    | otherwise          = ([], m:ms)
    where !(stms,rest)  = spanMatchChan ms

-- | Dequeue an element
--
-- The timeout (if any) is applied only to waiting for incoming messages, not
-- to checking messages that have already arrived
dequeue :: forall m a.
           CQueue m          -- ^ Queue
        -> BlockSpec         -- ^ Blocking behaviour
        -> [MatchOn m a]     -- ^ List of matches
        -> IO (Maybe a)      -- ^ 'Nothing' only on timeout
dequeue (CQueue arrived incoming size) blockSpec matchons = mask_ $ decrementJust $
  case blockSpec of
    Timeout n -> timeout n $ fmap fromJust run
    _other    ->
       case chunks of
         [Right ports] -> -- channels only, this is easy:
           case blockSpec of
             NonBlocking -> atomically $ waitChans ports (return Nothing)
             _           -> atomically $ waitChans ports retry
                              -- no onException needed
         _other -> run
  where
    -- Decrement counter is smth is returned from the queue,
    -- this is safe to use as method is called under a mask
    -- and there is no 'unmasked' operation inside
    decrementJust :: IO (Maybe (Either a a)) -> IO (Maybe a)
    decrementJust f =
       traverse (either return (\x -> decrement >> return x)) =<< f
    decrement = atomically $ modifyTVar' size pred

    chunks = chunkMatches matchons

    run = do
           arr <- takeMVar arrived
           let grabNew xs = do
                 r <- atomically $ tryReadTChan incoming
                 case r of
                   Nothing -> return xs
                   Just x  -> grabNew (Snoc xs x)
           arr' <- grabNew arr
           goCheck chunks arr'

    -- Yields the value of the first succesful STM transaction as
    -- @Just (Left v)@. If all transactions fail, yields the value of the second
    -- argument.
    waitChans :: [STM a] -> STM (Maybe (Either a a)) -> STM (Maybe (Either a a))
    waitChans ports on_block =
        foldr orElse on_block (map (fmap (Just . Left)) ports)

    --
    -- First check the MatchChunks against the messages already in the
    -- mailbox.  For channel matches, we do a non-blocking check at
    -- this point.
    --
    -- Yields @Just (Left a)@ when a channel is matched, @Just (Right a)@
    -- when a message is matched and @Nothing@ when there are no messages and we
    -- aren't blocking.
    --
    goCheck :: MatchChunks m a
            -> StrictList m  -- messages to check, in this order
            -> IO (Maybe (Either a a))

    goCheck [] old = goWait old

    goCheck (Right ports : rest) old = do
      r <- atomically $ waitChans ports (return Nothing) -- does not block
      case r of
        Just _  -> returnOld old r
        Nothing -> goCheck rest old

    goCheck (Left matches : rest) old = do
           -- checkArrived might in principle take arbitrary time, so
           -- we ought to call restore and use an exception handler.  However,
           -- the check is usually fast (just a comparison), and the overhead
           -- of passing around restore and setting up exception handlers is
           -- high.  So just don't use expensive matchIfs!
      case checkArrived matches old of
        (old', Just r)  -> returnOld old' (Just (Right r))
        (old', Nothing) -> goCheck rest old'
          -- use the result list, which is now left-biased

    --
    -- Construct an STM transaction that looks at the relevant channels
    -- in the correct order.
    --
    mkSTM :: MatchChunks m a -> STM (Either m a)
    mkSTM [] = retry
    mkSTM (Left _ : rest)
      = fmap Left (readTChan incoming) `orElse` mkSTM rest
    mkSTM (Right ports : rest)
      = foldr orElse (mkSTM rest) (map (fmap Right) ports)

    waitIncoming :: IO (Maybe (Either m a))
    waitIncoming = case blockSpec of
      NonBlocking -> atomically $ fmap Just stm `orElse` return Nothing
      _           -> atomically $ fmap Just stm
     where
      stm = mkSTM chunks

    --
    -- The initial pass didn't find a message, so now we go into blocking
    -- mode.
    --
    -- Contents of 'arrived' from now on is (old ++ new), and
    -- messages that arrive are snocced onto new.
    --
    goWait :: StrictList m -> IO (Maybe (Either a a))
    goWait old = do
      r <- waitIncoming `onException` putMVar arrived old
      case r of
        --  Nothing => non-blocking and no message
        Nothing -> returnOld old Nothing
        Just e  -> case e of
          --
          -- Left => message arrived in the process mailbox.  We now have to
          -- run through the MatchChunks checking each one, because we might
          -- have a situation where the first chunk fails to match and the
          -- second chunk is a channel match and there *is* a message in the
          -- channel.  In that case the channel wins.
          --
          Left m -> goCheck1 chunks m old
          --
          -- Right => message arrived on a channel first
          --
          Right a -> returnOld old (Just (Left a))

    --
    -- A message arrived in the process inbox; check the MatchChunks for
    -- a valid match.
    --
    goCheck1 :: MatchChunks m a
             -> m               -- single message to check
             -> StrictList m    -- old messages we have already checked
             -> IO (Maybe (Either a a))

    goCheck1 [] m old = goWait (Snoc old m)

    goCheck1 (Right ports : rest) m old = do
      r <- atomically $ waitChans ports (return Nothing) -- does not block
      case r of
        Nothing -> goCheck1 rest m old
        Just _  -> returnOld (Snoc old m) r

    goCheck1 (Left matches : rest) m old = do
      case checkMatches matches m of
        Nothing -> goCheck1 rest m old
        Just p  -> returnOld old (Just (Right p))

    -- a common pattern for putting back the arrived queue at the end
    returnOld :: StrictList m -> Maybe (Either a a) -> IO (Maybe (Either a a))
    returnOld old r = do putMVar arrived old; return r

    -- as a side-effect, this left-biases the list
    checkArrived :: [m -> Maybe a] -> StrictList m -> (StrictList m, Maybe a)
    checkArrived matches list = go list Nil
      where
        -- @go xs ys@ searches for a message match in @append xs ys@
        go Nil Nil           = (Nil, Nothing)
        go Nil r             = go r Nil
        go (Append xs ys) tl = go xs (append ys tl)
        go (Snoc xs x)    tl = go xs (Cons x tl)
        go (Cons x xs)    tl
          | Just y <- checkMatches matches x = (append xs tl, Just y)
          | otherwise = let !(rest,r) = go xs tl in (Cons x rest, r)

    checkMatches :: [m -> Maybe a] -> m -> Maybe a
    checkMatches []     _ = Nothing
    checkMatches (m:ms) a = case m a of Nothing -> checkMatches ms a
                                        Just b  -> Just b

-- | Weak reference to a CQueue
mkWeakCQueue :: CQueue a -> IO () -> IO (Weak (CQueue a))
mkWeakCQueue m@(CQueue (StrictMVar (MVar m#)) _ _) f = IO $ \s ->
#if MIN_VERSION_base(4,9,0)
  case mkWeak# m# m (unIO f) s of (# s1, w #) -> (# s1, Weak w #)
#else
  case mkWeak# m# m f s of (# s1, w #) -> (# s1, Weak w #)
#endif

queueSize :: CQueue a -> IO Int
queueSize (CQueue _ _ size) = readTVarIO size