module DDC.Core.Flow.Transform.Rates.Graph
        ( Graph
        , Edge
        , graphOfBinds
        , graphTopoOrder 
        , mergeWeights
        , traversal
        , invertMap
        , mlookup )
where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import qualified DDC.Type.Env           as Env

import           Data.List              (intersect, nub)
import qualified Data.Map               as Map
import           Data.Maybe             (catMaybes)
import qualified Data.Set               as Set

-- | An edge of the dependency graph.
--
--   Pairs the name of a binding that the owning node refers to with a flag
--   saying whether that /referenced/ binding's output can be fused or
--   contracted into its consumers.
type Edge  = (Name, Bool)

-- | Dependency graph for the bindings of a function.
--
--   Each key is a bound name, and its edges are the names its right-hand
--   side depends on. The flag on each edge tells whether the dependency can
--   be fused or contracted. For example, map and filter outputs can be,
--   but a reduce cannot, because it must consume the entire stream before
--   producing output.
type Graph = Map.Map Name [Edge]

-- | Build the dependency graph for a list of bindings.
--
--   Every binding becomes a node. Its edges are the free variables of its
--   right-hand side that are either bound by one of the given bindings or
--   listed in @extra_names@; any other free variables are dropped.
--
--   Each edge is tagged with the fusibility of the binding it refers to.
--   References to names that have no binding of their own (those only in
--   @extra_names@) are tagged fusible.
graphOfBinds :: [(Name,ExpF)] -> [Name] -> Graph
graphOfBinds binds extra_names
 = Map.map (\(deps, _) -> map tagEdge deps) nodes
 where
  -- Every name a binding is allowed to depend on.
  known = map fst binds ++ extra_names

  -- For each binding: the known names it references, and whether its own
  -- output is fusible.
  nodes = Map.fromList [ (b, (depsOf rhs, isFusible rhs)) | (b, rhs) <- binds ]

  depsOf rhs
   = catMaybes (map takeNameOfBound (Set.toList (freeX Env.empty rhs)))
       `intersect` known

  -- Pair a dependency with its fusibility; names without a binding
  -- default to fusible.
  tagEdge dep
   = (dep, maybe True snd (Map.lookup dep nodes))

  -- Only applications of vector primitives can be non-fusible.
  isFusible rhs
   = case takeXApps rhs of
      Just (XVar (UPrim (NameOpVector op) _), _)
        -> vectorOpFusible op
      _ -> True

  -- A reduce must consume its whole input before producing a result.
  vectorOpFusible OpVectorReduce = False
  -- The length of a `concrete rate' is known before iteration, so this
  -- should in principle be contractible, but it is currently treated as
  -- non-fusible.
  vectorOpFusible OpVectorLength = False
  vectorOpFusible _              = True


-- | Topologically order the nodes of a graph, listing every node after all
--   of the nodes it has edges to (its dependencies).
--
--   This is a depth-first search. Starting from the smallest unvisited node,
--   each node's dependencies are visited first, and then the node is pushed
--   onto a result stack, which is reversed at the end. Edge targets that are
--   not keys of the graph are skipped.
--
--   Cycles are NOT detected. A node is marked visited before its
--   dependencies are explored, so a cyclic graph still terminates, but the
--   resulting order is then not meaningful. The input really should be a DAG.
graphTopoOrder :: Graph -> [Name]
graphTopoOrder graph
 = reverse (loop [] (Map.keysSet graph))
 where
  -- Keep starting new searches until every node has been visited.
  loop acc unvisited
   = maybe acc
           (\(start, _) -> uncurry loop (dfs (acc, unvisited) start))
           (Set.minView unvisited)

  -- Visit one node: if it is still unvisited, mark it, visit its
  -- dependencies, then push it above everything they produced.
  dfs st@(acc, unvisited) node
   | not (Set.member node unvisited)
   = st
   | otherwise
   = let deps         = map fst (mlookup "visit" graph node)
         (acc', rest) = foldl dfs (acc, Set.delete node unvisited) deps
     in  (node : acc', rest)



-- | Assign each node the heaviest weighted path leading into it.
--
--   Nodes are processed in 'graphTopoOrder', so a node's dependencies are
--   normally weighted before the node itself. A node's weight is the maximum,
--   over its edges, of the dependency's weight plus @weight edge node@.
--
--   An edge whose target has no weight yet (for example, a name that is not
--   a key of the graph) contributes 0, not its edge weight. A node with no
--   edges gets weight 0.
traversal :: Graph -> (Edge -> Name -> Int) -> Map.Map Name Int
traversal graph weight
 = foldl assign Map.empty order
 where
  order = graphTopoOrder graph

  -- Weigh one node using the weights already assigned to its dependencies.
  assign done node
   = let incoming = mlookup "traversal" graph node
         costs    = map (costVia done node) incoming
     in  Map.insert node (maximum (0 : costs)) done

  -- Cost of reaching @node@ through a single edge.
  costVia done node edge@(dep, _)
   = case Map.lookup dep done of
      Just w  -> w + weight edge node
      Nothing -> 0


-- | Collapse every group of nodes that share a weight into a single node.
--
--   Nodes with the same entry in @weights@ are merged into one
--   representative: the first name in that weight's group, as returned by
--   'invertMap'. Nodes absent from @weights@ keep their own name.
--
--   Edge targets are renamed the same way, and duplicate edges are removed.
--   Nodes are visited in 'graphTopoOrder'.
mergeWeights :: Graph -> Map.Map Name Int -> Graph
mergeWeights graph weights
 = foldl absorb Map.empty (graphTopoOrder graph)
 where
  -- Names grouped by the weight they were assigned.
  groups = invertMap weights

  -- The name that a node is merged into.
  representative n
   = case Map.lookup n weights >>= \w -> Map.lookup w groups of
      Just (v : _) -> v
      _            -> n

  -- Add a node's renamed edges to its representative's entry, combining
  -- with anything already merged there.
  absorb acc node
   = case Map.lookup node graph of
      Nothing    -> acc
      Just edges ->
       let renamed = nub [ (representative n, f) | (n, f) <- edges ]
       in  Map.insertWith (\new old -> nub (new ++ old))
                          (representative node) renamed acc


-- | Invert a map, grouping together all keys that share a value.
--
--   Each value of the input becomes a key of the result. It maps to the
--   list of input keys that had that value, in ascending key order.
--
--   Uses 'Map.foldrWithKey'. The older 'Map.foldWithKey' (an alias for it)
--   is deprecated and has been removed from recent versions of @containers@.
invertMap :: (Ord k, Ord v) => Map.Map k v -> Map.Map v [k]
invertMap m
 = Map.foldrWithKey go Map.empty m
 where
  -- foldrWithKey reaches the largest key first. Each key is prepended to
  -- its value's list, so the lists end up in ascending key order.
  go k v m' = Map.insertWith (++) v [k] m'


-- | Look up a key that must be present in the map.
--
--   Calls 'error' if the key is missing. The message names the calling
--   context @str@ rather than the key, because keys are not required to be
--   'Show'able.
mlookup :: Ord k => String -> Map.Map k v -> k -> v
mlookup str m k
 = Map.findWithDefault missing k m
 where
  -- Only forced when the key is absent.
  missing = error ("ddc-core-flow.mlookup: no key " ++ str)