{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A continuation-passing monad transformer that threads a context value
-- of type @c@ through a computation and can abort with an error of type
-- @e@, together with the 'MonadContextError' class that lifts its
-- operations through the standard transformers.
module Control.Monad.ContextError
  ( ContextErrorT
  , runContextErrorT
  , ContextError
  , runContextError
  , MonadContextError (..)
  ) where

#if !MIN_VERSION_mtl(2,2,0)
import Control.Monad.Error
#endif
import Control.Applicative
import Control.Monad (MonadPlus, mplus, mzero)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Cont as Cont (ContT, liftLocal)
import Control.Monad.Trans.Except (ExceptT, mapExceptT)
import Control.Monad.Trans.Identity (IdentityT, mapIdentityT)
import Control.Monad.Trans.Maybe (MaybeT, mapMaybeT)
import Control.Monad.Trans.Reader (ReaderT, mapReaderT)
import qualified Control.Monad.Trans.RWS.Lazy as Lazy (RWST, mapRWST)
import qualified Control.Monad.Trans.RWS.Strict as Strict (RWST, mapRWST)
import qualified Control.Monad.Trans.State.Lazy as Lazy (StateT, mapStateT)
import qualified Control.Monad.Trans.State.Strict as Strict (StateT, mapStateT)
import Control.Monad.Trans.Writer.Lazy as Lazy (WriterT, mapWriterT)
import Control.Monad.Trans.Writer.Strict as Strict (WriterT, mapWriterT)
import Control.Monad.State (MonadState (..))
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.Writer (MonadWriter (..))
import Data.Functor.Identity
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif

----------------------------------------------------------------------
-- Monad

-- | A computation over the base monad @m@ that carries a context @c@ and
-- may fail with an error @e@.
--
-- The representation is continuation-passing: a computation receives the
-- current context, an error continuation and a success continuation. The
-- success continuation is handed the (possibly updated) context along with
-- the result; the error continuation receives only the error.
newtype ContextErrorT c e m a = ContextErrorT
  { unContextErrorT
      :: forall b. c -> (e -> m b) -> (c -> a -> m b) -> m b
  }

-- | Run a computation from the given initial context. The final context is
-- discarded; errors are returned as 'Left' and results as 'Right'.
runContextErrorT :: (Monad m) => ContextErrorT c e m a -> c -> m (Either e a)
runContextErrorT k c =
  unContextErrorT k c (return . Left) (const $ return . Right)

-- | 'ContextErrorT' specialised to the 'Identity' base monad.
type ContextError c e a = ContextErrorT c e Identity a

-- | Pure counterpart of 'runContextErrorT'.
runContextError :: ContextError c e a -> c -> Either e a
runContextError k c = runIdentity (runContextErrorT k c)

-- | Maps over the result; context and error handling are untouched.
instance Functor (ContextErrorT c e m) where
  fmap f e = ContextErrorT $ \c err ret ->
    unContextErrorT e c err (\c' -> ret c' . f)

-- | Runs the function computation first, then the argument computation in
-- the context left behind by the first.
instance Applicative (ContextErrorT c e m) where
  pure a = ContextErrorT $ \c _ ret -> ret c a
  {-# INLINE pure #-}

  fe <*> ae = ContextErrorT $ \c err ret ->
    unContextErrorT fe c err $ \c' f ->
      unContextErrorT ae c' err (\c'' -> ret c'' . f)
  {-# INLINE (<*>) #-}

-- | '<|>' tries the left computation; if it fails, the right one is run
-- from the /original/ context, and if that fails as well the two errors
-- are combined with '<>'.
instance (Semigroup e) => Alternative (ContextErrorT c e m) where
  -- FIXME: sane 'empty' needed!
  -- With only a 'Semigroup' constraint there is no neutral error value, so
  -- 'empty' fails with a bottom error that blows up once it is forced.
  empty = ContextErrorT $ \_ err _ -> err (error "empty ContextErrorT")
  {-# INLINE empty #-}

  ae <|> be = ContextErrorT $ \c err ret ->
    unContextErrorT
      ae
      c
      (\e -> unContextErrorT be c (\e' -> err (e <> e')) ret)
      ret
  {-# INLINE (<|>) #-}

-- | Sequencing threads the context from one step to the next; the first
-- error aborts the remainder of the computation.
instance Monad (ContextErrorT c e m) where
  return = pure
  {-# INLINE return #-}

  ma >>= fb = ContextErrorT $ \c err ret ->
    unContextErrorT ma c err $ \c' a ->
      unContextErrorT (fb a) c' err ret
  {-# INLINE (>>=) #-}

instance (Semigroup e) => MonadPlus (ContextErrorT c e m) where
  mzero = empty
  {-# INLINE mzero #-}

  mplus = (<|>)
  {-# INLINE mplus #-}

-- | Lifted actions leave the context unchanged and never fail.
instance MonadTrans (ContextErrorT c e) where
  lift act = ContextErrorT $ \c _ ret -> act >>= ret c
  {-# INLINE lift #-}

instance MonadState s m => MonadState s (ContextErrorT c e m) where
  get = lift get
  put = lift . put
  state = lift . state

-- | 'listen' and 'pass' are scoped to their argument only: the argument is
-- first run to an 'Either' inside the base monad's operation, and the
-- continuation is resumed outside of it.
instance MonadWriter w m => MonadWriter w (ContextErrorT c e m) where
  writer = lift . writer
  tell = lift . tell

  listen m = ContextErrorT $ \c err ret -> do
    (res, w) <-
      listen (unContextErrorT m c (return . Left) (curry (return . Right)))
    case res of
      Left e -> err e
      Right (c', a) -> ret c' (a, w)

  pass m = ContextErrorT $ \c err ret -> do
    -- Only the output produced by @m@ itself may be rewritten, so the
    -- continuation must not run inside the base monad's 'pass'.
    res <- pass $ do
      r <- unContextErrorT m c (return . Left) (curry (return . Right))
      return $ case r of
        -- A failed computation yields no rewriting function; keep the
        -- output as it is.
        Left e -> (Left e, id)
        Right (c', (a, f)) -> (Right (c', a), f)
    case res of
      Left e -> err e
      Right (c', a) -> ret c' a

-- | 'local' is scoped to its argument only; the rest of the computation
-- runs in the unmodified environment.
instance MonadReader r m => MonadReader r (ContextErrorT c e m) where
  ask = lift ask

  local f m = ContextErrorT $ \c err ret -> do
    -- Reify the outcome of @m@ under the modified environment, then resume
    -- the continuation outside of the base monad's 'local'.
    res <-
      local f (unContextErrorT m c (return . Left) (curry (return . Right)))
    case res of
      Left e -> err e
      Right (c', a) -> ret c' a

  reader = lift . reader

----------------------------------------------------------------------
-- Monad class stuff

-- | Monads that carry a context @c@ and can fail with an error @e@ that is
-- built from that context.
class (Monad m) => MonadContextError c e m | m -> c e where
  -- | Abort with an error computed from the current context.
  throwInContext :: (c -> e) -> m a

  -- | Retrieve the current context.
  askContext :: m c

  -- | Run a computation in a modified context. In the 'ContextErrorT'
  -- instance the context that was current before the call is restored
  -- afterwards, so context changes made by the inner computation are
  -- dropped.
  localContext :: (c -> c) -> m a -> m a

  -- | Update the context for the remainder of the computation.
  modifyContext :: (c -> c) -> m ()

instance Monad m => MonadContextError c e (ContextErrorT c e m) where
  throwInContext f = ContextErrorT $ \c err _ -> err (f c)
  askContext = ContextErrorT $ \c _ ret -> ret c c
  localContext f m = ContextErrorT $ \c err ret ->
    -- Ignore the context left by @m@ and continue with the outer one.
    unContextErrorT m (f c) err (\_ -> ret c)
  modifyContext f = ContextErrorT $ \c _ ret -> ret (f c) ()

instance MonadContextError c e m => MonadContextError c e (ContT r m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Cont.liftLocal askContext localContext
  modifyContext = lift . modifyContext

#if MIN_VERSION_mtl(2, 2, 0)
-- | The error type @e'@ of the 'ExceptT' layer is independent of the
-- context-error type @e@ of the underlying monad.
instance MonadContextError c e m => MonadContextError c e (ExceptT e' m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = mapExceptT . localContext
  modifyContext = lift . modifyContext
#else
instance (Error e', MonadContextError c e m) => MonadContextError c e (ErrorT e' m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = mapErrorT . localContext
  modifyContext = lift . modifyContext
#endif

instance MonadContextError c e m => MonadContextError c e (IdentityT m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = mapIdentityT . localContext
  modifyContext = lift . modifyContext

instance MonadContextError c e m => MonadContextError c e (MaybeT m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = mapMaybeT . localContext
  modifyContext = lift . modifyContext

instance MonadContextError c e m => MonadContextError c e (ReaderT r m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = mapReaderT . localContext
  modifyContext = lift . modifyContext

instance (Monoid w, MonadContextError c e m) => MonadContextError c e (Lazy.WriterT w m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Lazy.mapWriterT . localContext
  modifyContext = lift . modifyContext

instance (Monoid w, MonadContextError c e m) => MonadContextError c e (Strict.WriterT w m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Strict.mapWriterT . localContext
  modifyContext = lift . modifyContext

instance MonadContextError c e m => MonadContextError c e (Lazy.StateT s m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Lazy.mapStateT . localContext
  modifyContext = lift . modifyContext

instance MonadContextError c e m => MonadContextError c e (Strict.StateT s m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Strict.mapStateT . localContext
  modifyContext = lift . modifyContext

instance (Monoid w, MonadContextError c e m) => MonadContextError c e (Lazy.RWST r w s m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Lazy.mapRWST . localContext
  modifyContext = lift . modifyContext

instance (Monoid w, MonadContextError c e m) => MonadContextError c e (Strict.RWST r w s m) where
  throwInContext = lift . throwInContext
  askContext = lift askContext
  localContext = Strict.mapRWST . localContext
  modifyContext = lift . modifyContext