{-|
Module      : Control.Monad.AStar
Description : Provides a monad transformer for A* search. See README for details and examples.
Copyright   : (c) Chris Penner, 2019
License     : BSD3

See the README for usage info and examples.
-}

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Control.Monad.AStar
    (
    -- * Types
      AStar
    , AStarT

    -- * Methods
    , MonadAStar(..)

    -- * Executing Search
    , runAStarT
    , execAStarT
    , evalAStarT
    , runAStar
    , execAStar
    , evalAStar

    , tryWhile
    , tryWhileT
    )
    where

import Control.Monad.AStar.Class
import Control.Monad.Except
import Control.Monad.Fail
import Control.Monad.Logic
import Control.Applicative
import Control.Monad.State
import Data.Functor.Identity
import Data.Maybe

-- | One item in the stream of results a search branch produces.
--
-- This type is internal: it is not part of the module's export list.
data Step c r a
    = Pure a
    -- ^ An ordinary value, as produced by 'return' or 'lift'.
    | Weighted c
    -- ^ A cost announcement, as emitted by 'updateCost'.
    | Solved r
    -- ^ A finished search carrying its result, as emitted by 'done'.
    deriving (Show, Functor, Eq)

-- | Non-transformer version of 'AStarT'
type AStar s c r a = AStarT s c r Identity a

-- | The 'AStar' search monad transformer
--
-- Lots of type variables here:
--
-- @s@: State; keep anything you want in here, it will stay coherent across
-- branch switches.
--
-- @c@: Cost measure: The type you'll use for determining the estimated cost of following a given
-- branch. Usually requires 'Ord'.
--
-- @r@: Result type, this is often redundant to State but is provided for convenience.
-- This is the type you pass to 'done' when you've found a solution.
--
-- @m@: An arbitrary monad which will be threaded through.
--
-- Be wary that effects will be run in seemingly non-deterministic ordering as we switch
-- chaotically between branches.
newtype AStarT s c r m a =
    AStarT { unAStarT :: StateT s (LogicT m) (Step c r a)
             -- ^ The underlying computation: a stateful, backtracking stream
             -- of 'Step's.
           } deriving stock Functor

-- mapResult :: (r -> r') -> AStarT s c r m a -> AStarT s c r' m a
-- mapResult f (AStarT m) = AStarT $ fmap go m
--   where
--     go (Pure a) = Pure a
--     go (Weighted c) = Weighted c
--     go (Solved r) = Solved $ f r

-- | Lifts an action of the base monad through both the 'LogicT' and 'StateT'
-- layers and tags its result as a 'Pure' step.
instance MonadTrans (AStarT s c r) where
  lift m = AStarT . lift . lift $ (Pure <$> m)

instance (MonadIO m, Ord c) => MonadIO (AStarT s c r m) where
  liftIO io = lift $ liftIO io

-- | 'pure' carries the canonical definition; 'return' in the 'Monad' instance
-- simply defers to it.
instance (Monad m, Ord c) => Applicative (AStarT s c r m) where
  pure = AStarT . pure . Pure
  (<*>) = ap

instance (Ord c, Monad m) => MonadPlus (AStarT s c r m) where
  mzero = empty
  mplus = (<|>)

-- | A failed pattern match (or explicit 'fail') abandons the current branch;
-- the message is discarded.
instance (Ord c, Monad m) => MonadFail (AStarT s c r m) where
  fail _ = empty

instance (Ord c, Monad m) => MonadState s (AStarT s c r m) where
  get = AStarT $ Pure <$> get
  put s = AStarT $ Pure <$> put s

-- | Bind inspects the first 'Step' of the left-hand computation:
--
-- * no step: the whole bind is empty;
-- * 'Pure': the continuation is run on the value and interleaved (by weight)
--   with binding the remaining steps;
-- * 'Solved': the solution is passed through and the remaining steps dropped;
-- * 'Weighted': the cost is re-emitted, followed by binding the remaining steps.
instance (Monad m, Ord c) => Monad (AStarT s c r m) where
  return = pure
  AStarT m >>= f = AStarT $
      msplit m >>= \case
        Nothing -> empty
        Just (Pure a, continue) ->
            unAStarT $ (f a) `weightedInterleave` (AStarT continue >>= f)
        -- The step is rebuilt because its type parameter changes here.
        Just (Solved r, _) -> pure $ Solved r
        Just (Weighted c, continue) ->
            reflect $ Just (Weighted c, unAStarT $ AStarT continue >>= f)

instance (Ord c, Monad m) => Alternative (AStarT s c r m) where
  empty = AStarT empty
  (<|>) = weightedInterleave

-- | Combine two searches, choosing between them by announced cost.
-- A newtype wrapper around 'weightedInterleave''.
weightedInterleave :: (Ord c, Monad m) => AStarT s c r m a -> AStarT s c r m a -> AStarT s c r m a
weightedInterleave (AStarT a) (AStarT b) = AStarT $ weightedInterleave' a b

-- | Merge two streams of 'Step's.
--
-- Both streams are split from the same starting state, and the state each one
-- reaches after its first step is remembered. That state is restored with
-- 'put' before anything from that side is emitted or resumed.
--
-- * If one side is exhausted, the other side is returned as-is.
-- * If either side's first step is 'Solved', that solution is emitted and
--   everything else is dropped (the left side wins when both are solved).
-- * If both sides start with 'Weighted', the strictly smaller cost is emitted
--   first (the right side wins ties). The rest of that side is then merged
--   with the untouched other side.
weightedInterleave' :: (Ord c, MonadLogic m, MonadState s m) => m (Step c r a) -> m (Step c r a) -> m (Step c r a)
weightedInterleave' ma mb = do
    beforeBoth <- get
    (rA, lState) <- liftA2 (,) (msplit ma) get
    -- Rewind so the right side starts from the same state as the left did.
    put beforeBoth
    (rB, rState) <- liftA2 (,) (msplit mb) get
    case (rA, rB) of
        (m, Nothing) -> put lState >> reflect m
        (Nothing, m) -> put rState >> reflect m
        (Just (Solved a, _), _) -> put lState >> pure (Solved a)
        (_ , Just (Solved a, _)) -> put rState >> pure (Solved a)
        (l@(Just (Weighted lw, lm)), r@(Just (Weighted rw, rm)))
          | lw < rw ->
              (put lState >> pure (Weighted lw))
                <|> ((put lState >> lm) `weightedInterleave'` (put rState >> reflect r))
          | otherwise ->
              (put rState >> pure (Weighted rw))
                <|> ((put rState >> rm) `weightedInterleave'` (put lState >> reflect l))
        -- NOTE(review): reached only when at least one side starts with 'Pure'.
        -- The recursive call re-splits the very same heads, so it looks like
        -- it cannot make progress in that situation -- confirm whether a
        -- 'Pure' head is actually reachable here and how it should be ordered.
        (l, r) -> (put lState >> reflect l) `weightedInterleave'` (put rState >> reflect r)

-- | Run an A* computation effect returning the solution and branch state if one was found.
--
-- Only the first 'Solved' step is observed; every other step is skipped.
runAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe (r, s))
runAStarT (AStarT m) s = fmap listToMaybe . observeManyT 1 $
    runStateT m s >>= \case
      (Solved a, s') -> return (a, s')
      _ -> empty

-- | Run a pure A* computation returning the solution and branch state if one was found.
runAStar :: AStar s c r a -> s -> Maybe (r, s)
runAStar m s = runIdentity $ runAStarT m s

-- | Run an effectful A* computation returning only the branch state
execAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe s)
execAStarT m s = fmap snd <$> runAStarT m s

-- | Run an effectful A* computation returning only the solution
evalAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe r)
evalAStarT m s = fmap fst <$> runAStarT m s

-- | Run a pure A* computation returning only the branch state
execAStar :: AStar s c r a -> s -> Maybe s
execAStar m s = fmap snd $ runAStar m s

-- | Run a pure A* computation returning only the solution
evalAStar :: AStar s c r a -> s -> Maybe r
evalAStar m s = fmap fst $ runAStar m s

-- | Run a pure A* search but short-circuit when the lowest cost fails a predicate.
--
-- This is useful for detecting if your search is diverging, or is likely to fail.
tryWhile
    :: (c -> Bool)   -- ^ Keep searching while each announced cost satisfies this
    -> AStar s c r a -- ^ The search to run
    -> s             -- ^ Initial state
    -> Maybe (r, s)  -- ^ The solution and the state of the branch that found it
tryWhile p m s = runIdentity $ tryWhileT p m s

-- | Effectful version of 'tryWhile'
--
-- Steps through the search one 'Step' at a time:
--
-- * 'Pure' steps are skipped;
-- * a 'Weighted' step continues the search only if its cost satisfies the
--   predicate, otherwise the whole search gives up with 'Nothing';
-- * the first 'Solved' step ends the search with its result and state.
tryWhileT :: Monad m => (c -> Bool) -> AStarT s c r m a -> s -> m (Maybe (r, s))
tryWhileT p m s =
    stepAStar m s >>= \case
      Nothing -> return Nothing
      Just ((Pure _, s'), continue) -> tryWhileT p continue s'
      Just ((Weighted c, s'), continue)
        | p c -> tryWhileT p continue s'
        | otherwise -> return Nothing
      Just ((Solved r, s'), _) -> return (Just (r, s'))

-- | Observe the first 'Step' of a search started from the given state,
-- together with the state reached at that step and the rest of the search.
--
-- The returned continuation already carries the state of every remaining
-- step, so it ignores whatever state it is later run with. Dropping those
-- states would pair later results with the state of an unrelated, earlier
-- step instead of the state of the branch that produced them.
stepAStar :: (Monad m) => AStarT s c r m a -> s -> m (Maybe ((Step c r a, s), AStarT s c r m a))
stepAStar (AStarT m) s = fmap (fmap go) . observeT $ msplit (runStateT m s)
  where
    -- Re-wrap the remaining (step, state) stream as a state-ignoring 'StateT'.
    go (result, rest) = (result, AStarT (StateT (const rest)))

-- | 'branch' interleaves by weight. 'updateCost' announces the cost as a
-- 'Weighted' step and then carries on with a unit 'Pure' step. 'done' emits
-- a 'Solved' step.
instance (Ord w, Monad m) => MonadAStar w r (AStarT s w r m) where
  branch = (<|>)
  updateCost c = AStarT $ pure (Weighted c) <|> return (Pure ())
  done = AStarT . pure . Solved