module Data.Graph.AStar (aStar) where
import qualified Data.Set as Set
import Data.Set (Set, (\\))
import qualified Data.Map as Map
import Data.Map (Map, (!))
import qualified Data.PSQueue as PSQ
import Data.PSQueue (PSQ, Binding(..), minView)
import Data.List (foldl')
-- | Complete state of an A* search over nodes of type @a@ with costs of
-- type @c@. Every field is strict.
data AStar a c = AStar
  { visited  :: !(Set a)
    -- ^ nodes that have already been popped from the queue and expanded
  , waiting  :: !(PSQ a c)
    -- ^ open set: node -> priority (cost so far plus cached heuristic;
    --   the start node is seeded with priority 0)
  , score    :: !(Map a c)
    -- ^ best cost found so far from the start node to each node
  , memoHeur :: !(Map a c)
    -- ^ heuristic values computed so far, keyed by node
  , cameFrom :: !(Map a a)
    -- ^ predecessor of each node on its best known path
  , end      :: !(Maybe a)
    -- ^ the goal node that ended the search, if one was reached
  }
  deriving Show
-- | Initial search state for a search rooted at @start@.
--
-- Only @start@ is waiting, with priority 0, and its known cost is 0.
-- Nothing has been visited, no heuristic values are cached, no
-- predecessors are recorded, and no goal has been reached yet.
aStarInit :: (Ord a, Ord c, Num c) => a -> AStar a c
aStarInit start =
  AStar
    { visited  = Set.empty
    , waiting  = PSQ.singleton start 0
    , score    = Map.singleton start 0
    , memoHeur = Map.empty
    , cameFrom = Map.empty
    , end      = Nothing
    }
-- | Run an A* search from @start@ and return the final search state.
--
-- Each step pops the waiting node with the lowest priority (cost so far
-- plus cached heuristic). If that node satisfies @goal@, the returned
-- state has 'end' set to it. If the queue empties first, 'end' is
-- 'Nothing'. The search only terminates in one of those two ways, so on
-- an infinite graph with no reachable goal it does not return.
--
-- Visited nodes are never re-opened. The best path is therefore only
-- guaranteed when @heur@ is consistent. NOTE(review): presumably callers
-- supply a consistent heuristic; confirm.
runAStar :: (Ord a, Ord c, Num c)
         => (a -> Set a)   -- ^ neighbours of a node
         -> (a -> a -> c)  -- ^ cost of the edge between two adjacent nodes
         -> (a -> c)       -- ^ heuristic estimate of the remaining cost
         -> (a -> Bool)    -- ^ goal predicate
         -> a              -- ^ start node
         -> AStar a c      -- ^ final search state
runAStar graph dist heur goal start = aStar' (aStarInit start)
  where
    -- Pop the cheapest waiting node. Stop if the queue is empty or the node
    -- is a goal. Otherwise mark it visited, relax its unvisited neighbours,
    -- and continue.
    aStar' s =
      case minView (waiting s) of
        Nothing -> s
        Just (x :-> _, w')
          | goal x -> s { end = Just x }
          | otherwise ->
              let visited' = Set.insert x (visited s)
                  s' = s { waiting = w', visited = visited' }
                  -- Filter against visited', which already contains x, so a
                  -- self-loop x -> x is ignored. Otherwise x would be linked
                  -- as its own predecessor, and walking 'cameFrom' back to
                  -- the start node could loop forever.
                  frontier = Set.toList (graph x \\ visited')
              in aStar' (foldl' (expand x) s' frontier)

    -- Relax the edge x -> y.
    -- A node seen for the first time gets its heuristic cached and is linked
    -- unconditionally. A node already waiting is re-linked only if this
    -- route is strictly cheaper.
    expand x s y =
      let v = score s ! x + dist x y
      in case PSQ.lookup y (waiting s) of
           Nothing ->
             link x y v (s { memoHeur = Map.insert y (heur y) (memoHeur s) })
           Just _
             | v < score s ! y -> link x y v s
             | otherwise -> s

    -- Record x as y's predecessor with cost v.
    -- Then (re)queue y with priority v plus its cached heuristic.
    link x y v s =
      s { cameFrom = Map.insert y x (cameFrom s)
        , score    = Map.insert y v (score s)
        , waiting  = PSQ.insert y (v + memoHeur s ! y) (waiting s)
        }
-- | Search for a path from @start@ to a node satisfying @goal@.
--
-- Returns 'Nothing' if the search exhausts the queue without reaching a goal.
-- On success, the path lists the nodes after @start@ up to and including the
-- goal node, in travel order; @start@ itself is not included. If @start@
-- already satisfies @goal@, the result is @Just []@.
aStar :: (Ord a, Ord c, Num c)
      => (a -> Set a)   -- ^ neighbours of a node
      -> (a -> a -> c)  -- ^ cost of the edge between two adjacent nodes
      -> (a -> c)       -- ^ heuristic estimate of the remaining cost
      -> (a -> Bool)    -- ^ goal predicate
      -> a              -- ^ start node
      -> Maybe [a]
aStar graph dist heur goal start = pathTo <$> end finalState
  where
    finalState = runAStar graph dist heur goal start

    -- Follow predecessor links back from the goal until reaching start.
    -- Then reverse, so the path runs start-first.
    pathTo goalNode =
      reverse (takeWhile (/= start) (iterate (cameFrom finalState !) goalNode))
-- | The four orthogonal neighbours (left, right, down, up) of a point on an
-- unbounded integer grid.
plane :: (Integer, Integer) -> Set (Integer, Integer)
plane (x, y) = Set.fromList [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)]
-- | Like 'plane', but keeps only neighbours whose Euclidean distance from the
-- origin is strictly greater than 10. This punches a circular hole in the grid.
planeHole :: (Integer, Integer) -> Set (Integer, Integer)
planeHole point = Set.filter outsideHole (plane point)
  where
    outsideHole q = planeDist q (0, 0) > 10
-- | Euclidean distance between two integer grid points, as a 'Double'.
planeDist :: (Integer, Integer) -> (Integer, Integer) -> Double
planeDist (x1, y1) (x2, y2) = sqrt (dx * dx + dy * dy)
  where
    -- Subtract in exact Integer arithmetic before converting. This keeps small
    -- differences between very large coordinates that a Double cannot hold
    -- exactly, and it needs no partial list pattern.
    dx = fromInteger (x1 - x2)
    dy = fromInteger (y1 - y2)