module Control.Monad.Distributive where
import qualified Control.Monad.State.Strict as Strict
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Identity
import Control.Monad.Morph
-- | Exchange the components of a pair.
--
-- Lazy in its argument: the result pair is built without forcing the input,
-- and each component is projected out of it only on demand.
swap :: (a, b) -> (b, a)
swap p = (snd p, fst p)
-- | Monad transformers whose layer can be run against an arbitrary inner
-- monad and later replayed.
--
-- The functor @y@ (determined by @m@) is the shape in which one run of the
-- layer reports its effect next to the result, e.g. @((,) s)@ for state.
class (MonadTrans m) => Takeout m y | m -> y where
  -- | Run the @m@ layer of a computation over @n@.  Whatever the layer needs
  -- to read up front is read in @m Identity@; the returned @n@ action yields
  -- the result together with the layer's effect packed in @y@.
  takeout :: (Monad n) => m n a -> m Identity (n (y a))
  -- | Re-apply an effect captured by 'takeout' in a layer over any monad.
  combine :: (Monad x) => y a -> m x a
-- | Strict state: 'takeout' starts the inner run from the current state and
-- reports the final state first in the pair; 'combine' installs that state
-- and returns the value.
instance Takeout (Strict.StateT s) ((,) s) where
  takeout m = do
    s0 <- Strict.get
    return (liftM swap (Strict.runStateT m s0))
  combine (s, x) = do
    Strict.put s
    return x
-- | Lazy state: 'takeout' starts the inner run from the current state and
-- reports the final state first in the pair; 'combine' installs that state
-- and returns the value.
instance Takeout (StateT s) ((,) s) where
  takeout m = do
    s0 <- get
    return (liftM swap (runStateT m s0))
  combine (s, x) = do
    put s
    return x
-- | Reader: 'takeout' runs the inner computation with the current
-- environment; the reader has no effect to report, so it is wrapped in
-- 'Identity' and 'combine' simply returns the value.
instance Takeout (ReaderT r) Identity where
  takeout m = do
    env <- ask
    return (liftM Identity (runReaderT m env))
  combine (Identity x) = return x
-- | Writer: 'takeout' runs the inner computation and reports the accumulated
-- output first in the pair; 'combine' re-emits that output with 'tell'.
instance (Monoid w) => Takeout (WriterT w) ((,) w) where
  takeout m = return (liftM swap (runWriterT m))
  combine (w, x) = do
    tell w
    return x
-- | Splice a computation that returns a reified @t Identity@ layer back into
-- the stack, turning @outer x (t Identity a)@ into @outer (t x) a@.
--
-- The returned layer is generalised from 'Identity' to the base monad. The
-- outer computation's base is pushed under @t@ with 'hoist', and the
-- generalised layer is then run inside it.
putin m = do
  layer <- hoist lift (liftM (hoist (return . runIdentity)) m)
  lift layer
-- | Run an @x@ action returned by an @n Identity@ computation inside @n x@.
-- The computation's base is first generalised from 'Identity' to @x@, and
-- its result is then lifted in.
putin1 m = lift =<< hoist (return . runIdentity) m
-- | Transformers @m@ that can be pushed out from over a transformed monad
-- @n x@: the @m@ layer is run inside @n x@, and its outcome is handed back
-- reified as an @m Identity@ computation.
class Leftdistr m where
  ldist :: (Monad (n x), Monad x) => m (n x) a -> n x (m Identity a)
-- | Transformers @m@ whose actions, when produced by an @n Identity@
-- computation, can be run as a single @m (n x)@ computation.
class Rightdistr m where
  rdist
    :: (Monad (n Identity), Monad (n x), MonadTrans n, MFunctor n, Monad x)
    => n Identity (m x a)
    -> m (n x) a
-- | Errors: the 'ErrorT' layer is run inside @n x@, and its 'Either' outcome
-- is turned back into a throw or a return in @ErrorT e Identity@.
instance (Error e) => Leftdistr (ErrorT e) where
  ldist m = liftM (either throwError return) (runErrorT m)
-- | Writer: the 'WriterT' layer is run inside @n x@, and its output is handed
-- back as a @WriterT w Identity@ action that re-emits it before returning
-- the value.
instance (Monoid w) => Leftdistr (WriterT w) where
  ldist m = do
    (a, out) <- runWriterT m
    return (tell out >> return a)
-- | Nondeterminism: the 'ListT' layer is run inside @n x@, and the collected
-- results are handed back as a @ListT Identity@ computation yielding them in
-- order.
instance Leftdistr ListT where
  ldist m = do
    xs <- runListT m
    return (msum (map return xs))
-- | Strict state distributed over @n@. The 'Strict.StateT' action produced
-- inside @n Identity@ is run from the current state, its final state is
-- written back, and its value is returned.
instance Rightdistr (Strict.StateT v) where
  rdist m = do
    s0 <- Strict.get
    -- runStateT, not evalStateT: the inner computation's final state must be
    -- kept; otherwise every state update made under @n@ is silently dropped.
    (x, s1) <- lift $ putin1 $ liftM (`Strict.runStateT` s0) m
    Strict.put s1
    return x
-- | Lazy state distributed over @n@. The 'StateT' action produced inside
-- @n Identity@ is run from the current state, its final state is written
-- back, and its value is returned.
instance Rightdistr (StateT v) where
  rdist m = do
    s0 <- get
    -- runStateT, not evalStateT: the inner computation's final state must be
    -- kept; otherwise every state update made under @n@ is silently dropped.
    -- The lazy pattern keeps the lazy-StateT behaviour of not forcing the
    -- result pair early.
    ~(x, s1) <- lift $ putin1 $ liftM (`runStateT` s0) m
    put s1
    return x
-- | Reader distributed over @n@. The 'ReaderT' action produced inside
-- @n Identity@ is run with the current environment.
instance Rightdistr (ReaderT v) where
  rdist m = do
    env <- ask
    lift (putin1 (liftM (\inner -> runReaderT inner env) m))
-- | Commute a 'Leftdistr' layer outward, turning @m (n x) a@ into
-- @n (m x) a@. The @m@ layer is reified with 'ldist' and spliced back with
-- 'putin'.
ldist' m = putin distributed
  where
    distributed = ldist m
-- | Commute a 'Takeout' layer @k@ over a 'Rightdistr' layer @r@, turning
-- @k (r x) a@ into @r (k x) a@.
--
-- The @k@ layer is taken out and the inner @r@ computation is distributed
-- with 'rdist'. The captured @k@ effect is then replayed under @r@ with
-- 'combine'.
rdist' m = do
  captured <- rdist (takeout m)
  lift (combine captured)