{-# LANGUAGE BangPatterns #-}
-- | Graph algorithms in the ST monad.
module Data.Graph.ST (
    Graph, newGraph, newGraphNoDupeNodes,
    successorNodes,
    -- * SCC Computation
    SCC(..), sccs,
    -- * Relation Tools
    nonReflexiveRepresentativesForNodes,
) where

import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Data.Hashable
import qualified Data.HashTable.Class as H
import qualified Data.HashTable.ST.Basic as B
import Data.List (sortBy)
import Data.STRef
import qualified Data.Set.MutableBit as BS

-- | The mutable hash table implementation used to map nodes to their ids.
type HashTable = B.HashTable

-- | A strongly connected component, as produced by 'sccs'.
data SCC a
    = AcyclicSCC a
      -- ^ A single node that has no edge to itself.
    | CyclicSCC [a]
      -- ^ Nodes lying on a common cycle (possibly a single node with an edge
      -- to itself).
    deriving (Eq, Show)

-- | All nodes belonging to a component.
nodesOfScc :: SCC a -> [a]
nodesOfScc scc = case scc of
    AcyclicSCC node -> [node]
    CyclicSCC members -> members

-- | A graph whose nodes have type @a@, living in the state thread @s@.
--
-- Every node is given an integer id. The successor ids of all nodes are
-- stored back to back in a single unboxed array, and a second unboxed array
-- records, for each node id, the index at which that node's successors
-- start. This is very memory efficient and cache friendly.
data Graph s a = Graph
    { nodeMap :: HashTable s a Int
      -- ^ Maps each node to its id.
    , invNodeMap :: STArray s Int a
      -- ^ Maps each id back to its node.
    , successorsArray :: UArray Int Int
      -- ^ Successor ids of all nodes, grouped by source node.
    , successorStarts :: UArray Int Int
      -- ^ @successorStarts ! n@ is the index in 'successorsArray' of the
      -- first successor of node @n@; the entry at @nodeCount@ holds the
      -- total number of edges.
    , nodeCount :: Int
      -- ^ The number of node ids.
    }

-- | A version of 'mapM' that accumulates its results in a list on the heap,
-- not on the stack, so it can safely be used on very long lists in strict
-- monads such as 'ST'. Results are returned in input order.
mapM' :: Monad m => (a -> m b) -> [a] -> m [b]
mapM' f = go []
  where
    go acc [] = return (reverse acc)
    go acc (a:as) = do
        x <- f a
        go (x:acc) as

-- | The ids of the successors of the node with id @nid@, in storage order.
successorsForNode :: Graph s a -> Int -> [Int]
successorsForNode gr nid = [succs ! ix | ix <- [firstIx .. endIx - 1]]
  where
    succs = successorsArray gr
    firstIx = successorStarts gr ! nid
    endIx = successorStarts gr ! (nid + 1)

-- | Returns the successors of a node, in the order in which they are stored
-- in the graph.
--
-- Raises an error if the node is not part of the graph.
successorNodes :: (Eq a, Hashable a) => Graph s a -> a -> ST s [a]
successorNodes graph node = do
    mnid <- H.lookup (nodeMap graph) node
    case mnid of
        -- mapM' keeps stack usage constant even for nodes with many successors.
        Just nid -> mapM' (readArray (invNodeMap graph)) (successorsForNode graph nid)
        Nothing -> error "Data.Graph.ST.successorNodes: node is not in the graph"

-- | Builds a graph from a list of nodes and a list of edges between them.
--
-- The node list may contain duplicates: each distinct node (according to its
-- 'Eq' and 'Hashable' instances) is given a single id, in order of first
-- occurrence. Every endpoint of every edge must occur in the node list,
-- otherwise construction fails with a pattern-match error. Duplicate edges
-- are kept as repeated successors.
newGraph :: (Eq a, Hashable a) => [a] -> [(a, a)] -> ST s (Graph s a)
newGraph nodes edges = do
    nodeNumberTable <- H.new :: ST s (HashTable s a Int)

    -- Map each distinct node to an id in [0, number of distinct nodes).
    nextNode <- newSTRef 0
    forM_ nodes $ \ n -> do
        mnid <- H.lookup nodeNumberTable n
        case mnid of
            Just _ -> return ()
            Nothing -> do
                nid <- readSTRef nextNode
                H.insert nodeNumberTable n nid
                writeSTRef nextNode $! nid+1
    count <- readSTRef nextNode

    intEdges <- edgesToIds nodeNumberTable edges

    invNodeNumberTable <- newArray_ (0, count-1) :: ST s (STArray s Int a)
    H.mapM_ (\ (a, aid) -> writeArray invNodeNumberTable aid a) nodeNumberTable

    mkGraph nodeNumberTable invNodeNumberTable count intEdges

-- | Like 'newGraph', but assumes the node list contains no duplicates. This
-- lets the node table be pre-sized and ids be assigned by position: the i-th
-- node (counting from 0) gets id i.
--
-- Callers must ensure there are no duplicate nodes; this is not checked.
newGraphNoDupeNodes :: (Eq a, Hashable a) => [a] -> [(a, a)] -> ST s (Graph s a)
newGraphNoDupeNodes nodes edges = do
    let count = length nodes

    nodeNumberTable <- H.newSized count :: ST s (HashTable s a Int)
    invNodeNumberTable <- newArray_ (0, count-1) :: ST s (STArray s Int a)
    -- Node i of the input list gets id i, in both directions.
    forM_ (zip [0..] nodes) $ \ (nid, n) -> do
        H.insert nodeNumberTable n nid
        writeArray invNodeNumberTable nid n

    intEdges <- edgesToIds nodeNumberTable edges

    mkGraph nodeNumberTable invNodeNumberTable count intEdges

-- | Translates edges between nodes into edges between node ids. Fails with a
-- pattern-match error if an endpoint is missing from the table. Uses 'mapM''
-- because the edge list may be very long and 'mapM' could pop the stack.
edgesToIds :: (Eq a, Hashable a) => HashTable s a Int -> [(a, a)] -> ST s [(Int, Int)]
edgesToIds table = mapM' $ \ (x, y) -> do
    Just ix <- H.lookup table x
    Just iy <- H.lookup table y
    return (ix, iy)

-- | Assembles a 'Graph' from a node table, its inverse, the number of node
-- ids, and the edges expressed as pairs of node ids in @[0, count)@.
--
-- The edges are stably sorted by source id, so each node's successors are
-- contiguous in the successor array and keep their relative input order.
-- Both unboxed arrays are fully built before the graph is returned.
mkGraph :: HashTable s a Int -> STArray s Int a -> Int -> [(Int, Int)]
        -> ST s (Graph s a)
mkGraph table invTable count intEdges = do
    let !sortedEdges = sortBy (\ x y -> compare (fst x) (fst y)) intEdges
        edgeCount = length intEdges
        successors :: UArray Int Int
        !successors = listArray (0, edgeCount-1) (map snd sortedEdges)
        -- nodeStarts!count == edgeCount, so the successors of node n always
        -- occupy indexes [nodeStarts!n, nodeStarts!(n+1)).
        nodeStarts :: UArray Int Int
        !nodeStarts = listArray (0, count) (computeStarts (-1) 0 sortedEdges)
        -- computeStarts cur eix es: cur is the node whose edges are being
        -- skipped over, eix is the index of the first edge of es. Whenever we
        -- move on to the next node, its start index is emitted.
        computeStarts :: Int -> Int -> [(Int, Int)] -> [Int]
        computeStarts currentNode _ [] | currentNode == count = []
        computeStarts currentNode eix [] = eix : computeStarts (currentNode+1) eix []
        computeStarts currentNode eix ((n,_):es) | currentNode == n =
            computeStarts currentNode (eix+1) es
        computeStarts currentNode eix es =
            eix : computeStarts (currentNode+1) eix es
    return $! Graph {
        nodeMap = table,
        invNodeMap = invTable,
        successorsArray = successors,
        successorStarts = nodeStarts,
        nodeCount = count
    }

-- | Computes the strongly connected components of the graph using an
-- optimised implementation of Tarjan's algorithm.
--
-- Components are listed in a topological order of the component DAG: if a
-- component x has an edge to a component y, then x precedes y.
sccs :: (Eq a, Hashable a) => Graph s a -> ST s [SCC a]
sccs graph = do
    components <- intSccs graph
    -- mapM' rather than mapM: there may be as many components as nodes, and
    -- in strict ST mapM may use stack proportional to the list length.
    mapM' toNodes components
  where
    nodeOf nid = readArray (invNodeMap graph) nid
    toNodes (AcyclicSCC xid) = do
        x <- nodeOf xid
        return $! AcyclicSCC x
    toNodes (CyclicSCC xids) = do
        xs <- mapM' nodeOf xids
        return $! CyclicSCC xs

-- | An optimised implementation of Tarjan's SCC algorithm, working on node
-- ids.
--
-- Returns the SCCs according to a topological sort of the DAG of the SCCs
-- (i.e. if an SCC x has an edge to an SCC y, then x precedes y). A single
-- node forms a 'CyclicSCC' exactly when it has an edge to itself.
intSccs :: (Eq a, Hashable a) => Graph s a -> ST s [SCC Int]
intSccs graph = do
    let successors = successorsArray graph
        nodeStarts = successorStarts graph
        lastNode = nodeCount graph - 1
    -- The DFS number of each node, or -1 if it has not been visited yet.
    -- Unboxed (and hence strict) so that no thunk chains build up.
    dfsNumber <- newArray (0, lastNode) (-1) :: ST s (STUArray s Int Int)
    -- The lowlink of each node.
    lowlink <- newArray (0, lastNode) (-1) :: ST s (STUArray s Int Int)
    -- Tarjan's node stack: visited nodes not yet assigned to an SCC, most
    -- recently visited first.
    pointStack <- newSTRef []
    -- A bitset representing the set of nodes currently on pointStack.
    -- NOTE(review): sized with nodeCount - 1; confirm that BS.newSized takes
    -- the largest element rather than the number of elements.
    stackSet <- BS.newSized lastNode
    -- The next DFS number.
    nextDFSNumber <- newSTRef 0
    -- The SCCs computed so far, most recently found first.
    computedSccs <- newSTRef []

    -- We use an explicit stack of (node, transition index) frames to avoid
    -- popping the Haskell stack; the helpers below only call each other in
    -- tail position. A frame (nid, tid) records the transition of nid that
    -- is currently being explored; when the child reached through tid
    -- finishes, popStack resumes nid at tid.
    programStack <- newSTRef []

    let -- Record that node nid is now exploring transition tid.
        modifyStackTop nid tid =
            modifySTRef' programStack (\ stk -> (nid, tid) : tail stk)

        -- Discard the finished top frame and resume its parent, if any.
        popStack = do
            _ : rest <- readSTRef programStack
            writeSTRef programStack rest
            case rest of
                [] -> return ()
                (nid, tid) : _ -> visitTransition' nid tid

        -- Start visiting a node that has not been visited before.
        strongConnect nid = do
            ix <- readSTRef nextDFSNumber
            writeArray lowlink nid ix
            writeArray dfsNumber nid ix
            writeSTRef nextDFSNumber $! ix + 1
            modifySTRef' pointStack (nid :)
            BS.insert stackSet nid

            let start = nodeStarts ! nid
            modifySTRef' programStack ((nid, start) :)
            visitTransition nid start

        -- All transitions of nid have been explored: finish visiting it.
        visitTransition nid tid | tid >= nodeStarts ! (nid + 1) = do
            ix <- readArray dfsNumber nid
            ourLowLink <- readArray lowlink nid
            when (ourLowLink == ix) $ do
                -- nid is the root of an SCC: pop its members off pointStack.
                -- They are the nodes with a DFS number of at least ix.
                let takeItems acc [] = return (acc, [])
                    takeItems acc (x : xs) = do
                        num <- readArray dfsNumber x
                        if num >= ix then takeItems (x : acc) xs
                        else return (acc, x : xs)
                items <- readSTRef pointStack
                (scc, remaining) <- takeItems [] items
                mapM_ (BS.remove stackSet) scc
                writeSTRef pointStack remaining
                let !scc' = case scc of
                        [x] | x `elem` successorsForNode graph x -> CyclicSCC [x]
                        [x] -> AcyclicSCC x
                        _ -> CyclicSCC scc
                modifySTRef' computedSccs (scc' :)
            popStack
        -- Explore transition tid of nid.
        visitTransition nid tid = do
            let sid = successors ! tid
            dfsNum <- readArray dfsNumber sid
            if dfsNum == -1 then do
                -- Found a tree arc: remember where to resume, then descend.
                modifyStackTop nid tid
                strongConnect sid
            else do
                onStack <- BS.member stackSet sid
                when onStack $ do
                    ourLowLink <- readArray lowlink nid
                    writeArray lowlink nid (min ourLowLink dfsNum)
                visitTransition nid (tid + 1)

        -- The child reached through transition tid of nid has finished:
        -- fold its lowlink into nid's and move on to the next transition.
        visitTransition' nid tid = do
            let sid = successors ! tid
            ourLowLink <- readArray lowlink nid
            theirLowLink <- readArray lowlink sid
            writeArray lowlink nid (min ourLowLink theirLowLink)
            visitTransition nid (tid + 1)

    forM_ [0 .. lastNode] $ \ nid -> do
        num <- readArray dfsNumber nid
        when (num == -1) $ strongConnect nid
    readSTRef computedSccs

-- | Maps every node to the representative of its strongly connected
-- component, returning the pairs @(node, representative)@.
--
-- Two nodes get the same representative exactly when each is reachable from
-- the other, i.e. when they are related in both directions by the transitive
-- closure of the edge relation. The representative of a component is the
-- first node 'intSccs' lists for it. Pairs of the form @(a, a)@ are never
-- returned, even if @a@ has an edge to itself; this keeps the result small.
-- Pairs are returned in increasing order of node id.
nonReflexiveRepresentativesForNodes :: (Eq a, Hashable a) => Graph s a -> ST s [(a, a)]
nonReflexiveRepresentativesForNodes graph = do
    components <- intSccs graph
    let sccCount = length components
    -- A map from node id to the id of the SCC containing it.
    sccForNode <- newArray (0, nodeCount graph-1) (-1) :: ST s (STUArray s Int Int)
    -- A map from SCC id to the node id of its representative.
    sccRepresentatives <- newArray (0, sccCount-1) (-1) :: ST s (STUArray s Int Int)
    forM_ (zip [0..] components) $ \ (sccId, scc) -> do
        mapM_ (\ nid -> writeArray sccForNode nid sccId) (nodesOfScc scc)
        writeArray sccRepresentatives sccId (representative scc)

    -- Now, create the representative pairs.
    xss <- mapM' (\nid -> do
            sccId <- readArray sccForNode nid
            sccRepr <- readArray sccRepresentatives sccId
            if sccRepr == nid then
                -- A node is never paired with itself.
                return []
            else do
                n <- readArray (invNodeMap graph) nid
                repr <- readArray (invNodeMap graph) sccRepr
                return [(n, repr)]
        ) [0..nodeCount graph-1]
    return $! concat xss
  where
    -- Components produced by intSccs always contain at least one node.
    representative scc = case nodesOfScc scc of
        (x:_) -> x
        [] -> error "Data.Graph.ST.nonReflexiveRepresentativesForNodes: empty SCC"