{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Concurrent.EQueue.STMEQueue where
import           Control.Concurrent.EQueue.Class
import           Control.Concurrent.STM
import           Control.Monad.Trans
import           Data.Foldable
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Maybe (catMaybes)
import           Data.Semigroup
import           Data.Unique
-- | An event queue whose state is a map of registered event sources.
--
-- Each source is an STM action that yields @Just@ an event when one is
-- ready and @Nothing@ otherwise; sources are keyed by a 'Unique' so they
-- can be removed individually. A single-field wrapper, so a 'newtype'.
newtype STMEQueue a =
    STMEQueue
    { _eqActiveSources :: TVar (Map Unique (STM (Maybe a)))
      -- ^ Currently registered sources, polled by the waiting side.
    }
-- | Add an event source to the queue.
--
-- The given STM action is stored in the queue's source map under a fresh
-- 'Unique' key. The returned 'IO' action deletes that key again; calling it
-- more than once is harmless because deleting an absent key is a no-op.
--
-- The strict 'modifyTVar'' is used so the map is updated eagerly instead of
-- leaving a chain of pending insert/delete thunks inside the 'TVar'.
register :: (MonadIO m) => STMEQueue a -> STM (Maybe a) -> m (IO ())
register (STMEQueue tqm) g = liftIO $ do
  u <- newUnique
  atomically $ modifyTVar' tqm (Map.insert u g)
  return . atomically $ modifyTVar' tqm (Map.delete u)
-- | Allocate a fresh queue that has no event sources registered yet.
newSTMEQueue :: MonadIO m => m (STMEQueue a)
newSTMEQueue = liftIO (fmap STMEQueue (newTVarIO Map.empty))
instance EQueue STMEQueue where
  -- Pushed values are merged with '<>' into a single pending value held in a
  -- 'TMVar'. Polling the source takes that value (if any) and maps it with
  -- @f@. The result pairs the push action with the unregister action
  -- returned by 'register'.
  registerSemi eq f = liftIO $ do
    t <- newEmptyTMVarIO
    (mappendTMVar t,) <$> register eq (fmap f <$> tryTakeTMVar t)
    where
      -- Merge @a@ into whatever is pending in @t@, or store it if nothing is.
      -- The merged value is forced to weak head normal form so that many
      -- pushes without an intervening poll do not pile up an unevaluated
      -- chain of '<>' applications (fully effective for strict semigroups
      -- such as numeric newtypes).
      mappendTMVar :: Semigroup a => TMVar a -> a -> IO ()
      mappendTMVar t a = atomically $ do
        mv <- tryTakeTMVar t
        case mv of
          Nothing -> putTMVar t a
          Just v  -> putTMVar t $! v <> a
  -- Every pushed value is appended to a 'TQueue'. Each poll of the source
  -- dequeues at most one of them.
  registerQueued eq = liftIO $ do
    t <- newTQueueIO
    (atomically . writeTQueue t,) <$> register eq (tryReadTQueue t)
-- | What 'waitEQ' does when no registered source currently has an event.
data STMEQueueWait
  = ReturnImmediate -- ^ Return straight away, possibly with an empty list.
  | RequireEvent    -- ^ Block (via STM 'retry') until some event is ready.
  deriving Eq
instance EQueueW STMEQueue where
  type WaitPolicy STMEQueue = STMEQueueWait
  -- All registered sources are polled inside a single STM transaction, in
  -- key order. Sources with nothing ready contribute no element. Under
  -- 'RequireEvent', an empty result triggers 'retry', which rolls the
  -- transaction back (restoring anything the polls consumed) and blocks
  -- until a variable it read changes.
  waitEQ (STMEQueue sources) policy = liftIO . atomically $ do
    pollers <- Map.elems <$> readTVar sources
    ready   <- catMaybes <$> sequenceA pollers
    let nothingReady = null ready
    if nothingReady && policy == RequireEvent
      then retry
      else pure ready