{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
module Control.Cond
( CondT, Cond
, runCondT, runCond, execCondT, evalCondT, test
, MonadQuery(..), guardM, guard_, guardM_, apply, consider
, accept, ignore, norecurse, prune
, matches, ifM, whenM, unlessM
, if_, when_, unless_, or_, and_, not_
, recurse
)
where
import Control.Applicative
import Control.Arrow (second)
import Control.Monad hiding (mapM_, sequence_)
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Cont.Class as C
import Control.Monad.Error.Class as E
import Control.Monad.Fix
import Control.Monad.Morph as M
import Control.Monad.Reader.Class as R
import Control.Monad.State.Class as S
import Control.Monad.Trans
import Control.Monad.Trans.Cont (ContT(..))
import Control.Monad.Trans.Control
import Control.Monad.Trans.Error (ErrorT(..))
import Control.Monad.Trans.Except (ExceptT(..))
import Control.Monad.Trans.Identity (IdentityT(..))
import Control.Monad.Trans.List (ListT(..))
import Control.Monad.Trans.Maybe (MaybeT(..))
import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict as StrictRWS
import Control.Monad.Trans.Reader (ReaderT(..))
import Control.Monad.Trans.State (StateT(..))
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import Control.Monad.Writer.Class
import Control.Monad.Zip
import Data.Foldable
import Data.Functor.Identity
import Data.Monoid hiding ((<>))
import Data.Semigroup
import Prelude hiding (mapM_, foldr1, sequence_)
-- | Instruction, produced alongside a condition's result, describing what
-- should happen at the next level of recursion.
--
-- 'runCondT' turns it into the next condition to apply: 'Stop' becomes
-- 'Nothing', 'Continue' reuses the condition that was just run, and
-- @'Recurse' c@ substitutes @c@.
data Recursor a m r = Stop | Recurse (CondT a m r) | Continue
    deriving Functor
-- | Combining two instructions: 'Stop' absorbs everything, a 'Recurse' on the
-- left wins over anything except 'Stop', and 'Continue' is neutral.
instance Semigroup (Recursor a m r) where
    lhs <> rhs = case lhs of
        Stop      -> Stop
        Recurse n -> case rhs of
            Stop -> Stop
            _    -> Recurse n
        Continue  -> rhs
    {-# INLINE (<>) #-}
-- | 'Continue' is the neutral instruction for '<>' on 'Recursor'.
instance Monoid (Recursor a m r) where
    mempty = Continue
    {-# INLINE mempty #-}
    mappend lhs rhs = lhs <> rhs
    {-# INLINE mappend #-}
-- | Change the base monad of a pending 'Recurse' continuation; 'Stop' and
-- 'Continue' carry no monadic payload and pass through unchanged.
instance MFunctor (Recursor a) where
    hoist nat recursor = case recursor of
        Recurse next -> Recurse (hoist nat next)
        Stop         -> Stop
        Continue     -> Continue
    {-# INLINE hoist #-}
-- | The raw outcome of one step of a condition: the accepted value
-- ('Nothing' when the item was rejected) paired with the recursion
-- instruction for the next level.
type CondR a m r = (Maybe r, Recursor a m r)
-- | Outcome that accepts the given value and keeps the default 'Continue'
-- recursion instruction.
accept' :: r -> CondR a m r
accept' = \r -> (Just r, Continue)
{-# INLINE accept' #-}
-- | Outcome that rejects the item (no value) while keeping the default
-- 'Continue' recursion instruction.  This is what 'empty' and 'mzero' return.
recurse' :: CondR a m r
recurse' = (Nothing, Continue)
{-# INLINE recurse' #-}
-- | A condition over items of type @a@ running in monad @m@.
--
-- It is a 'StateT' computation whose state is the current item (see the
-- 'MonadQuery' instance).  It yields a 'CondR': an optional accepted result
-- together with a 'Recursor' describing how to continue at the next level.
newtype CondT a m r = CondT { getCondT :: StateT a m (CondR a m r) }
    deriving Functor
-- | A pure condition: 'CondT' over 'Identity'.
type Cond a = CondT a Identity
-- | Run both conditions in sequence and combine their results with '<>'.
-- Because this goes through '>>=', the second condition is not run when the
-- first rejects.
instance (Monad m, Semigroup r) => Semigroup (CondT a m r) where
    lhs <> rhs = do
        x <- lhs
        y <- rhs
        return (x <> y)
    {-# INLINE (<>) #-}
-- | Monoid lifted through 'CondT': 'mappend' runs both conditions and
-- combines their results.
--
-- 'mempty' must be an /accepting/ condition yielding @mempty :: r@ so that
-- @mempty <> c == c@.  Returning @mempty@ at the 'StateT' level instead
-- would produce @mempty :: CondR a m r@, i.e. @(Nothing, Continue)@ -- a
-- rejection that makes every combination with 'mempty' reject.
instance (Monad m, Monoid r) => Monoid (CondT a m r) where
    mempty = return mempty
    {-# INLINE mempty #-}
    mappend = liftM2 mappend
    {-# INLINE mappend #-}
-- | 'Applicative' defined in terms of the 'Monad' instance.
instance Monad m => Applicative (CondT a m) where
    pure x = return x
    {-# INLINE pure #-}
    mf <*> mx = do
        f <- mf
        x <- mx
        return (f x)
    {-# INLINE (<*>) #-}
-- | Sequencing of conditions.  Besides the result, '>>=' threads the
-- 'Recursor' instruction so continuations are carried to the next level.
instance Monad m => Monad (CondT a m) where
    -- Accept the value with the default 'Continue' instruction.
    return = CondT . return . accept'
    {-# INLINE return #-}
    -- A failed pattern match rejects the item.
    -- NOTE(review): 'fail' is no longer a 'Monad' method from base 4.13
    -- (GHC 8.8); newer compilers need a separate 'MonadFail' instance.
    fail _ = mzero
    {-# INLINE fail #-}
    -- * Rejected step ('Nothing'): @k@ is not run now.  With @'Recurse' n@
    --   the continuation is deferred to the next level as @n >>= k@.
    -- * Accepted with 'Stop': run @k r@, keep its result, force 'Stop'.
    -- * Accepted with 'Continue': the result is exactly @k r@.
    -- * Accepted with @'Recurse' n@: run @k r@; if it says 'Continue', the
    --   next level becomes @n >>= k@, otherwise @k r@'s instruction wins.
    CondT m >>= k = CondT $ m >>= \case
        (Nothing, Stop) -> return (Nothing, Stop)
        (Nothing, Continue) -> return (Nothing, Continue)
        (Nothing, Recurse n) -> return (Nothing, Recurse (n >>= k))
        (Just r, Stop) -> fmap (const Stop) `liftM` getCondT (k r)
        (Just r, Continue) -> getCondT (k r)
        (Just r, Recurse n) -> getCondT (k r) >>= \case
            (v, Continue) -> return (v, Recurse (n >>= k))
            x -> return x
    {-# INLINEABLE (>>=) #-}
-- | Reader operations of the underlying monad.  'local' only wraps this
-- level's 'StateT' action; a 'Recurse' continuation in its result is
-- returned unchanged.
instance MonadReader r m => MonadReader r (CondT a m) where
    ask = lift R.ask
    {-# INLINE ask #-}
    local f = CondT . R.local f . getCondT
    {-# INLINE local #-}
    reader f = lift (R.reader f)
    {-# INLINE reader #-}
-- | Writer operations of the underlying monad.
--
-- 'writer' and 'tell' are lifted.  'listen' and 'pass' wrap this level's
-- 'StateT' action, so they observe or transform the output that action
-- actually produced.  They are also re-applied to any 'Recurse'
-- continuation, so the next level behaves the same way.
instance MonadWriter w m => MonadWriter w (CondT a m) where
    writer = lift . writer
    {-# INLINE writer #-}
    tell = lift . tell
    {-# INLINE tell #-}
    listen (CondT m) = CondT $ do
        ((mr, next), w) <- listen m
        let next' = case next of
                Stop      -> Stop
                Continue  -> Continue
                Recurse n -> Recurse (listen n)
        return (fmap (\r -> (r, w)) mr, next')
    {-# INLINEABLE listen #-}
    pass (CondT m) = CondT $ pass $ do
        (mr, next) <- m
        let next' = case next of
                Stop      -> Stop
                Continue  -> Continue
                Recurse n -> Recurse (pass n)
        -- A rejected step carries no output transformer, so its output is
        -- left unchanged.
        return $ case mr of
            Nothing     -> ((Nothing, next'), id)
            Just (r, f) -> ((Just r, next'), f)
    {-# INLINEABLE pass #-}
-- | State operations of the underlying monad, lifted past the item state
-- that 'CondT' keeps internally.
instance MonadState s m => MonadState s (CondT a m) where
    get = lift S.get
    {-# INLINE get #-}
    put s = lift (S.put s)
    {-# INLINE put #-}
    state f = lift (S.state f)
    {-# INLINE state #-}
-- | 'empty' rejects.  @f <|> g@ keeps @f@'s outcome if it accepted;
-- otherwise it runs @g@, discarding @f@'s recursion instruction.
instance Monad m => Alternative (CondT a m) where
    empty = CondT (return recurse')
    {-# INLINE empty #-}
    CondT f <|> CondT g = CondT (f >>= pick)
      where
        pick r@(Just _, _) = return r
        pick _             = g
    {-# INLINE (<|>) #-}
-- | Identical to the 'Alternative' instance: 'mzero' rejects, and 'mplus'
-- falls back to the second condition only when the first rejects.
instance Monad m => MonadPlus (CondT a m) where
    mzero = empty
    {-# INLINE mzero #-}
    mplus = (<|>)
    {-# INLINE mplus #-}
-- | Errors of the underlying monad.  'catchError' covers this level's
-- 'StateT' action; the handler's outcome replaces the whole 'CondR'.
instance MonadError e m => MonadError e (CondT a m) where
    throwError e = CondT (throwError e)
    {-# INLINE throwError #-}
    catchError act handler =
        CondT (catchError (getCondT act) (getCondT . handler))
    {-# INLINE catchError #-}
-- | Throw an exception in the underlying 'StateT' layer.
instance MonadThrow m => MonadThrow (CondT a m) where
    throwM e = CondT (throwM e)
    {-# INLINE throwM #-}
-- | Catch exceptions thrown while running this level's 'StateT' action; the
-- handler's outcome replaces the whole 'CondR'.
instance MonadCatch m => MonadCatch (CondT a m) where
    catch (CondT m) c = CondT $ m `catch` \e -> getCondT (c e)
    {-# INLINE catch #-}
-- With exceptions >= 0.6 the masking methods below belong to a separate
-- 'MonadMask' instance, opened here.  With older versions this header is
-- compiled out, so 'mask' and 'uninterruptibleMask' become further methods
-- of the 'MonadCatch' instance above (presumably where older exceptions
-- releases declared them).
-- NOTE(review): exceptions >= 0.9 also requires 'generalBracket' in
-- 'MonadMask'; confirm this instance still builds against the version used.
#if MIN_VERSION_exceptions(0,6,0)
instance MonadMask m => MonadMask (CondT a m) where
#endif
    -- Mask the underlying 'StateT' action; the restore function @u@ is
    -- re-wrapped (@q@) so the body can apply it to 'CondT' actions.
    mask a = CondT $ mask $ \u -> getCondT (a $ q u)
      where q u = CondT . u . getCondT
    {-# INLINE mask #-}
    uninterruptibleMask a =
        CondT $ uninterruptibleMask $ \u -> getCondT (a $ q u)
      where q u = CondT . u . getCondT
    {-# INLINEABLE uninterruptibleMask #-}
-- | Lift a base-monad action into 'CondT' as an accepting step.
instance MonadBase b m => MonadBase b (CondT a m) where
    liftBase = CondT . liftM accept' . liftBase
    {-# INLINE liftBase #-}
-- | Lift an 'IO' action into 'CondT' as an accepting step.
instance MonadIO m => MonadIO (CondT a m) where
    liftIO = CondT . liftM accept' . liftIO
    {-# INLINE liftIO #-}
-- | Lift an action of the underlying monad into 'CondT'.  The action's
-- result is accepted with the default 'Continue' instruction.
instance MonadTrans (CondT a) where
    lift act = CondT (liftM accept' (lift act))
    {-# INLINE lift #-}
-- MonadBaseControl for CondT.  The state captured by liftBaseWith is the
-- underlying monad's StM for the full result of this level's StateT action,
-- i.e. the CondR paired with the final item.  restoreM replays that captured
-- result and ignores the item that is current at restore time.
#if MIN_VERSION_monad_control(1,0,0)
instance MonadBaseControl b m => MonadBaseControl b (CondT r m) where
    type StM (CondT r m) a = StM m (CondR r m a, r)
    liftBaseWith f = CondT $ StateT $ \s ->
        liftM (\x -> (accept' x, s)) $ liftBaseWith $ \runInBase ->
            f $ \k -> runInBase $ runStateT (getCondT k) s
    {-# INLINABLE liftBaseWith #-}
    restoreM = CondT . StateT . const . restoreM
    {-# INLINE restoreM #-}
#else
-- monad-control < 1.0 declares StM as an associated data family, so the
-- captured state needs its own newtype wrapper.  Its payload is the same
-- type as in the branch above: the CondR result paired with the final item.
instance MonadBaseControl b m => MonadBaseControl b (CondT r m) where
    newtype StM (CondT r m) a =
        CondTStM { unCondTStM :: StM m (CondR r m a, r) }
    liftBaseWith f = CondT $ StateT $ \s ->
        liftM (\x -> (accept' x, s)) $ liftBaseWith $ \runInBase -> f $ \k ->
            liftM CondTStM $ runInBase $ runStateT (getCondT k) s
    {-# INLINEABLE liftBaseWith #-}
    restoreM = CondT . StateT . const . restoreM . unCondTStM
    {-# INLINE restoreM #-}
#endif
-- | Change the base monad of a condition, including any 'Recurse'
-- continuation held in its result.
instance MFunctor (CondT a) where
    hoist nat (CondT m) = CondT (hoist nat (liftM hoistNext m))
      where
        hoistNext (mr, next) = (mr, hoist nat next)
    {-# INLINE hoist #-}
-- | @callCC@ at the 'StateT' level.  Invoking the escape continuation with
-- @r@ aborts to the 'callCC' point, accepting @r@ with the 'Continue'
-- instruction and the item current at the time of the jump.
instance MonadCont m => MonadCont (CondT a m) where
    callCC f = CondT $ StateT $ \s0 ->
        callCC $ \exit ->
            let escape r = CondT $ StateT $ \s1 -> exit (accept' r, s1)
            in runStateT (getCondT (f escape)) s0
-- | Zipping runs the two conditions in sequence and combines their results.
instance Monad m => MonadZip (CondT a m) where
    mzipWith f ma mb = do
        x <- ma
        y <- mb
        return (f x y)
    {-# INLINE mzipWith #-}
-- | Value recursion for 'CondT', delegated to the underlying 'StateT'.
--
-- The value fed back into @f@ is the accepted result of the computation
-- itself.  It exists only if @f@ accepts; demanding it when @f@ rejects
-- raises an error, mirroring the 'MonadFix' instance of 'MaybeT'.
instance MonadFix m => MonadFix (CondT a m) where
    -- The lazy pattern is essential: the result tuple must not be inspected
    -- before it has been computed, or the knot can never be tied.
    mfix f = CondT $ mfix $ \ ~(mb, _) -> getCondT (f (maybe bomb id mb))
      where
        bomb = error "mfix (CondT): inner computation rejected (no result)"
    {-# INLINEABLE mfix #-}
-- | Run a condition against an item.
--
-- Returns the accepted value (if any) and the condition to apply at the next
-- level.  That next condition is 'Nothing' after 'Stop', the same condition
-- after 'Continue', or the supplied one after 'Recurse'.  The final, possibly
-- updated, item is returned alongside.
runCondT :: Monad m => a -> CondT a m r -> m ((Maybe r, Maybe (CondT a m r)), a)
runCondT a c = liftM finish (runStateT (getCondT c) a)
  where
    -- The inner pattern is lazy, matching the original 'second'-based code.
    finish (~(result, recursor), a') = ((result, nextCond recursor), a')
    nextCond Stop        = Nothing
    nextCond Continue    = Just c
    nextCond (Recurse n) = Just n
{-# INLINE runCondT #-}
-- | Run a pure condition and return only its accepted value.
runCond :: a -> Cond a r -> Maybe r
runCond a c = fst (fst (runIdentity (runCondT a c)))
{-# INLINE runCond #-}
-- | Run a condition and return the final item if it was accepted, together
-- with the condition to apply at the next level.
execCondT :: Monad m => a -> CondT a m r -> m (Maybe a, Maybe (CondT a m r))
execCondT a c = liftM toExec (runCondT a c)
  where
    toExec ((mr, mnext), a') = (fmap (const a') mr, mnext)
{-# INLINE execCondT #-}
-- | Run a condition and return only its accepted value.
evalCondT :: Monad m => a -> CondT a m r -> m (Maybe r)
evalCondT a c = liftM (\((mr, _), _) -> mr) (runCondT a c)
{-# INLINE evalCondT #-}
-- | Whether the condition accepts the item.
test :: Monad m => a -> CondT a m r -> m Bool
test a c = liftM accepted (runCondT a c)
  where
    accepted ((mr, _), _) = maybe False (const True) mr
{-# INLINE test #-}
-- | Monads with access to a current item of type @a@ that can be read and
-- replaced.  The functional dependency determines the item type from @m@.
--
-- Minimal definition: 'query' and 'update'.  'queries' and 'updates' have
-- defaults built from those two.
class Monad m => MonadQuery a m | m -> a where
    -- | The current item.
    query :: m a
    -- | The current item, projected through a function.
    queries :: (a -> b) -> m b
    queries f = liftM f query
    -- | Replace the current item.
    update :: a -> m ()
    -- | Replace the current item with a function of itself.
    updates :: (a -> a) -> m ()
    updates f = query >>= update . f
    {-# MINIMAL query, update #-}
-- | The item is the 'StateT' state of 'CondT'.  Every operation accepts,
-- with the default 'Continue' instruction.
instance Monad m => MonadQuery a (CondT a m) where
    query = CondT (liftM accept' get)
    {-# INLINE query #-}
    queries f = CondT (gets (accept' . f))
    {-# INLINE queries #-}
    update new = CondT (put new >> return (accept' ()))
    {-# INLINE update #-}
    updates f = CondT (modify f >> return (accept' ()))
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'ReaderT'.  The reader environment @r@ is
-- independent of the queried item type @a@, which comes from @m@.  The old
-- head forced them to be the same type, which ruled out e.g.
-- @ReaderT Config (CondT FilePath IO)@.
instance MonadQuery a m => MonadQuery a (ReaderT r m) where
    query = lift query
    {-# INLINE query #-}
    queries = lift . queries
    {-# INLINE queries #-}
    update = lift . update
    {-# INLINE update #-}
    updates = lift . updates
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through lazy 'LazyRWS.RWST'.  The RWST environment
-- @r@ is independent of the queried item type @a@, which comes from @m@.
instance (MonadQuery a m, Monoid w) => MonadQuery a (LazyRWS.RWST r w s m) where
    query = lift query
    {-# INLINE query #-}
    queries = lift . queries
    {-# INLINE queries #-}
    update = lift . update
    {-# INLINE update #-}
    updates = lift . updates
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through strict 'StrictRWS.RWST'.  The RWST environment
-- @r@ is independent of the queried item type @a@, which comes from @m@.
instance (MonadQuery a m, Monoid w)
    => MonadQuery a (StrictRWS.RWST r w s m) where
    query = lift query
    {-# INLINE query #-}
    queries = lift . queries
    {-# INLINE queries #-}
    update = lift . update
    {-# INLINE update #-}
    updates = lift . updates
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'ContT'; the item comes from the inner monad.
instance MonadQuery r' m => MonadQuery r' (ContT r m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'ErrorT'; the item comes from the inner monad.
instance (Error e, MonadQuery r m) => MonadQuery r (ErrorT e m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'ExceptT'; the item comes from the inner monad.
instance MonadQuery r m => MonadQuery r (ExceptT e m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'IdentityT'; the item comes from the inner
-- monad.
instance MonadQuery r m => MonadQuery r (IdentityT m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'ListT'; the item comes from the inner monad.
instance MonadQuery r m => MonadQuery r (ListT m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through 'MaybeT'; the item comes from the inner monad.
instance MonadQuery r m => MonadQuery r (MaybeT m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through lazy 'Lazy.StateT'.  The transformer's own
-- state @s@ is unrelated to the queried item, which comes from @m@.
instance MonadQuery r m => MonadQuery r (Lazy.StateT s m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through strict 'Strict.StateT'.  The transformer's own
-- state @s@ is unrelated to the queried item, which comes from @m@.
instance MonadQuery r m => MonadQuery r (Strict.StateT s m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through lazy 'Lazy.WriterT'; the item comes from the
-- inner monad.
instance (Monoid w, MonadQuery r m) => MonadQuery r (Lazy.WriterT w m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Lift 'MonadQuery' through strict 'Strict.WriterT'; the item comes from
-- the inner monad.
instance (Monoid w, MonadQuery r m) => MonadQuery r (Strict.WriterT w m) where
    query = lift query
    {-# INLINE query #-}
    queries f = lift (queries f)
    {-# INLINE queries #-}
    update x = lift (update x)
    {-# INLINE update #-}
    updates f = lift (updates f)
    {-# INLINE updates #-}
-- | Like 'guard', but the condition is computed by a monadic action.
guardM :: MonadPlus m => m Bool -> m ()
guardM mb = do
    ok <- mb
    guard ok
{-# INLINE guardM #-}
-- | Reject unless the predicate holds for the current item.
guard_ :: (MonadPlus m, MonadQuery a m) => (a -> Bool) -> m ()
guard_ p = do
    item <- query
    guard (p item)
{-# INLINE guard_ #-}
-- | Reject unless the monadic predicate holds for the current item.
guardM_ :: (MonadPlus m, MonadQuery a m) => (a -> m Bool) -> m ()
guardM_ p = query >>= \item -> guardM (p item)
{-# INLINE guardM_ #-}
-- | Run a monadic function on the current item and succeed with its result;
-- a 'Nothing' result becomes 'mzero'.
apply :: (MonadPlus m, MonadQuery a m) => (a -> m (Maybe r)) -> m r
apply f = do
    action <- queries f
    result <- action
    maybe mzero return result
{-# INLINE apply #-}
-- | Like 'apply', but the function may also return a replacement item,
-- which is stored with 'update' before its value is returned.  A 'Nothing'
-- result becomes 'mzero'.
consider :: (MonadPlus m, MonadQuery a m) => (a -> m (Maybe (r, a))) -> m r
consider f = do
    action <- queries f
    found <- action
    case found of
        Nothing      -> mzero
        Just (r, a') -> update a' >> return r
{-# INLINE consider #-}
-- | Succeed with @()@.  For 'CondT' this accepts the current item with the
-- default 'Continue' instruction.
accept :: MonadPlus m => m ()
accept = return ()
{-# INLINE accept #-}
-- | Fail via 'mzero'.  For 'CondT' this rejects the current item but still
-- allows recursion ('Continue').
ignore :: MonadPlus m => m r
ignore = mzero
{-# INLINE ignore #-}
-- | Accept the current item but do not recurse further ('Stop').
norecurse :: Monad m => CondT a m ()
norecurse = CondT (return (Just (), Stop))
{-# INLINE norecurse #-}
-- | Reject the current item and do not recurse further ('Stop').
prune :: Monad m => CondT a m r
prune = CondT (return (Nothing, Stop))
{-# INLINE prune #-}
-- | 'True' if the action succeeds, otherwise 'False' (via 'mplus').  The
-- action's own result is discarded.
matches :: MonadPlus m => m r -> m Bool
matches m = mplus (m >> return True) (return False)
{-# INLINE matches #-}
-- | Monadic @if@: run the condition, then exactly one of the branches.
ifM :: Monad m => m Bool -> m s -> m s -> m s
ifM c x y = do
    b <- c
    case b of
        True  -> x
        False -> y
{-# INLINE ifM #-}
-- | Run @x@ if condition @c@ succeeds, otherwise @y@ (see 'matches').
if_ :: MonadPlus m => m r -> m s -> m s -> m s
if_ c x y = ifM (matches c) x y
{-# INLINE if_ #-}
-- | Run the action, discarding its result, only if the monadic condition
-- yields 'True'.
whenM :: Monad m => m Bool -> m s -> m ()
whenM c x = do
    b <- c
    if b then x >> return () else return ()
{-# INLINE whenM #-}
-- | Run the action, discarding its result, only if condition @c@ succeeds.
when_ :: MonadPlus m => m r -> m s -> m ()
when_ c x = if_ c (liftM (const ()) x) (return ())
{-# INLINE when_ #-}
-- | Run the action, discarding its result, only if the monadic condition
-- yields 'False'.
unlessM :: Monad m => m Bool -> m s -> m ()
unlessM c x = do
    b <- c
    if b then return () else x >> return ()
{-# INLINE unlessM #-}
-- | Run the action, discarding its result, only if condition @c@ fails.
unless_ :: MonadPlus m => m r -> m s -> m ()
unless_ c x = if_ c (return ()) (liftM (const ()) x)
{-# INLINE unless_ #-}
-- | Combine the alternatives with 'Data.Foldable.msum'.  For 'CondT' this
-- yields the first accepting condition; an empty list rejects.
or_ :: MonadPlus m => [m r] -> m r
or_ conds = Data.Foldable.msum conds
{-# INLINE or_ #-}
-- | Run every condition in order, discarding the results.  For 'CondT' the
-- first rejection short-circuits the rest; an empty list succeeds.
and_ :: MonadPlus m => [m r] -> m ()
and_ = Data.Foldable.foldr (>>) (return ())
{-# INLINE and_ #-}
-- | Succeed with @()@ exactly when condition @c@ fails; fail when it
-- succeeds.
not_ :: MonadPlus m => m r -> m ()
not_ c = if_ c mzero (return ())
{-# INLINE not_ #-}
-- | Run @c@ on the current item, keep its result, and apply @c@ itself at the
-- next level (replacing whatever instruction @c@ produced).
recurse :: Monad m => CondT a m r -> CondT a m r
recurse c = CondT (liftM (\(mr, _) -> (mr, Recurse c)) (getCondT c))
{-# INLINE recurse #-}