{-# LANGUAGE ScopedTypeVariables #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Algorithms.Graph.MST
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--------------------------------------------------------------------------------
module Algorithms.Graph.MST( mst
                           , mstEdges
                           , makeTree
                           ) where

import           Algorithms.Graph.DFS (AdjacencyLists, dfs')
import           Control.Monad (forM_, when, filterM)
import           Control.Monad.ST (ST,runST)
import qualified Data.List as L
import           Data.PlanarGraph
import           Data.Tree
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Unboxed.Mutable as UMV

--------------------------------------------------------------------------------


-- | Minimum spanning tree of the graph, computed with Kruskal's algorithm
-- (see 'mstEdges').
--
-- The result is a rooted tree whose nodes are the vertices of the planar
-- graph. The root is the first vertex reported by @vertices'@ (see
-- 'makeTree'). Edge weights are not stored in the tree.
--
-- pre: the planar graph has at least one vertex.
--
-- NOTE(review): the tree is built by a DFS from the root over the chosen
-- edges. For a disconnected graph it presumably spans only the root's
-- component. TODO confirm.
--
-- running time: \(O(n \log n)\)
mst   :: Ord e => PlanarGraph s w v e f -> Tree (VertexId s w)
mst g = makeTree g (mstEdges g)
  -- TODO: Add edges/darts to the output somehow.

-- | The darts of the edges in a minimum spanning tree (Kruskal's algorithm).
--
-- The darts reported by 'edges' are sorted by edge weight and scanned in
-- that order. Ties keep the order of 'edges', because 'L.sortOn' is stable.
-- A dart is kept exactly when its endpoints lie in different components of
-- the union-find structure; keeping it merges those two components.
--
-- running time: \(O(n \log n)\)
mstEdges   :: Ord e => PlanarGraph s w v e f -> [Dart s]
mstEdges g = runST $ do
    components <- new (numVertices g)
    let connectsNewPair d = union components (headOf d g) (tailOf d g)
    filterM connectsNewPair byWeight
  where
    -- every dart from 'edges', lightest edge first
    byWeight = [ d | (d, _) <- L.sortOn snd (V.toList (edges g)) ]


-- | Given an underlying planar graph and a set of edges that form a tree,
-- build the actual tree.
--
-- Every dart is inserted into the adjacency lists in both directions
-- (head to tail and tail to head). The tree is then produced by @dfs'@,
-- starting from the first vertex reported by @vertices'@.
--
-- pre: the planar graph has at least one vertex ('V.head' fails otherwise).
makeTree   :: forall s w v e f.
              PlanarGraph s w v e f -> [Dart s] -> Tree (VertexId s w)
makeTree g treeDarts = dfs' adjacency root
  where
    root :: VertexId s w
    root = V.head (vertices' g)

    -- one (initially empty) neighbour list per vertex, filled symmetrically
    adjacency :: AdjacencyLists s w
    adjacency = V.create $ do
      lists <- MV.replicate (numVertices g) []
      forM_ treeDarts $ \d -> do
        let a = headOf d g
            b = tailOf d g
        push lists a b
        push lists b a
      pure lists

    -- prepend an item to the list stored for the given vertex
    push :: MV.MVector s' [x] -> VertexId s w -> x -> ST s' ()
    push vec (VertexId i) item = do
      old <- MV.read vec i
      MV.write vec i (item : old)
--------------------------------------------------------------------------------

-- | Union-find (disjoint-set) structure over elements of type @a@. Elements
-- are mapped to 'Int' indices with 'fromEnum' and back with 'toEnum'.
--
-- Entry @i@ of the vector holds the pair (parent index, rank) of element
-- @i@. An element that is its own parent is the representative (root) of its
-- component.
newtype UF s a = UF
  { _unUF :: UMV.MVector s (Int,Int)
  }

-- | A fresh union-find structure on the indices @0 .. n-1@. Every index is a
-- singleton component: its own parent, with rank 0.
new   :: Int -> ST s (UF s a)
new n = do
    table <- UMV.new n
    mapM_ (\i -> UMV.write table i (i, 0)) [0 .. n - 1]
    pure (UF table)

-- | Union the components containing @x@ and @y@, using union by rank.
--
-- Returns 'True' if @x@ and @y@ were in different components; those
-- components have now been merged.
--
-- Returns 'False' if they were already in the same component. In that case
-- no link is changed, apart from the path compression done by 'find''.
union               :: (Enum a, Eq a) => UF s a -> a -> a -> ST s Bool
union uf@(UF v) x y = do
    (rx, rankX) <- find' uf x
    (ry, rankY) <- find' uf y
    let merged = rx /= ry
        ix     = fromEnum rx
        iy     = fromEnum ry
    when merged $ case rankX `compare` rankY of
      -- hang the lower-ranked root below the other one; no rank changes
      LT -> UMV.write v ix (iy, rankX)
      GT -> UMV.write v iy (ix, rankY)
      -- equal ranks: hang ry below rx, and rx's rank grows by one
      EQ -> do UMV.write v iy (ix, rankY)
               UMV.write v ix (ix, rankX + 1)
    pure merged


-- Get the representative of the component containing x (unused; kept for reference)
-- find    :: (Enum a, Eq a) => UF s a -> a -> ST s a
-- find uf = fmap fst . find' uf

-- | Get the representative of the component containing @x@, together with
-- the rank stored for that representative.
--
-- Every node on the path to the root is re-pointed directly at the root
-- (path compression). Each node keeps its own rank.
find'             :: (Enum a, Eq a) => UF s a -> a -> ST s (a,Int)
find' uf@(UF v) x = UMV.read v ix >>= \(parent, rank) ->
    if toEnum parent == x
      then pure (x, rank)                 -- x is its own parent: a root
      else do
        root@(r, _) <- find' uf (toEnum parent)
        UMV.write v ix (fromEnum r, rank) -- path compression
        pure root
  where
    ix = fromEnum x


--------------------------------------------------------------------------------

-- Partial (unfinished) implementation of Prim's algorithm, kept for reference.
-- mst g = undefined

-- -- | runs MST with a given root
-- mstFrom     :: (Ord e, Monoid e)
--             => VertexId s w -> PlanarGraph s w v e f -> Tree (VertexId s w, e)
-- mstFrom r g = prims initialQ (Node (r,mempty) [])
--   where
--     update' k p q = Q.adjust (const p) k q

--     -- initial Q has the value of the root set to the zero element, and has no
--     -- parent. The others are all set to Top (and have no parent yet)
--     initialQ = update' r (ValT (mempty,Nothing))
--              . GV.foldr (\v q -> Q.insert v (Top,Nothing) q) Q.empty $ vertices g

--     prims qq t = case Q.minView qq of
--       Nothing -> t
--       Just (v Q.:-> (w,p), q) -> prims $

--------------------------------------------------------------------------------
-- Testing Stuff

-- testG = planarGraph' [ [ (Dart aA Negative, "a-")
--                        , (Dart aC Positive, "c+")
--                        , (Dart aB Positive, "b+")
--                        , (Dart aA Positive, "a+")
--                        ]
--                      , [ (Dart aE Negative, "e-")
--                        , (Dart aB Negative, "b-")
--                        , (Dart aD Negative, "d-")
--                        , (Dart aG Positive, "g+")
--                        ]
--                      , [ (Dart aE Positive, "e+")
--                        , (Dart aD Positive, "d+")
--                        , (Dart aC Negative, "c-")
--                        ]
--                      , [ (Dart aG Negative, "g-")
--                        ]
--                      ]
--   where
--     (aA:aB:aC:aD:aE:aG:_) = take 6 [Arc 0..]