{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

-- | Free monads, a 'MonadFree' class abstracting over their encodings,
-- and the 'FreeT' free monad transformer.
module Control.Monad.Free
  ( module Control.Monad
  , module Control.Monad.Fail
    -- * Free Monads
  , MonadFree(..)
  , Free(..)
  , isPure
  , isImpure
  , foldFree
  , evalFree
  , mapFree
  , mapFreeM
  , mapFreeM'
    -- * Monad Morphisms
  , foldFreeM
  , induce
    -- * Free Monad Transformers
  , FreeT(..)
  , foldFreeT
  , foldFreeT'
  , mapFreeT
  , foldFreeA
  , mapFreeA
    -- * Translate between Free monad and Free monad transformer computations
  , trans
  , trans'
  , untrans
  , liftFree
  ) where

import Control.Applicative
import Control.Monad hiding (fail)
import Control.Monad.Fail
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Data.Bifunctor
import Data.Foldable
import Data.Functor.Classes
import Data.Traversable as T
import Data.Typeable (Typeable)
import GHC.Generics (Generic)

-- | This type class generalizes over encodings of Free Monads.
class (Functor f, Monad m) => MonadFree f m where
  -- | \'Opens\' a computation and allows to observe its side effects:
  -- the result is either the final value ('Left') or one layer of
  -- effects wrapping the rest of the computation ('Right').
  free :: m a -> m (Either a (f (m a)))
  -- | Wraps a side effect into a monadic computation.
  wrap :: f (m a) -> m a

instance Functor f => MonadFree f (Free f) where
  free = evalFree (Pure . Left) (Pure . Right)
  wrap = Impure

-- | The free monad over @f@: either a finished computation ('Pure') or
-- one layer of @f@ wrapping further computations ('Impure').
data Free f a = Impure (f (Free f a)) | Pure a
  deriving (Generic, Typeable)

instance Eq1 f => Eq1 (Free f) where
  liftEq eq (Pure a) (Pure b) = eq a b
  liftEq eq (Impure a) (Impure b) = liftEq (liftEq eq) a b
  liftEq _ _ _ = False

instance (Eq a, Eq1 f) => Eq (Free f a) where
  (==) = eq1

-- | Every 'Impure' value orders before every 'Pure' value.
instance Ord1 f => Ord1 (Free f) where
  liftCompare _ Impure{} Pure{} = LT
  liftCompare _ Pure{} Impure{} = GT
  liftCompare cmp (Pure a) (Pure b) = cmp a b
  liftCompare cmp (Impure a) (Impure b) = liftCompare (liftCompare cmp) a b

instance (Ord a, Ord1 f) => Ord (Free f a) where
  compare = compare1

-- | Renders like a derived instance: a constructor application is
-- parenthesised only when the context binds tighter than function
-- application (precedence above 10).
instance (Show a, Show1 f) => Show (Free f a) where
  showsPrec p (Pure a) =
    showParen (p > 10) $ showString "Pure " . showsPrec 11 a
  showsPrec p (Impure fa) =
    showParen (p > 10) $
      showString "Impure " . liftShowsPrec showsPrec showList 11 fa

instance Functor f => Functor (Free f) where
  fmap f = go
    where
      go (Pure a) = Pure (f a)
      go (Impure fa) = Impure (fmap go fa)
  {-# INLINE fmap #-}

-- | Only 'Foldable' is needed for @f@: each layer is folded directly.
instance Foldable f => Foldable (Free f) where
  foldMap f (Pure a) = f a
  foldMap f (Impure fa) = foldMap (foldMap f) fa

instance Traversable f => Traversable (Free f) where
  traverse f (Pure a) = Pure <$> f a
  traverse f (Impure fa) = Impure <$> traverse (traverse f) fa

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure f <*> x = fmap f x
  Impure f <*> x = Impure (fmap (<*> x) f)

-- 'return' is left to its default ('pure'), the canonical definition.
instance Functor f => Monad (Free f) where
  Pure a >>= f = f a
  Impure fa >>= f = Impure (fmap (>>= f) fa)

-- | Test which constructor a 'Free' value was built with.
isPure, isImpure :: Free f a -> Bool
isPure Pure{} = True
isPure _ = False
isImpure = not . isPure

-- | Fold a 'Free' value bottom-up: 'Pure' leaves go through the first
-- function, and each 'Impure' layer, after its children are folded, goes
-- through the second.
foldFree :: Functor f => (a -> b) -> (f b -> b) -> Free f a -> b
foldFree onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure x) = onImpure (fmap go x)

-- | Monadic variant of 'foldFree': children of a layer are folded with
-- 'T.mapM', then the layer's algebra is run on the results.
foldFreeM
  :: (Traversable f, Monad m)
  => (a -> m b) -> (f b -> m b) -> Free f a -> m b
foldFreeM onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure x) = onImpure =<< T.mapM go x

-- | Applicative variant of 'foldFree' where the layer algebra itself is
-- produced inside @m@.
foldFreeA
  :: (Traversable f, Applicative m)
  => (a -> m b) -> m (f b -> b) -> Free f a -> m b
foldFreeA onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure x) = onImpure <*> traverse go x

-- | Interpret a 'Free' computation in any monad, given an interpretation
-- of a single @f@ layer.
induce :: (Functor f, Monad m) => (forall a. f a -> m a) -> Free f a -> m a
induce f = foldFree return (join . f)

-- | Non-recursive case analysis on the outermost constructor.
evalFree :: (a -> b) -> (f (Free f a) -> b) -> Free f a -> b
evalFree p _ (Pure x) = p x
evalFree _ i (Impure x) = i x

-- | Change the functor of a 'Free' value layer by layer; @eta@ receives
-- each layer after its children have already been translated.
mapFree
  :: (Functor f, Functor g)
  => (f (Free g a) -> g (Free g a)) -> Free f a -> Free g a
mapFree eta = foldFree Pure (Impure . eta)

-- | Monadic variant of 'mapFree'.
mapFreeM
  :: (Traversable f, Functor g, Monad m)
  => (f (Free g a) -> m (g (Free g a))) -> Free f a -> m (Free g a)
mapFreeM eta = foldFreeM (return . Pure) (liftM Impure . eta)

-- | Applicative variant of 'mapFree'; the translation is produced in @m@.
mapFreeA
  :: (Traversable f, Functor g, Applicative m)
  => m (f (Free g a) -> g (Free g a)) -> Free f a -> m (Free g a)
mapFreeA eta = foldFreeA (pure . Pure) (liftA (Impure .) eta)

-- | Change the functor with a polymorphic, effectful translation; the
-- translated layer's children are then sequenced with 'T.sequence'.
mapFreeM'
  :: (Functor f, Traversable g, Monad m)
  => (forall a. f a -> m (g a)) -> Free f a -> m (Free g a)
mapFreeM' eta =
  foldFree (return . Pure) (liftM Impure . join . liftM T.sequence . eta)

-- Monad transformer
-- (built upon Luke Palmer's control-monad-free hackage package)

-- | Free monad transformer: each step runs an @m@ action that yields
-- either the final result ('Left') or a layer of @f@ wrapping further
-- steps ('Right').
newtype FreeT f m a = FreeT { unFreeT :: m (Either a (f (FreeT f m a))) }

instance (Traversable m, Traversable f) => Foldable (FreeT f m) where
  foldMap = foldMapDefault

instance (Traversable m, Traversable f) => Traversable (FreeT f m) where
  traverse f (FreeT a) = FreeT <$> traverse step a
    where
      step (Left x) = Left <$> f x
      step (Right x) = Right <$> (traverse . traverse) f x

instance (Functor f, Functor m) => Functor (FreeT f m) where
  fmap f = FreeT . fmap (bimap f ((fmap . fmap) f)) . unFreeT

instance (Functor f, Functor a, Monad a) => Applicative (FreeT f a) where
  pure = FreeT . return . Left
  (<*>) = ap

-- 'return' is left to its default ('pure'), the canonical definition.
instance (Functor f, Monad m) => Monad (FreeT f m) where
  m >>= f =
    FreeT $ unFreeT m >>= either (unFreeT . f) (return . Right . fmap (>>= f))

instance (Functor f, Monad m) => MonadFree f (FreeT f m) where
  wrap = FreeT . return . Right
  free = lift . unFreeT

instance Functor f => MonadTrans (FreeT f) where
  lift = FreeT . liftM Left

instance (Functor f, Monad m, MonadIO m) => MonadIO (FreeT f m) where
  liftIO = lift . liftIO

-- | Choice is delegated to the underlying monad's 'mplus' on the first step.
instance (Functor f, Functor m, Monad m, MonadPlus m) => Alternative (FreeT f m) where
  empty = lift mzero
  a <|> b = FreeT (mplus (unFreeT a) (unFreeT b))

-- Uses the defaults @mzero = empty@ and @mplus = (<|>)@.
instance (Functor f, Monad m, MonadPlus m) => MonadPlus (FreeT f m)

-- | Run a 'FreeT' computation, folding it with a monadic algebra.
foldFreeT
  :: (Traversable f, Monad m)
  => (a -> m b) -> (f b -> m b) -> FreeT f m a -> m b
foldFreeT p i = go
  where
    go m = unFreeT m >>= either p (\fx -> T.mapM go fx >>= i)

-- | Run a 'FreeT' computation, folding it with a pure algebra.
foldFreeT'
  :: (Traversable f, Monad m)
  => (a -> b) -> (f b -> b) -> FreeT f m a -> m b
foldFreeT' p i = go
  where
    go (FreeT m) = m >>= either (return . p) (liftM i . T.mapM go)

-- | Change the base monad with a natural transformation applied at every
-- step of the computation.
mapFreeT
  :: (Functor f, Functor m)
  => (forall a. m a -> m' a) -> FreeT f m a -> FreeT f m' a
mapFreeT f (FreeT m) = FreeT (f ((fmap . fmap . fmap) (mapFreeT f) m))

-- | Run the effects of a 'FreeT' computation, collecting its @f@ layers
-- into a 'Free' value.
untrans :: (Traversable f, Monad m) => FreeT f m a -> m (Free f a)
untrans = foldFreeT (return . Pure) (return . Impure)

-- | Embed a 'Free' computation into any 'MonadFree' encoding.
trans :: MonadFree f m => Free f a -> m a
trans = foldFree return wrap

-- | Turn an @m@ action producing a 'Free' computation into a 'FreeT'.
trans' :: (Functor f, Monad m) => m (Free f a) -> FreeT f m a
trans' = FreeT . join . liftM (unFreeT . trans)

-- | Lift a 'Free'-valued function into a 'FreeT'-valued one.
liftFree :: (Functor f, Monad m) => (a -> Free f b) -> (a -> FreeT f m b)
liftFree f = trans . f