module Pipes.Safe
    ( 
      SafeT
    , runSafeT
    , runSafeP
     
    , ReleaseKey
    , Base
    , MonadSafe(..)
    , onException
    , finally
    , bracket
    , bracket_
    , bracketOnError
    
    
    , module Control.Monad.Catch
    , module Control.Exception
    ) where
import Control.Applicative (Applicative(pure, (<*>)))
import Control.Exception(Exception(..), SomeException(..))
import qualified Control.Monad.Catch as C
import Control.Monad.Catch
    ( MonadCatch(..)
    , mask_
    , uninterruptibleMask_
    , catchAll
    , catchIOError
    , catchJust
    , catchIf
    , Handler(..)
    , catches
    , handle
    , handleAll
    , handleIOError
    , handleJust
    , handleIf
    , try
    , tryJust
    , Exception(..)
    , SomeException
    )
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Trans.Class (MonadTrans(lift))
import qualified Control.Monad.Catch.Pure          as E
import qualified Control.Monad.Trans.Identity      as I
import qualified Control.Monad.Trans.Reader        as R
import qualified Control.Monad.Trans.RWS.Lazy      as RWS
import qualified Control.Monad.Trans.RWS.Strict    as RWS'
import qualified Control.Monad.Trans.State.Lazy    as S
import qualified Control.Monad.Trans.State.Strict  as S'
import qualified Control.Monad.Trans.Writer.Lazy   as W
import qualified Control.Monad.Trans.Writer.Strict as W'
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.Map as M
import Data.Monoid (Monoid)
import Pipes (Proxy, Effect, Effect', runEffect)
import Pipes.Internal (unsafeHoist, Proxy(..))
import Pipes.Lift (liftCatchError)
-- | Masking state tracked by 'liftMask' while it walks a 'Proxy'.
--
-- 'Unmasked' means no masked chunk of base-monad actions is currently
-- running; @'Masked' restore@ holds the @restore@ function that the base
-- monad's masking function handed out for the chunk currently running.
data Restore m = Unmasked | Masked (forall x . m x -> m x)
-- | Lift a base-monad masking function (such as 'mask' or
-- 'uninterruptibleMask') so that it works on a 'Proxy'.
--
-- A single base-monad mask cannot cover an entire 'Proxy', because control
-- leaves the proxy at every 'Request' and 'Respond'.  Instead, each run of
-- adjacent base-monad steps ('M' constructors) is executed inside its own call
-- to @maskFunction@, while 'Request' and 'Respond' steps are passed through
-- outside of any mask.
--
-- The function handed to the continuation @k@ (called @unmask@ below) looks up
-- the @restore@ function of the currently running masked chunk, which is kept
-- in an 'IORef', and applies it to every base-monad action of its argument.
liftMask
    :: forall m a' a b' b r . (MonadIO m, MonadCatch m)
    => (forall s . ((forall x . m x -> m x) -> m s) -> m s)
    -> ((forall x . Proxy a' a b' b m x -> Proxy a' a b' b m x)
        -> Proxy a' a b' b m r)
    -> Proxy a' a b' b m r
liftMask maskFunction k = do
        -- Masking state shared between 'loop' (writer) and 'unmask' (reader).
        ioref <- liftIO (newIORef Unmasked)
        let unmask
                :: forall y . (Monad m)
                => Proxy a' a b' b m y -> Proxy a' a b' b m y
            -- Reads the masking state when @p@ starts.  If a chunk is masked,
            -- every base-monad action of @p@ is wrapped in that chunk's
            -- @restore@; otherwise @p@ runs unchanged.
            unmask p = do
                mRestore <- liftIO (readIORef ioref)
                case mRestore of
                    Unmasked       -> p
                    Masked restore -> do
                        r <- unsafeHoist restore p
                        -- NOTE(review): runs one extra, empty action through
                        -- @restore@ after @p@ finishes; the reason is not
                        -- evident from this module -- confirm before changing.
                        lift $ restore $ return ()
                        return r
            -- Walk the proxy: pass 'Request'/'Respond' through unchanged and
            -- run each maximal run of 'M' steps inside one @maskFunction@
            -- call, recording that call's @restore@ in @ioref@ for 'unmask'.
            loop p = case p of
                Request a' fa  -> Request a' (loop . fa )
                Respond b  fb' -> Respond b  (loop . fb')
                M m0           -> M $ maskFunction $ \restore -> do
                    liftIO $ writeIORef ioref (Masked restore)
                    -- Execute consecutive 'M' layers inside this same mask,
                    -- stopping at the first non-'M' constructor.
                    let loop' m = do
                            p' <- m
                            case p' of
                                M m' -> loop' m'
                                _    -> return p'
                    p' <- loop' m0
                    -- Only reached on normal return.  NOTE(review): if
                    -- @loop' m0@ throws, @ioref@ is left holding 'Masked'.
                    liftIO $ writeIORef ioref  Unmasked
                    -- Hand back the next 'Request'/'Respond'/'Pure' step; it
                    -- is interpreted after this masked action has returned.
                    return (loop p')
                Pure r         -> Pure r
        loop (k unmask)
-- | Exception support inside a 'Proxy': throwing is lifted from the base
-- monad, catching goes through 'liftCatchError', and both masking variants
-- are lifted with 'liftMask'.
instance (MonadCatch m, MonadIO m) => MonadCatch (Proxy a' a b' b m) where
    throwM e = lift (throwM e)

    catch = liftCatchError C.catch

    mask k = liftMask mask k

    uninterruptibleMask k = liftMask uninterruptibleMask k
-- | Finalizers registered inside a 'SafeT' block, keyed by 'Integer'.
data Finalizers m = Finalizers
    { _nextKey    :: !Integer
      -- ^ key that the next 'register' call will use
    , _finalizers :: !(M.Map Integer (m ()))
      -- ^ finalizers that have been registered and not yet released
    }
-- | 'SafeT' is a monad transformer that extends the base monad with the
-- ability to 'register' and 'release' finalizers.
--
-- Internally it is a reader over an 'IORef' holding the current
-- 'Finalizers'; 'runSafeT' runs whichever finalizers are still registered
-- when the computation ends.
newtype SafeT m r = SafeT { unSafeT :: R.ReaderT (IORef (Finalizers m)) m r }
-- | Maps over the result using only the underlying 'Monad' operations, so a
-- 'Monad' constraint on the base monad is all that is required.
instance (Monad m) => Functor (SafeT m) where
    fmap f m = SafeT (unSafeT m >>= \r -> return (f r))
-- | Runs the function effect, then the argument effect, through the
-- underlying 'Monad'.
instance (Monad m) => Applicative (SafeT m) where
    pure = SafeT . return

    mf <*> mx = SafeT (
        unSafeT mf >>= \f ->
        unSafeT mx >>= \x ->
        return (f x) )
-- | Binds through the underlying reader, so both sides of '>>=' see the same
-- finalizer 'IORef'.
instance (Monad m) => Monad (SafeT m) where
    return = SafeT . return

    m >>= f = SafeT (unSafeT m >>= unSafeT . f)
-- | 'IO' actions are lifted straight through the underlying reader.
instance (MonadIO m) => MonadIO (SafeT m) where
    liftIO = SafeT . liftIO
-- | Throwing, catching and masking are delegated to the underlying
-- 'R.ReaderT', unwrapping and rewrapping the 'SafeT' newtype around it.
instance (MonadCatch m) => MonadCatch (SafeT m) where
    throwM = SafeT . throwM

    catch m handler = SafeT (C.catch (unSafeT m) (unSafeT . handler))

    mask k = SafeT (mask (\restore -> unSafeT (k (SafeT . restore . unSafeT))))

    uninterruptibleMask k = SafeT (uninterruptibleMask (\restore ->
        unSafeT (k (SafeT . restore . unSafeT))))
-- | Base-monad actions are lifted through the underlying reader.
instance MonadTrans SafeT where
    lift = SafeT . lift
-- | Run the 'SafeT' monad transformer, executing all still-registered
-- finalizers when the computation ends, whether it returns normally or throws.
--
-- Finalizers run in reverse order of registration (most recently registered
-- first).  Every one of them gets a chance to run even if an earlier one
-- throws.  An exception raised by a finalizer still propagates, but only after
-- the remaining finalizers have run.  If several throw, only one of those
-- exceptions is propagated.
runSafeT :: (MonadCatch m, MonadIO m) => SafeT m r -> m r
runSafeT m = C.bracket
    (liftIO $ newIORef $! Finalizers 0 M.empty)
    (\ioref -> do
        Finalizers _ fs <- liftIO (readIORef ioref)
        -- Keys grow with registration order, so the descending list puts the
        -- newest finalizer first.  Chaining with 'C.finally' (rather than
        -- plain sequencing) ensures a throwing finalizer cannot skip the
        -- finalizers after it.
        foldr C.finally (return ()) (map snd (M.toDescList fs)) )
    (R.runReaderT (unSafeT m))
-- | Run 'SafeT' in the base monad of an 'Effect', running the finalizers that
-- remain registered when the 'Effect' finishes.  The result is a polymorphic
-- @Effect'@ over @m@, so it can be composed with other proxies.
--
-- The argument is bound explicitly.  The result type @Effect'@ contains nested
-- @forall@s, and GHC 9.0+ (simplified subsumption) accepts such a definition
-- only when it is eta-expanded; the point-free form is rejected there.
runSafeP :: (MonadCatch m, MonadIO m) => Effect (SafeT m) r -> Effect' m r
runSafeP e = lift (runSafeT (runEffect e))
-- | Token returned by 'register'; pass it to 'release' to unregister the
-- corresponding finalizer.
newtype ReleaseKey = ReleaseKey { unlock :: Integer }
-- | The monad in which finalizers and 'liftBase' actions run.
--
-- 'IO' maps to itself and @'SafeT' m@ maps to @m@.  Every other instance
-- below is a monad transformer that passes 'Base' through to the monad it
-- transforms.
type family Base (m :: * -> *) :: * -> *
type instance Base IO = IO
type instance Base (SafeT m) = m
type instance Base (Proxy a' a b' b m) = Base m
type instance Base (I.IdentityT m) = Base m
type instance Base (E.CatchT m) = Base m
type instance Base (R.ReaderT i m) = Base m
type instance Base (S.StateT s m) = Base m
type instance Base (S'.StateT s m) = Base m
type instance Base (W.WriterT w m) = Base m
type instance Base (W'.WriterT w m) = Base m
type instance Base (RWS.RWST i w s m) = Base m
type instance Base (RWS'.RWST i w s m) = Base m
-- | 'MonadSafe' lets you 'register' and 'release' finalizers that execute in
-- the 'Base' monad.
class (MonadCatch m, MonadIO m, Monad (Base m)) => MonadSafe m where

    -- | Lift an action from the 'Base' monad.
    liftBase :: Base m r -> m r

    -- | Register a finalizer and return a key that can later 'release' it.
    -- For 'SafeT', finalizers that are still registered run when 'runSafeT'
    -- finishes, most recently registered first.
    register :: Base m () -> m ReleaseKey

    -- | Unregister a previously registered finalizer without running it.
    release  :: ReleaseKey -> m ()
-- | The base instance: finalizers live in the 'IORef' that 'runSafeT'
-- supplies through the reader environment.
instance (MonadIO m, MonadCatch m) => MonadSafe (SafeT m) where
    liftBase = lift

    register io = SafeT R.ask >>= \ioref -> liftIO (do
        Finalizers next fins <- readIORef ioref
        -- Store the finalizer under the current counter, then bump the
        -- counter so keys increase with registration order.
        let fins' = M.insert next io fins
        writeIORef ioref $! Finalizers (next + 1) fins'
        return (ReleaseKey next) )

    release (ReleaseKey key) = SafeT R.ask >>= \ioref -> liftIO (do
        Finalizers next fins <- readIORef ioref
        writeIORef ioref $! Finalizers next (M.delete key fins) )
-- The instances below forward every 'MonadSafe' operation to the wrapped
-- monad, lifting the result one layer with 'lift'.

instance (MonadSafe m) => MonadSafe (Proxy a' a b' b m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m) => MonadSafe (I.IdentityT m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m) => MonadSafe (E.CatchT m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m) => MonadSafe (R.ReaderT i m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m) => MonadSafe (S.StateT s m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m) => MonadSafe (S'.StateT s m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (W.WriterT w m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (W'.WriterT w m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (RWS.RWST i w s m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)

instance (MonadSafe m, Monoid w) => MonadSafe (RWS'.RWST i w s m) where
    liftBase act = lift (liftBase act)
    register fin = lift (register fin)
    release key  = lift (release key)
-- | Run an action, arranging for a 'Base'-monad finalizer to run if the
-- action does not return.
--
-- The finalizer is 'register'ed before the action starts and 'release'd once
-- the action returns.  If the action never returns (for example because it
-- throws), the finalizer stays registered.  In a 'SafeT' stack it then runs
-- when 'runSafeT' finishes, which is not necessarily where the exception was
-- raised.
onException :: (MonadSafe m) => m a -> Base m b -> m a
onException action cleanup =
    register (cleanup >> return ()) >>= \key ->
    action >>= \r ->
    release key >> return r
-- | Run an action, then a 'Base'-monad finalizer, whether or not the action
-- returns normally.  This is a 'bracket_' whose acquire step does nothing.
finally :: (MonadSafe m) => m a -> Base m b -> m a
finally action after = bracket_ (return ()) after action
-- | Acquire a resource in the 'Base' monad, use it, then release it.
--
-- Acquisition and the release bookkeeping run masked; only the user action
-- runs under @restore@.  The release step is registered through 'onException'
-- while the action runs, and is run directly once the action returns.
bracket :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracket before after action = mask (\restore ->
    liftBase before >>= \h ->
    (restore (action h) `onException` after h) >>= \r ->
    liftBase (after h) >> return r )
-- | Like 'bracket', for when the acquired value is not needed by either the
-- release step or the action.
bracket_ :: (MonadSafe m) => Base m a -> Base m b -> m c -> m c
bracket_ before after action = bracket before (const after) (const action)
-- | Like 'bracket', but the release step only happens if the action does not
-- return.  On a normal return the resource is left for the caller.
bracketOnError
    :: (MonadSafe m) => Base m a -> (a -> Base m b) -> (a -> m c) -> m c
bracketOnError before after action = mask (\restore ->
    liftBase before >>= \h ->
    restore (action h) `onException` after h )