-------------------------------------------------------------------------
-- |
-- Module      : Control.Monad.Logic
-- Copyright   : (c) 2007-2014 Dan Doel,
--               (c) 2011-2013 Edward Kmett,
--               (c) 2014      Roman Cheplyaka,
--               (c) 2020-2021 Andrew Lelechenko,
--               (c) 2020-2021 Kevin Quick
-- License     : BSD3
-- Maintainer  : Andrew Lelechenko
--
-- Adapted from the paper
-- <http://okmij.org/ftp/papers/LogicT.pdf Backtracking, Interleaving, and Terminating Monad Transformers>
-- by Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry.
-- Note that the paper uses 'MonadPlus' vocabulary
-- ('mzero' and 'mplus'),
-- while examples below prefer 'empty' and '<|>'
-- from 'Alternative'.
-------------------------------------------------------------------------

{-# LANGUAGE CPP                   #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE UndecidableInstances  #-}

#if __GLASGOW_HASKELL__ >= 704
{-# LANGUAGE Safe #-}
#endif

module Control.Monad.Logic (
    module Control.Monad.Logic.Class,
    -- * The Logic monad
    Logic,
    logic,
    runLogic,
    observe,
    observeMany,
    observeAll,
    -- * The LogicT monad transformer
    LogicT(..),
    runLogicT,
    observeT,
    observeManyT,
    observeAllT,
    module Control.Monad,
    module Trans
  ) where

import Control.Applicative

import Control.Monad
import qualified Control.Monad.Fail as Fail
import Control.Monad.Identity (Identity(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans (MonadTrans(..))
import qualified Control.Monad.Trans as Trans

import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.State.Class (MonadState(..))
import Control.Monad.Error.Class (MonadError(..))

#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (Monoid (..))
#endif

#if MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup (..))
#endif

import qualified Data.Foldable as F
import qualified Data.Traversable as T

import Control.Monad.Logic.Class
-------------------------------------------------------------------------
-- | A monad transformer for performing backtracking computations
-- layered over another monad @m@.
--
-- The computation is stored in continuation-passing form. It receives a
-- success continuation and a failure continuation:
--
-- * The success continuation is applied to each result, together with
--   the action that produces the remaining results.
-- * The failure continuation is run once no results remain.
newtype LogicT m a =
    LogicT { unLogicT :: forall r. (a -> m r -> m r) -> m r -> m r }

-------------------------------------------------------------------------
-- | Extracts the first result from a 'LogicT' computation,
-- failing if there are no results at all.
--
-- The failure goes through 'fail' of the underlying monad, with a fixed
-- message. Results after the first one are never demanded, because the
-- success continuation discards the action for the remaining answers.
#if !MIN_VERSION_base(4,13,0)
observeT :: Monad m => LogicT m a -> m a
#else
observeT :: MonadFail m => LogicT m a -> m a
#endif
observeT lt = unLogicT lt keepFirst (fail "No answer.")
  where
    -- Return the current answer; ignore the action for further answers.
    keepFirst answer _rest = return answer

-------------------------------------------------------------------------
-- | Extracts all results from a 'LogicT' computation, unless blocked by the
-- underlying monad.
--
-- For example, given
--
-- >>> let nats = pure 0 <|> fmap (+ 1) nats
--
-- some monads (like 'Identity', 'Control.Monad.Reader.Reader',
-- 'Control.Monad.Writer.Writer', and 'Control.Monad.State.State')
-- will be productive:
--
-- >>> take 5 $ runIdentity (observeAllT nats)
-- [0,1,2,3,4]
--
-- but others (like 'Control.Monad.Except.ExceptT',
-- and 'Control.Monad.Cont.ContT') will not; the following never yields
-- a list:
--
-- > take 20 <$> runExcept (observeAllT nats)
--
-- In general, if the underlying monad manages control flow then
-- 'observeAllT' may be unproductive under infinite branching,
-- and 'observeManyT' should be used instead.
observeAllT :: Applicative m => LogicT m a -> m [a]
observeAllT m = unLogicT m prepend (pure [])
  where
    -- Cons the current answer onto the list built from the remaining ones.
    prepend answer rest = (answer :) <$> rest

-------------------------------------------------------------------------
-- | Extracts up to a given number of results from a 'LogicT' computation.
--
-- A non-positive count yields @[]@ without running the computation.
-- A count of exactly one stops at the first success without using
-- 'msplit'. Larger counts peel off one answer at a time with 'msplit'
-- and recurse on the remainder.
observeManyT :: Monad m => Int -> LogicT m a -> m [a]
observeManyT n m = case compare n 1 of
    LT -> return []
    EQ -> unLogicT m (\a _ -> return [a]) (return [])
    GT -> unLogicT (msplit m) step (return [])
  where
    -- 'Nothing' means the computation is exhausted. Otherwise keep the head
    -- answer and ask the rest for one fewer result. 'liftM' (rather than
    -- 'fmap') keeps this working on bases where Functor is not a Monad
    -- superclass.
    step Nothing            _ = return []
    step (Just (a, rest)) _ = liftM (a :) (observeManyT (n - 1) rest)
-------------------------------------------------------------------------
-- | Runs a 'LogicT' computation with the specified initial success and
-- failure continuations.
--
-- The second argument ("success continuation") takes one result of
-- the 'LogicT' computation and the monad to run for any subsequent
-- matches.
--
-- The third argument ("failure continuation") is called when the
-- 'LogicT' cannot produce any more results.
--
-- For example:
--
-- >>> yieldWords = foldr ((<|>) . pure) empty
-- >>> showEach wrd nxt = putStrLn wrd >> nxt
-- >>> runLogicT (yieldWords ["foo", "bar"]) showEach (putStrLn "none!")
-- foo
-- bar
-- none!
-- >>> runLogicT (yieldWords []) showEach (putStrLn "none!")
-- none!
-- >>> showFirst wrd _ = putStrLn wrd
-- >>> runLogicT (yieldWords ["foo", "bar"]) showFirst (putStrLn "none!")
-- foo
--
runLogicT :: LogicT m a -> (a -> m r -> m r) -> m r -> m r
runLogicT (LogicT r) = r

-------------------------------------------------------------------------
-- | The basic 'Logic' monad, for performing backtracking computations
-- returning values (e.g. 'Logic' @a@ will return values of type @a@).
type Logic = LogicT Identity

-------------------------------------------------------------------------
-- | A smart constructor for 'Logic' computations.
--
-- The argument is a pure continuation-passing search. Its success
-- continuation has type @a -> r -> r@ and its failure value has type @r@.
-- It is lifted into 'Logic' by wrapping and unwrapping 'Identity'
-- around each continuation.
logic :: (forall r. (a -> r -> r) -> r -> r) -> Logic a
logic f = LogicT $ \k -> Identity .
                         f (\a -> runIdentity . k a . Identity) .
                         runIdentity

-------------------------------------------------------------------------
-- | Extracts the first result from a 'Logic' computation, failing if
-- there are no results.
--
-- This function is partial: it calls 'error' when there is no answer.
--
-- >>> observe (pure 5 <|> pure 3 <|> empty)
-- 5
--
-- >>> observe empty
-- *** Exception: No answer.
--
observe :: Logic a -> a
observe lt = runIdentity $ unLogicT lt (const . pure) (error "No answer.")

-------------------------------------------------------------------------
-- | Extracts all results from a 'Logic' computation.
--
-- >>> observeAll (pure 5 <|> empty <|> empty <|> pure 3 <|> empty)
-- [5,3]
--
observeAll :: Logic a -> [a]
observeAll = runIdentity . observeAllT

-------------------------------------------------------------------------
-- | Extracts up to a given number of results from a 'Logic' computation.
--
-- >>> let nats = pure 0 <|> fmap (+ 1) nats
-- >>> observeMany 5 nats
-- [0,1,2,3,4]
--
observeMany :: Int -> Logic a -> [a]
observeMany i = take i . observeAll
-- Implementing 'observeMany' using 'observeManyT' is quite costly,
-- because it calls 'msplit' multiple times. Over 'Identity' the lazy
-- list from 'observeAll' can simply be truncated.

-------------------------------------------------------------------------
-- | Runs a 'Logic' computation with the specified initial success and
-- failure continuations.
--
-- >>> runLogic empty (+) 0
-- 0
--
-- >>> runLogic (pure 5 <|> pure 3 <|> empty) (+) 0
-- 8
--
runLogic :: Logic a -> (a -> r -> r) -> r -> r
runLogic l s f = runIdentity $ unLogicT l si fi
  where
    -- Lift the pure success continuation over the 'Identity' layer.
    si = fmap . s
    fi = Identity f

-- | Maps over every result by post-composing the success continuation.
instance Functor (LogicT f) where
    fmap f lt = LogicT $ \sk fk -> unLogicT lt (sk . f) fk

-- | 'pure' succeeds exactly once.
--
-- For @f '<*>' a@: for every function produced by @f@, every result of
-- @a@ is fed to it. When @a@ is exhausted, control backtracks into @f@.
instance Applicative (LogicT f) where
    pure a = LogicT $ \sk fk -> sk a fk
    f <*> a = LogicT $ \sk fk -> unLogicT f (\g fk' -> unLogicT a (sk . g) fk') fk

-- | 'empty' fails immediately.
--
-- @f1 '<|>' f2@ yields all results of @f1@ and then all results of @f2@:
-- @f2@ is installed as @f1@'s failure continuation.
instance Alternative (LogicT f) where
    empty = LogicT $ \_ fk -> fk
    f1 <|> f2 = LogicT $ \sk fk -> unLogicT f1 sk (unLogicT f2 sk fk)

-- | Bind runs @f a@ for each result @a@ of @m@. When @f a@ is exhausted,
-- control resumes @m@ for its next result.
instance Monad (LogicT m) where
    return = pure
    m >>= f = LogicT $ \sk fk -> unLogicT m (\a fk' -> unLogicT (f a) sk fk') fk
#if !MIN_VERSION_base(4,13,0)
    fail = Fail.fail
#endif

-- | 'fail' discards its message and behaves like 'empty'.
instance Fail.MonadFail (LogicT m) where
    fail _ = LogicT $ \_ fk -> fk

instance MonadPlus (LogicT m) where
    mzero = empty
    mplus = (<|>)

#if MIN_VERSION_base(4,9,0)
-- | Combining computations concatenates their result streams.
instance Semigroup (LogicT m a) where
    (<>) = mplus
    sconcat = foldr1 mplus
#endif

-- | 'mempty' has no results. 'mappend' concatenates result streams.
instance Monoid (LogicT m a) where
    mempty = empty
#if MIN_VERSION_base(4,9,0)
    -- Canonical definition: defer to the Semigroup instance, which is
    -- 'mplus' and therefore the same as '<|>'.
    mappend = (<>)
#else
    mappend = (<|>)
#endif
    mconcat = F.asum

-- | Runs the base action once and passes its value to the success
-- continuation.
instance MonadTrans LogicT where
    lift m = LogicT $ \sk fk -> m >>= \a -> sk a fk

instance (MonadIO m) => MonadIO (LogicT m) where
    liftIO = lift . liftIO

instance (Monad m) => MonadLogic (LogicT m) where
    -- 'msplit' is quite costly even if the base 'Monad' is 'Identity'.
    -- Try to avoid it.
    --
    -- It runs the computation in the base monad. On the first answer it
    -- packages that answer with the (re-lifted) remainder. If there is no
    -- answer it returns 'Nothing'.
    msplit m = lift $ unLogicT m ssk (return Nothing)
      where
        ssk a fk = return $ Just (a, lift fk >>= reflect)
    -- Keep only the first answer: the continuation for further answers is
    -- replaced by the outer failure continuation.
    once m = LogicT $ \sk fk -> unLogicT m (\a _ -> sk a fk) fk
    -- Succeed with () only if @m@ has no answers; any answer causes failure.
    lnot m = LogicT $ \sk fk -> unLogicT m (\_ _ -> fk) (sk () fk)

#if MIN_VERSION_base(4,8,0)

-- | Folds by running the computation inside the base monad (collecting a
-- monoidal summary) and then folding the resulting container.
instance {-# OVERLAPPABLE #-} (Applicative m, F.Foldable m) => F.Foldable (LogicT m) where
    foldMap f m = F.fold $ unLogicT m (fmap . mappend . f) (pure mempty)

-- | Over 'Identity', folding is a plain right fold via 'runLogic'.
instance {-# OVERLAPPING #-} F.Foldable (LogicT Identity) where
    foldr f z m = runLogic m f z

#else

instance (Applicative m, F.Foldable m) => F.Foldable (LogicT m) where
    foldMap f m = F.fold $ unLogicT m (fmap . mappend . f) (pure mempty)

#endif

-- | Traverses results in order and rebuilds the computation as an
-- alternation of the traversed values.
instance T.Traversable (LogicT Identity) where
    traverse g l = runLogic l (\a ft -> cons <$> g a <*> ft) (pure empty)
      where
        cons a l' = pure a <|> l'

-- Needs undecidable instances.
--
-- 'local' applies @f@ only while running @m@ itself. The success and
-- failure continuations are run under the environment that was in effect
-- before 'local', so the change does not leak into later computations.
instance MonadReader r m => MonadReader r (LogicT m) where
    ask = lift ask
    local f (LogicT m) = LogicT $ \sk fk -> do
        env <- ask
        local f $ m ((local (const env) .) . sk) (local (const env) fk)

-- Needs undecidable instances.
--
-- The state lives in the underlying monad, so updates made on one branch
-- are not undone when the search backtracks.
instance MonadState s m => MonadState s (LogicT m) where
    get = lift get
    put = lift . put

-- Needs undecidable instances.
--
-- 'catchError' wraps both the initial run of @m@ and the action for @m@'s
-- remaining results, so errors raised while producing later answers are
-- also handed to @h@.
-- NOTE(review): the outer 'handle' covers the whole run, including calls
-- to the downstream success continuation @sk@. Errors raised downstream
-- are therefore routed to @h@ as well. Confirm this is intended.
instance MonadError e m => MonadError e (LogicT m) where
    throwError = lift . throwError
    catchError m h = LogicT $ \sk fk ->
        let handle r = r `catchError` \e -> unLogicT (h e) sk fk
        in handle $ unLogicT m (\a -> sk a . handle) fk