{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
module Control.Monad.AStar
(
AStar
, AStarT
, BranchState (..)
, MonadAStar(..)
, branch
, failure
, runAStar
, runAStarT
, evalAStar
, evalAStarT
, execAStar
, execAStarT
)
where
import Control.Monad.AStar.Class
import Control.Monad.Except
import Control.Monad.Fail
import Control.Monad.Logic
import Control.Applicative
import Control.Monad.State
import Data.Functor.Identity
import Data.Maybe
-- | What a search branch yields at each step.
--
-- @r@ is the type of a final solution, @a@ the type of an ordinary
-- intermediate value. The derived 'Functor' instance maps over @a@ only.
data Step r a
  = Pure a
    -- ^ An ordinary value handed to the rest of the computation.
  | Checkpoint
    -- ^ A value-less marker. 'estimate' and 'spend' emit one after updating
    -- the branch costs, and the weighted interleaving compares branch costs
    -- when both sides are sitting at a 'Checkpoint'.
  | Solved r
    -- ^ A final solution; it is passed through untouched by '>>='.
  deriving (Show, Functor, Eq)
-- | Bookkeeping carried by every search branch.
--
-- @s@ is the user-visible state type and @c@ the cost type.
data BranchState s c = BranchState
  { branchState :: s
    -- ^ The state exposed to users through the 'MonadState' instance.
  , cumulativeCost :: c
    -- ^ Cost accumulated so far; 'spend' appends to it with '<>'.
  , estimateTillDone :: c
    -- ^ Most recent remaining-cost estimate; 'estimate' overwrites it.
  }
  deriving (Show, Eq)
-- | The value used to rank a branch: the cost already spent combined (via
-- '<>') with the current estimate of the cost still to come.
totalCost :: Semigroup c => BranchState s c -> c
totalCost bs = spent <> remaining
  where
    spent = cumulativeCost bs
    remaining = estimateTillDone bs
-- | 'AStarT' specialised to the 'Identity' base monad, i.e. a search with no
-- underlying effects. @s@ is the state, @c@ the cost, @r@ the solution type
-- and @a@ the result type of the computation.
type AStar s c r a = AStarT s c r Identity a
-- | The search monad transformer.
--
-- A computation is a 'LogicT' stream of 'Step's, each produced alongside the
-- 'BranchState' of the branch that yielded it.
--
-- Type parameters: @s@ user state, @c@ cost, @r@ solution, @m@ base monad,
-- @a@ result.
newtype AStarT s c r m a = AStarT
  { unAStarT :: StateT (BranchState s c) (LogicT m) (Step r a)
    -- ^ Unwrap to the underlying transformer stack.
  }
  deriving stock Functor
-- | Lifts a base-monad action through both layers of the stack and reports
-- its result as an ordinary 'Pure' step.
instance MonadTrans (AStarT s c r) where
  lift = AStarT . lift . lift . fmap Pure
-- | 'IO' actions are run in the base monad and then lifted with 'lift'.
instance (MonadIO m, Semigroup c, Ord c) => MonadIO (AStarT s c r m) where
  liftIO = lift . liftIO
-- | 'pure' yields a single 'Pure' step without touching the branch state;
-- '<*>' is derived from the 'Monad' instance via 'ap'.
instance (Monad m, Semigroup c, Ord c) => Applicative (AStarT s c r m) where
  pure = AStarT . pure . Pure
  (<*>) = ap
-- | Uses the class defaults, which are @mzero = empty@ and @mplus = (\<|\>)@,
-- so this instance always agrees with the 'Alternative' instance.
instance (Ord c, Semigroup c, Monad m) => MonadPlus (AStarT s c r m)
-- | A failure (for example a failed pattern match in @do@ notation) becomes
-- 'empty'; the message is discarded.
instance (Ord c, Semigroup c, Monad m) => MonadFail (AStarT s c r m) where
  fail = const empty
-- | Exposes only the 'branchState' field of the internal 'BranchState'; the
-- cost fields are not reachable through this interface.
instance (Ord c, Semigroup c, Monad m) => MonadState s (AStarT s c r m) where
  -- Read the user state out of the current branch.
  get = AStarT $ gets (Pure . branchState)

  -- Replace the user state, leaving both cost fields as they were.
  put newState =
    AStarT $ Pure <$> modify (\current -> current {branchState = newState})
-- | Sequencing for the search monad.
--
-- '>>=' pulls the first 'Step' off the left-hand computation with 'msplit'
-- and dispatches on it:
--
-- * no step at all: the result is 'empty';
-- * 'Solved': the solution is passed on and the remaining alternatives of
--   the left-hand side are dropped;
-- * 'Pure': the continuation is run on the value, followed by the bind of
--   the remaining alternatives;
-- * 'Checkpoint': the checkpoint is re-emitted, followed by the bind of the
--   remaining alternatives.
--
-- The @\<|\>@ used here belongs to the underlying @StateT@/@LogicT@ stack,
-- not to 'AStarT' itself.
instance (Monad m, Semigroup c, Ord c) => Monad (AStarT s c r m) where
  return = AStarT . return . Pure

  AStarT m >>= f = AStarT (msplit m >>= dispatch)
    where
      -- Bind whatever is left of the left-hand side after its first step.
      bindRest rest = unAStarT (AStarT rest >>= f)

      dispatch Nothing = empty
      dispatch (Just (Solved r, _)) = pure (Solved r)
      dispatch (Just (Pure a, rest)) = unAStarT (f a) <|> bindRest rest
      dispatch (Just (Checkpoint, rest)) = pure Checkpoint <|> bindRest rest
-- | 'empty' is a branch with no steps; '<|>' merges two branches with
-- 'weightedInterleave' rather than plain concatenation.
instance (Ord c, Monad m, Semigroup c) => Alternative (AStarT s c r m) where
  empty = AStarT empty
  left <|> right = weightedInterleave left right
-- | Cost-aware merge of two 'AStarT' computations: unwraps both sides, hands
-- them to 'weightedInterleave'' and rewraps the result.
weightedInterleave
  :: (Ord c, Semigroup c, Monad m)
  => AStarT s c r m a
  -> AStarT s c r m a
  -> AStarT s c r m a
weightedInterleave left right =
  AStarT (weightedInterleave' (unAStarT left) (unAStarT right))
-- | Merge two step streams, using 'totalCost' to decide which side advances
-- when both are waiting at a 'Checkpoint'.
--
-- Both sides are split with 'msplit' starting from the /same/ saved state,
-- and the state each one ends up in is recorded. The cases are then tried in
-- this order:
--
-- 1. One side has no steps: the other side is replayed in its own state.
-- 2. Either side is 'Solved' (left checked first): that solution is returned.
-- 3. Both sides are at a 'Checkpoint': one 'Checkpoint' is emitted in the
--    state of the side whose 'totalCost' is smaller (left wins ties), then
--    that side's remainder is merged against the other side, which is put
--    back unchanged with 'reflect'.
-- 4. Otherwise at least one side holds a 'Pure' value (left checked first):
--    it is emitted straight away and fairly interleaved with the merge of its
--    remainder and the untouched other side.
weightedInterleave'
  :: (Ord c, Semigroup c, MonadLogic m, MonadState (BranchState s c) m)
  => m (Step r a)
  -> m (Step r a)
  -> m (Step r a)
weightedInterleave' leftBranch rightBranch = do
  origin <- get
  -- Take the first step of the left side and remember the state it left.
  leftHead <- msplit leftBranch
  leftState <- get
  -- Rewind so the right side starts from the same state the left side did.
  put origin
  rightHead <- msplit rightBranch
  rightState <- get
  case (leftHead, rightHead) of
    (only, Nothing) -> put leftState >> reflect only
    (Nothing, only) -> put rightState >> reflect only
    (Just (Solved answer, _), _) -> put leftState >> pure (Solved answer)
    (_, Just (Solved answer, _)) -> put rightState >> pure (Solved answer)
    (Just (Checkpoint, leftRest), Just (Checkpoint, rightRest))
      | totalCost leftState <= totalCost rightState ->
          (put leftState >> pure Checkpoint)
            <|> ( (put leftState >> leftRest)
                    `weightedInterleave'` (put rightState >> reflect rightHead)
                )
      | otherwise ->
          (put rightState >> pure Checkpoint)
            <|> ( (put rightState >> rightRest)
                    `weightedInterleave'` (put leftState >> reflect leftHead)
                )
    (Just (Pure value, leftRest), _) ->
      (put leftState >> pure (Pure value))
        `interleave` ( (put leftState >> leftRest)
                         `weightedInterleave'` (put rightState >> reflect rightHead)
                     )
    (_, Just (Pure value, rightRest)) ->
      (put rightState >> pure (Pure value))
        `interleave` ( (put rightState >> rightRest)
                         `weightedInterleave'` (put leftState >> reflect leftHead)
                     )
-- | Run a pure search from the given initial state. Returns the first
-- solution found together with the user state of the branch that found it,
-- or 'Nothing' when no branch reaches a solution.
runAStar :: (Monoid c) => AStar s c r a -> s -> Maybe (r, s)
runAStar m = runIdentity . runAStarT m
-- | Run a search in the base monad from the given initial state.
--
-- Both costs start at 'mempty'. At most one result is requested from the
-- underlying 'LogicT' stream; steps that are not 'Solved' are filtered out,
-- so the result is the first 'Solved' value paired with the 'branchState' of
-- the branch that produced it, or 'Nothing' if there is none.
runAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe (r, s))
runAStarT search initial = listToMaybe <$> observeManyT 1 solutions
  where
    start = BranchState initial mempty mempty

    -- Stream of (solution, final user state) pairs, one per 'Solved' step.
    solutions = do
      (step, final) <- runStateT (unAStarT search) start
      case step of
        Solved r -> return (r, branchState final)
        _ -> empty
-- | Like 'runAStar', but keeps only the solution and drops the final state.
evalAStar :: (Monoid c) => AStar s c r a -> s -> Maybe r
evalAStar m = fmap fst . runAStar m
-- | Like 'runAStarT', but keeps only the solution and drops the final state.
evalAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe r)
evalAStarT m s = fmap (fmap fst) (runAStarT m s)
-- | Like 'runAStar', but keeps only the final user state of the solving
-- branch and drops the solution.
execAStar :: (Monoid c) => AStar s c r a -> s -> Maybe s
execAStar m s = snd <$> runAStar m s
-- | Like 'runAStarT', but keeps only the final user state of the solving
-- branch and drops the solution.
execAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe s)
execAStarT m s = fmap (fmap snd) (runAStarT m s)
-- | The search primitives for 'AStarT'.
--
-- 'estimate' and 'spend' each update a cost field and then yield two steps
-- in order: a 'Checkpoint', then @'Pure' ()@ to carry on with the rest of
-- the computation.
instance (Ord c, Monoid c, Monad m) => MonadAStar c r (AStarT s c r m) where
  -- Overwrite the remaining-cost estimate of the current branch.
  estimate c = AStarT $ do
    modify $ \bs -> bs {estimateTillDone = c}
    pure Checkpoint <|> pure (Pure ())

  -- Append @c@ to the cost accumulated by the current branch.
  spend c = AStarT $ do
    modify $ \bs -> bs {cumulativeCost = cumulativeCost bs <> c}
    pure Checkpoint <|> pure (Pure ())

  -- Finish the current branch with a solution.
  done = AStarT . pure . Solved