-- | Concurrent queue for single reader, single writer
{-# LANGUAGE MagicHash, UnboxedTuples #-}
module Control.Distributed.Process.Internal.CQueue 
  ( CQueue
  , BlockSpec(..)
  , newCQueue
  , enqueue
  , dequeue
  , mkWeakCQueue
  ) where

import Prelude hiding (length, reverse)
import Control.Concurrent.STM 
  ( atomically
  , TChan
  , newTChan
  , writeTChan
  , readTChan
  , tryReadTChan
  )
import Control.Applicative ((<$>), (<*>))
import Control.Exception (mask, onException)
import System.Timeout (timeout)
import Control.Distributed.Process.Internal.StrictMVar 
  ( StrictMVar(StrictMVar)
  , newMVar
  , takeMVar
  , putMVar
  )
import Control.Distributed.Process.Internal.StrictList
  ( StrictList(..)
  , reverse
  , reverse'
  )
import GHC.MVar (MVar(MVar))
import GHC.IO (IO(IO)) 
import GHC.Prim (mkWeak#)
import GHC.Weak (Weak(Weak))

-- | A queue made of two parts: the messages that 'dequeue' has already read
-- from the channel but that were not matched ("arrived"), and the messages
-- written by 'enqueue' that have not been read yet ("incoming").
--
-- We use a TChan rather than a Chan so that we have a non-blocking read
data CQueue a = CQueue (StrictMVar (StrictList a)) -- Arrived
                       (TChan a)                   -- Incoming

-- | Create an empty queue: no arrived messages and a fresh incoming channel
newCQueue :: IO (CQueue a)
newCQueue = CQueue <$> emptyArrived <*> emptyIncoming
  where
    -- Messages already read but not yet matched; initially none
    emptyArrived  = newMVar Nil
    -- Channel that 'enqueue' writes to
    emptyIncoming = atomically newTChan

-- | Add an element to the queue
--
-- The element is evaluated (to weak head normal form) before it is written
-- to the incoming channel, so 'enqueue' is strict in the element.
enqueue :: CQueue a -> a -> IO ()
enqueue (CQueue _ incoming) x = x `seq` atomically (writeTChan incoming x)

-- | What 'dequeue' should do when none of the messages that are currently
-- available is matched
data BlockSpec = 
    NonBlocking   -- ^ Do not wait for further messages
  | Blocking      -- ^ Wait until a matching message arrives
  | Timeout Int   -- ^ Wait for a matching message; the value is handed to
                  --   'System.Timeout.timeout' (microseconds)

-- | Dequeue the first element that is accepted by one of the matches
--
-- Elements left over from earlier calls are checked first; only when none of
-- them matches are new elements read from the incoming channel. Elements
-- that are inspected but not matched are stored back in the queue so that
-- later calls can still see them.
--
-- The timeout (if any) is applied only to waiting for incoming messages, not
-- to checking messages that have already arrived. A timeout of zero never
-- waits: it behaves like 'NonBlocking'.
dequeue :: forall a b.
           CQueue a          -- ^ Queue
        -> BlockSpec         -- ^ Blocking behaviour
        -> [a -> Maybe b]    -- ^ List of matches, tried in order
        -> IO (Maybe b)      -- ^ 'Nothing' if no match was found in time
                             --   (never for 'Blocking')
dequeue (CQueue arrived incoming) blockSpec matches = go
  where
    go :: IO (Maybe b)
    go = mask $ \restore -> do
      arr <- takeMVar arrived
      -- We first check the arrived messages. If we get interrupted during this
      -- search, we just put the MVar back (we haven't read from the Chan yet)
      (arr', mb) <- onException (restore (checkArrived Nil arr))
                                (putMVar arrived arr)
      case (mb, blockSpec) of
        (Just b, _) -> do
          putMVar arrived arr'
          return (Just b)
        (Nothing, NonBlocking) ->
          checkNonBlocking arr'
        (Nothing, Blocking) ->
          Just <$> checkBlocking arr'
        (Nothing, Timeout n)
          -- 'timeout 0' returns 'Nothing' immediately *without running its
          -- action*. Calling it here would leave 'arrived' empty (so every
          -- later 'dequeue' would block forever in 'takeMVar') and would drop
          -- the messages held in arr'. A zero timeout means "do not wait",
          -- which is exactly what 'checkNonBlocking' does, and it always
          -- puts 'arrived' back.
          | n == 0 ->
              checkNonBlocking arr'
          -- Negative values are passed on unchanged; 'timeout' documents
          -- them as "wait indefinitely".
          | otherwise ->
              timeout n $ checkBlocking arr'

    -- Search the already-arrived messages for a match.
    --
    -- The first argument accumulates the messages that were skipped, most
    -- recent first. We reverse the accumulator on return only if we find a
    -- match; without a match it is returned as is (still most recent first),
    -- which is the form 'checkBlocking' and 'checkNonBlocking' expect.
    checkArrived :: StrictList a -> StrictList a -> IO (StrictList a, Maybe b)
    checkArrived acc Nil = return (acc, Nothing)
    checkArrived acc (Cons x xs) =
      case check x of
        -- The matched message x is dropped; the skipped messages and the
        -- unexamined rest are recombined by reverse'
        Just y  -> return (reverse' acc xs, Just y)
        Nothing -> checkArrived (Cons x acc) xs

    -- Read from the incoming channel until a message matches.
    --
    -- If we call checkBlocking there may or may not be a timeout
    checkBlocking :: StrictList a -> IO b
    checkBlocking acc = do
      -- readTChan is a blocking call, and hence is interruptible. If it is
      -- interrupted, we put the value of the accumulator in 'arrived'
      -- (as opposed to the original value), so that no messages get lost
      -- (hence the low-level structure using mask rather than modifyMVar)
      x <- onException (atomically $ readTChan incoming)
                       (putMVar arrived $ reverse acc)
      -- NOTE(review): 'check x' is forced outside the 'onException' above; if
      -- a match function throws, 'arrived' is not put back. Confirm that
      -- matches are expected not to throw.
      case check x of
        Nothing -> checkBlocking (Cons x acc)
        Just y  -> putMVar arrived (reverse acc) >> return y

    -- Drain the incoming channel without ever waiting for a message.
    --
    -- checkNonBlocking is only called when no timer is running: for
    -- 'NonBlocking' and for a timeout of zero
    checkNonBlocking :: StrictList a -> IO (Maybe b)
    checkNonBlocking acc = do
      -- tryReadTChan is *not* interruptible
      mx <- atomically $ tryReadTChan incoming
      case mx of
        -- Channel is empty: store everything we skipped and give up
        Nothing -> putMVar arrived (reverse acc) >> return Nothing
        -- NOTE(review): as in checkBlocking, an exception thrown by a match
        -- function here would leave 'arrived' empty.
        Just x  -> case check x of
          Nothing -> checkNonBlocking (Cons x acc)
          Just y  -> putMVar arrived (reverse acc) >> return (Just y)

    -- Apply the matches to a single message
    check :: a -> Maybe b
    check = checkMatches matches

    -- Result of the first match that accepts the message, if any
    checkMatches :: [a -> Maybe b] -> a -> Maybe b
    checkMatches []     _ = Nothing
    checkMatches (m:ms) a = case m a of Nothing -> checkMatches ms a
                                        Just b  -> Just b

-- | Create a weak pointer to a queue
--
-- The weak pointer uses the primitive @MVar#@ inside the queue's arrived
-- 'StrictMVar' as its key and the queue itself as its value; the given action
-- is registered as the finalizer.
mkWeakCQueue :: CQueue a -> IO () -> IO (Weak (CQueue a))
mkWeakCQueue queue@(CQueue (StrictMVar (MVar mvar#)) _) finalizer =
  IO $ \s0 ->
    case mkWeak# mvar# queue finalizer s0 of
      (# s1, weak# #) -> (# s1, Weak weak# #)