-- |
-- Module      : MonusWeightedSearch.Examples.Dijkstra
-- Copyright   : (c) Donnacha Oisín Kidney 2021
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : non-portable
--
-- An implementation of Dijkstra's algorithm, using the 'HeapT' monad.
--
-- This is taken from section 6.1.3 of the paper
--
-- * Donnacha Oisín Kidney and Nicolas Wu. 2021. /Algebras for weighted search/.
--   Proc. ACM Program. Lang. 5, ICFP, Article 72 (August 2021), 30 pages.
--   DOI:<https://doi.org/10.1145/3473577>
--
-- This is a pretty simple implementation of the algorithm, defined monadically,
-- but it retains the time complexity of a standard purely functional
-- implementation.
--
-- We use the state monad here to avoid searching from the same node more than
-- once (which would lead to an infinite loop). Different algorithms use
-- different permutations of the monad transformers: for Dijkstra's algorithm,
-- we use @'HeapT' w ('State' ('Set' a)) a@, i.e. the 'HeapT' is outside of the
-- 'State'. This means that each branch of the search proceeds with a different
-- state; if we switch the order (to @'StateT' s ('Heap' w) a@, for example), we
-- get "global" state, which has the semantics of a /parser/. For an example
-- of that, see the module "MonusWeightedSearch.Examples.Parsing", where the
-- heap is used to implement a probabilistic parser.

module MonusWeightedSearch.Examples.Dijkstra where

-- $setup
-- >>> import Prelude hiding (head)
-- >>> import Data.List.NonEmpty (head)

import Prelude hiding (head)
import Control.Monad.State.Strict
import Control.Applicative
import Control.Monad.Writer
import Data.Foldable

import Data.Monus.Dist
import Data.Set (Set)
import qualified Data.Set as Set

import Data.List.NonEmpty (NonEmpty(..))

import Control.Monad.Heap


-- | The example graph from
-- <https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm the Wikipedia article on Dijkstra's algorithm>.
--
-- <<https://upload.wikimedia.org/wikipedia/commons/5/57/Dijkstra_Animation.gif>>
--
-- Applying the graph to a vertex yields that vertex's outgoing edges as
-- @(target, weight)@ pairs. Vertices @1@ to @6@ are defined explicitly; every
-- other vertex has no outgoing edges.
graph :: Graph Int
graph 1 = [(2, 7), (3, 9), (6, 14)]
graph 2 = [(3, 10), (4, 15)]
graph 3 = [(4, 11), (6, 2)]
graph 4 = [(5, 6)]
graph 5 = []
graph 6 = [(5, 9)]
graph _ = []

-- | @'unique' x@ succeeds, returning @x@, only if @x@ is not yet a member of
-- the 'Set' held in the underlying 'State'. On success @x@ is inserted into
-- that set, so a later @'unique' x@ that observes the updated set fails.
--
-- Failure is signalled with 'guard', i.e. the computation becomes 'empty'.
unique :: Ord a => a -> HeapT w (State (Set a)) a
unique x = do
  seen <- get
  -- Abandon this computation if the vertex was already recorded.
  guard (Set.notMember x seen)
  -- Record the vertex before handing it back to the caller.
  modify (Set.insert x)
  pure x
{-# INLINE unique #-}

-- | This is the Kleene star on the semiring of 'MonadPlus'. It is analogous to
-- the 'many' function on 'Alternative's.
--
-- @'star' f x@ offers @x@ itself as a result, and, as an alternative, every
-- result obtained by feeding an output of @f x@ back into @'star' f@. In
-- other words it yields the results of applying @f@ zero or more times.
--
-- The recursion lives in the local worker @go@ rather than in 'star' itself:
-- GHC ignores an @INLINE@ pragma on a self-recursive binding, so keeping
-- 'star' non-recursive is what lets the pragma below take effect.
star :: MonadPlus m => (a -> m a) -> a -> m a
star f = go
  where
    -- Zero applications of @f@ (just @x@), or one application followed by
    -- any further number of applications.
    go x = pure x <|> (f x >>= go)
{-# INLINE star #-}

-- | This is a version of 'star' which keeps track of the inputs it was given.
--
-- Every result is the non-empty list of values visited so far, most recent
-- first: the seed is wrapped as a singleton, and each step applies @f@ to the
-- head of the current path and conses the new value onto the front.
pathed :: MonadPlus m => (a -> m a) -> a -> m (NonEmpty a)
pathed f = star extend . (:| [])
  where
    -- The lazy pattern keeps the match on the path from being forced before
    -- @f x@ (or the tail of the new path) actually demands it.
    extend ~(x :| xs) = fmap (:| (x : xs)) (f x)
{-# INLINE pathed #-}

-- | Dijkstra's algorithm. This function returns the length of the shortest path
-- from a given vertex to every vertex in the graph that the search reaches
-- from it.
--
-- >>> dijkstra graph 1
-- [(1,0),(2,7),(3,9),(6,11),(5,20),(4,20)]
--
-- The search starts from the source vertex with an empty set of visited
-- vertices and repeatedly follows outgoing edges with 'star'; 'unique' stops
-- the search from continuing through a vertex that is already in the set.
--
-- A version which actually produces the paths is 'shortestPaths'
dijkstra :: Ord a => Graph a -> a -> [(a, Dist)]
dijkstra g source =
  evalState (searchT (star step =<< unique source)) Set.empty
  where
    -- One search step: choose among all outgoing edges of a vertex, record
    -- the weight of the chosen edge with 'tell', then continue only if the
    -- target vertex has not been visited yet.
    step = asum . map (\(v, w) -> tell w >> unique v) . g
{-# INLINE dijkstra #-}

-- | Dijkstra's algorithm, which produces a path.
--
-- The only difference between this function and 'dijkstra' is that this
-- uses 'pathed' rather than 'star'.
--
-- The following finds the shortest path from vertex 1 to 5:
--
-- >>> filter ((5==) . head . fst) (shortestPaths graph 1)
-- [(5 :| [6,3,1],20)]
--
-- And it is indeed @[1,3,6,5]@. (it's returned in reverse)
shortestPaths :: Ord a => Graph a -> a -> [(NonEmpty a, Dist)]
shortestPaths g source =
  evalState (searchT (pathed step =<< unique source)) Set.empty
  where
    -- One search step: choose among all outgoing edges of a vertex, record
    -- the weight of the chosen edge with 'tell', then continue only if the
    -- target vertex has not been visited yet.
    step = asum . map (\(v, w) -> tell w >> unique v) . g
{-# INLINE shortestPaths #-}