module Control.Concurrent.MVar.Expiring
( ExpiringMVar
, newExpiringMVar
, readExpiringMVar
, resetExpiringMVarTimer
, isExpiredMVar
, cancelExpiration
, changeExpiration
, removeExpiredMVars
) where

import Control.Concurrent
import Control.Concurrent.MVar
import qualified Data.Foldable as Foldable
import Control.Monad
import Data.Traversable as Traversable
import Control.Applicative
import Data.Monoid

{- | A mutable cell holding a value that is thrown out once a timer fires.

   The cell is an 'MVar' that a background thread empties after
   'expireTime' microseconds; an empty cell means the value has expired.
   The timer can be reset (or cancelled) before it fires. -}
data ExpiringMVar a = ExpiringMVar
  { var        :: MVar a   -- ^ Holds the value; empty once it has expired.
  , expireTime :: Int      -- ^ Timer length in microseconds, reused when the timer is reset.
  , expirer    :: ThreadId -- ^ Thread that empties 'var' when the timer fires.
  }
{- | Create a new value that is set to be thrown away after a minimum period
   of time. Each call to 'newExpiringMVar' forks a timer thread that sleeps
   for the given delay and then empties the cell, so the thread runs until
   the value expires (or until the timer is cancelled). -}
newExpiringMVar :: a   -- ^ The value which will expire.
                -> Int -- ^ The number of microseconds after which the value will expire.
                   -- Note that the caveats which apply to 'Control.Concurrent.threadDelay'
                   -- apply to the expiration, also.
                -> IO (ExpiringMVar a)
newExpiringMVar initial delay = do
  cell  <- newMVar initial
  timer <- forkExpireIn delay cell
  return (ExpiringMVar { var = cell, expireTime = delay, expirer = timer })

{- | Retrieve the value if it has not yet expired.

   Returns 'Nothing' once the value has expired. This never blocks, and
   reading neither expires the value nor touches its timer.

   'tryReadMVar' inspects the cell atomically. A take-then-put sequence
   would leave the cell briefly empty. During that window a concurrent
   reader could see 'Nothing' for a live value. An expiry timer firing in
   that window would also find nothing to remove, and the value put back
   afterwards would then never expire. -}
readExpiringMVar :: ExpiringMVar a -> IO (Maybe a)
readExpiringMVar (ExpiringMVar cell _ _) = tryReadMVar cell

{- | Restart the expiration timer using the stored 'expireTime'.

   If the value has already expired, the 'ExpiringMVar' is returned unchanged
   and no timer thread is forked. Otherwise the current timer thread is
   killed and a fresh one is forked; the returned 'ExpiringMVar' refers to
   the new timer, so callers should use it in place of the argument. -}
resetExpiringMVarTimer :: ExpiringMVar a -> IO (ExpiringMVar a)
resetExpiringMVarTimer expires@(ExpiringMVar cell delay _) = do
  expired <- isExpiredMVar expires
  case expired of
    True  -> return expires
    False -> do
      cancelExpiration expires
      timer <- forkExpireIn delay cell
      return expires { expirer = timer }

{- | Fork a thread that empties the given 'MVar' after @delay@ microseconds.
   The thread's 'ThreadId' is returned so the timer can later be cancelled. -}
forkExpireIn :: Int -> MVar a -> IO ThreadId
forkExpireIn delay = forkIO . expireIn delay

{- | Sleep for @delay@ microseconds, then empty the 'MVar' if it is still full.

   'tryTakeMVar' is used so that an already-empty cell does not block the
   timer thread. -}
expireIn :: Int -> MVar a -> IO ()
expireIn delay cell = do
  threadDelay delay
  -- The removed value is discarded; only the cell's emptiness matters.
  void (tryTakeMVar cell)

{- | 'True' once the value has expired, i.e. the underlying 'MVar' is empty. -}
isExpiredMVar :: ExpiringMVar a -> IO Bool
isExpiredMVar = isEmptyMVar . var

{- | Stop the expiration timer by killing its thread, so the value is never
   thrown away (unless a timer is started again by a later reset). The
   value itself is not touched. -}
cancelExpiration :: ExpiringMVar a -> IO ()
cancelExpiration = killThread . expirer

{- | Replace the stored timer length and, if the value has not yet expired,
   restart the timer with the new length.

   If the value has already expired, no timer is started, but the returned
   'ExpiringMVar' still records the new length. -}
changeExpiration :: Int -- ^ The new number of microseconds to set the expiration timer to
                        --   (same unit as 'newExpiringMVar' and 'threadDelay').
                 -> ExpiringMVar a -> IO (ExpiringMVar a)
changeExpiration newDelay expires =
  resetExpiringMVarTimer (expires { expireTime = newDelay })

{- mfilterM was written with the help of byorgey on #haskell. -}

{- | Monadic filter whose result can be any container that is both
   'Applicative' and a 'Monoid'. Each kept element becomes a singleton via
   'pure' and each dropped element becomes 'mempty'. The pieces are then
   combined in traversal order with 'Foldable.fold'. -}
mfilterM :: (Monoid (f a), Applicative f, Traversable t, Monad m) =>
            (a -> m Bool) -> t a -> m (f a)
mfilterM p c = liftM Foldable.fold (Traversable.mapM keep c)
  where
    keep x = do
      b <- p x
      return (if b then pure x else mempty)

{- | For a collection of 'ExpiringMVar's, keep only those that have not
   expired and collect them into a new container of the caller's choosing
   (for example a list).

   The result container only needs 'Applicative' (to build singletons) and
   'Monoid' (to concatenate them).

   This is a snapshot: timers run in their own threads, so a surviving
   element may expire right after it has been checked. -}
removeExpiredMVars :: (Monoid (f (ExpiringMVar a)), Applicative f, Traversable.Traversable t) =>
                      t (ExpiringMVar a) -> IO (f (ExpiringMVar a))
removeExpiredMVars = mfilterM (liftM not . isExpiredMVar)

{- | Convert a whole number of seconds to microseconds, the unit expected by
   'threadDelay' and 'newExpiringMVar'. -}
secondsToMicroseconds :: Int -> Int
secondsToMicroseconds seconds = seconds * 1000000

{- | Manual test driver. It drops expired values from @xs@ and prints how
   many remain. It then waits ten seconds and checks again, stopping once
   every value has expired. -}
testAux :: [ExpiringMVar Char] -> IO ()
testAux xs = do
  xs' <- removeExpiredMVars xs :: IO [ExpiringMVar Char]
  print (length xs')
  -- 'null' is O(1); 'length xs' > 0' would traverse the whole list.
  unless (null xs') $ do
    threadDelay (secondsToMicroseconds 10)
    testAux xs'

{- | Manual stress test. It creates 100000 values whose lifetimes run from
   1 ms to 100 s in 1 ms steps. It then reports how many survive, every ten
   seconds, until all have expired. -}
test :: IO ()
test = Control.Monad.forM [1 .. 100000] (\d -> newExpiringMVar 'c' (d * 1000)) >>= testAux