{-|
  Copyright   :  (C) 2018, QBayLogic
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Collection of utilities
-}

module Clash.Util.Graph
  ( topSort
  , reverseTopSort
  , callGraphBindings
  ) where

import           Data.Tuple            (swap)
import           Data.Foldable         (foldlM)
import qualified Data.IntMap.Strict    as IntMap
import qualified Data.IntSet           as IntSet

import           Clash.Core.Var (Id)
import           Clash.Core.Term (Term)
import           Clash.Driver.Types (BindingMap, Binding (bindingTerm))
import           Clash.Unique (lookupUniqMap', keysUniqMap)
import           Clash.Normalize.Util (callGraph)

-- | Visit state of a node during the depth-first topological sort.
data Marker
  = Temporary
  -- ^ The node is currently being visited: its successors are still being
  -- processed. Reaching such a node again means the graph has a cycle.
  | Permanent
  -- ^ The node and all of its successors have been processed; visiting it
  -- again is a no-op.

-- | Total variant of 'head': yields the first element of the list, or
-- 'Nothing' if the list is empty.
headSafe :: [a] -> Maybe a
headSafe []    = Nothing
headSafe (a:_) = Just a

-- | Depth-first visit of a single node. The node is marked 'Temporary' while
-- its successors are visited, then marked 'Permanent' and prepended to the
-- sorted list. Yields the updated (unmarked nodes, markers, sorted list), or
-- an error message if a node that is still marked 'Temporary' is reached
-- again (i.e., the graph contains a cycle).
topSortVisit'
  :: IntMap.IntMap [Int]
  -- ^ Edges
  -> IntSet.IntSet
  -- ^ Unmarked nodes
  -> IntMap.IntMap Marker
  -- ^ Marked nodes
  -> [Int]
  -- ^ Sorted so far
  -> Int
  -- ^ Node to visit
  -> Either String (IntSet.IntSet, IntMap.IntMap Marker, [Int])
topSortVisit' edges unmarked marked sorted node =
  case IntMap.lookup node marked of
    -- Already fully processed: leave the state untouched.
    Just Permanent -> Right (unmarked, marked, sorted)
    -- Reached a node whose visit has not finished yet: back edge.
    Just Temporary -> Left "cycle detected: cannot topsort cyclic graph"
    Nothing -> do
      let marked'   = IntMap.insert node Temporary marked
          unmarked' = IntSet.delete node unmarked
          -- Successors of this node; nodes without outgoing edges have none
          nodeToM   = IntMap.findWithDefault [] node edges
      (unmarked'', marked'', sorted'') <-
        foldlM visit (unmarked', marked', sorted) nodeToM
      let marked''' = IntMap.insert node Permanent marked''
      -- Prepending after the successors puts this node before all of them
      return (unmarked'', marked''', node : sorted'')
  where
    -- Thread the traversal state through the visit of one successor
    visit (unmarked', marked', sorted') node' =
      topSortVisit' edges unmarked' marked' sorted' node'

-- | Visit the given node with 'topSortVisit'' and then keep visiting
-- remaining unmarked nodes until none are left. Yields the final traversal
-- state, or the error message produced by 'topSortVisit''.
topSortVisit
  :: IntMap.IntMap [Int]
  -- ^ Edges
  -> IntSet.IntSet
  -- ^ Unmarked nodes
  -> IntMap.IntMap Marker
  -- ^ Marked nodes
  -> [Int]
  -- ^ Sorted so far
  -> Int
  -- ^ Node to visit
  -> Either String (IntSet.IntSet, IntMap.IntMap Marker, [Int])
topSortVisit edges unmarked marked sorted node = do
  (unmarked', marked', sorted') <-
    topSortVisit' edges unmarked marked sorted node

  -- Pick the next node that has not been visited yet, if any
  case headSafe (IntSet.toList unmarked') of
    Nothing    -> return (unmarked', marked', sorted')
    Just node' -> topSortVisit edges unmarked' marked' sorted' node'

-- | See: https://en.wikipedia.org/wiki/Topological_sorting. This function
-- yields an error message ('Left') if edges mention nodes not mentioned in
-- the node list or if the given graph contains cycles.
--
-- For every edge @(n, m)@, the payload of @n@ appears before the payload of
-- @m@ in the result. The traversal starts at the first node of the node list.
topSort
  :: [(Int, a)]
  -- ^ Nodes
  -> [(Int, Int)]
  -- ^ Edges
  -> Either String [a]
  -- ^ Error message or topologically sorted nodes
topSort []             []     = Right []
topSort []             _edges = Left "Node list was empty, but edges non-empty"
topSort nodes@(node:_) edges  = do
  -- Validate all edges up front; only the possible error is of interest
  mapM_ (\(n, m) -> checkNode n >> checkNode m) edges

  (_, _, sorted) <-
    topSortVisit edges' (IntMap.keysSet nodes') IntMap.empty [] (fst node)

  mapM lookup' sorted
    where
      nodes' = IntMap.fromList nodes

      -- Adjacency map for quick lookup of edges from n to m, given n. Later
      -- edges end up in front of earlier ones in each successor list.
      edges' = IntMap.fromListWith (++) [(n, [m]) | (n, m) <- edges]

      -- Lookup node in nodes map. If not present, yield error
      lookup' n =
        case IntMap.lookup n nodes' of
          Nothing
            -> Left ("Node " ++ show n ++ " in edge list, but not in node list.")
          Just n'
            -> Right n'

      -- Check if edge is valid (i.e., mentioned nodes are in node list)
      checkNode n
        | IntMap.notMember n nodes' =
            Left ("Node " ++ show n ++ " in edge list, but not in node list.")
        | otherwise =
            Right n

-- | Same as `reverse (topSort nodes edges)` if alternative representations are
-- considered the same. That is, topSort might produce multiple answers and
-- still deliver on its promise of yielding a topologically sorted node list.
-- Likewise, this function promises __one__ of those lists in reverse, but not
-- necessarily the reverse of topSort itself.
--
-- Implemented by topologically sorting the graph with every edge flipped.
reverseTopSort
  :: [(Int, a)]
  -- ^ Nodes
  -> [(Int, Int)]
  -- ^ Edges
  -> Either String [a]
  -- ^ Reversely, topologically sorted nodes
reverseTopSort nodes edges =
  topSort nodes (map swap edges)

-- | Get all the terms corresponding to a call graph: every key of the call
-- graph rooted at the given identifier is looked up in the binding map, and
-- the term of the binding found is returned.
--
-- NOTE(review): 'lookupUniqMap'' presumably fails on keys missing from the
-- binding map; confirm that every call graph key is guaranteed to be bound.
callGraphBindings
  :: BindingMap
  -- ^ All bindings
  -> Id
  -- ^ Root of the call graph
  -> [Term]
callGraphBindings bindingsMap tm =
  map (bindingTerm . (bindingsMap `lookupUniqMap'`)) (keysUniqMap cg)
  where
    cg = callGraph bindingsMap tm