{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Control.Monad.AStar
(
AStar
, AStarT
, MonadAStar(..)
, branch
, failure
, runAStarT
, execAStarT
, evalAStarT
, runAStar
, execAStar
, evalAStar
, tryWhile
, tryWhileT
)
where
import Control.Monad.AStar.Class
import Control.Monad.Except
import Control.Monad.Fail
import Control.Monad.Logic
import Control.Applicative
import Control.Monad.State
import Data.Functor.Identity
import Data.Maybe
-- | One observable event produced by a branch of the search.
data Step c r a
    = Pure a
      -- ^ An ordinary intermediate value, handed to the next bind.
    | Weighted c
      -- ^ A cost report; the branch merge compares these to decide which
      -- branch to advance first.
    | Solved r
      -- ^ A final answer; the runners in this module stop at the first one.
    deriving (Show, Functor, Eq)
-- | 'AStarT' over 'Identity': a search with no underlying effects.
--
-- The synonym is eta-reduced so it can also be used partially applied
-- (for example as the monad in @MonadState s (AStar s c r)@). Fully applied
-- uses such as @AStar s c r a@ expand exactly as before.
type AStar s c r = AStarT s c r Identity
-- | A* search monad transformer.
--
-- A computation is a 'LogicT' stream of 'Step's threaded through user state
-- @s@. Each step is either a plain value, a cost report of type @c@ (compared
-- with 'Ord' when branches are merged), or a final answer of type @r@.
newtype AStarT s c r m a = AStarT
    { unAStarT :: StateT s (LogicT m) (Step c r a)
    }
    deriving stock Functor
-- | Run a base-monad action inside the search; its result becomes a single
-- 'Pure' step.
--
-- NOTE(review): transformers >= 0.6 gives 'MonadTrans' a quantified
-- @forall m. Monad m => Monad (t m)@ superclass. The 'Monad' instance here
-- also needs @Ord c@, so confirm which transformers version this builds
-- against.
instance MonadTrans (AStarT s c r) where
    lift = AStarT . lift . lift . fmap Pure
-- | Perform 'IO' by lifting it through the base monad.
instance (MonadIO m, Ord c) => MonadIO (AStarT s c r m) where
    liftIO = lift . liftIO
-- | Applicative structure for the search.
--
-- 'pure' yields a single 'Pure' step without touching the state. It is
-- defined directly instead of as @pure = return@. GHC flags that form as a
-- non-canonical Applicative/Monad definition
-- (@-Wnoncanonical-monad-instances@), and it would loop if the 'Monad'
-- instance ever adopted the canonical @return = pure@.
-- '<*>' is sequencing through '>>=', so it inherits the weighted
-- interleaving of the 'Monad' instance.
instance (Monad m, Ord c) => Applicative (AStarT s c r m) where
    pure = AStarT . pure . Pure
    (<*>) = ap
-- | 'mzero' and 'mplus' use the class defaults, which are exactly 'empty'
-- and '(<|>)' from the 'Alternative' instance (weighted interleaving).
instance (Ord c, Monad m) => MonadPlus (AStarT s c r m)
-- | A failed pattern match in @do@ notation prunes the current branch.
-- The failure message is ignored.
instance (Ord c, Monad m) => MonadFail (AStarT s c r m) where
    fail = const empty
-- | Access the per-branch user state.
--
-- Each operation yields a single 'Pure' step carrying its result.
instance (Ord c, Monad m) => MonadState s (AStarT s c r m) where
    get = AStarT (Pure <$> get)
    put newState = AStarT (Pure <$> put newState)
    -- Defined directly rather than left to mtl's default. The default builds
    -- 'state' (and therefore 'modify') from 'get' and 'put' via this type's
    -- '>>=', which msplits and weighted-interleaves at every step. Lifting
    -- the underlying 'StateT' operation produces the same single 'Pure'
    -- result in one step.
    state f = AStarT (Pure <$> state f)
-- | Sequencing for the search.
--
-- Binding splits the first step off the left-hand computation and decides
-- what to do based on which kind of 'Step' it is. See 'resumeWith'.
instance (Monad m, Ord c) => Monad (AStarT s c r m) where
    return = AStarT . return . Pure
    AStarT m >>= f = AStarT (msplit m >>= resumeWith)
      where
        -- No steps at all: the bound computation has no steps either.
        resumeWith Nothing = empty
        -- A plain value: run the continuation on it, weighted-interleaved
        -- with the continuation applied to the remaining steps.
        resumeWith (Just (Pure a, rest)) =
            unAStarT (f a `weightedInterleave` (AStarT rest >>= f))
        -- A final answer short-circuits; the remaining steps are dropped.
        resumeWith (Just (Solved r, _)) = pure (Solved r)
        -- A cost report is re-emitted unchanged, then the rest is bound.
        resumeWith (Just (Weighted c, rest)) =
            reflect (Just (Weighted c, unAStarT (AStarT rest >>= f)))
-- | Branching for the search.
--
-- 'empty' is a branch with no steps. '<|>' merges two branches with
-- 'weightedInterleave', which advances the cheaper branch first.
instance (Ord c, Monad m) => Alternative (AStarT s c r m) where
    empty = AStarT empty
    left <|> right = left `weightedInterleave` right
-- | Merge two search branches by cost. This is a thin wrapper that unwraps
-- both 'AStarT's and rewraps the result of the underlying merge.
weightedInterleave :: (Ord c, Monad m) => AStarT s c r m a -> AStarT s c r m a -> AStarT s c r m a
weightedInterleave left right =
    AStarT (weightedInterleave' (unAStarT left) (unAStarT right))
-- | Merge two branches of the search while keeping each branch's user state
-- separate.
--
-- Both branches are 'msplit' starting from the same state. The state
-- observed right after each split is recorded (@lState@, @rState@) and is
-- re-installed with 'put' before anything from that branch is emitted. The
-- merge then depends on the two first results:
--
-- * If one branch has no results, the other is replayed as-is.
-- * A 'Solved' result is emitted immediately (the left branch is checked
--   first), and the rest of both branches is discarded.
-- * If both branches report a cost ('Weighted'), the strictly cheaper report
--   is emitted first (ties go to the right branch). The merge then continues
--   with the rest of that branch against the whole other branch.
-- * Otherwise one side has a 'Pure' value (left checked first). It is
--   emitted and fairly 'interleave'd with a merge of the remaining results.
weightedInterleave' :: (Ord c, MonadLogic m, MonadState s m) => m (Step c r a) -> m (Step c r a) -> m (Step c r a)
weightedInterleave' ma mb = do
    -- Snapshot the state so that both branches start from the same point.
    beforeBoth <- get
    (rA, lState) <- liftA2 (,) (msplit ma) get
    put beforeBoth
    (rB, rState) <- liftA2 (,) (msplit mb) get
    case (rA, rB) of
        -- NOTE(review): when the other branch is empty, a 'Solved' head is
        -- reflected together with its continuation. The two 'Solved' cases
        -- below drop the continuation instead.
        (m, Nothing) -> put lState >> reflect m
        (Nothing, m) -> put rState >> reflect m
        (Just (Solved a, _), _) -> put lState >> pure (Solved a)
        (_ , Just (Solved a, _)) -> put rState >> pure (Solved a)
        -- Emit the cheaper cost report first; the other branch is replayed
        -- untouched via 'reflect'.
        (l@(Just (Weighted lw, lm)), r@(Just (Weighted rw, rm)))
            | lw < rw ->
                (put lState >> pure (Weighted lw))
                <|> ((put lState >> lm) `weightedInterleave'` (put rState >> reflect r))
            | otherwise ->
                (put rState >> pure (Weighted rw))
                <|> ((put rState >> rm) `weightedInterleave'` (put lState >> reflect l))
        -- Plain values carry no cost, so they are interleaved fairly.
        ((Just (Pure la, lm)), r) ->
            (put lState >> pure (Pure la)) `interleave` ((put lState >> lm) `weightedInterleave'` (put rState >> reflect r))
        (l, (Just (Pure ra, rm))) ->
            (put rState >> pure (Pure ra)) `interleave` ((put rState >> rm) `weightedInterleave'` (put lState >> reflect l))
-- | Run a search from an initial state.
--
-- Returns the first 'Solved' answer (as produced by 'done') together with
-- the state it was produced in. Returns 'Nothing' if the search runs out of
-- branches without one. Only that first answer is demanded
-- (@observeManyT 1@).
runAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe (r, s))
runAStarT (AStarT m) s =
    listToMaybe <$> observeManyT 1 (runStateT m s >>= keepSolved)
  where
    -- Discard every step that is not a final answer.
    keepSolved (Solved r, finalState) = pure (r, finalState)
    keepSolved _ = empty
-- | Effect-free version of 'runAStarT'.
runAStar :: AStar s c r a -> s -> Maybe (r, s)
runAStar m = runIdentity . runAStarT m
-- | Like 'runAStarT', but keep only the final state.
execAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe s)
execAStarT m = fmap (fmap snd) . runAStarT m
-- | Like 'runAStarT', but keep only the answer.
evalAStarT :: (Monad m) => AStarT s c r m a -> s -> m (Maybe r)
evalAStarT m = fmap (fmap fst) . runAStarT m
-- | Like 'runAStar', but keep only the final state.
execAStar :: AStar s c r a -> s -> Maybe s
execAStar m = fmap snd . runAStar m
-- | Like 'runAStar', but keep only the answer.
evalAStar :: AStar s c r a -> s -> Maybe r
evalAStar m = fmap fst . runAStar m
-- | Effect-free version of 'tryWhileT'.
tryWhile :: (c -> Bool) -> AStar s c r a -> s -> Maybe (r, s)
tryWhile p m = runIdentity . tryWhileT p m
-- | Walk the search one step at a time while the reported costs satisfy @p@.
--
-- Behaviour by step kind:
--
-- * 'Pure' steps are skipped.
-- * A 'Weighted' cost for which @p@ returns 'False' abandons the search with
--   'Nothing'.
-- * The first 'Solved' answer is returned together with the state
--   'stepAStar' reports for it.
--
-- Also returns 'Nothing' if the search runs out of steps.
tryWhileT :: Monad m => (c -> Bool) -> AStarT s c r m a -> s -> m (Maybe (r, s))
tryWhileT p m s = stepAStar m s >>= maybe (return Nothing) advance
  where
    advance ((step, stepState), rest) = case step of
        Solved r -> return (Just (r, stepState))
        Weighted c
            | p c -> tryWhileT p rest stepState
            | otherwise -> return Nothing
        Pure _ -> tryWhileT p rest stepState
-- | Split off the first step of a search started from state @s@.
--
-- Returns 'Nothing' if the search has no steps. Otherwise returns:
--
-- * the first 'Step', paired with the state it was produced in;
-- * a computation yielding the remaining steps.
--
-- 'msplit' always produces exactly one answer ('Nothing' or 'Just'), so
-- 'observeT' never reaches its no-answer case here.
--
-- Each remaining step keeps the state it was produced in, and the returned
-- continuation ignores whatever state it is later run with. Discarding those
-- states and 'lift'ing the continuation into 'StateT' instead would report
-- every later step, including the final 'Solved', with the caller-supplied
-- state. In 'tryWhileT' that is the first step's state, which made it
-- disagree with 'runAStarT' on the final state.
stepAStar :: (Monad m) => AStarT s c r m a -> s -> m (Maybe ((Step c r a, s), AStarT s c r m a))
stepAStar (AStarT m) s = fmap (fmap resume) . observeT $ msplit (runStateT m s)
  where
    -- Replay the remaining (step, state) pairs verbatim.
    resume (firstStep, rest) = (firstStep, AStarT (StateT (const rest)))
-- | Search primitives.
--
-- * 'updateCost' emits a 'Weighted' cost report followed by a 'Pure' unit,
--   so the code bound after it runs once the report has been seen.
-- * 'done' emits a single 'Solved' answer.
instance (Ord w, Monad m) => MonadAStar w r (AStarT s w r m) where
    updateCost cost = AStarT (pure (Weighted cost) <|> pure (Pure ()))
    done result = AStarT (pure (Solved result))