module Data.Iteratee.STM (
  -- * Channel enumerator/iteratee primitives
  iterChan
 ,enumChan
  -- ** Channel control functions
 ,iterCloseChan
 ,enumCloseChan
  -- * Forking combinators
 ,forkIter
 ,forkEnum
)

where

import Control.Concurrent.STM.TBMChan
import Data.Iteratee as I

import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception (finally)
import Control.Monad
import Control.Monad.IO.Class

-- | Return every value currently buffered in a 'TBMChan', in the order they
-- are read from the channel.
--
-- Uses 'tryReadTBMChan', so it stops as soon as no value is available,
-- whether the channel is still open or already closed.
--
-- NOTE(review): not exported and no caller is visible in this module.
drainChan :: TBMChan a -> STM [a]
drainChan chan = go []
  where
    -- Values are accumulated newest-first, so reverse once at the end to
    -- preserve read order (previously the list was returned reversed).
    go acc = do
      res <- tryReadTBMChan chan
      case res of
        Just (Just a) -> go (a : acc)
        _             -> return (reverse acc)

-- | An iteratee that closes the given 'TBMChan' in a single STM transaction.
--
-- It consumes no input from the stream. After the close, 'iterChan' refuses
-- to write further chunks to this channel.
iterCloseChan :: (Nullable s, MonadIO m) => TBMChan s -> Iteratee s m ()
iterCloseChan = liftIO . atomically . closeTBMChan

-- | An iteratee which forwards every chunk of its input stream into a
-- 'TBMChan'.
--
-- Each chunk is written in its own STM transaction. The iteratee completes
-- in one of two ways:
--
--   * its input stream reaches EOF, in which case the channel is left open;
--
--   * the channel has been closed, in which case the chunk that could not be
--     delivered is returned as unconsumed input.
iterChan :: (Nullable s, MonadIO m) => TBMChan s -> Iteratee s m ()
iterChan chan = loop
  where
    loop = do
      atEof <- isFinished
      if atEof
        then return ()
        else do
          chunk <- getChunk
          delivered <- liftIO (atomically (offer chunk))
          if delivered
            then loop
            else idone () (Chunk chunk)

    -- Check for a closed channel in the same transaction as the write, so the
    -- chunk is either written or reported as undelivered, never lost.
    offer c = do
      closed <- isClosedTBMChan chan
      if closed
        then return False
        else writeTBMChan chan c >> return True

-- | An enumerator which feeds an iteratee with values read from a 'TBMChan'.
--
-- Each step performs one 'readTBMChan'. A value read from the channel is
-- passed on as the next chunk. Once the read reports the channel as closed,
-- the callback signals "stop" with an empty chunk. This enumerator never sends
-- EOF; compose with 'enumCloseChan' or 'enumEof' for that.
enumChan :: (NullPoint s, MonadIO m) => TBMChan s -> Enumerator s m a
enumChan chan = enumFromCallback pull ()
  where
    pull () = liftIO $ do
      next <- atomically (readTBMChan chan)
      return . Right $ maybe ((False, ()), I.empty) (\s -> ((True, ()), s)) next

-- | An enumerator which closes the given channel and then sends EOF to the
-- iteratee.
--
-- The close is unconditional and happens before EOF is delivered, whatever
-- state the iteratee is in.
enumCloseChan :: (MonadIO m) => TBMChan s -> Enumerator s m a
enumCloseChan chan iter =
  liftIO (atomically (closeTBMChan chan)) >> enumEof iter

-- | Fork an enumerator to run in a separate thread, with a @sz@ upper bound
-- on the channel size.
--
-- The forked thread runs @enum@ into a 'TBMChan' (via 'iterChan'). The
-- current thread feeds the chunks it reads from that channel to @iter@ (via
-- 'enumChan'). The current thread waits for the forked thread to terminate
-- before returning.
--
-- The channel is closed from both sides so neither thread can block forever:
--
--   * the forked thread closes it when it finishes /or fails/. The reader is
--     therefore not left waiting on a channel that will never receive more
--     data.
--
--   * the current thread closes it once @iter@ stops consuming. A producer
--     blocked on a full channel is then released before we wait for it.
--
-- If the forked thread fails, its exception is not rethrown here; @iter@ just
-- receives whatever was produced before the failure.
--
-- NOTE(review): if enumerating @iter@ throws in @m@, the channel is not
-- closed by this thread. @m@ is only 'MonadIO', so there is no generic
-- @finally@ for it.
forkEnum
  :: (MonadIO m, Nullable s, NullPoint s)
  => Int
  -> Enumerator s IO ()
  -> Enumerator s m a
forkEnum sz enum iter = do
   chan <- liftIO $ newTBMChanIO sz
   finishedVar <- liftIO newEmptyMVar
   let closeChan = atomically $ closeTBMChan chan
       producer  = (enum >>> enumCloseChan chan) (iterChan chan) >>= run
   -- Close the channel even if the producer throws, so 'enumChan' below
   -- cannot block forever on an open, empty channel.
   _ <- liftIO . forkIO $
          (producer `finally` closeChan) `finally` putMVar finishedVar ()
   i2 <- enumChan chan iter
   -- @iter@ may have stopped early; closing makes the producer's 'iterChan'
   -- give up instead of blocking on a full channel, so the wait terminates.
   liftIO $ closeChan >> readMVar finishedVar
   return i2

-- | Fork an iteratee to run in a separate thread, with a @sz@ upper bound on
-- the channel size.
--
-- The returned iteratee writes its input into a 'TBMChan' (via 'iterChan').
-- The forked thread runs @iter@ over that channel (via 'enumChan') and then
-- 'run's it. The current thread waits for the forked thread to finish before
-- returning.
--
-- The forked thread closes the channel when @iter@ completes /or fails/. The
-- writer then stops instead of blocking on a full channel. Any chunk it could
-- not deliver is left as unconsumed input of the returned iteratee.
--
-- If the forked thread fails, its exception is not rethrown here.
forkIter
  :: (Nullable s, NullPoint s, MonadIO m)
  => Int
  -> Iteratee s IO ()
  -> Iteratee s m ()
forkIter sz iter = do
  chan <- liftIO $ newTBMChanIO sz
  finishedVar <- liftIO newEmptyMVar
  let closeChan = atomically $ closeTBMChan chan
      consumer  = enumChan chan iter >>= run
  -- Without closing here, an early-finishing or failing @iter@ would leave
  -- 'iterChan' below blocked forever once the channel fills up.
  _ <- liftIO . forkIO $
         (consumer `finally` closeChan) `finally` putMVar finishedVar ()
  iterChan chan
  iterCloseChan chan
  liftIO $ readMVar finishedVar