module Control.Concurrent.Bag.TaskBuffer
 ( TaskBufferSTM (..)
 , SplitFunction
 , takeFirst
 , splitVertical
 , splitHalf
 , BufferType (..)
 , newChanBuffer
 , newStackBuffer )
where

import Control.Concurrent.STM
  ( STM
  , TChan
  , newTChan
  , writeTChan
  , readTChan
  , tryReadTChan
  , isEmptyTChan
  , unGetTChan
  , retry )
import Control.Concurrent.STM.TStack
import Control.Monad ( liftM )
import Data.Maybe ( isNothing, fromJust )

-- | A task buffer described by the STM operations that can be performed on
-- it, so that different backing structures (see 'newChanBuffer' and
-- 'newStackBuffer') can be used interchangeably by the split functions.
data TaskBufferSTM a =
  TaskBufferSTM
    { writeBufferSTM   :: a -> STM ()
      -- ^ Add an element to the buffer.
    , unGetBufferSTM   :: a -> STM ()
      -- ^ Put an element back into the buffer; the split functions use this
      -- to return tasks they have taken out.
    , readBufferSTM    :: STM a
      -- ^ Remove and return an element; for the 'TChan'-backed buffer this
      -- retries while the buffer is empty.
    , tryReadBufferSTM :: STM (Maybe a)
      -- ^ Remove and return an element if one is available, 'Nothing'
      -- otherwise.
    , isEmptyBufferSTM :: STM Bool
      -- ^ Whether the buffer currently holds no elements.
    }

-- | Selects a kind of task buffer. Presumably 'Queue' selects
-- 'newChanBuffer' and 'Stack' selects 'newStackBuffer'; that mapping is not
-- made here, so confirm it at the use site.
--
-- The derived instances let callers compare, print and enumerate the
-- choices, for example when reading configuration or logging.
data BufferType = Queue | Stack
  deriving (Eq, Show, Enum, Bounded)

-- | Create a FIFO task buffer backed by a new (empty) 'TChan'.
-- Un-getting uses 'unGetTChan', which puts the element at the read end so it
-- is the next one read.
newChanBuffer :: STM (TaskBufferSTM r)
newChanBuffer = fromChan <$> newTChan
  where
    -- Wrap each channel operation as the corresponding buffer operation.
    fromChan chan = TaskBufferSTM
      { writeBufferSTM   = writeTChan chan
      , unGetBufferSTM   = unGetTChan chan
      , readBufferSTM    = readTChan chan
      , tryReadBufferSTM = tryReadTChan chan
      , isEmptyBufferSTM = isEmptyTChan chan
      }

-- | Create a task buffer backed by a fresh @TStack@ (from 'newTStack').
-- Writing and un-getting are the same operation here ('writeTStack').
-- NOTE(review): presumably the stack is LIFO, so the most recently written
-- task is read first; confirm against "Control.Concurrent.STM.TStack".
newStackBuffer :: STM (TaskBufferSTM r)
newStackBuffer = do
  stack <- newTStack
  pure TaskBufferSTM
    { writeBufferSTM   = writeTStack stack
    , unGetBufferSTM   = writeTStack stack
    , readBufferSTM    = readTStack stack
    , tryReadBufferSTM = tryReadTStack stack
    , isEmptyBufferSTM = isEmptyTStack stack
    }

-- Split functions --

-- | A strategy for redistributing tasks between two buffers in one STM
-- transaction. The first argument is the buffer tasks may be moved /to/ and
-- the second is the buffer they are taken /from/; the split functions below
-- name them @to@ and @from@. The result is a task (an @IO (Maybe r)@ action)
-- removed from one of the buffers and handed to the caller.
type SplitFunction r =
     TaskBufferSTM (IO (Maybe r))
  -> TaskBufferSTM (IO (Maybe r))
  -> STM (IO (Maybe r))

-- | The trivial split: move nothing and hand the caller the next task read
-- from the source buffer. The destination buffer is ignored.
takeFirst :: SplitFunction r
takeFirst _ = readBufferSTM

-- | Take the first task of @from@ for the caller, then deal the remaining
-- tasks of @from@ out alternately. Of those remaining tasks, the 1st, 3rd,
-- 5th, ... stay in @from@ and the 2nd, 4th, ... are moved ahead of the
-- existing contents of @to@. Relative order is kept within each buffer,
-- assuming (as with 'unGetTChan') that un-getting a task makes it the next
-- one read.
splitVertical :: SplitFunction r
splitVertical to from =
  -- @from@ is expected to be non-empty here, so this read should not have
  -- to retry.
  readBufferSTM from <* dealPairs
  where
    -- Pull the remaining tasks off two at a time and recurse, then push them
    -- back on the way out, so the last pair is restored first and the
    -- original order survives.
    dealPairs = do
      pair <- (,) <$> tryReadBufferSTM from <*> tryReadBufferSTM from
      case pair of
        (Just kept, Just moved) -> do
          dealPairs
          unGetBufferSTM to   moved
          unGetBufferSTM from kept
        (Just kept, Nothing) -> unGetBufferSTM from kept
        (Nothing, _)         -> return ()

-- | Split the tasks of @from@ in half and give the caller one of them.
--
-- With @n@ tasks in @from@, the first @div n 2@ tasks stay in @from@. The
-- remaining ones are put ahead of the existing contents of @to@, and the
-- first of the moved tasks is then read back from @to@ and returned.
-- Relative order is kept in both buffers, assuming (as with 'unGetTChan')
-- that un-getting a task makes it the next one read.
--
-- If @from@ is empty nothing is moved, and the result is whatever
-- 'readBufferSTM' yields for @to@. For a 'TChan'-backed buffer, that read
-- retries while @to@ is empty.
splitHalf :: SplitFunction r
splitHalf to from = do
  _ <- distribute (0 :: Int)
  -- A non-empty @from@ always moves at least one task to @to@, so this read
  -- finds the first task that was just moved.
  readBufferSTM to
  where
    -- Drain @from@ recursively, counting the tasks seen. The innermost call
    -- returns the total count n. Each level then passes back the balance it
    -- received, reduced by two, so the task at 0-based index k receives
    -- n - 2 * (n - 1 - k). That balance is positive exactly for the last
    -- ceiling(n/2) tasks, which go to @to@; the others go back to @from@.
    distribute seen = do
      next <- tryReadBufferSTM from
      case next of
        Nothing   -> return seen
        Just task -> do
          balance <- distribute (seen + 1)
          if balance > 0
            then unGetBufferSTM to   task
            else unGetBufferSTM from task
          return (balance - 2)