{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric, DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances #-}
{-# OPTIONS -Wno-name-shadowing #-}
module Control.Monad.Free (
module Control.Monad,
module Control.Monad.Fail,
MonadFree(..),
Free(..), isPure, isImpure,
foldFree,
evalFree, mapFree, mapFreeM, mapFreeM',
foldFreeM,
induce,
FreeT(..),
foldFreeT, foldFreeT', mapFreeT,
foldFreeA, mapFreeA,
trans, trans', untrans,liftFree
) where
import Control.Applicative
import Control.Monad hiding (fail)
import Control.Monad.Fail
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Data.Bifunctor
import Data.Foldable
import Data.Functor.Classes
import Data.Traversable as T
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
-- | Monads that can build and inspect layers of the functor @f@.
--
-- In the instances defined in this module ('Free' and 'FreeT'):
--
-- * 'wrap' turns one @f@-layer whose positions hold computations into a
--   single computation;
-- * 'free' exposes the top-most layer of a computation, returning @Left@
--   for a finished result and @Right@ for an @f@-layer of sub-computations.
--
-- NOTE(review): no laws relating 'free' and 'wrap' are stated in this
-- module; confirm the intended laws before relying on them.
class (Functor f, Monad m) => MonadFree f m where
  -- | Expose the top-most layer of a computation.
  free :: m a -> m (Either a (f (m a)))
  -- | Embed one @f@-layer of computations as a computation.
  wrap :: f (m a) -> m a
-- | For 'Free', wrapping is the 'Impure' constructor itself, and 'free'
-- reports the outermost constructor inside a 'Pure' result.
instance Functor f => MonadFree f (Free f) where
  wrap = Impure
  free t = case t of
    Pure x -> Pure (Left x)
    Impure layer -> Pure (Right layer)
-- | The free monad over a functor @f@: a tree whose leaves ('Pure') hold
-- values of type @a@ and whose inner nodes ('Impure') are @f@-layers of
-- subtrees.
data Free f a = Impure (f (Free f a)) | Pure a deriving (Generic, Typeable)
-- | Two trees are equal when they have the same shape, equal leaves
-- (compared with the supplied predicate) and equal @f@-layers
-- (compared with @f@'s 'liftEq').
instance (Eq1 f) => Eq1 (Free f) where
  liftEq eq = go
    where
      go (Pure x) (Pure y) = eq x y
      go (Impure lx) (Impure ly) = liftEq go lx ly
      go _ _ = False
-- | Structural equality, delegating to the 'Eq1' instance.
instance (Eq a, Eq1 f) => Eq (Free f a) where
  x == y = eq1 x y
-- | Ordering on trees: any 'Impure' node sorts before any 'Pure' leaf.
-- Two leaves are compared with the supplied function, and two layers
-- with @f@'s 'liftCompare'.
instance Ord1 f => Ord1 (Free f) where
  liftCompare cmp = go
    where
      go (Pure x) (Pure y) = cmp x y
      go (Impure lx) (Impure ly) = liftCompare go lx ly
      go (Impure _) (Pure _) = LT
      go (Pure _) (Impure _) = GT
-- | Total ordering, delegating to the 'Ord1' instance.
instance (Ord a, Ord1 f) => Ord (Free f a) where
  compare x y = compare1 x y
-- | Renders values as constructor applications, e.g. @Pure 1@ or
-- @Impure (Just (Pure 1))@.
--
-- Parentheses are added only when the surrounding precedence exceeds that
-- of function application (10), as derived 'Show' instances do. The
-- previous @p > 0@ test also parenthesised in operator contexts
-- (precedence 1..10), where no parentheses are needed.
instance (Show a, Show1 f) => Show (Free f a) where
  showsPrec p (Pure a) =
    showParen (p > 10) $ showString "Pure " . showsPrec 11 a
  showsPrec p (Impure a) =
    showParen (p > 10) $
      showString "Impure " . liftShowsPrec showsPrec showList 11 a
-- | Maps over the 'Pure' leaves, leaving every 'Impure' layer's shape intact.
instance Functor f => Functor (Free f) where
  fmap g = loop
    where
      loop t = case t of
        Pure x -> Pure (g x)
        Impure layer -> Impure (fmap loop layer)
  {-# INLINE fmap #-}
-- | Folds over the 'Pure' leaves, visiting the subtrees of each 'Impure'
-- layer in the order @f@'s 'foldMap' visits them.
--
-- Only @Foldable f@ is required. The previous @fold . fmap (foldMap f)@
-- formulation also demanded @Functor f@ and built an intermediate layer.
-- For lawful instances it is equal to @foldMap (foldMap f)@.
instance Foldable f => Foldable (Free f) where
  foldMap f (Pure a) = f a
  foldMap f (Impure fa) = foldMap (foldMap f) fa
-- | Traverses the 'Pure' leaves, rebuilding the tree inside the applicative.
instance Traversable f => Traversable (Free f) where
  traverse g = go
    where
      go (Pure x) = fmap Pure (g x)
      go (Impure layer) = fmap Impure (traverse go layer)
-- | Substitution: '>>=' replaces each 'Pure' leaf with the tree produced by
-- the continuation and keeps every 'Impure' layer in place.
--
-- 'return' is intentionally not defined. It defaults to 'pure' (which is
-- 'Pure' in the 'Applicative' instance), the canonical form that
-- @-Wnoncanonical-monad-instances@ expects.
instance Functor f => Monad (Free f) where
  Pure a >>= f = f a
  Impure fa >>= f = Impure (fmap (>>= f) fa)
-- | A 'Pure' function is mapped over the argument tree. An 'Impure' layer
-- of functions has the application pushed into each of its subtrees.
instance Functor f => Applicative (Free f) where
  pure = Pure
  mf <*> mx = case mf of
    Pure g -> fmap g mx
    Impure layer -> Impure (fmap (<*> mx) layer)
-- | Tests on the outermost constructor: 'isPure' holds for a 'Pure' leaf,
-- 'isImpure' for an 'Impure' layer. Exactly one of them is 'True'.
isPure, isImpure :: Free f a -> Bool
isPure t = case t of
  Pure _ -> True
  Impure _ -> False
isImpure t = case t of
  Impure _ -> True
  Pure _ -> False
-- | Fold a tree bottom-up. Leaves are replaced with @onPure@. Each layer is
-- handed to @onImpure@ after its subtrees have been folded.
foldFree :: Functor f => (a -> b) -> (f b -> b) -> Free f a -> b
foldFree onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure layer) = onImpure (fmap go layer)
-- | Monadic bottom-up fold. The subtrees of a layer are folded with
-- 'T.mapM' first, then @onImpure@ runs on the collected results.
foldFreeM :: (Traversable f, Monad m) => (a -> m b) -> (f b -> m b) -> Free f a -> m b
foldFreeM onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure layer) = T.mapM go layer >>= onImpure
-- | Applicative bottom-up fold. At each layer the effect of @onImpure@ is
-- sequenced before the effects of folding that layer's subtrees.
foldFreeA :: (Traversable f, Applicative m) => (a -> m b) -> m (f b -> b) -> Free f a -> m b
foldFreeA onPure onImpure = go
  where
    go (Pure x) = onPure x
    go (Impure layer) = onImpure <*> traverse go layer
-- | Interpret a tree in the monad @m@ using a polymorphic interpretation
-- @nat@ of single layers. Leaves become 'return'. Each layer is interpreted
-- after its subtrees, and the resulting nested computation is flattened
-- with 'join'.
induce :: (Functor f, Monad m) => (forall a. f a -> m a) -> Free f a -> m a
induce nat = go
  where
    go (Pure a) = return a
    go (Impure layer) = join (nat (fmap go layer))
-- | Case analysis on the outermost constructor only; it does not recurse.
evalFree :: (a -> b) -> (f (Free f a) -> b) -> Free f a -> b
evalFree onPure onImpure t = case t of
  Pure x -> onPure x
  Impure layer -> onImpure layer
-- | Rewrite every @f@-layer into a @g@-layer, bottom-up. @eta@ receives a
-- layer whose subtrees have already been converted. Leaves are kept.
mapFree :: (Functor f, Functor g) => (f (Free g a) -> g (Free g a)) -> Free f a -> Free g a
mapFree eta = go
  where
    go (Pure a) = Pure a
    go (Impure layer) = Impure (eta (fmap go layer))
-- | Monadic variant of 'mapFree'. The subtrees of a layer are converted
-- (via 'T.mapM') before @eta@ runs on that layer.
mapFreeM :: (Traversable f, Functor g, Monad m) => (f (Free g a) -> m (g (Free g a))) -> Free f a -> m (Free g a)
mapFreeM eta = go
  where
    go (Pure x) = return (Pure x)
    go (Impure layer) = liftM Impure (eta =<< T.mapM go layer)
-- | Applicative variant of 'mapFree'. The effect of @eta@ runs once per
-- 'Impure' layer, before the effects of converting that layer's subtrees.
mapFreeA :: (Traversable f, Functor g, Applicative m) =>
            m (f (Free g a) -> g (Free g a)) -> Free f a -> m (Free g a)
mapFreeA eta = go
  where
    -- The layer rewriter, post-composed with 'Impure'; shared by all layers.
    node = liftA (Impure .) eta
    go (Pure x) = pure (Pure x)
    go (Impure layer) = node <*> traverse go layer
-- | Convert each @f@-layer into a @g@-layer with the polymorphic monadic
-- map @eta@. @eta@ is applied to the layer of (not yet run) subtree
-- conversions, and those are then run with 'T.sequence' inside the
-- resulting @g@-layer.
mapFreeM' :: (Functor f, Traversable g, Monad m) => (forall a. f a -> m (g a)) -> Free f a -> m (Free g a)
mapFreeM' eta = go
  where
    go (Pure x) = return (Pure x)
    go (Impure layer) =
      liftM Impure (join (liftM T.sequence (eta (fmap go layer))))
-- | The free monad transformer. Each step runs the base monad @m@ and yields
-- either a final result ('Left') or an @f@-layer of further steps ('Right').
newtype FreeT f m a = FreeT { unFreeT :: m (Either a (f (FreeT f m a))) }
-- | Folding is derived from the 'Traversable' instance via 'foldMapDefault'.
instance (Traversable m, Traversable f) => Foldable (FreeT f m) where
  foldMap g t = foldMapDefault g t
-- | Traverses the base-monad structure. Final results are visited directly.
-- In @f@-layers, each nested 'FreeT' step is traversed recursively.
instance (Traversable m, Traversable f) => Traversable (FreeT f m) where
  traverse g (FreeT inner) = FreeT <$> traverse step inner
    where
      step = either (fmap Left . g) (fmap Right . traverse (traverse g))
-- | Maps over final results. Inside @f@-layers the mapping is pushed into
-- every nested step.
instance (Functor f, Functor m) => Functor (FreeT f m) where
  fmap g (FreeT inner) = FreeT (fmap step inner)
    where
      step (Left a) = Left (g a)
      step (Right layer) = Right (fmap (fmap g) layer)
-- | 'pure' is a single base-monad step that finishes immediately. '<*>' is
-- derived from the 'Monad' instance via 'ap'.
instance (Functor f, Functor m, Monad m) => Applicative (FreeT f m) where
  pure x = FreeT (return (Left x))
  mf <*> mx = ap mf mx
-- | Sequencing. The first computation's base-monad step is run. A finished
-- step ('Left') continues with the continuation's step. An @f@-layer
-- ('Right') is returned with the continuation pushed into each
-- sub-computation.
--
-- 'return' is intentionally not defined. It defaults to 'pure' from the
-- 'Applicative' instance, which has the same definition
-- (@FreeT . return . Left@). This is the canonical form that
-- @-Wnoncanonical-monad-instances@ expects.
instance (Functor f, Monad m) => Monad (FreeT f m) where
  m >>= f = FreeT $ unFreeT m >>= \r ->
    case r of
      Left x -> unFreeT (f x)
      Right xc -> return . Right $ fmap (>>= f) xc
-- | 'wrap' is a base-monad step that immediately yields the given layer.
-- 'free' runs one base-monad step and returns what it produced.
instance (Functor f, Monad m) => MonadFree f (FreeT f m) where
  free t = lift (unFreeT t)
  wrap layer = FreeT (return (Right layer))
-- | Lifting runs the base action as one step whose result is final ('Left').
instance (Functor f) => MonadTrans (FreeT f) where
  lift m = FreeT $ do
    x <- m
    return (Left x)
-- | IO is lifted into the base monad first, then into 'FreeT' with 'lift'.
instance (Functor f, Monad m, MonadIO m) => MonadIO (FreeT f m) where
  liftIO io = lift (liftIO io)
-- | Choice is delegated to the base monad. Failure lifts the base monad's
-- 'mzero'. Choice combines the two computations' first base-monad steps
-- with the base monad's 'mplus'.
--
-- The methods are defined in 'Alternative', their canonical home. The
-- 'MonadPlus' instance relies on its defaults, @mzero = empty@ and
-- @mplus = (<|>)@, so the observable behaviour is unchanged.
instance (Functor f, Functor m, Monad m, MonadPlus m) => Alternative (FreeT f m) where
  empty = lift mzero
  a <|> b = FreeT (mplus (unFreeT a) (unFreeT b))

instance (Functor f, Monad m, MonadPlus m) => MonadPlus (FreeT f m)
-- | Fold a 'FreeT' computation while running its base-monad steps. Final
-- results go to @p@. Each @f@-layer goes to @i@ after its sub-computations
-- have been folded via 'T.mapM'.
foldFreeT :: (Traversable f, Monad m) => (a -> m b) -> (f b -> m b) -> FreeT f m a -> m b
foldFreeT p i = go
  where
    go (FreeT step) = step >>= either p (\layer -> T.mapM go layer >>= i)
-- | Like 'foldFreeT', but with pure handlers. The base monad is used only to
-- run the steps of the computation.
foldFreeT' :: (Traversable f, Monad m) => (a -> b) -> (f b -> b) -> FreeT f m a -> m b
foldFreeT' p i = go
  where
    go t = unFreeT t >>= finish
    finish (Left x) = return (p x)
    finish (Right layer) = liftM i (T.mapM go layer)
-- | Change the base monad by applying @nat@ to every step. The @f@-layers
-- and final results are left unchanged.
mapFreeT :: (Functor f, Functor m) => (forall a. m a -> m' a) -> FreeT f m a -> FreeT f m' a
mapFreeT nat = go
  where
    go (FreeT step) = FreeT (nat (fmap (fmap (fmap go)) step))
-- | Run all base-monad steps of a 'FreeT' computation and collect the
-- resulting tree as a 'Free' value.
untrans :: (Traversable f, Monad m) => FreeT f m a -> m (Free f a)
untrans = foldFreeT leaf node
  where
    leaf x = return (Pure x)
    node layer = return (Impure layer)
-- | Interpret a 'Free' tree in any 'MonadFree' monad. Leaves become
-- 'return' and layers become 'wrap'.
trans :: MonadFree f m => Free f a -> m a
trans t = case t of
  Pure a -> return a
  Impure layer -> wrap (fmap trans layer)
-- | Turn a base-monad action that produces a 'Free' tree into a 'FreeT'
-- computation. The action and the first step of the translated tree are
-- merged into one step.
trans' :: (Functor f, Monad m) => m (Free f a) -> FreeT f m a
trans' source = FreeT (join steps)
  where
    trees = liftM trans source
    steps = liftM unFreeT trees
-- | Post-compose a 'Free'-producing function with 'trans' so that it
-- produces a 'FreeT' computation instead.
liftFree :: (Functor f, Monad m) => (a -> Free f b) -> (a -> FreeT f m b)
liftFree f x = trans (f x)