{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}
#include "free-common.h"
module Control.Monad.Trans.Free.Church
(
FT(..)
, F, free, runF
, improveT
, toFT, fromFT
, iterT
, iterTM
, hoistFT
, transFT
, joinFT
, cutoff
, improve
, fromF, toF
, retract
, retractT
, iter
, iterM
, MonadFree(..)
, liftF
) where
import Control.Applicative
import Control.Category ((<<<), (>>>))
import Control.Monad
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Identity
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
import Control.Monad.Cont.Class
import Control.Monad.Free.Class
import Control.Monad.Trans.Free (FreeT(..), FreeF(..), Free)
import qualified Control.Monad.Trans.Free as FreeT
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Functor.Bind hiding (join)
import Data.Functor.Classes.Compat
#if !(MIN_VERSION_base(4,8,0))
import Data.Foldable (Foldable)
import Data.Traversable (Traversable)
#endif
-- | The free monad transformer for a functor @f@ over a base monad @m@,
-- in Church (continuation-passing) encoding.
--
-- A computation is represented by its own eliminator. 'runFT' is given:
--
-- * a handler for a pure result (@a -> m r@), and
-- * a handler for one layer of @f@, which also receives the continuation
--   that interprets that layer's children (@(x -> m r) -> f x -> m r@).
--
-- From these it produces an @m r@ for any answer type @r@.
newtype FT f m a = FT
  { runFT
      :: forall r.
         (a -> m r)
      -> (forall x. (x -> m r) -> f x -> m r)
      -> m r
  }
#ifdef LIFTED_FUNCTOR_CLASSES
-- | Lifted equality: both sides are converted to 'FreeT' and compared there.
instance (Functor f, Monad m, Eq1 f, Eq1 m) => Eq1 (FT f m) where
  liftEq eqElem lhs rhs = liftEq eqElem lhs' rhs'
    where
      lhs' = fromFT lhs
      rhs' = fromFT rhs
-- | Lifted ordering: both sides are converted to 'FreeT' and compared there.
instance (Functor f, Monad m, Ord1 f, Ord1 m) => Ord1 (FT f m) where
  liftCompare cmpElem lhs rhs = liftCompare cmpElem lhs' rhs'
    where
      lhs' = fromFT lhs
      rhs' = fromFT rhs
#else
-- | Equality: both sides are converted to 'FreeT' and compared there.
-- On base < 4.8 the instance additionally requires 'Functor' @m@.
instance ( Functor f, Monad m, Eq1 f, Eq1 m
# if !(MIN_VERSION_base(4,8,0))
         , Functor m
# endif
         ) => Eq1 (FT f m) where
  eq1 lhs rhs = eq1 lhs' rhs'
    where
      lhs' = fromFT lhs
      rhs' = fromFT rhs
-- | Ordering: both sides are converted to 'FreeT' and compared there.
-- On base < 4.8 the instance additionally requires 'Functor' @m@.
instance ( Functor f, Monad m, Ord1 f, Ord1 m
# if !(MIN_VERSION_base(4,8,0))
         , Functor m
# endif
         ) => Ord1 (FT f m) where
  compare1 lhs rhs = compare1 lhs' rhs'
    where
      lhs' = fromFT lhs
      rhs' = fromFT rhs
#endif
-- | Equality on 'FT' values, delegated to the 'Eq1' instance.
instance (Eq1 (FT f m), Eq a) => Eq (FT f m a) where
  lhs == rhs = eq1 lhs rhs
-- | Ordering on 'FT' values, delegated to the 'Ord1' instance.
instance (Ord1 (FT f m), Ord a) => Ord (FT f m a) where
  compare lhs rhs = compare1 lhs rhs
-- | Mapping pre-composes the function onto the pure-result handler; the
-- layer handler is passed through untouched.
instance Functor (FT f m) where
  fmap g (FT run) = FT $ \onPure onFree -> run (onPure . g) onFree
-- | 'Apply' is the 'Applicative' application.
instance Apply (FT f m) where
  ff <.> fx = ff <*> fx
-- | 'pure' feeds its value straight to the pure-result handler.
-- Application runs the function computation first; each function it yields
-- is composed onto the pure handler used to run the argument computation.
instance Applicative (FT f m) where
  pure x = FT (\onPure _ -> onPure x)
  FT runG <*> FT runX = FT $ \onPure onFree ->
    runG (\g -> runX (onPure . g) onFree) onFree
-- | 'Bind' is the 'Monad' bind.
instance Bind (FT f m) where
  ma >>- k = ma >>= k
-- | Binding replaces the pure-result handler with one that runs the
-- continuation's computation under the same handlers.
instance Monad (FT f m) where
  -- Kept explicit for bases where 'Monad' has no 'Applicative' superclass.
  return = pure
  FT run >>= k = FT $ \onPure onFree ->
    let resume x = runFT (k x) onPure onFree
    in run resume onFree
-- | Wrapping hands the layer to the layer handler. The continuation it
-- receives runs each child computation under the same handlers.
instance MonadFree f (FT f m) where
  wrap layer = FT $ \onPure onFree ->
    let resume inner = runFT inner onPure onFree
    in onFree resume layer
-- | Lifting runs the base action and feeds its result to the pure handler;
-- the layer handler is never used.
instance MonadTrans (FT f) where
  lift action = FT (\onPure _ -> onPure =<< action)
-- | Choice is delegated to the base monad. Both alternatives are run with
-- the same handlers, and the resulting @m@ actions are combined with @m@'s
-- '<|>'.
instance Alternative m => Alternative (FT f m) where
  empty = FT (\_ _ -> empty)
  (<|>) (FT left) (FT right) = FT $ \onPure onFree ->
    left onPure onFree <|> right onPure onFree
-- | 'mzero' and 'mplus' are delegated to the base monad, like the
-- 'Alternative' instance.
instance MonadPlus m => MonadPlus (FT f m) where
  mzero = FT (\_ _ -> mzero)
  FT left `mplus` FT right = FT $ \onPure onFree ->
    mplus (left onPure onFree) (right onPure onFree)
-- | Folding interprets the computation into @m (b -> b)@: each pure value
-- becomes an endomorphism, and each layer of @f@ composes its children's
-- endomorphisms. The resulting endomorphisms inside @m@ are then composed
-- using @m@'s own 'Foldable' instance and applied to the seed.
instance (Foldable f, Foldable m, Monad m) => Foldable (FT f m) where
  foldr step seed xs = F.foldr (<<<) id endos seed
    where
      -- Children are composed right-to-left to match 'foldr' order.
      endos = runFT xs (return . step)
                (\sub layer -> F.foldr (\x rest -> liftM2 (<<<) (sub x) rest) (return id) layer)
  {-# INLINE foldr #-}
#if MIN_VERSION_base(4,6,0)
  foldl' step seed xs = F.foldl' thenStrict id endos seed
    where
      -- Left-to-right composition that forces each intermediate accumulator.
      thenStrict h g = \acc -> g $! h acc
      -- Children are composed left-to-right to match left-fold order.
      endos = runFT xs (return . flip step)
                (\sub layer -> F.foldr (\x rest -> liftM2 (>>>) (sub x) rest) (return id) layer)
  {-# INLINE foldl' #-}
#endif
-- | Traversal interprets the computation into @m (g (FT f m b))@ and
-- sequences the outer @m@ through the applicative @g@. Each resulting
-- @m (FT f m b)@ is then spliced back into 'FT' with @join . lift@.
instance (Monad m, Traversable m, Traversable f) => Traversable (FT f m) where
  traverse g (FT run) = fmap (join . lift) (T.sequenceA (run onPure onFree))
    where
      -- Pure leaf: run the effect and re-embed its result.
      onPure a = return (fmap return (g a))
      -- Layer of f: traverse the children, splice each one back with
      -- join . lift, and rebuild the layer with 'wrap'.
      onFree sub layer =
        return (fmap (wrap . fmap (join . lift)) (T.traverse (T.sequenceA . sub) layer))
-- | IO actions are lifted through the base monad.
instance MonadIO m => MonadIO (FT f m) where
  liftIO io = lift (liftIO io)
  {-# INLINE liftIO #-}
-- | Errors are thrown in the base monad. Catching converts the action and
-- the handler's results to 'FreeT', catches there, and converts back.
instance (Functor f, MonadError e m) => MonadError e (FT f m) where
  throwError err = lift (throwError err)
  {-# INLINE throwError #-}
  catchError action handler =
    toFT (catchError (fromFT action) (fromFT . handler))
-- | 'callCC' is delegated to the base monad. The escape continuation
-- captured in @m@ is exposed to the body as an 'FT' action, and the 'FT'
-- computation the captured action yields is spliced back in with
-- @join . lift@.
instance MonadCont m => MonadCont (FT f m) where
  callCC body = join (lift captured)
    where
      captured = callCC (\escape -> return (body (lift . escape . return)))
-- | The environment comes from the base monad. 'local' is applied to the
-- base monad through 'hoistFT'.
instance MonadReader r m => MonadReader r (FT f m) where
  ask = lift ask
  {-# INLINE ask #-}
  local adjust prog = hoistFT (local adjust) prog
  {-# INLINE local #-}
-- | Output is written in the base monad. 'listen' and 'pass' convert the
-- computation to 'FreeT', apply the operation there, and convert back.
instance (Functor f, MonadWriter w m) => MonadWriter w (FT f m) where
  tell w = lift (tell w)
  {-# INLINE tell #-}
  listen prog = toFT (listen (fromFT prog))
  pass prog = toFT (pass (fromFT prog))
#if MIN_VERSION_mtl(2,1,1)
  writer aw = lift (writer aw)
  {-# INLINE writer #-}
#endif
-- | All state operations are lifted from the base monad.
instance MonadState s m => MonadState s (FT f m) where
  get = lift get
  {-# INLINE get #-}
  put s = lift (put s)
  {-# INLINE put #-}
#if MIN_VERSION_mtl(2,1,1)
  state transition = lift (state transition)
  {-# INLINE state #-}
#endif
-- | Exceptions are thrown in the base monad.
instance MonadThrow m => MonadThrow (FT f m) where
  throwM err = lift (throwM err)
  {-# INLINE throwM #-}
-- | Catching converts the action and the handler's results to 'FreeT',
-- catches there, and converts back.
instance (Functor f, MonadCatch m) => MonadCatch (FT f m) where
  -- Qualified so the call cannot clash with a Prelude 'catch' on old bases.
  catch action handler =
    toFT (Control.Monad.Catch.catch (fromFT action) (fromFT . handler))
  {-# INLINE catch #-}
-- | Convert a concrete 'FreeT' into its Church encoding.
--
-- The 'FreeT' step is run in @m@, and its result is dispatched:
--
-- * a 'Pure' value goes to the pure handler;
-- * a 'Free' layer goes to the layer handler, with a continuation that
--   converts and runs each child recursively.
toFT :: Monad m => FreeT f m a -> FT f m a
toFT (FreeT step) = FT $ \onPure onFree ->
  let dispatch (Pure a)     = onPure a
      dispatch (Free layer) = onFree (\next -> runFT (toFT next) onPure onFree) layer
  in step >>= dispatch
-- | Convert a Church-encoded computation back into a concrete 'FreeT'.
--
-- Pure results become 'Pure' steps. Each layer is rebuilt by re-wrapping
-- the children's steps as 'FreeT' values and applying 'wrap'.
fromFT :: (Monad m, Functor f) => FT f m a -> FreeT f m a
fromFT (FT run) =
  FreeT (run (return . Pure) (\sub layer ->
    runFreeT (wrap (fmap (FreeT . sub) layer))))
-- | The Church-encoded free monad: 'FT' over the 'Identity' base monad.
-- The synonym is deliberately left partially applied (arity 1), so that
-- @F f@ can be used where a type of kind @* -> *@ is expected.
type F f = FT f Identity
-- | Eliminate a Church-encoded free monad with a pure handler and an
-- algebra for one layer of @f@.
runF :: Functor f => F f a -> (forall r. (a -> r) -> (f r -> r) -> r)
runF (FT run) = \onPure onFree ->
  runIdentity (run (Identity . onPure) (\sub -> Identity . onFree . fmap (runIdentity . sub)))
-- | Build a Church-encoded free monad from a polymorphic eliminator, i.e.
-- a function taking a pure handler and an algebra for @f@.
free :: (forall r. (a -> r) -> (f r -> r) -> r) -> F f a
free church = FT $ \onPure onFree ->
  Identity (church (runIdentity . onPure) (\layer -> runIdentity (onFree Identity layer)))
-- | Tear down the computation in the base monad. Pure results are returned
-- as-is, and each layer is interpreted by @phi@ after its children have
-- been interpreted.
iterT :: (Functor f, Monad m) => (f (m a) -> m a) -> FT f m a -> m a
iterT phi (FT run) = run return (\sub layer -> phi (fmap sub layer))
{-# INLINE iterT #-}
-- | Like 'iterT', but the interpretation happens in a transformer @t@
-- stacked on the base monad. Effects of @m@ are brought into @t m@ with
-- @join . lift@.
iterTM :: (Functor f, Monad m, MonadTrans t, Monad (t m)) => (f (t m a) -> t m a) -> FT f m a -> t m a
iterTM f (FT run) = join (lift nested)
  where
    nested = run (return . return) (\sub layer -> return (f (fmap (join . lift . sub) layer)))
-- | Change the base monad with a natural transformation @phi@. The
-- transformation is applied, with a 'join', both to the overall result and
-- to every layer's children.
hoistFT :: (Monad m, Monad n) => (forall a. m a -> n a) -> FT f m b -> FT f n b
hoistFT phi (FT run) = FT $ \onPure onFree ->
  let collapse = join . phi
  in collapse (run (return . onPure) (\sub -> return . onFree (collapse . sub)))
-- | Change the layer functor with a natural transformation @phi@, applied
-- to each layer before it reaches the layer handler.
transFT :: (forall a. f a -> g a) -> FT f m b -> FT g m b
transFT phi (FT run) = FT (\onPure onFree -> run onPure (\sub layer -> onFree sub (phi layer)))
-- | Run all base-monad effects and return the remaining pure 'F' structure.
-- Each layer's children are sequenced with 'T.mapM', then re-wrapped.
joinFT :: (Monad m, Traversable f) => FT f m a -> m (F f a)
joinFT (FT run) = run (return . return) (\sub layer -> liftM wrap (T.mapM sub layer))
-- | Truncate the computation to a given depth of layers by going through
-- 'FreeT' and reusing 'FreeT.cutoff'. The exact truncation semantics are
-- those of 'FreeT.cutoff'.
cutoff :: (Functor f, Monad m) => Integer -> FT f m a -> FT f m (Maybe a)
cutoff depth prog = toFT (FreeT.cutoff depth (fromFT prog))
-- | Collapse a Church-encoded free monad over a monad @f@ back into @f@.
-- Pure leaves are interpreted with 'return', and each layer is flattened
-- with 'join'.
#if __GLASGOW_HASKELL__ < 710
retract :: (Functor f, Monad f) => F f a -> f a
#else
retract :: Monad f => F f a -> f a
#endif
retract prog = runF prog return join
{-# INLINE retract #-}
-- | Collapse an 'FT' whose layer functor is @t m@ into @t m@ itself.
-- Each layer is bound in @t m@, and its children's @m@ effects are
-- spliced in with @join . lift@.
retractT :: (MonadTrans t, Monad (t m), Monad m) => FT (t m) m a -> t m a
retractT (FT run) = join (lift nested)
  where
    nested = run (return . return) (\sub layer -> return (layer >>= join . lift . sub))
-- | Pure tear-down of a Church-encoded free monad with an @f@-algebra.
-- This is 'iterT' specialised to 'Identity'.
iter :: Functor f => (f a -> a) -> F f a -> a
iter phi = runIdentity . iterT (\layer -> Identity (phi (runIdentity <$> layer)))
{-# INLINE iter #-}
-- | Tear down a pure 'F' into an arbitrary monad. The structure is first
-- re-based from 'Identity' onto @m@, then interpreted with 'iterT'.
iterM :: (Functor f, Monad m) => (f (m a) -> m a) -> F f a -> m a
iterM phi prog = iterT phi (hoistFT (return . runIdentity) prog)
-- | Reinterpret a Church-encoded free monad in any 'MonadFree' instance.
-- Pure leaves become 'return', and layers become 'wrap'.
fromF :: (Functor f, MonadFree f m) => F f a -> m a
fromF prog = runF prog return wrap
{-# INLINE fromF #-}
-- | Convert a concrete 'Free' into its Church encoding. This is 'toFT' at
-- the 'Identity' base monad.
toF :: Free f a -> F f a
toF prog = toFT prog
{-# INLINE toF #-}
-- | Run a computation written against any 'MonadFree' instance by
-- instantiating it at the Church-encoded 'F' and converting the result to
-- 'Free'.
improve :: Functor f => (forall m. MonadFree f m => m a) -> Free f a
-- The argument must stay eta-expanded: it has a rank-2 type.
improve prog = fromF prog
{-# INLINE improve #-}
-- | Transformer version of 'improve'. The polymorphic computation is
-- instantiated at @FT f m@ and converted to 'FreeT'.
improveT :: (Functor f, Monad m) => (forall t. MonadFree f (t m) => t m a) -> FreeT f m a
-- The argument must stay eta-expanded: it has a rank-2 type.
improveT prog = fromFT prog
{-# INLINE improveT #-}