--------------------------------------------------------------------------------
-- |
-- Module      :  Algorithms.FloydWarshall
-- Copyright   :  (C) David Himmelstrup
-- License     :  see the LICENSE file
-- Maintainer  :  David Himmelstrup
--
-- Implementation of the Floyd-Warshall all-pairs shortest path algorithm.
--
-- See Wikipedia article for details: https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
--
--------------------------------------------------------------------------------
module Algorithms.FloydWarshall
  ( mkIndex
  , mkGraph
  , floydWarshall
  ) where

import           Control.Monad               (forM_, when)
import           Control.Monad.ST            (ST)
import           Data.Vector.Unboxed.Mutable as V (MVector, length, replicate, unsafeRead,
                                                   unsafeWrite, Unbox)

-- | Run the Floyd-Warshall all-pairs shortest-path algorithm in place,
-- \( O(n^3) \).
--
-- The graph is an \(n \times n\) matrix stored row-major (index
-- @i * n + j@, see 'mkIndex') in a flat mutable vector. Each cell @(i, j)@
-- holds a pair @(dist, next)@. Whenever going through an intermediate
-- vertex @k@ is shorter, the cell is replaced by
-- @(dist(i,k) + dist(k,j), next(i,k))@. When the matrix was built by
-- 'mkGraph', @next@ is therefore the vertex following @i@ on the best known
-- path to @j@.
--
-- Calls 'error' (\"Bad bounds\") if the vector's length is not @n * n@.
floydWarshall :: (Unbox a, Fractional a, Ord a) => Int -> MVector s (a, Int) -> ST s ()
floydWarshall n graph = do
    -- This check is what makes the unchecked reads/writes below safe: every
    -- index is built from i, j, k in [0, n) and therefore lies in [0, n*n).
    when (n * n /= V.length graph) $ error "Bad bounds"
    forM_ [0 .. n - 1] $ \k ->
      forM_ [0 .. n - 1] $ \i -> do
        -- Cell (i, k) cannot change inside the j loop. A write to (i, k)
        -- only happens when j == k, and then the guard requires
        -- distIK > distIK, which is false. Reading it once per i is
        -- therefore equivalent to re-reading it for every j.
        (distIK, pathIK) <- access (i, k)
        forM_ [0 .. n - 1] $ \j -> do
          (distIJ, _) <- access (i, j)
          (distKJ, _) <- access (k, j)
          let indirectDist = distIK + distKJ
          -- The relative eps margin keeps the direct path when the detour
          -- is only marginally shorter. The last two comparisons ensure
          -- cells (i, k) and (k, j) are never rewritten while iterating
          -- over this k.
          when (distIJ > indirectDist + indirectDist * eps
                  && distIJ > distIK
                  && distIJ > distKJ) $
            put (i, j) (indirectDist, pathIK)
  where
    access idx = V.unsafeRead graph (mkIndex n idx)
    put idx e = V.unsafeWrite graph (mkIndex n idx) e
    eps = 1e-10 -- When two paths are nearly the same length, pick the one with the fewest segments.

-- | Compute the flat, row-major index of cell @(i, j)@ in a matrix with
-- @n@ columns: @i * n + j@.
--
-- No bounds checking is performed; callers must keep @i@ and @j@ in
-- @[0, n)@ for the result to address a cell of an @n * n@ vector.
mkIndex :: Num a => a -> (a, a) -> a
mkIndex n (i, j) = i * n + j

-- | Construct a weighted graph from \(n\) vertices, a max bound, and a list
-- of weighted edges.
--
-- The result is an \(n \times n\) row-major matrix (see 'mkIndex') of
-- @(distance, next vertex)@ pairs, in the layout 'floydWarshall' consumes:
--
-- * every cell starts as @(maxValue, maxBound)@, meaning "no known path";
-- * each diagonal cell @(v, v)@ is set to @(0, v)@;
-- * each edge @(i, j, cost)@ is treated as undirected and sets both
--   @(i, j)@ to @(cost, j)@ and @(j, i)@ to @(cost, i)@.
--
-- Edges are written in list order, so a later edge for the same pair
-- overwrites an earlier one. A self-loop @(v, v, cost)@ overwrites that
-- vertex's @(0, v)@ diagonal entry.
--
-- Calls 'error' if an edge endpoint lies outside @[0, n)@. The writes are
-- unchecked, so without this guard such an edge would write outside the
-- vector or silently into another cell.
mkGraph :: (Unbox a, Num a) => Int -> a -> [(Int,Int,a)] -> ST s (MVector s (a, Int))
mkGraph n maxValue edges = do
  graph <- V.replicate (n * n) (maxValue, maxBound)
  forM_ [0 .. n - 1] $ \v ->
    V.unsafeWrite graph (mkIndex n (v, v)) (0, v)
  forM_ edges $ \(i, j, cost) -> do
    when (outOfRange i || outOfRange j) $
      error ("Algorithms.FloydWarshall.mkGraph: edge endpoint out of range: "
             ++ show (i, j) ++ " with n = " ++ show n)
    V.unsafeWrite graph (mkIndex n (i, j)) (cost, j)
    V.unsafeWrite graph (mkIndex n (j, i)) (cost, i)
  return graph
  where
    outOfRange v = v < 0 || v >= n