module Control.Concurrent.Bag.TaskBuffer
 ( TaskBufferSTM (..)
 , SplitFunction
 , takeFirst )
where

import Control.Concurrent.STM
  ( STM
  , TChan
  , newTChan
  , writeTChan
  , readTChan
  , tryReadTChan
  , isEmptyTChan
  , unGetTChan
  , retry )
import Control.Concurrent.STM.TStack
import Control.Monad ( liftM )
import Data.Maybe ( isNothing )

-- | A buffer of items that is manipulated inside 'STM'.
--
--   Instances must define 'newBufferSTM', 'writeBufferSTM' and
--   'unGetBufferSTM', plus either 'tryReadBufferSTM' or both
--   'readBufferSTM' and 'isEmptyBufferSTM'.  The reading operations have
--   defaults written in terms of one another, so any smaller set of
--   definitions would make those defaults recurse forever.
class TaskBufferSTM b where
  {-# MINIMAL newBufferSTM, writeBufferSTM, unGetBufferSTM,
              (tryReadBufferSTM | (readBufferSTM, isEmptyBufferSTM)) #-}

  -- | Create a new buffer.
  newBufferSTM     :: STM (b a)

  -- | Insert an item into the buffer.
  writeBufferSTM   :: b a -> a -> STM ()

  -- | Put the data back into the buffer.
  --   The item will be the next item read.
  unGetBufferSTM   :: b a -> a -> STM ()

  -- | Remove and return the next item.  The default calls 'retry' while
  --   'tryReadBufferSTM' yields 'Nothing'.
  readBufferSTM    :: b a -> STM a
  readBufferSTM buf = do
    thing <- tryReadBufferSTM buf
    case thing of
      Nothing -> retry
      Just v  -> return v

  -- | Remove and return the next item, or 'Nothing' if there is none.
  --   The default checks 'isEmptyBufferSTM' and only then calls
  --   'readBufferSTM'.
  tryReadBufferSTM :: b a -> STM (Maybe a)
  tryReadBufferSTM buf = do
    empty <- isEmptyBufferSTM buf
    if empty
      then return Nothing
      else liftM Just (readBufferSTM buf)

  -- | Check whether the buffer holds no items, leaving its contents
  --   unchanged.
  isEmptyBufferSTM :: b a -> STM Bool
  isEmptyBufferSTM buf = do
    thing <- tryReadBufferSTM buf
    -- Peeking must not consume: a successfully read item is put straight
    -- back so it is still the next item read.  (The previous default
    -- dropped it.)
    maybe (return ()) (unGetBufferSTM buf) thing
    return (isNothing thing)

  -- | Take the first item of the source buffer for the caller.  Of the
  --   items after it, alternate ones (the 2nd, 4th, ... of the remainder)
  --   move to the destination buffer, and the others stay in the source.
  --   Items are put back with 'unGetBufferSTM' innermost-first, so both
  --   buffers keep the items' original relative order.  The first item is
  --   obtained with 'readBufferSTM', so an empty source behaves as that
  --   read does (the class default retries).
  splitVertical :: SplitFunction b r
  splitVertical to from = do
    first <- readBufferSTM from
    alternate
    return first
   where
    -- Pull two items; recurse on the rest, then send the second to the
    -- destination and return the first to the source.
    alternate = do
      x <- tryReadBufferSTM from
      y <- tryReadBufferSTM from
      case (x, y) of
        (Nothing, _)       -> return ()
        (Just kept, Nothing) -> unGetBufferSTM from kept
        (Just kept, Just moved) -> do
          alternate
          unGetBufferSTM to   moved
          unGetBufferSTM from kept

  -- | Move the back ceiling(n/2) of the source's n items into the
  --   destination buffer, keeping their order.  Then read and return one
  --   item from the destination.  If the source was empty, nothing is
  --   moved and the destination read behaves as 'readBufferSTM' does (the
  --   class default retries when it is empty).
  splitHalf :: SplitFunction b r
  splitHalf to from = do
    _ <- moveBackHalf (0 :: Int)
    -- With a non-empty source, its last item always sees a positive
    -- counter and lands in the destination, so this read finds an item.
    readBufferSTM to
   where
    -- Drain the source recursively, counting items on the way down.  On
    -- the way back up, the item k places from the end sees (n - 2k); a
    -- positive value sends it to the destination, otherwise it returns
    -- to the source.
    moveBackHalf n = do
      next <- tryReadBufferSTM from
      case next of
        Nothing -> return n
        Just item -> do
          c <- moveBackHalf (n + 1)
          unGetBufferSTM (if c > 0 then to else from) item
          return (c - 2)


-- | Buffer instance backed by 'TChan'.  Every class method is mapped
--   directly onto the matching 'TChan' operation, so none of the class
--   defaults are used.
instance TaskBufferSTM TChan where
  -- creation and insertion
  newBufferSTM     = newTChan
  writeBufferSTM   = writeTChan
  unGetBufferSTM   = unGetTChan
  -- inspection and removal
  isEmptyBufferSTM = isEmptyTChan
  tryReadBufferSTM = tryReadTChan
  readBufferSTM    = readTChan

-- | Buffer instance backed by 'TStack'.
--
--   'tryReadBufferSTM' is not defined here, so the class default is used.
--   That default checks emptiness first and then reads.
--
--   'unGetBufferSTM' reuses 'writeTStack'.  Presumably a stack pushes onto
--   the same end it pops from, which would satisfy the "next item read"
--   contract.
--   NOTE(review): confirm against Control.Concurrent.STM.TStack.
instance TaskBufferSTM TStack where
  isEmptyBufferSTM = isEmptyTStack
  readBufferSTM    = readTStack
  newBufferSTM     = newTStack
  unGetBufferSTM   = writeTStack
  writeBufferSTM   = writeTStack

-- Split functions --
-- | A strategy for sharing work between two buffers of tasks.
--   It receives a destination buffer and a source buffer, may move tasks
--   between them, and returns one task taken out of a buffer.
type SplitFunction b r
  =  b (IO (Maybe r))   -- destination buffer
  -> b (IO (Maybe r))   -- source buffer
  -> STM (IO (Maybe r)) -- the task handed back to the caller

-- | The trivial split: nothing is moved into the destination buffer.
--   The next task is read from the source buffer with 'readBufferSTM', so
--   an empty source behaves however that instance's read does (the class
--   default retries).
takeFirst :: TaskBufferSTM b
          => b (IO (Maybe r))   -- destination (ignored)
          -> b (IO (Maybe r))   -- source
          -> STM (IO (Maybe r))
takeFirst _unusedTo from = readBufferSTM from