{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TypeOperators #-}

-- | Reorganising monad transformer stacks.
--
-- 'Takeout' separates a transformer from the monad beneath it, and
-- 'Leftdistr' \/ 'Rightdistr' build on that to move one transformer past
-- another: 'ldist'' turns @m (n x) t@ into @n (m x) t@, and 'rdist'' turns
-- @n (m x) t@ into @m (n x) t@.
module Control.Monad.Distributive where

import qualified Control.Monad.State.Strict as Strict
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
-- NOTE(review): 'Control.Monad.Maybe' comes from the legacy MaybeT package,
-- and 'Control.Monad.Error' / 'Control.Monad.List' are deprecated in mtl
-- (and may be missing from recent releases) -- confirm the package set this
-- module is built against.
import Control.Monad.Maybe
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Identity
import Control.Monad.Morph

-- | Exchange the two components of a pair.
--
-- The pattern is lazy, so the argument pair is not forced until one of the
-- result components is demanded.
swap :: (a, b) -> (b, a)
swap ~(x, y) = (y, x)

-- | Transformers that can be pulled apart from the monad beneath them.
--
-- The functional dependency fixes, for each transformer @m@, the structure
-- @y@ that carries the transformer's own data next to a result
-- (e.g. @(,) s@ for @StateT s@, 'Identity' for @ReaderT r@).
class (MonadTrans m) => Takeout m y | m -> y where
  -- | Pop out the underlying monad of a transformer. The resulting
  -- @m Identity@ computation yields the @n@ action, whose result packs the
  -- value together with the transformer's data in @y@.
  takeout :: (Monad n) => m n t -> m Identity (n (y t))
  -- | Put the data structure back in: replay the data held in @y@ as an
  -- effect of @m@ and return the value.
  combine :: (Monad x) => y t -> m x t

-- | Reads the current state, runs the computation from it, and yields the
-- inner action, whose result is @(finalState, value)@. 'combine' writes the
-- state back with 'Strict.put'.
instance Takeout (Strict.StateT s) ((,) s) where
  takeout m = Strict.get >>= return . liftM swap . Strict.runStateT m
  combine (s, x) = Strict.put s >> return x

-- | Lazy-state counterpart of the 'Strict.StateT' instance.
instance Takeout (StateT s) ((,) s) where
  takeout m = get >>= return . liftM swap . runStateT m
  combine (s, x) = put s >> return x

-- | The environment is read and supplied up front, so no extra data needs
-- to be carried: @y@ is 'Identity'.
instance Takeout (ReaderT r) Identity where
  takeout m = ask >>= return . liftM Identity . runReaderT m
  combine = return . runIdentity

-- | Carries the accumulated output as @(w, value)@; 'combine' re-emits it
-- with 'tell'.
instance (Monoid w) => Takeout (WriterT w) ((,) w) where
  takeout = return . liftM swap . runWriterT
  combine (w, x) = tell w >> return x

-- | Roughly the opposite of 'takeout'. Given @n x (m Identity t)@, it:
--
-- * generalises the inner pure @m Identity@ computation to @m x@,
-- * makes @m x@ the base monad of @n@ via 'hoist' 'lift',
-- * runs the inner computation there.
--
-- The result is @n (m x) t@.
putin
  :: ( MFunctor m, MonadTrans m, MFunctor n, MonadTrans n
     , Monad x, Monad (m x), Monad (n x), Monad (n (m x)) )
  => n x (m Identity t) -> n (m x) t
putin comp = hoist lift (liftM (hoist (return . runIdentity)) comp) >>= lift

-- | Single-layer variant of 'putin': move @n@ off 'Identity' onto @k@, then
-- run the inner @k@ action with 'lift', turning @n Identity (k t)@ into
-- @n k t@.
putin1
  :: (MFunctor n, MonadTrans n, Monad k, Monad (n k))
  => n Identity (k t) -> n k t
putin1 comp = hoist (return . runIdentity) comp >>= lift

-- | Transformers that distribute over one another, for reorganising a monad
-- stack.
--
-- 'ldist' runs @m@'s effect inside the monad @n x@ beneath it and
-- re-expresses the outcome as a pure @m Identity@ computation.
class Leftdistr m where
  ldist :: (Monad (n x), Monad x) => m (n x) t -> n x (m Identity t)

-- | 'rdist' takes an @m x@ computation produced inside a pure
-- @n Identity@ layer and rebuilds it with @m@ on the outside of @n x@.
class Rightdistr m where
  rdist
    :: (Monad (n Identity), Monad (n x), MonadTrans n, MFunctor n, Monad x)
    => n Identity (m x t) -> m (n x) t

-- | 'Nothing' becomes 'mzero'; @'Just' v@ becomes @'return' v@.
instance Leftdistr MaybeT where
  ldist m = runMaybeT m >>= return . maybe mzero return

-- | A 'Left' is rethrown with 'throwError'; a 'Right' becomes 'return'.
instance (Error t) => Leftdistr (ErrorT t) where
  ldist m = runErrorT m >>= return . either throwError return

-- | The collected output is replayed with 'tell' before returning the value.
instance (Monoid x) => Leftdistr (WriterT x) where
  ldist m = runWriterT m >>= \(x, w) -> return $ tell w >> return x

-- | Each result becomes a 'return', and the alternatives are joined with
-- 'msum'.
instance Leftdistr ListT where
  ldist m = runListT m >>= return . msum . map return

-- | Runs the inner strict-state computation from the current state and
-- stores its final state afterwards, so state updates made underneath @n@
-- are kept rather than discarded.
instance Rightdistr (Strict.StateT v) where
  rdist m = do
    s <- Strict.get
    (x, s') <- lift $ putin1 $ liftM (`Strict.runStateT` s) m
    Strict.put s'
    return x

-- | Lazy-state counterpart of the 'Strict.StateT' instance. The pair is
-- matched lazily to keep the laziness of the lazy 'StateT'.
instance Rightdistr (StateT v) where
  rdist m = do
    s <- get
    ~(x, s') <- lift $ putin1 $ liftM (`runStateT` s) m
    put s'
    return x

-- | Supplies the current environment to the inner reader computation.
instance Rightdistr (ReaderT v) where
  rdist m = ask >>= \v -> lift $ putin1 $ liftM (`runReaderT` v) m

-- | Left distributivity of a monad transformer: move @m@ from above @n@ to
-- beneath it, turning @m (n x) t@ into @n (m x) t@.
ldist'
  :: ( Leftdistr m, MFunctor m, MonadTrans m, MFunctor n, MonadTrans n
     , Monad x, Monad (m x), Monad (n x), Monad (n (m x)) )
  => m (n x) t -> n (m x) t
ldist' m = putin $ ldist m

-- | Right distributivity: move @m@ from beneath @n@ to above it, turning
-- @n (m x) t@ into @m (n x) t@. The outer layer's data captured by
-- 'takeout' is restored with 'combine' afterwards.
rdist'
  :: ( Takeout n y, Rightdistr m, MonadTrans m, MFunctor n
     , Monad x, Monad (m x), Monad (n Identity), Monad (n x), Monad (m (n x)) )
  => n (m x) t -> m (n x) t
rdist' m = rdist (takeout m) >>= lift . combine