{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

{-|

Allocate resources which are guaranteed to be released.

One point to note: all registered cleanup actions live in IO, not the main
monad. This allows for more efficient code, and allows the main monad to be
transformed (see 'transResourceT').

-}

module Control.Monad.Resource
    (  -- * Data types
      ResourceT
    , ReleaseKey
      -- * Run
    , runResourceT
      -- * Resource allocation
    , with
    , register
    , release
      -- * Monad transformation
    , transResourceT
    )
where

import           Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import           Data.IORef
                     ( IORef
                     , newIORef
                     , readIORef
                     , writeIORef
                     , atomicModifyIORef
                     )
import           Data.Word (Word)
import           Control.Applicative (Applicative (..))
import           Control.Exception (SomeException, mask, mask_, try, finally)
import           Control.Monad (liftM, when)
import           Control.Monad.Base (MonadBase (..))
import           Control.Monad.Cont.Class (MonadCont(..))
import           Control.Monad.Error.Class (MonadError (..))
import           Control.Monad.Fork.Class (MonadFork (..))
import           Control.Monad.IO.Class (MonadIO (..))
import           Control.Monad.Reader.Class (MonadReader (..))
import           Control.Monad.RWS.Class (MonadRWS (..))
import           Control.Monad.State.Class (MonadState (..))
import           Control.Monad.Writer.Class (MonadWriter (..))
import           Control.Monad.Trans.Class (MonadTrans (..))
import           Control.Monad.Trans.Control
                     ( MonadBaseControl (..)
                     , MonadTransControl (..)
                     , control
                     )


------------------------------------------------------------------------------
-- | A lookup key for a specific release action. This value is returned by
-- 'register' and 'with' and is passed to 'release'.
--
-- The wrapped 'Int' is the key under which the action is stored in the
-- shared 'ReleaseMap'.
--
-- NOTE(review): the key does not record which state it came from, and every
-- 'runResourceT' numbers its keys from 0. Passing a key to 'release' inside a
-- different 'runResourceT' would look up an unrelated action with the same
-- number.
newtype ReleaseKey = ReleaseKey Int


------------------------------------------------------------------------------
-- | Mutable state shared by a 'ResourceT' computation and any threads forked
-- from it. The fields are, in order:
--
-- * the next key to hand out (incremented by each registration);
--
-- * a reference count of the threads currently using the state (incremented
--   by 'stateAlloc', decremented by 'stateCleanup'; the release actions are
--   run when it drops to zero);
--
-- * the still-registered release actions, indexed by 'ReleaseKey' number.
data ReleaseMap = ReleaseMap !Int !Word !(IntMap (IO ()))


------------------------------------------------------------------------------
-- | The Resource transformer. This transformer keeps track of all registered
-- actions, and calls them upon exit (via 'runResourceT'). Actions may be
-- registered via 'register', or resources may be allocated atomically via
-- 'with'. The 'with' function corresponds closely to @bracket@.
--
-- Releasing may be performed before exit via the 'release' function. This is
-- a highly recommended optimization, as it will ensure that scarce resources
-- are freed early. Note that calling 'release' will deregister the action, so
-- that a release action will only ever be called once.
--
-- Representation: a computation is a function of the 'IORef' holding the
-- shared 'ReleaseMap'. Binding passes the same reference to every step, and
-- lifted actions simply ignore it.
newtype ResourceT m a = ResourceT (IORef ReleaseMap -> m a)


------------------------------------------------------------------------------
-- | Lifting ignores the resource state entirely: the lifted action neither
-- registers nor releases anything.
instance MonadTrans ResourceT where
    lift action = ResourceT (\_ -> action)


------------------------------------------------------------------------------
-- | The only "state" of 'ResourceT' is the reader-like 'IORef', so the
-- captured state is just the result value wrapped in a newtype.
instance MonadTransControl ResourceT where
    newtype StT ResourceT a = StReader {unStReader :: a}
    liftWith f = ResourceT $ \istate ->
        f (\(ResourceT g) -> liftM StReader (g istate))
    restoreT m = ResourceT (\_ -> liftM unStReader m)


------------------------------------------------------------------------------
-- | Map over the result, threading the resource state through unchanged.
instance Functor m => Functor (ResourceT m) where
    fmap f (ResourceT g) = ResourceT (fmap f . g)


------------------------------------------------------------------------------
-- | Both sides of '<*>' see the same resource state.
instance Applicative m => Applicative (ResourceT m) where
    pure x = ResourceT (\_ -> pure x)
    ResourceT getF <*> ResourceT getX = ResourceT $ \istate ->
        let fs = getF istate
            xs = getX istate
        in  fs <*> xs


------------------------------------------------------------------------------
-- | Sequencing passes the same resource state to each step.
instance Monad m => Monad (ResourceT m) where
    return x = ResourceT (\_ -> return x)
    ResourceT first >>= k = ResourceT $ \istate -> do
        a <- first istate
        case k a of
            ResourceT next -> next istate


------------------------------------------------------------------------------
-- | Lift an 'IO' action through the underlying monad.
instance MonadIO m => MonadIO (ResourceT m) where
    liftIO io = lift (liftIO io)


------------------------------------------------------------------------------
-- | Lift a base-monad action through the underlying monad.
instance MonadBase b m => MonadBase b (ResourceT m) where
    liftBase base = lift (liftBase base)


------------------------------------------------------------------------------
-- | Run-in-base for 'ResourceT': supply the captured resource state, then
-- defer to the underlying monad's run-in-base.
instance MonadBaseControl b m => MonadBaseControl b (ResourceT m) where
    newtype StM (ResourceT m) a = StMT (StM m a)
    liftBaseWith f = ResourceT $ \istate ->
        liftBaseWith $ \runInBase ->
            f (\(ResourceT g) -> liftM StMT (runInBase (g istate)))
    restoreM (StMT st) = ResourceT (\_ -> restoreM st)


------------------------------------------------------------------------------
-- | Forking shares the parent's resource state with the child thread.
--
-- The parent adds a reference to the shared state ('stateAlloc') before the
-- child is forked. The child drops it ('stateCleanup') when its action
-- finishes, whether normally or by exception. Registered release actions
-- therefore run only once the last thread using the state is done.
--
-- NOTE(review): if the underlying 'fork' itself throws, no child exists to
-- drop the reference taken by 'stateAlloc'. The parent's 'runResourceT' would
-- then never bring the count to zero, and the release actions would not run.
-- Consider wrapping the fork with @onException@, but first confirm that the
-- underlying 'fork' cannot throw after the child has already started.
instance (MonadFork m, MonadBaseControl IO m) => MonadFork (ResourceT m) where
    fork (ResourceT f) = ResourceT $ \istate ->
        control $ \run -> mask $ \unmask -> do
            stateAlloc istate
            -- The child is started from inside 'mask'. Presumably it inherits
            -- the masked state (as 'forkIO' does; TODO confirm for the
            -- underlying 'fork'), so 'finally' is in place before
            -- asynchronous exceptions can reach it. The parent's @unmask@
            -- restores the outer masking state only around the user's action.
            run . fork $ control $ \run' -> do
                unmask (run' $ f istate) `finally` stateCleanup istate


------------------------------------------------------------------------------
-- | The captured continuation, when invoked, ignores the resource state it
-- is given and jumps back into the underlying monad's continuation.
instance MonadCont m => MonadCont (ResourceT m) where
    callCC f = ResourceT $ \istate ->
        callCC $ \k ->
            let ResourceT body = f (\a -> ResourceT (\_ -> k a))
            in  body istate
    

------------------------------------------------------------------------------
-- | Errors are thrown and caught in the underlying monad. The handler runs
-- against the same resource state as the protected action.
instance MonadError e m => MonadError e (ResourceT m) where
    throwError err = lift (throwError err)
    catchError (ResourceT body) handler = ResourceT $ \istate ->
        catchError (body istate) $ \err ->
            let ResourceT recovery = handler err
            in  recovery istate


------------------------------------------------------------------------------
-- | Reader operations act on the underlying monad's environment.
instance MonadReader r m => MonadReader r (ResourceT m) where
    ask = lift ask
    local f (ResourceT g) = ResourceT (\istate -> local f (g istate))


------------------------------------------------------------------------------
-- | Follows from the Reader, Writer and State instances above.
instance (MonadRWS r w s m) => MonadRWS r w s (ResourceT m)


------------------------------------------------------------------------------
-- | State operations act on the underlying monad's state.
instance MonadState s m => MonadState s (ResourceT m) where
    get = lift get
    put = lift . put


------------------------------------------------------------------------------
-- | Writer operations act on the underlying monad's output.
instance MonadWriter w m => MonadWriter w (ResourceT m) where
    tell = lift . tell
    listen (ResourceT g) = ResourceT (listen . g)
    pass (ResourceT g) = ResourceT (pass . g)


------------------------------------------------------------------------------
-- | Perform some allocation, and automatically register a cleanup action.
--
-- As with @bracket@, the allocation runs with asynchronous exceptions
-- masked. An asynchronous exception therefore cannot arrive after the
-- resource has been acquired but before its release action is registered.
-- Interruptible (blocking) operations inside the allocation can still
-- receive asynchronous exceptions, exactly as in @bracket@.
with :: MonadBase IO m
     => IO a -- ^ allocate
     -> (a -> IO ()) -- ^ free resource
     -> ResourceT m (ReleaseKey, a)
with acquire m = ResourceT $ \istate -> liftBase $ mask_ $ do
    a <- acquire
    key <- register' istate $ m a
    return (key, a)


------------------------------------------------------------------------------
-- | Register an action to be run exactly once: either by 'release' with the
-- returned 'ReleaseKey', or by 'runResourceT' when the computation exits.
register :: MonadBase IO m => IO () -> ResourceT m ReleaseKey
register action = ResourceT $ \istate -> liftBase (register' istate action)

-- Store the action under the next free key and advance the key counter, in a
-- single atomic update of the shared state.
register' :: IORef ReleaseMap -> IO () -> IO ReleaseKey
register' istate action = atomicModifyIORef istate insertAction
  where
    insertAction (ReleaseMap next refcount actions) =
        ( ReleaseMap (next + 1) refcount (IntMap.insert next action actions)
        , ReleaseKey next
        )


------------------------------------------------------------------------------
-- | Call a release action early, and deregister it from the list of cleanup
-- actions to be performed. Releasing the same key a second time does nothing,
-- because the action is no longer registered.
release :: MonadBase IO m => ReleaseKey -> ResourceT m ()
release key = ResourceT $ \istate -> liftBase $ release' istate key

-- | Atomically remove the action registered under the key and, if one was
-- found, run it.
--
-- The action runs with asynchronous exceptions masked, matching
-- 'stateCleanup' and the release step of @bracket@. The action has already
-- been deregistered before it starts. If an asynchronous exception could
-- abandon it midway, nothing would be left to free the resource.
release' :: IORef ReleaseMap -> ReleaseKey -> IO ()
release' istate (ReleaseKey key) = mask_ $
    atomicModifyIORef istate lookupAction >>= maybe (return ()) id
  where
    lookupAction rm@(ReleaseMap next refcount actions) =
        case IntMap.lookup key actions of
            Nothing -> (rm, Nothing)
            Just action ->
                ( ReleaseMap next refcount (IntMap.delete key actions)
                , Just action
                )


------------------------------------------------------------------------------
-- | Change the monad underneath a @ResourceT@, for example to strip or add
-- transformers in the stack (such as running a @ReaderT@). The resource
-- state is passed through untouched.
transResourceT :: (m a -> n b) -> ResourceT m a -> ResourceT n b
transResourceT f (ResourceT mx) = ResourceT (f . mx)


------------------------------------------------------------------------------
-- | Run a 'ResourceT' computation in the underlying monad. Afterwards, run
-- every release action that is still registered.
--
-- The resource state is reference counted because of the 'MonadFork'
-- instance. When several threads share the same collection of resources,
-- they are only deallocated when the last @runResourceT@ (or forked thread)
-- using them finishes.
runResourceT :: MonadBaseControl IO m => ResourceT m a -> m a
runResourceT (ResourceT body) = do
    -- Fresh state: first key 0, no references yet, nothing registered.
    istate <- liftBase (newIORef (ReleaseMap 0 0 IntMap.empty))
    control $ \runInIO -> mask $ \restore -> do
        -- Take this thread's reference while still masked, so cleanup is
        -- guaranteed to be paired with it.
        stateAlloc istate
        let action = runInIO (body istate)
        restore action `finally` stateCleanup istate


------------------------------------------------------------------------------
-- | Atomically add one reference to the shared resource state.
stateAlloc :: IORef ReleaseMap -> IO ()
stateAlloc istate = atomicModifyIORef istate bump
  where
    bump (ReleaseMap next refcount actions) =
        (ReleaseMap next (refcount + 1) actions, ())


------------------------------------------------------------------------------
-- | Drop one reference to the shared resource state. When the last reference
-- goes away, do two things:
--
-- * run every still-registered release action, most recently registered
--   first (the reverse of allocation order, as with nested @bracket@s);
--
-- * poison the 'IORef' so that any later use fails loudly.
--
-- Exceptions thrown by individual release actions are deliberately caught and
-- discarded, so one failing action cannot prevent the others from running.
stateCleanup :: IORef ReleaseMap -> IO ()
stateCleanup istate = mask_ $ do
    (ref, im) <- atomicModifyIORef istate $ \(ReleaseMap key rc actions) ->
        (ReleaseMap key (rc - 1) actions, (rc - 1, actions))
    when (ref == 0) $ do
        -- Keys increase with each registration, so the reversed ascending
        -- list visits the newest resource first (LIFO).
        mapM_ (\x -> try' x >> return ()) $ reverse $ IntMap.elems im
        writeIORef istate $ error "Control.Monad.Resource.Trans.stateCleanup:\
            \ There is a bug in the implementation. The mutable state is\
            \ being accessed after cleanup. Please contact the maintainers."
  where
    try' = try :: IO a -> IO (Either SomeException a)