{-# LANGUAGE RankNTypes #-}

{- |
  Module      :  Data.Graph.Dom
  Copyright   :  (c) Matt Morrow 2009
  License     :  BSD3
  Maintainer  :  <morrow@moonpatio.com>
  Stability   :  experimental
  Portability :  portable

  The Lengauer-Tarjan graph dominators algorithm.

    \[1\] Lengauer, Tarjan,
      /A Fast Algorithm for Finding Dominators in a Flowgraph/, 1979.

    \[2\] Muchnick,
      /Advanced Compiler Design and Implementation/, 1997.

    \[3\] Brisk, Sarrafzadeh,
      /Interference Graphs for Procedures in Static Single/
      /Information Form are Interval Graphs/, 2007.

  TODO: An ST version.
-}

module Data.Graph.Dom (
   Node,Path,Edge
  ,Graph,Rooted
  ,idom,ipdom
  ,domTree,pdomTree
  ,dom,pdom
  ,pddfs,rpddfs
  ,fromAdj,fromEdges
  ,toAdj,toEdges
  ,asTree,asGraph
  ,parents,ancestors
) where

import Data.Tree
import Data.Map(Map)
import Data.IntMap(IntMap)
import Data.IntSet(IntSet)
import qualified Data.Map as M
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.Monoid(Monoid(..))
import Control.Applicative
import Control.Monad
import Data.List

-----------------------------------------------------------------------------

-- | A graph node, identified by an 'Int'.
type Node       = Int
-- | A list of nodes (e.g. a chain of dominators in 'dom').
type Path       = [Node]
-- | A directed edge @(from, to)@.
type Edge       = (Node,Node)
-- | Adjacency map: each node mapped to the set of its successors.
type Graph      = IntMap IntSet
-- | A graph together with a distinguished root node.
type Rooted     = (Node, Graph)

-----------------------------------------------------------------------------

-- | /Dominators/.
-- For every node reachable from the root, other than the root itself,
-- the list of its strict dominators, nearest first and ending at the root.
-- Complexity as for @idom@.
dom :: Rooted -> [(Node, Path)]
dom rooted = ancestors (domTree rooted)

-- | /Post-dominators/.
-- The given node acts as the exit: edges are reversed before the
-- dominator computation. For every other node that can reach the exit,
-- the list of its strict post-dominators, nearest first and ending at the
-- exit.
-- Complexity as for @idom@.
pdom :: Rooted -> [(Node, Path)]
pdom rooted = ancestors (pdomTree rooted)

-- | /Dominator tree/.
-- Rooted at the graph's root; each node's parent is its immediate
-- dominator. Nodes unreachable from the root do not appear.
-- Complexity as for @idom@.
domTree :: Rooted -> Tree Node
domTree rooted@(r, _) = asTree (r, fromEdges treeEdges)
  where
    -- 'idom' yields (node, immediate dominator); turn those into
    -- (dominator, node) tree edges, dropping the root's (r, r) entry.
    treeEdges = [ (d, n) | (n, d) <- idom rooted, n /= r ]

-- | /Post-dominator tree/.
-- Rooted at the given (exit) node; each node's parent is its immediate
-- post-dominator.
-- Complexity as for @idom@.
pdomTree :: Rooted -> Tree Node
pdomTree rooted@(r, _) = asTree (r, fromEdges treeEdges)
  where
    -- 'ipdom' yields (node, immediate post-dominator); turn those into
    -- (post-dominator, node) tree edges, dropping the exit's (r, r) entry.
    treeEdges = [ (d, n) | (n, d) <- ipdom rooted, n /= r ]

-- | /Immediate dominators/.
-- One @(node, immediate dominator)@ pair per node reachable from the root,
-- in increasing node order; the root is paired with itself.
--
-- /O(|E|*alpha(|E|,|V|))/, where /alpha(m,n)/ is
-- \"a functional inverse of Ackermann's function\".
--
-- That bound assumes /O(1)/ indexing. Every state access here goes
-- through an @IntMap@, whose operations cost /O(min(|V|, W))/ (/W/ = word
-- size), which multiplies into the bound.
idom :: Rooted -> [(Node,Node)]
idom rooted =
  IM.toList (domE (execS idomM (initEnv (pruneReach rooted))))

-- | /Immediate post-dominators/.
-- Runs the dominator computation on the reversed graph, with the given
-- node acting as the exit. One @(node, immediate post-dominator)@ pair per
-- node that can reach the exit; the exit is paired with itself.
-- Complexity as for @idom@.
ipdom :: Rooted -> [(Node,Node)]
ipdom (exit, g) =
  IM.toList (domE (execS idomM (initEnv (pruneReach (exit, predG g)))))

-----------------------------------------------------------------------------

-- | /Post-dominated depth-first search/.
-- The nodes of 'pdomTree' in reverse level order: deepest level first,
-- the exit node last.
pddfs :: Rooted -> [Node]
pddfs rooted = reverse (rpddfs rooted)

-- | /Reverse post-dominated depth-first search/.
-- The nodes of 'pdomTree' level by level, starting with the exit node.
rpddfs :: Rooted -> [Node]
rpddfs rooted = [ n | level <- levels (pdomTree rooted), n <- level ]

-----------------------------------------------------------------------------

-- | The algorithm's monad: the continuation-passing state monad 'S'
-- over 'Env'.
type Dom a = S Env a
-- | A set of nodes.
type NodeSet    = IntSet
-- | A map keyed by node.
type NodeMap a  = IntMap a
-- | State of the Lengauer-Tarjan computation. The tables are created by
-- 'initEnv', filled in by 'dfsDom', and updated by 'idomM', 'eval',
-- 'compress' and 'link'.
data Env = Env
  {dfsE       :: !Int             -- ^ Last DFS number handed out by 'nextM' (starts at 0).
  ,zeroE      :: !Node            -- ^ Sentinel node outside the graph; the null ancestor\/child\/parent.
  ,rootE      :: !Node            -- ^ Root node.
  ,succE      :: !Graph           -- ^ Successor sets.
  ,predE      :: !Graph           -- ^ Predecessor sets (computed by 'predG').
  ,bucketE    :: !Graph           -- ^ Bucket of each node: nodes whose semidominator it is.
  ,labelE     :: !(NodeMap Node)  -- ^ Per-node label in the eval\/link forest.
  ,parentE    :: !(NodeMap Node)  -- ^ Parent in the DFS tree (the root's is the sentinel).
  ,ancestorE  :: !(NodeMap Node)  -- ^ Ancestor link in the eval\/link forest.
  ,childE     :: !(NodeMap Node)  -- ^ Child link used by the balanced 'link'.
  ,ndfsE      :: !(IntMap  Node)  -- ^ DFS number to node.
  ,dfnE       :: !(NodeMap Int)   -- ^ Node to DFS number.
  ,sdnoE      :: !(NodeMap Int)   -- ^ Semidominator as a DFS number (0 = not yet visited).
  ,sizeE      :: !(NodeMap Int)   -- ^ Subtree sizes used by 'link' for balancing.
  ,domE       :: !(NodeMap Node)} -- ^ Dominator of each node; after 'idomM', the immediate one.
  deriving(Eq,Ord,Read,Show)

-----------------------------------------------------------------------------

-- | Core of the Lengauer-Tarjan algorithm, run over the state built by
-- 'initEnv'. The immediate dominators are left in 'domE'.
--
-- 1. 'dfsDom' numbers the nodes reachable from the root (the root gets 1).
--
-- 2. Nodes are visited in decreasing DFS order, skipping the root. For
--    each node @w@:
--
--    * its semidominator number becomes the minimum of its current value
--      and @sdno (eval v)@ over all predecessors @v@;
--    * @w@ is put into the bucket of its semidominator;
--    * @w@ is linked under its DFS parent @pw@;
--    * every node @v@ in @pw@'s bucket gets a tentative dominator: @eval v@
--      if that has a smaller semidominator number than @v@, otherwise @pw@;
--    * the bucket is then emptied.
--
-- 3. Nodes are visited in increasing DFS order. Any node whose tentative
--    dominator is not its semidominator takes its dominator's dominator.
idomM :: Dom ()
idomM = do
  dfsDom =<< rootM
  n <- gets dfsE
  -- The root (DFS number 1) has no semidominator and no parent to link
  -- under, so the main loop stops at 2.
  forM_ [n,n-1..2] (\i-> do
    w <- ndfsM i
    ps <- predsM w
    forM_ ps (\v-> do
      -- Read sdno(w) afresh for each predecessor. An earlier predecessor
      -- may already have lowered it, and we must keep the minimum, not
      -- the last candidate that beat the initial value.
      sw <- sdnoM w
      u <- eval v
      su <- sdnoM u
      when (su < sw)
        (modify(\e->e{sdnoE
          =IM.insert w su (sdnoE e)})))
    z <- ndfsM =<< sdnoM w
    modify(\e->e{bucketE
      =IM.adjust (w`IS.insert`) z (bucketE e)})
    pw <- parentM w
    link pw w
    bps <- bucketM pw
    forM_ bps (\v-> do
      u <- eval v
      su <- sdnoM u
      sv <- sdnoM v
      let dv = case su < sv of
                True-> u
                False-> pw
      modify(\e->e{domE
        =IM.insert v dv (domE e)}))
    -- Every node in bucket(pw) now has its tentative dominator. Empty the
    -- bucket, as in [1], so later children of pw do not redo that work.
    modify(\e->e{bucketE
      =IM.insert pw IS.empty (bucketE e)}))
  forM_ [1..n] (\i-> do
    w <- ndfsM i
    j <- sdnoM w
    z <- ndfsM j
    dw <- domM w
    -- A tentative dominator that differs from the semidominator is
    -- corrected to the (already final) dominator of that node.
    when (dw /= z)
      (do ddw <- domM dw
          modify(\e->e{domE
            =IM.insert w ddw (domE e)})))

-----------------------------------------------------------------------------

-- | EVAL from Lengauer-Tarjan [1].
--
-- If @v@ has no ancestor in the eval\/link forest (its ancestor is the
-- sentinel), this returns @v@'s label. Otherwise it compresses @v@'s path
-- and returns whichever of @label v@ and @label (ancestor v)@ has the
-- smaller semidominator number, preferring @label v@ on a tie.
eval :: Node -> Dom Node
eval v = do
  sentinel <- zeroM
  anc      <- ancestorM v
  if anc == sentinel
    then labelM v
    else do
      compress v
      anc'     <- ancestorM v
      lblV     <- labelM v
      lblAnc   <- labelM anc'
      semiV    <- sdnoM lblV
      semiAnc  <- sdnoM lblAnc
      return (if semiV <= semiAnc then lblV else lblAnc)

-- | Path compression for 'eval' (COMPRESS in Lengauer-Tarjan [1]).
--
-- If the ancestor @a@ of @v@ itself has an ancestor other than the
-- sentinel, this first compresses @a@ recursively. It then:
--
--   * copies @label a@ into @label v@ when @label a@ has a strictly
--     smaller semidominator number;
--   * points @v@'s ancestor at @a@'s (now compressed) ancestor.
--
-- The recursion depth grows with the length of the ancestor chain.
compress :: Node -> Dom ()
compress v = do
  n0  <- zeroM
  a   <- ancestorM v
  aa  <- ancestorM a
  when (aa /= n0) (do
    compress a
    -- Re-read after the recursive call, which may have rewritten a's
    -- ancestor and label.
    a   <- ancestorM v
    aa  <- ancestorM a
    l   <- labelM v
    la  <- labelM a
    sl  <- sdnoM l
    sla <- sdnoM la
    when (sla < sl)
      (modify(\e->e{labelE
        =IM.insert v la (labelE e)}))
    -- Skip over a: v now points directly at a's ancestor.
    modify(\e->e{ancestorE
      =IM.insert v aa (ancestorE e)}))

-----------------------------------------------------------------------------

-- | @link v w@: the balanced LINK of Lengauer-Tarjan [1]. It attaches the
-- forest tree containing @w@ under @v@; 'idomM' calls it with @v@ being
-- @w@'s DFS parent.
--
-- The steps are:
--
--   * @balance@ walks a subtree root @s@ down from @w@ along child links,
--     rewriting size\/child\/ancestor links, while the semidominator number
--     of @w@'s label is smaller than that of the label of @s@'s child.
--   * @s@ takes @w@'s label, and @v@'s size grows by @w@'s size.
--   * When @v@'s new size is less than twice @w@'s, @s@ becomes @v@'s child
--     and the walk continues from @v@'s old child instead.
--   * Every node on the chosen child chain (up to the sentinel) gets @v@
--     as its ancestor.
link :: Node -> Node -> Dom ()
link v w = do
  n0  <- zeroM
  lw  <- labelM w
  slw <- sdnoM lw
  -- Descend from s while w's label beats the label of s's child.
  let balance s = do
        c   <- childM s
        lc  <- labelM c
        slc <- sdnoM lc
        case slw < slc of
          False-> return s
          True-> do
            zs  <- sizeM s
            zc  <- sizeM c
            cc  <- childM c
            zcc <- sizeM cc
            case 2*zc <= zs+zcc of
              -- Splice c out: it hangs off s, and s adopts c's child.
              True-> do
                modify(\e->e
                  {ancestorE=IM.insert c s (ancestorE e)
                  ,childE=IM.insert s cc (childE e)})
                balance s
              -- Otherwise c takes over s's size and becomes s's ancestor.
              False-> do
                modify(\e->e
                  {sizeE=IM.insert c zs (sizeE e)
                  ,ancestorE=IM.insert s c (ancestorE e)})
                balance c
  s   <- balance w
  lw  <- labelM w
  zw  <- sizeM w
  modify(\e->e
    {labelE=IM.insert s lw (labelE e)
    ,sizeE=IM.adjust (+zw) v (sizeE e)})
  -- Point every node on a child chain (until the sentinel) at v.
  let follow s = do
        when (s /= n0) (do
          modify(\e->e{ancestorE
            =IM.insert s v (ancestorE e)})
          follow =<< childM s)
  -- zv is v's size *after* adding w's size above.
  zv  <- sizeM v
  follow =<< case zv < 2*zw of
              False-> return s
              True-> do
                cv <- childM v
                modify(\e->e{childE
                  =IM.insert v s (childE e)})
                return cv

-----------------------------------------------------------------------------

-- | Depth-first numbering starting at the given node ('idomM' passes the
-- root).
--
-- Each newly reached node gets:
--
--   * the next DFS number, used as both its @dfn@ and its initial
--     semidominator number;
--   * an entry in the number-to-node table;
--   * itself as its label;
--   * its DFS-tree parent.
--
-- A node counts as unvisited while its semidominator number is still 0.
-- Afterwards the root's parent is set to the sentinel node.
dfsDom :: Node -> Dom ()
dfsDom start = do
  visit start
  sentinel <- zeroM
  root     <- rootM
  modify (\env -> env {parentE = IM.insert root sentinel (parentE env)})
  where
    visit node = do
      num <- nextM
      modify (\env -> env
        { dfnE   = IM.insert node num (dfnE env)
        , sdnoE  = IM.insert node num (sdnoE env)
        , ndfsE  = IM.insert num node (ndfsE env)
        , labelE = IM.insert node node (labelE env)
        })
      successors <- succsM node
      forM_ successors $ \next -> do
        semi <- sdnoM next
        when (semi == 0) $ do
          modify (\env -> env {parentE = IM.insert next node (parentE env)})
          visit next

-----------------------------------------------------------------------------

-- | Build the initial algorithm state for a rooted graph.
--
-- A fresh sentinel node @n0@, one greater than every node of the graph
-- (root included), stands in for \"no node\". Initial values:
--
--   * every node's ancestor and child start out as @n0@;
--   * @n0@ has semidominator number 0, size 0 and label @n0@;
--   * real nodes start unvisited (semidominator number 0), with size 1
--     and an empty bucket;
--   * the root's dominator is itself.
--
-- The root always gets table entries, even when the graph is empty or has
-- no key for it. Otherwise only keys of the graph get entries, so nodes
-- that appear solely as edge targets must already be keys (as
-- 'pruneReach' arranges).
initEnv :: Rooted -> Env
initEnv (r,g) =
  let nodes     = IS.insert r (IM.keysSet g)
      -- 'nodes' contains r, so findMax is safe; n0 exceeds every node.
      n0        = 1 + IS.findMax nodes
      ns        = n0 : IS.toList nodes
      doms      = IM.singleton r r
      sdno      = IM.fromList (zip ns (repeat 0))
      bucket    = IM.fromList (zip ns (repeat mempty))
      size      = IM.fromList (zip ns (0 : repeat 1))
      ancestor  = IM.fromList (zip ns (repeat n0))
      child     = ancestor
      label     = IM.singleton n0 n0
      preds     = predG g
  in Env {dfsE       = 0
         ,zeroE      = n0
         ,rootE      = r
         ,labelE     = label
         ,parentE    = mempty
         ,ancestorE  = ancestor
         ,childE     = child
         ,ndfsE      = mempty
         ,dfnE       = mempty
         ,sdnoE      = sdno
         ,sizeE      = size
         ,succE      = g
         ,predE      = preds
         ,bucketE    = bucket
         ,domE       = doms}

-----------------------------------------------------------------------------

-- State accessors used by the algorithm. Readers of the per-node tables
-- use 'IM.!' and fail on a missing key. The graph and bucket readers go
-- through '!', so a missing key reads as an empty list.
zeroM :: Dom Node
zeroM = gets zeroE
domM :: Node -> Dom Node
domM node = gets (\env -> domE env IM.! node)
rootM :: Dom Node
rootM = gets rootE
succsM :: Node -> Dom [Node]
succsM node = gets (\env -> IS.toList (succE env ! node))
predsM :: Node -> Dom [Node]
predsM node = gets (\env -> IS.toList (predE env ! node))
bucketM :: Node -> Dom [Node]
bucketM node = gets (\env -> IS.toList (bucketE env ! node))
sizeM :: Node -> Dom Int
sizeM node = gets (\env -> sizeE env IM.! node)
sdnoM :: Node -> Dom Int
sdnoM node = gets (\env -> sdnoE env IM.! node)
dfnM :: Node -> Dom Int
dfnM node = gets (\env -> dfnE env IM.! node)
ndfsM :: Int -> Dom Node
ndfsM num = gets (\env -> ndfsE env IM.! num)
childM :: Node -> Dom Node
childM node = gets (\env -> childE env IM.! node)
ancestorM :: Node -> Dom Node
ancestorM node = gets (\env -> ancestorE env IM.! node)
parentM :: Node -> Dom Node
parentM node = gets (\env -> parentE env IM.! node)
labelM :: Node -> Dom Node
labelM node = gets (\env -> labelE env IM.! node)
-- | Hand out the next DFS number: bump the counter and return its new value.
nextM :: Dom Int
nextM = do
  next <- gets ((+ 1) . dfsE)
  modify (\env -> env {dfsE = next})
  return next

-----------------------------------------------------------------------------

-- | Total lookup: a missing key reads as 'mempty'.
(!) :: Monoid a => IntMap a -> Int -> a
g ! n = IM.findWithDefault mempty n g

-- | Build a graph from adjacency lists. If a node is listed more than
-- once, its last entry wins.
fromAdj :: [(Node, [Node])] -> Graph
fromAdj adj = IM.fromList [ (node, IS.fromList succs) | (node, succs) <- adj ]

-- | Build a graph from a list of @(from, to)@ edges. Only nodes with at
-- least one outgoing edge become keys.
fromEdges :: [Edge] -> Graph
fromEdges edges = IM.fromListWith IS.union [ (from, IS.singleton to) | (from, to) <- edges ]

-- | Adjacency lists of a graph, in increasing node order with sorted
-- successor lists.
toAdj :: Graph -> [(Node, [Node])]
toAdj g = [ (node, IS.toList succs) | (node, succs) <- IM.toList g ]

-- | All @(from, to)@ edges of a graph, ordered by source then target.
toEdges :: Graph -> [Edge]
toEdges g = [ (from, to) | (from, tos) <- IM.toList g, to <- IS.toList tos ]

-- | Reverse every edge: map each node to the set of its predecessors.
--
-- Every key of the input is a key of the output (with an empty set if it
-- has no predecessors). So is every node that appears only as an edge
-- target.
predG :: Graph -> Graph
predG g = IM.unionWith IS.union (IM.foldrWithKey addEdges IM.empty g) noPreds
  where
    -- Every existing node starts with no predecessors.
    noPreds = fmap (const IS.empty) g
    -- Record src as a predecessor of each of its successors.
    -- (foldrWithKey replaces IM.foldWithKey, removed in containers 0.6.)
    addEdges src succs acc =
      foldl' (\m dst -> IM.insertWith IS.union dst (IS.singleton src) m)
             acc
             (IS.toList succs)

-- | Restrict a rooted graph to the nodes reachable from its root.
--
-- Every reachable node becomes a key of the result, with an empty
-- successor set if the input had no entry for it. This matters because
-- 'initEnv' only allocates state for keys: a reachable node that appeared
-- solely as an edge target (as 'fromEdges' produces for sinks) would
-- otherwise make 'idom' fail with a missing-key error. Successor sets need
-- no filtering, since every successor of a reachable node is itself
-- reachable.
pruneReach :: Rooted -> Rooted
pruneReach (r,g) = (r,g2)
  where is = reachable (g !) r
        g2 = IM.fromList [ (i, g ! i) | i <- IS.toList is ]

-- | Split a tree node into its label and its children.
tip :: Tree a -> (a, [Tree a])
tip t@Node{} = (rootLabel t, subForest t)

-- | Every @(child, parent)@ pair of a tree. Children of a node come first,
-- followed by the pairs from each child's subtree.
parents :: Tree a -> [(a, a)]
parents (Node label kids) =
  [ (rootLabel kid, label) | kid <- kids ] ++ concatMap parents kids

-- | For every non-root node of a tree, the list of its ancestors, nearest
-- first and ending at the root. Same traversal order as 'parents'.
ancestors :: Tree a -> [(a, [a])]
ancestors = walk []
  where
    walk above (Node label kids) =
      let chain = label : above
      in [ (rootLabel kid, chain) | kid <- kids ] ++ concatMap (walk chain) kids

-- | View a tree as a rooted graph whose edges point from parent to child.
-- Every tree node (leaves included) becomes a key.
asGraph :: Tree Node -> Rooted
asGraph tree@(Node root _) = (root, fromAdj (adjacency tree))
  where
    adjacency (Node label kids) =
      (label, map rootLabel kids) : concatMap adjacency kids

-- | Unfold a rooted graph into a tree from its root, children in
-- increasing node order. The result is built lazily; a cycle reachable
-- from the root yields an infinite tree.
asTree :: Rooted -> Tree Node
asTree (r, g) = build r
  where
    build node = Node node (map build (IS.toList (g ! node)))

-- | All nodes reachable from a start node (inclusive), given a successor
-- function. Each discovered node is expanded exactly once.
reachable :: (Node -> NodeSet) -> (Node -> NodeSet)
reachable f a = explore (IS.singleton a) [a]
  where
    explore seen [] = seen
    explore seen (x:pending) =
      let fresh = f x `IS.difference` seen
      in explore (seen `IS.union` fresh) (IS.toList fresh ++ pending)

-- | Group a list into an 'IntMap' keyed by @f@, with values from @g@.
-- Values with equal keys are merged as @combine new old@.
collectI :: (c -> c -> c)
        -> (a -> Int) -> (a -> c) -> [a] -> IntMap c
collectI combine f g = foldl' step IM.empty
  where
    step m x = IM.insertWith combine (f x) (g x) m

-- | Group a list into a 'Map' keyed by @f@, with values from @g@. Values
-- with equal keys are merged as @combine new old@.
--
-- Each stored value is forced to WHNF. This preserves the strictness of
-- @M.insertWith'@, which was removed from containers and no longer
-- compiles.
collect :: (Ord b) => (c -> c -> c)
        -> (a -> b) -> (a -> c) -> [a] -> Map b c
collect combine f g = foldl' step M.empty
  where
    step m x =
      let k = f x
          v = maybe (g x) (combine (g x)) (M.lookup k m)
      in v `seq` M.insert k v m

-- | Exchange the components of a pair. Lazy in the pair itself.
swap :: (a,b) -> (b,a)
swap p = (snd p, fst p)

-- | Apply a function to the first component of a pair.
mapfst :: (a -> c) -> (a,b) -> (c,b)
mapfst f (a, b) = (f a, b)

-- | Apply a function to the second component of a pair.
mapsnd :: (b -> c) -> (a,b) -> (a,c)
mapsnd f (a, b) = (a, f b)

-----------------------------------------------------------------------------

-- | A continuation-passing state monad: a computation receives the
-- continuation for its result and the current state.
newtype S s a = S {unS :: forall o. (a -> s -> o) -> s -> o}
instance Functor (S s) where
  fmap f (S g) = S (\k -> g (k . f))
-- 'pure' is the primary definition and 'return' aliases it (the canonical
-- form GHC's -Wnoncanonical-monad-instances expects); this still compiles
-- on compilers whose Monad class has its own 'return'.
instance Applicative (S s) where
  pure a = S (\k -> k a)
  (<*>) = ap
instance Monad (S s) where
  return = pure
  S g >>= f = S (\k -> g (\a -> unS (f a) k))
-- | Read the whole state.
get :: S s s
get = S (\k s -> k s s)
-- | Read a projection of the state.
gets :: (s -> a) -> S s a
gets f = S (\k s -> k (f s) s)
-- | Replace the state.
set :: s -> S s ()
set s = S (\k _ -> k () s)
-- | Apply a function to the state.
modify :: (s -> s) -> S s ()
modify f = S (\k -> k () . f)
-- | Run a computation, returning its result and the final state.
runS :: S s a -> s -> (a, s)
runS (S g) = g (,)
-- | Run a computation, returning only its result.
evalS :: S s a -> s -> a
evalS (S g) = g const
-- | Run a computation, returning only the final state.
execS :: S s a -> s -> s
execS (S g) = g (flip const)

-----------------------------------------------------------------------------

-- | Sample graph (not exported), presumably kept for manually exercising
-- 'idom' and friends, e.g. @idom (1, g0)@.
g0 :: Graph
g0 = fromAdj
  [(1,[2,3])
  ,(2,[3])
  ,(3,[4])
  ,(4,[3,5,6])
  ,(5,[7])
  ,(6,[7])
  ,(7,[4,8])
  ,(8,[3,9,10])
  ,(9,[1])
  ,(10,[7])]

-- | Second sample graph (not exported). Node 0 has no predecessors, so it
-- is presumably the intended root, e.g. @idom (0, g1)@.
g1 :: Graph
g1 = fromAdj
  [(0,[1])
  ,(1,[2,3])
  ,(2,[7])
  ,(3,[4])
  ,(4,[5,6])
  ,(5,[7])
  ,(6,[4])
  ,(7,[])]

-----------------------------------------------------------------------------