{-|
Module      : Control.Concurrent.Bag.TaskBufferSTM
Description : Task buffers for the STM monad
Copyright   : (c) Bastian Holst, 2014
License     : BSD3
Maintainer  : bastianholst@gmx.de
Stability   : experimental
Portability : POSIX

This module defines 'TaskBufferSTM', a task buffer whose operations run in the
'STM' monad, together with several split functions and the functions that
create 'Stack' and 'Queue' buffers.
-}
module Control.Concurrent.Bag.TaskBufferSTM
 ( TaskBufferSTM (..)
 , BufferType (..)
 , SplitFunction
 , takeFirst
 , splitVertical
 , splitHalf
 , newChanBufferSTM
 , newStackBufferSTM )
where

import Control.Concurrent.STM
  ( STM
  , TChan
  , newTChan
  , writeTChan
  , readTChan
  , tryReadTChan
  , isEmptyTChan
  , unGetTChan
  , retry )
import Control.Concurrent.STM.TStack
import Control.Monad ( liftM )
import Data.Maybe ( isNothing, fromJust )
import Control.Concurrent.Bag.BufferType

-- | A buffer holding tasks, represented as a record of 'STM' actions.
--
--   Every access function runs in the 'STM' monad.
--
--   This is deliberately a plain record and not a type class: that way the
--   user can pick between several buffer implementations as ordinary values
--   instead of only on the type level.
data TaskBufferSTM a = TaskBufferSTM
  { writeBufferSTM   :: a -> STM ()
    -- ^ Write an item into the buffer in the normal way.
  , unGetBufferSTM   :: a -> STM ()
    -- ^ Write an item into the buffer at its read end.
  , readBufferSTM    :: STM a
    -- ^ Read an item from the buffer; blocks while the buffer is empty.
  , tryReadBufferSTM :: STM (Maybe a)
    -- ^ Try to read an item from the buffer; yields 'Nothing' when it is
    --   empty.
  , isEmptyBufferSTM :: STM Bool
    -- ^ Check whether the buffer is empty.
  }

-- | Create a new Queue buffer backed by a fresh 'TChan'.
--
--   Every buffer operation is forwarded to the corresponding 'TChan'
--   operation; the write at the read end uses 'unGetTChan'.
newChanBufferSTM :: STM (TaskBufferSTM r)
newChanBufferSTM = fmap fromChan newTChan
  where
    fromChan chan = TaskBufferSTM
      { writeBufferSTM   = writeTChan chan
      , unGetBufferSTM   = unGetTChan chan
      , readBufferSTM    = readTChan chan
      , tryReadBufferSTM = tryReadTChan chan
      , isEmptyBufferSTM = isEmptyTChan chan
      }

-- | Create a new Stack buffer backed by a fresh 'TStack'.
--
--   Both the normal write and the write at the read end use 'writeTStack':
--   a stack is read from the same end it is written to.
newStackBufferSTM :: STM (TaskBufferSTM r)
newStackBufferSTM = fmap fromStack newTStack
  where
    fromStack stack = TaskBufferSTM
      { writeBufferSTM   = writeTStack stack
      , unGetBufferSTM   = writeTStack stack
      , readBufferSTM    = readTStack stack
      , tryReadBufferSTM = tryReadTStack stack
      , isEmptyBufferSTM = isEmptyTStack stack
      }

-- | A split function divides the contents of a source buffer into two parts.
--   One part stays in the source buffer (or is put back into it later); the
--   other part is written into the sink buffer. The resulting 'STM' action
--   returns one item that was taken out of the source buffer.
--
--   The source buffer should therefore always have at least one item
--   available; if it has not, the action suspends.
--
--   The first argument is the sink buffer, the second one the source buffer.
type SplitFunction r
  =  TaskBufferSTM (IO (Maybe r))  -- sink
  -> TaskBufferSTM (IO (Maybe r))  -- source
  -> STM (IO (Maybe r))

-- | Ignore the sink buffer and simply read the next item from the source
--   buffer, blocking while the source is empty.
takeFirst :: SplitFunction r
takeFirst = const readBufferSTM

-- | Split the buffer vertically. The first item of the source buffer is
--   taken out and returned. Of the remaining items, every other one stays in
--   the source buffer and the ones in between are moved into the sink buffer;
--   both buffers keep the original relative order.
splitVertical :: SplitFunction r
splitVertical sink source = do
  -- Blocks while the source is empty, as the 'SplitFunction' contract allows.
  headItem <- readBufferSTM source
  alternate
  return headItem
  where
    -- Take the items out two at a time. Once the source is exhausted, put
    -- them back while unwinding the recursion, so that the items read first
    -- end up at the read ends: the first of each pair goes back into the
    -- source, the second one into the sink.
    alternate = do
      mKeep <- tryReadBufferSTM source
      mMove <- tryReadBufferSTM source
      case mKeep of
        Nothing   -> return ()
        Just keep ->
          case mMove of
            Nothing   -> unGetBufferSTM source keep
            Just move -> do
              alternate
              unGetBufferSTM sink move
              unGetBufferSTM source keep

-- | Split the buffer into two halves.
--
--   All items currently in the source buffer are read. Of @n@ items, the
--   first @n `div` 2@ (the ones that would have been read first) are put back
--   into the source buffer. The remaining items are moved into the sink
--   buffer. Both halves keep their original order. The first item of the
--   moved half is then read back out of the sink buffer and returned.
--
--   Like the other split functions, this suspends (via 'retry') while the
--   source buffer is empty. It never returns an item that was already in the
--   sink buffer before the call.
splitHalf :: SplitFunction r
splitHalf sink source = do
  items <- drainSource
  case items of
    -- Nothing to split: suspend instead of reading from the sink, which could
    -- otherwise hand back an item the source never contained.
    [] -> retry
    _  -> do
      let (stay, move) = splitAt (length items `div` 2) items
      -- 'unGetBufferSTM' writes at the read end, so putting the items back
      -- last-to-first restores their original order in each buffer.
      mapM_ (unGetBufferSTM sink)   (reverse move)
      mapM_ (unGetBufferSTM source) (reverse stay)
      -- 'move' is non-empty here (it holds at least the last item), so the
      -- head of the moved half was just written at the sink's read end.
      readBufferSTM sink
  where
    -- Read every item currently available from the source, in read order.
    drainSource = do
      next <- tryReadBufferSTM source
      case next of
        Nothing -> return []
        Just x  -> do
          rest <- drainSource
          return (x : rest)