-- Copyright (c) 2011, David Amos. All rights reserved.

{-# LANGUAGE NoMonomorphismRestriction #-}

-- |A module for working with directed graphs (digraphs).
-- Some of the functions are specifically for working with directed acyclic graphs (DAGs),
-- that is, directed graphs containing no cycles.
module Math.Combinatorics.Digraph where

import Data.List as L
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Math.Core.Utils (picks, toSet)

-- |A directed graph, written @DG vs es@: @vs@ is the list of vertices and @es@ the list of edges.
-- Edges are directed, so an edge @(u,v)@ goes from @u@ to @v@.
-- A digraph is in normal form when both @vs@ and @es@ are in ascending order.
-- Normal form is the preferred representation, and some functions only work on digraphs in normal form.
data Digraph v
    = DG [v]      -- the vertices
         [(v,v)]  -- the edges, each directed from its first component to its second
    deriving (Eq, Ord, Show)

instance Functor Digraph where
    -- Relabel every vertex with f, in both the vertex list and the edge list.
    -- If f is not order-preserving, call nf afterwards to restore normal form.
    fmap f (DG vs es) = DG (fmap f vs) [(f u, f v) | (u, v) <- es]

-- |Put a digraph into normal form by sorting its vertex list and its edge list.
nf (DG vs es) = DG sortedVs sortedEs
    where sortedVs = L.sort vs
          sortedEs = L.sort es

-- |The vertex list of a digraph.
vertices g = case g of DG vs _ -> vs

-- |The edge list of a digraph; each edge (u,v) goes from u to v.
edges g = case g of DG _ es -> es


-- Is it valid to call them predecessors / successors in the case when the digraph contains cycles?

-- |The sources u of all edges (u,v) ending at v, in edge-list order
-- (a source is repeated if its edge is repeated).
predecessors (DG _ es) v = map fst (filter ((== v) . snd) es)

-- |The targets v of all edges (u,v) starting at u, in edge-list order
-- (a target is repeated if its edge is repeated).
successors (DG _ es) u = map snd (filter ((== u) . fst) es)

-- |Calculate maps of predecessor and successor lists for each vertex in a digraph,
-- returned as the pair (predecessor map, successor map).
-- Within each list, vertices appear in the order of the edges that produced them.
-- If a vertex has no predecessors (respectively successors), then it is left out of the relevant map.
adjLists (DG _ es) = (preds, succs)
    where -- fromListWith (++) prepends each later value to the list built so far, so every list
          -- comes out in reverse edge order; reversing restores edge order. This avoids the
          -- O(d^2) cost of appending to the end of a degree-d list one element at a time, and
          -- the chain of unevaluated map updates an explicit lazy-tuple accumulator builds up.
          preds = M.map reverse $ M.fromListWith (++) [(v,[u]) | (u,v) <- es]
          succs = M.map reverse $ M.fromListWith (++) [(u,[v]) | (u,v) <- es]


-- |List all isomorphisms between two digraphs by brute-force search.
-- Each isomorphism is a list of (vertex of first digraph, vertex of second digraph) pairs,
-- most recently assigned pair first.
-- The vertices of the first digraph are assigned in order, each to every still-unused vertex
-- of the second. An assignment is kept only if it agrees with every earlier pair on edges in
-- both directions, and on whether the vertex carries a self-loop.
digraphIsos1 (DG vsa esa) (DG vsb esb)
    | length vsa /= length vsb = []
    | length esa /= length esb = []
    | otherwise = digraphIsos' [] vsa vsb
    where digraphIsos' xys [] [] = [xys]
          digraphIsos' xys (x:xs) ys =
              concat [ digraphIsos' ((x,y):xys) xs ys'
                     | (y,ys') <- picks ys, isCompatible (x,y) xys]
          -- x never appears among the earlier pairs xys, so a loop (x,x) has to be compared
          -- separately; otherwise a looped vertex could be matched with a loop-free one.
          isCompatible (x,y) xys = ((x,x) `elem` esa) == ((y,y) `elem` esb)
                                && and [ ((x,x') `elem` esa) == ((y,y') `elem` esb)
                                      && ((x',x) `elem` esa) == ((y',y) `elem` esb)
                                       | (x',y') <- xys ]

-- |List all isomorphisms between two digraphs, pruning the search with in/out-degree information.
-- Each isomorphism is a list of (vertex of a, vertex of b) pairs, most recently assigned pair first.
digraphIsos2 a b
    | length (vertices a) /= length (vertices b) = []
    | degreeMultiset indega /= degreeMultiset indegb = []
    | degreeMultiset outdega /= degreeMultiset outdegb = []
    | otherwise = extend [] (vertices a) (vertices b)
    where (preda, succa) = adjLists a
          (predb, succb) = adjLists b
          indega  = M.map length preda
          indegb  = M.map length predb
          outdega = M.map length succa
          outdegb = M.map length succb
          -- sorted nonzero degrees (vertices of degree 0 are absent from the degree maps)
          degreeMultiset degs = L.sort (M.elems degs)
          -- may x be sent to y, given the pairs already chosen?
          fits (x, y) chosen =
              M.findWithDefault 0 x indega == M.findWithDefault 0 y indegb
              && M.findWithDefault 0 x outdega == M.findWithDefault 0 y outdegb
              && all agrees chosen
              where inX  = M.findWithDefault [] x preda
                    inY  = M.findWithDefault [] y predb
                    outX = M.findWithDefault [] x succa
                    outY = M.findWithDefault [] y succb
                    agrees (x', y') = (x' `elem` inX) == (y' `elem` inY)
                                   && (x' `elem` outX) == (y' `elem` outY)
          extend chosen [] [] = [chosen]
          extend chosen (x:xs) ys =
              concatMap (\(y, rest) -> extend ((x, y) : chosen) xs rest)
                        (filter (\(y, _) -> fits (x, y) chosen) (picks ys))

-- For DAGs, can almost certainly do better than the above by using the height partition
-- However see remarks in Poset on orderIsos:
-- What is most efficient will depend on whether you want to list all of them, or just find out whether there are any or not
-- Could also try refining the height partition by (indegree,outdegree)


-- |Partition the vertices of a DAG into levels by height.
-- The first level holds the vertices with no predecessors, in the order they appear in the vertex list.
-- A vertex joins the next level once all of its predecessors lie in earlier levels, so it sits one
-- level above its highest predecessor. Levels after the first are normalised with 'toSet'.
--
-- The input is not checked for acyclicity. In a digraph with a cycle, the vertices on the cycle
-- (and everything reachable from them) never become ready, so they are missing from the output.
heightPartitionDAG dag@(DG vs _) = go S.empty sources
    where (preds, succs) = adjLists dag
          -- vertices with no predecessors
          sources = filter (`M.notMember` preds) vs
          go placed frontier
              | null frontier = []
              | otherwise     = frontier : go placed' next
              where placed' = placed `S.union` S.fromList frontier
                    -- v is ready when every one of its predecessors has been placed
                    ready v = all (`S.member` placed') (preds M.! v)
                    next = toSet (filter ready (concatMap (\u -> M.findWithDefault [] u succs) frontier))

-- |Is the digraph acyclic? Every vertex of a DAG lands in some height level, whereas vertices
-- on a cycle (or reachable from one) are left out of the height partition.
isDAG dag@(DG vs _) = sum (map length (heightPartitionDAG dag)) == length vs

-- |List the isomorphisms between two DAGs, each as a list of (vertex of dagA, vertex of dagB)
-- pairs, most recently assigned pair first. Only valid for DAGs, not for digraphs in general:
-- calls error if either argument is not acyclic.
--
-- Vertices are matched level by level through the height partitions, so each height level
-- of dagA is mapped onto the corresponding level of dagB.
dagIsos dagA@(DG vsA _) dagB@(DG vsB _)
    | length vsA /= placedA = error "dagIsos: dagA is not a DAG"
    | length vsB /= placedB = error "dagIsos: dagB is not a DAG"
    | map length levelsA /= map length levelsB = []
    | otherwise = extend [] levelsA levelsB
    where levelsA = heightPartitionDAG dagA
          levelsB = heightPartitionDAG dagB
          placedA = sum (map length levelsA)
          placedB = sum (map length levelsB)
          (predsA, _) = adjLists dagA
          (predsB, _) = adjLists dagB
          extend chosen [] [] = [chosen]
          extend chosen ([]:restA) ([]:restB) = extend chosen restA restB
          extend chosen ((x:xs):restA) (ys:restB) =
              [ iso | (y, ys') <- picks ys
                    , fits (x, y) chosen
                    , iso <- extend ((x, y) : chosen) (xs:restA) (ys':restB) ]
          -- Only predecessor edges need comparing. Every earlier-chosen vertex is at the same or a
          -- lower height, and an edge always goes to a strictly higher level, so no edge can run
          -- from x (or y) back to an earlier-chosen vertex.
          fits (x, y) chosen = all samePred chosen
              where predX = M.findWithDefault [] x predsA
                    predY = M.findWithDefault [] y predsB
                    samePred (x', y') = (x' `elem` predX) == (y' `elem` predY)

-- can probably do better by intersecting the height partition with the (indegree,outdegree) partition
-- (although on very symmetrical posets such as B n, this won't help at all)

-- |Are the two DAGs isomorphic? Calls error (via 'dagIsos') if either argument is not a DAG.
isDagIso :: (Ord a, Ord b) => Digraph a -> Digraph b -> Bool
isDagIso dagA dagB = case dagIsos dagA dagB of
    []    -> False
    (_:_) -> True


-- |All permutations of a list. Each permutation is built by inserting the head at every
-- position of each permutation of the tail. The result order differs from L.permutations,
-- which is why that function is not used as a drop-in replacement.
perms xs = foldr step [[]] xs
    where step x acc = concatMap (insertAll x) acc
          -- x inserted at each position of ps, from the front to the back
          insertAll x ps = [ls ++ x : rs | (ls, rs) <- zip (inits ps) (tails ps)]

{-
-- orderings compatible with the height partition
heightOrderingsDAG dag@(DG vs es) = heightOrderings' [[]] (heightPartitionDAG dag)
    where heightOrderings' initsegs (level:levels) =
              let addsegs = perms level
                  initsegs' = [init ++ add | init <- initsegs, add <- addsegs]
              in heightOrderings' initsegs' levels
          heightOrderings' segs [] = segs
-}

-- Canonical representative by exhaustive search. The vertices are numbered 1..n level by level
-- through the height partition, with each level taking the next consecutive block of numbers.
-- Every ordering within every level is tried, and the smallest sorted renumbered edge list wins.
isoRepDAG1 dag@(DG vs es) = DG [1..length vs] (minimum (map relabel labellings))
    where levels = heightPartitionDAG dag
          -- first number assigned to each level: 1, then 1 + size of the first level, ...
          starts = scanl (+) 1 (map length levels)
          -- for each level, all the ways to number its vertices with that level's block of numbers
          levelChoices = [ [M.fromList (zip p [s..]) | p <- perms level]
                         | (level, s) <- zip levels starts ]
          -- one choice per level; the levels are disjoint, so the unions never collide
          labellings = map M.unions (sequence levelChoices)
          relabel m = L.sort [(m M.! u, m M.! v) | (u, v) <- es]

-- For example (the two inputs are isomorphic, so they get the same representative):
-- > isoRepDAG1 (DG ['a'..'e'] [('a','c'),('a','d'),('b','d'),('b','e'),('d','e')])
-- DG [1,2,3,4,5] [(1,3),(1,4),(2,3),(2,5),(3,5)]
-- > isoRepDAG1 (DG ['a'..'e'] [('a','d'),('a','e'),('b','c'),('b','d'),('d','e')])
-- DG [1,2,3,4,5] [(1,3),(1,4),(2,3),(2,5),(3,5)]


-- Experimental: enumerate, by depth-first search, every numbering of the vertices in which each
-- height level takes the next consecutive block of numbers. Return the smallest such numbering,
-- as a list of (vertex, number) pairs in reverse order of assignment.
-- No compatibility condition is applied yet, so every numbering is generated.
isoRepDAG2 dag@(DG vs es) = minimum (assign [] srclevels trglevels)
    where srclevels = heightPartitionDAG dag
          -- cut [1..] into consecutive blocks with the same sizes as the levels
          trglevels = snd (L.mapAccumL (\is sl -> let (js,ks) = splitAt (length sl) is in (ks,js)) [1..] srclevels)
          assign acc [] [] = [acc]
          assign acc ([]:sls) ([]:tls) = assign acc sls tls
          assign acc ((x:xs):sls) (ys:tls) =
              concatMap (\(y, ys') -> assign ((x,y):acc) (xs:sls) (ys':tls)) (picks ys)


-- Find the height-respecting numbering of the vertices which leads to the minimal (sorted) edge list.
-- So this is calculating the same function as isoRepDAG1, but more efficiently:
-- it uses depth-first search with pruning, rather than exhaustive search.
--
-- A search node is a 4-tuple (es, (i,j), m, (sls,tls)):
--   es        the sorted list of renumbered edges between vertices numbered so far
--   i         every number below i belongs to a finished vertex (see nextunfinished)
--   j         the number most recently assigned (0 at the root)
--   m         the map from original vertices to their assigned numbers
--   sls, tls  the source vertices still to be numbered, and the numbers still available,
--             grouped level by level
-- NOTE(review): the input is assumed to be acyclic; this is not checked here.
isoRepDAG3 dag@(DG vs es) = dfs root [root]
    where n = length vs
          root = ([],(1,0),M.empty,(srclevels,trglevels)) -- root of the search tree
          (preds,succs) = adjLists dag
          srclevels = heightPartitionDAG dag
          -- cut [1..] into consecutive blocks with the same sizes as the height levels
          trglevels = reverse $ fst $ foldl
                      (\(tls,is) sl -> let (js,ks) = splitAt (length sl) is in (js:tls,ks))
                      ([],[1..]) srclevels
          -- dfs best stack: best is the best node found so far; stack holds the nodes still to visit
          dfs best (node:stack) =
              -- node : -- for debugging
              case cmpPartial best node of
              LT -> dfs best stack                      -- ie prune the search tree at this node
              GT -> dfs node (successors node ++ stack) -- ie replace best with this node
              EQ -> dfs best (successors node ++ stack)
          -- dfs best [] = [best] -- !! for debugging
          -- stack exhausted: the best node's edge list gives the result
          dfs best@(es',_,_,_) [] = DG [1..n] es'
          -- children of a node: number one more vertex x of the current level with the next available
          -- number y, adding the renumbered edges from x's predecessors (for a DAG these lie in
          -- earlier levels, so they already have numbers in m)
          successors (es,_,_,([],[])) = []
          successors (es,(i,j),m,([]:sls,[]:tls)) = successors (es,(i,j),m,(sls,tls))
          successors (es,(i,j),m,(xs:sls,(y:ys):tls)) =
              [ (es', (i',y), m', (L.delete x xs : sls, ys : tls))
              | x <- xs,
                let m' = M.insert x y m,
                let es' = L.sort $ es ++ [(m M.! u, y) | u <- M.findWithDefault [] x preds],
                let i' = nextunfinished m' i ]
          -- a vertex is considered finished when all its successors have assignments in the map
          -- nextunfinished m i: starting from i, skip over numbers whose vertex is finished;
          -- stop at the first number that is unassigned or whose vertex is unfinished
          nextunfinished m i =
              case [v | (v,i') <- M.assocs m, i' == i] of
              [] -> i
              [u] -> if all (`M.member` m) (M.findWithDefault [] u succs)
                     then nextunfinished m (i+1) -- i is finished: all successors already have assignments in the map
                     else i
          -- Compare the current best node with a candidate node, using the candidate's (i',j'):
          -- LT means the candidate can never beat best (prune it), GT means the candidate already
          -- beats or extends best (it becomes best), EQ means the comparison is not yet decided.
          cmpPartial (es,_,_,_) (es',(i',j'),_,_) =
              cmpPartial' (i',j') es es'
              -- where j' = maximum $ 0 : map snd es'
          cmpPartial' (i',j') ((u,v):es) ((u',v'):es') =
          -- Any new e' that can be added to es' must be greater than (i',j')
          -- (we don't care about possible extensions of es, because we're not extending them)
              case compare (u,v) (u',v') of
              EQ -> cmpPartial' (i',j') es es'
              LT -> if (u,v) <= (i',j') then LT else EQ
              GT -> GT -- always replace best if you beat it
                       -- (even if it could improve, it's not going to as we're not progressing it)
          cmpPartial' (i',j') ((u,v):es) [] = if (u,v) <= (i',j') then LT else EQ
          cmpPartial' _ [] ((u',v'):es') = GT -- always extend an existing partial best
          cmpPartial' _ [] [] = EQ


-- Now we seek a numbering of the vertices which respects height-ordering,
-- and within each height level respects (indegree,outdegree) ordering.
-- We seek the numbering which minimises the resulting edge list.


-- |Given a directed acyclic graph (DAG), return a canonical representative for its isomorphism class.
-- @isoRepDAG dag@ is isomorphic to @dag@. It follows that if @isoRepDAG dagA == isoRepDAG dagB@ then @dagA@ is isomorphic to @dagB@.
-- Conversely, @isoRepDAG dag@ is the minimal element in the isomorphism class, subject to some constraints.
-- The vertices are numbered 1..n level by level through the height partition. Within each level,
-- vertices with smaller (indegree,outdegree) get smaller numbers. Among all such numberings,
-- the one giving the smallest sorted edge list is chosen.
-- It follows that if @dagA@ is isomorphic to @dagB@, then @isoRepDAG dagA == isoRepDAG dagB@.
--
-- The algorithm of course is faster on some DAGs than others: roughly speaking,
-- it prefers \"tall\" DAGs (long chains) to \"wide\" DAGs (long antichains),
-- and it prefers asymmetric DAGs (ie those with smaller automorphism groups).
--
-- NOTE(review): the input is assumed to be acyclic; this is not checked here.
isoRepDAG :: (Ord a) => Digraph a -> Digraph Int
isoRepDAG dag@(DG vs es) = dfs root [root]
    where n = length vs
          -- A search node is a 4-tuple (es, (i,j), m, (sls,tls)):
          --   es        the sorted list of renumbered edges between vertices numbered so far
          --   i         every number below i belongs to a finished vertex (see nextunfinished)
          --   j         the number most recently assigned (0 at the root)
          --   m         the map from original vertices to their assigned numbers
          --   sls, tls  the source vertices still to be numbered, and the numbers still available,
          --             grouped block by block
          root = ([],(1,0),M.empty,(srclevels,trglevels)) -- root of the search tree
          (preds,succs) = adjLists dag
          indegs = M.map length preds
          outdegs = M.map length succs
          -- split a list of vertices into groups sharing the same (indegree,outdegree),
          -- with the groups in ascending order of that pair
          byDegree vs = (map . map) snd $ L.groupBy (\(du,u) (dv,v) -> du == dv) $ L.sort
                        [( (M.findWithDefault 0 v indegs, M.findWithDefault 0 v outdegs), v) | v <- vs]
          -- height levels refined by degree: each group is numbered as its own consecutive block
          srclevels = concatMap byDegree $ heightPartitionDAG dag
          -- cut [1..] into consecutive blocks with the same sizes as the groups in srclevels
          trglevels = reverse $ fst $ foldl
                      (\(tls,is) sl -> let (js,ks) = splitAt (length sl) is in (js:tls,ks))
                      ([],[1..]) srclevels
          -- dfs best stack: best is the best node found so far; stack holds the nodes still to visit
          dfs best (node:stack) =
              -- node : -- for debugging
              case cmpPartial best node of
              LT -> dfs best stack                      -- ie prune the search tree at this node
              GT -> dfs node (successors node ++ stack) -- ie replace best with this node
              EQ -> dfs best (successors node ++ stack)
          -- dfs best [] = [best] -- !! for debugging
          -- stack exhausted: the best node's edge list gives the result
          dfs best@(es',_,_,_) [] = DG [1..n] es'
          -- children of a node: number one more vertex x of the current group with the next available
          -- number y, adding the renumbered edges from x's predecessors (for a DAG these lie in
          -- earlier height levels, so they already have numbers in m)
          successors (es,_,_,([],[])) = []
          successors (es,(i,j),m,([]:sls,[]:tls)) = successors (es,(i,j),m,(sls,tls))
          successors (es,(i,j),m,(xs:sls,(y:ys):tls)) =
              [ (es', (i',y), m', (L.delete x xs : sls, ys : tls))
              | x <- xs,
                let m' = M.insert x y m,
                let es' = L.sort $ es ++ [(m M.! u, y) | u <- M.findWithDefault [] x preds],
                let i' = nextunfinished m' i ]
          -- a vertex is considered finished when all its successors have assignments in the map
          -- nextunfinished m i: starting from i, skip over numbers whose vertex is finished;
          -- stop at the first number that is unassigned or whose vertex is unfinished
          nextunfinished m i =
              case [v | (v,i') <- M.assocs m, i' == i] of
              [] -> i
              [u] -> if all (`M.member` m) (M.findWithDefault [] u succs)
                     then nextunfinished m (i+1) -- i is finished: all successors already have assignments in the map
                     else i
          -- Compare the current best node with a candidate node, using the candidate's (i',j'):
          -- LT means the candidate can never beat best (prune it), GT means the candidate already
          -- beats or extends best (it becomes best), EQ means the comparison is not yet decided.
          cmpPartial (es,_,_,_) (es',(i',j'),_,_) =
              cmpPartial' (i',j') es es'
              -- where j' = maximum $ 0 : map snd es'
          cmpPartial' (i',j') ((u,v):es) ((u',v'):es') =
          -- Any new e' that can be added to es' must be greater than (i',j')
          -- (we don't care about possible extensions of es, because we're not extending them)
              case compare (u,v) (u',v') of
              EQ -> cmpPartial' (i',j') es es'
              LT -> if (u,v) <= (i',j') then LT else EQ
              GT -> GT -- always replace best if you beat it
                       -- (even if it could improve, it's not going to as we're not progressing it)
          cmpPartial' (i',j') ((u,v):es) [] = if (u,v) <= (i',j') then LT else EQ
          cmpPartial' _ [] ((u',v'):es') = GT -- always extend an existing partial best
          cmpPartial' _ [] [] = EQ