{-# LANGUAGE TypeFamilies, ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP #-}
module CFG
    ( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
    , TransitionSource(..)
    
    , addWeightEdge, addEdge, delEdge
    , addNodesBetween, shortcutWeightMap
    , reverseEdges, filterEdges
    , addImmediateSuccessor
    , mkWeightInfo, adjustEdgeWeight
    
    , infoEdgeList, edgeList
    , getSuccessorEdges, getSuccessors
    , getSuccEdgesSorted, weightedEdgeList
    , getEdgeInfo
    , getCfgNodes, hasNode
    , loopMembers
    
    , getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
    
    , optimizeCFG )
where
#include "HsVersions.h"
import GhcPrelude
import BlockId
import Cmm ( RawCmmDecl, GenCmmDecl( .. ), CmmBlock, succ, g_entry
           , CmmGraph )
import CmmNode
import CmmUtils
import CmmSwitch
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import qualified Hoopl.Graph as G
import Util
import Digraph
import Outputable
import PprCmm ()
import qualified DynFlags as D
import Data.List
-- | A directed edge between two blocks, written as @(source, destination)@.
type Edge = (BlockId, BlockId)
-- | A list of 'Edge's.
type Edges = [Edge]
-- | Heuristic weight attached to a CFG edge.
--
-- A thin wrapper around 'Int'. The numeric classes are derived through
-- GeneralizedNewtypeDeriving, so weights can be combined and converted with
-- ordinary arithmetic ('+', 'negate', 'fromIntegral', ...).
newtype EdgeWeight
  = EdgeWeight Int
  deriving (Eq,Ord,Enum,Num,Real,Integral)
-- | Prints the bare underlying integer, without the constructor.
instance Outputable EdgeWeight where
  ppr weight = case weight of
    EdgeWeight raw -> ppr raw
-- | Adjacency map keyed on the source block. Each value maps a destination
-- block to the information stored for that edge.
type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)
-- | The control flow graph: every edge carries an 'EdgeInfo'.
-- The outer key is the source block and the inner key is the target block.
-- A source may map to an empty inner map, i.e. a node without outgoing edges.
type CFG = EdgeInfoMap EdgeInfo
-- | An explicit CFG edge together with the information attached to it,
-- as produced by 'infoEdgeList'.
data CfgEdge
  = CfgEdge
  { edgeFrom :: !BlockId -- ^ Source block.
  , edgeTo :: !BlockId -- ^ Destination block.
  , edgeInfo :: !EdgeInfo -- ^ Weight and origin of the edge.
  }
-- | Two edges are equal when they connect the same source and destination.
--
-- NOTE(review): the attached 'EdgeInfo' (including the weight) is ignored.
-- This instance is therefore coarser than the 'Ord' instance, which also
-- compares weights.
instance Eq CfgEdge where
  e1 == e2 = edgeFrom e1 == edgeFrom e2 && edgeTo e1 == edgeTo e2
-- | Edges are ordered lexicographically, all ascending: first by weight,
-- then by source block, then by destination block.
--
-- NOTE(review): unlike 'Eq', this compares weights. Two edges that are '=='
-- can still compare as 'LT' or 'GT'.
instance Ord CfgEdge where
  compare (CfgEdge srcA dstA infoA) (CfgEdge srcB dstB infoB) =
    compare (edgeWeight infoA, srcA, dstA) (edgeWeight infoB, srcB, dstB)
-- | Renders an edge as @(from -(weight:w)-> to)@.
instance Outputable CfgEdge where
  ppr edge = parens body
    where
      body = ppr (edgeFrom edge) <+> text "-(" <> ppr (edgeInfo edge)
               <> text ")->" <+> ppr (edgeTo edge)
-- | Records what created an edge.
data TransitionSource
  = CmmSource (CmmNode O C) -- ^ The block-ending Cmm node the edge was derived from (see 'getCfg').
  | AsmCodeGen -- ^ Added by the native code generator itself (see 'mkWeightInfo').
  deriving (Eq)
-- | Information stored for every CFG edge.
data EdgeInfo
  = EdgeInfo
  { transitionSource :: !TransitionSource -- ^ What created the edge.
  , edgeWeight :: !EdgeWeight -- ^ Heuristic weight; heavier edges come first in 'getSuccEdgesSorted'.
  } deriving (Eq)
-- | Shows only the weight; the 'TransitionSource' is not printed.
instance Outputable EdgeInfo where
  ppr info = text "weight:" <+> ppr weight
    where
      weight = edgeWeight info
{-# INLINEABLE mkWeightInfo #-}
-- | Build an 'EdgeInfo' with the given weight. The edge is attributed to the
-- code generator ('AsmCodeGen') rather than to a Cmm node.
mkWeightInfo :: Integral n => n -> EdgeInfo
mkWeightInfo w = EdgeInfo { transitionSource = AsmCodeGen
                          , edgeWeight = fromIntegral w }
-- | Apply @f@ to the weight of the edge @from -> to@, keeping the rest of
-- its 'EdgeInfo'. If the edge does not exist, the CFG is returned unchanged.
adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                 -> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to =
  case getEdgeInfo from to cfg of
    Nothing   -> cfg
    Just info ->
      let adjusted = info { edgeWeight = f (edgeWeight info) }
      in addEdge from to adjusted cfg
-- | Every block mentioned by the CFG: each source key (even one without
-- outgoing edges) and each edge target.
getCfgNodes :: CFG -> LabelSet
getCfgNodes m =
  setFromList [ node | (src, dests) <- mapToList m
                     , node <- src : mapKeys dests ]
-- | Does the block occur in the CFG, either as a source key or as the
-- target of some edge?
hasNode :: CFG -> BlockId -> Bool
hasNode m node
  | mapMember node m = True
  | otherwise        = any (mapMember node) (mapElems m)
-- | Check that the blocks known to the CFG ('getCfgNodes') are exactly
-- @blockSet@.
--
-- Returns 'True' when the sets match. Otherwise it panics, printing the
-- blocks present in only one of the two sets, both sets, and the
-- caller-supplied @msg@ for context. It never returns 'False'.
sanityCheckCfg :: CFG -> LabelSet -> SDoc -> Bool
sanityCheckCfg m blockSet msg
    | blockSet == cfgNodes
    = True
    | otherwise =
        -- pprPanic diverges at the result type (Bool); no further argument
        -- is applied to it.
        pprPanic "Block list and cfg nodes don't match" (
            text "difference:" <+> ppr diff $$
            text "blocks:" <+> ppr blockSet $$
            text "cfg:" <+> ppr m $$
            msg )
    where
      cfgNodes = getCfgNodes m :: LabelSet
      -- Symmetric difference: blocks that appear in exactly one of the sets.
      diff = (setUnion cfgNodes blockSet) `setDifference` (setIntersection cfgNodes blockSet) :: LabelSet
-- | Keep only the edges for which the predicate (given source, target and
-- edge info) holds.
--
-- Source entries are kept even when all their edges are removed, so the
-- outer key set is unchanged.
filterEdges :: (BlockId -> BlockId -> EdgeInfo -> Bool) -> CFG -> CFG
filterEdges keep cfg = mapMapWithKey keepFrom cfg
  where
    keepFrom from = mapFilterWithKey (keep from)
-- | Update the CFG after blocks have been shortcut.
--
-- Every key @B@ of @cuts@ is deleted from the CFG, together with its
-- outgoing edges:
--
--   * @B -> Nothing@: every edge into @B@ is deleted as well.
--   * @B -> Just C@: every edge @A -> B@ is redirected to @A -> C@, keeping
--     the 'EdgeInfo' of the @A -> B@ edge. This overwrites the info of any
--     existing @A -> C@ edge. If @C@ is itself a key of @cuts@, the
--     redirection is followed transitively, so @A -> B -> C -> D@ becomes
--     @A -> D@.
--
-- NOTE(review): the transitive chase only stops at a block that is not a
-- key of @cuts@ or that maps to @Nothing@. The code therefore assumes @cuts@
-- contains no cycle of @Just@ links; such a cycle would recurse forever.
shortcutWeightMap :: CFG -> LabelMap (Maybe BlockId) -> CFG
shortcutWeightMap cfg cuts =
  foldl' applyMapping cfg $ mapToList cuts
    where
      applyMapping :: CFG -> (BlockId,Maybe BlockId) -> CFG
      -- Drop the block entirely: its own entry and every edge into it.
      applyMapping m (from, Nothing) =
        mapDelete from .
        fmap (mapDelete from) $ m
      -- Drop the block and redirect every edge into it towards 'to'.
      applyMapping m (from, Just to) =
        let updatedMap :: CFG
            updatedMap
              = fmap (shortcutEdge (from,to)) $
                (mapDelete from m :: CFG )
        -- 'to' may itself be a key of 'cuts' (a chain of shortcuts), in
        -- which case the redirection continues from 'to' on the already
        -- updated CFG.
        in case mapLookup to cuts of
            Nothing -> updatedMap
            Just dest -> applyMapping updatedMap (to, dest)
      -- In one successor map, replace the edge to 'from' (if any) by an edge
      -- to 'to' carrying the same EdgeInfo.
      shortcutEdge :: (BlockId, BlockId) -> LabelMap EdgeInfo -> LabelMap EdgeInfo
      shortcutEdge (from, to) m =
        case mapLookup from m of
          Just info -> mapInsert to info $ mapDelete from m
          Nothing   -> m
-- | Insert @follower@ directly after @node@.
--
-- Adds the edge @node -> follower@ with the unconditional weight taken from
-- the global 'D.DynFlags'. Every outgoing edge @node -> t@ of the input CFG
-- is removed and re-added as @follower -> t@ with its original 'EdgeInfo'.
--
-- NOTE(review): the old successors are read from the input CFG. If
-- @follower@ was already a successor of @node@, the new @node -> follower@
-- edge is deleted again and a @follower -> follower@ edge is created.
-- Presumably callers always pass a fresh @follower@; confirm.
addImmediateSuccessor :: BlockId -> BlockId -> CFG -> CFG
addImmediateSuccessor node follower cfg =
    foldl' moveEdge withoutOldSuccs oldEdges
  where
    oldEdges = getSuccessorEdges cfg node
    weight = fromIntegral . D.uncondWeight . D.cfgWeightInfo $ D.unsafeGlobalDynFlags
    withFollower = addWeightEdge node follower weight cfg
    withoutOldSuccs =
      foldl' (\acc (old, _) -> delEdge node old acc) withFollower oldEdges
    moveEdge acc (target, info) = addEdge follower target info acc
-- | Insert the edge @from -> to@ with the given info, replacing any info
-- already stored for that edge. The entry for @from@ is created if needed.
addEdge :: BlockId -> BlockId -> EdgeInfo -> CFG -> CFG
addEdge from to info = mapAlter (Just . insertInto) from
  where
    insertInto = maybe (mapSingleton to info) (mapInsert to info)
-- | Like 'addEdge', but the edge info is built from a bare weight with
-- 'mkWeightInfo', so the edge is attributed to 'AsmCodeGen'.
addWeightEdge :: BlockId -> BlockId -> EdgeWeight -> CFG -> CFG
addWeightEdge from to = addEdge from to . mkWeightInfo
-- | Remove the edge @from -> to@ if it is present.
--
-- The entry for @from@ is kept even if its successor map becomes empty.
-- If @from@ has no entry, the CFG is unchanged.
delEdge :: BlockId -> BlockId -> CFG -> CFG
delEdge from to = mapAlter (fmap (mapDelete to)) from
-- | Outgoing edges of a block, heaviest first (sorted on the negated
-- weight). A block without an entry yields the empty list.
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid = sortWith heaviestFirst outgoing
  where
    outgoing = mapToList (mapFindWithDefault mapEmpty bid m)
    heaviestFirst (_target, info) = negate (edgeWeight info)
-- | Outgoing edges of a block with their info. A block without an entry
-- yields the empty list.
getSuccessorEdges :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccessorEdges m bid =
  case mapLookup bid m of
    Just succs -> mapToList succs
    Nothing    -> []
-- | Look up the info of the edge @from -> to@, if that edge exists.
-- The info is forced to WHNF when the result is.
getEdgeInfo :: BlockId -> BlockId -> CFG -> Maybe EdgeInfo
getEdgeInfo from to m = do
    succs <- mapLookup from m
    info  <- mapLookup to succs
    return $! info
-- | Flip the direction of every edge, keeping its 'EdgeInfo'.
--
-- Only blocks that are the target of at least one edge in the input become
-- keys of the result. Blocks without incoming edges are absent from the
-- result, not even present with an empty successor map.
reverseEdges :: CFG -> CFG
reverseEdges cfg = foldl' flipEdge mapEmpty forwardEdges
  where
    forwardEdges = [ (from, to, info) | (from, succs) <- mapToList cfg
                                      , (to, info) <- mapToList succs ]
    flipEdge acc (from, to, info) = addEdge to from info acc
-- | All edges of the CFG as 'CfgEdge' values, grouped by source block.
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList m =
  [ CfgEdge from to info
  | (from, succs) <- mapToList m
  , (to, info) <- mapToList succs ]
-- | All edges of the CFG as @(source, target, weight)@ triples.
weightedEdgeList :: CFG -> [(BlockId,BlockId,EdgeWeight)]
weightedEdgeList m =
  [ (from, to, edgeWeight info)
  | (from, succs) <- mapToList m
  , (to, info) <- mapToList succs ]
      
-- | All edges of the CFG as @(source, target)@ pairs, without their info.
edgeList :: CFG -> [Edge]
edgeList m =
  [ (from, to) | (from, succs) <- mapToList m, to <- mapKeys succs ]
-- | Successor blocks of a block. A block without an entry yields the empty
-- list.
getSuccessors :: CFG -> BlockId -> [BlockId]
getSuccessors m bid = maybe [] mapKeys (mapLookup bid m)
-- | Render the CFG in Graphviz @dot@ syntax.
--
-- Every edge is printed once, in sorted order, with its weight used both as
-- the label and as the @weight@ attribute. Blocks that have an empty
-- successor map and occur in no edge are printed as bare nodes.
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
    text "digraph {\n" <>
        concatDocs (map edgeLine edges) <>
        concatDocs (map nodeLine isolated) <>
    text "}\n"
  where
    concatDocs = foldl' (<>) empty
    edges = sort (weightedEdgeList m)
    edgeLine (from, to, weight)
        = text "\t" <> ppr from <+> text "->" <+> ppr to <>
          text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
          ppr weight <> text "\"];\n"
    nodeLine node = text "\t" <> ppr node <> text ";\n"
    -- Blocks that already appear as an endpoint of some printed edge.
    mentioned = setFromList (concat [ [from, to] | (from, to, _) <- edges ]) :: LabelSet
    isolated = [ n | n <- mapKeys (mapFilter null m)
                   , not (setMember n mentioned) ]
{-# INLINE updateEdgeWeight #-}
-- | Apply @f@ to the weight of the edge @(from, to)@, keeping the rest of
-- its 'EdgeInfo'. Panics if the edge is not in the CFG ('adjustEdgeWeight'
-- is the non-panicking variant).
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg =
    case getEdgeInfo from to cfg of
      Just info ->
        addEdge from to (info { edgeWeight = f (edgeWeight info) }) cfg
      Nothing ->
        panic "Trying to update invalid edge"
-- | Replace the weight of every edge @from -> to@ with
-- @f from to oldWeight@, keeping the rest of its 'EdgeInfo'.
--
-- Source entries without outgoing edges are kept unchanged. The CFG is
-- mapped structurally rather than rebuilt with one 'addEdge' per edge,
-- which avoids a lookup/insert per edge and the shadowed accumulator.
mapWeights :: (BlockId -> BlockId -> EdgeWeight -> EdgeWeight) -> CFG -> CFG
mapWeights f = mapMapWithKey reweightFrom
  where
    reweightFrom :: BlockId -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    reweightFrom from = mapMapWithKey (reweight from)

    reweight :: BlockId -> BlockId -> EdgeInfo -> EdgeInfo
    reweight from to info = info { edgeWeight = f from to (edgeWeight info) }
-- | Splice new blocks into existing edges.
--
-- Each triple @(from, between, old)@ replaces the edge @from -> old@ with
-- two edges:
--
--   * @from -> between@, carrying the original edge's 'EdgeInfo';
--   * @between -> old@, carrying the unconditional weight from the global
--     'D.DynFlags'.
--
-- The original edge infos are looked up in the input CFG. The function
-- panics if some @from -> old@ edge does not exist there.
addNodesBetween :: CFG -> [(BlockId,BlockId,BlockId)] -> CFG
addNodesBetween m updates =
    foldl' splice m (map withEdgeInfo updates)
  where
    uncond = fromIntegral . D.uncondWeight . D.cfgWeightInfo $ D.unsafeGlobalDynFlags

    withEdgeInfo :: (BlockId,BlockId,BlockId) -> (BlockId,BlockId,BlockId,EdgeInfo)
    withEdgeInfo triple@(from, between, old) =
      case getEdgeInfo from old m of
        Just info -> (from, between, old, info)
        Nothing -> pprPanic "Can't find weight for edge that should have one" (
            text "triple" <+> ppr triple $$
            text "updates" <+> ppr updates )

    splice :: CFG -> (BlockId,BlockId,BlockId,EdgeInfo) -> CFG
    splice cfg (from, between, old, info) =
      addEdge from between info .
      addWeightEdge between old uncond .
      delEdge from old $ cfg
-- | CFG of a top-level Cmm declaration. Data declarations get an empty CFG;
-- procedures are handled by 'getCfg'.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
getCfgProc weights decl =
  case decl of
    CmmProc _info _lab _live graph -> getCfg weights graph
    CmmData {}                     -> mapEmpty
-- | Build the CFG of a Cmm graph. Each edge is weighted according to the
-- block-ending node that creates it:
--
--   * unconditional branch: 'D.uncondWeight';
--   * conditional branch without a likelihood hint: 'D.condBranchWeight' on
--     both targets;
--   * conditional branch with a hint: 'D.likelyCondWeight' on the likely
--     target and 'D.unlikelyCondWeight' on the other;
--   * switch: 'D.switchWeight' on every target, or -1 when there are more
--     than 10 targets;
--   * call or foreign call with a continuation: 'D.callWeight' to the
--     continuation;
--   * call without a continuation: no edge.
--
-- Every block returned by 'revPostorder' gets an entry, even if it has no
-- outgoing edges.
getCfg :: D.CfgWeights -> CmmGraph -> CFG
getCfg weights graph =
  foldl' insertEdge edgelessCfg $ concatMap getBlockEdges blocks
  where
    D.CFGWeights
            { D.uncondWeight = uncondWeight
            , D.condBranchWeight = condBranchWeight
            , D.switchWeight = switchWeight
            , D.callWeight = callWeight
            , D.likelyCondWeight = likelyCondWeight
            , D.unlikelyCondWeight = unlikelyCondWeight
            -- The remaining fields are not needed to build the CFG.
            } = weights
    -- Start with an empty entry for every block, so blocks without
    -- successors are still nodes of the CFG.
    edgelessCfg = mapFromList $ zip (map G.entryLabel blocks) (repeat mapEmpty)
    insertEdge :: CFG -> ((BlockId,BlockId),EdgeInfo) -> CFG
    insertEdge m ((from,to),info) =
      mapAlter f from m
        where
          f :: Maybe (LabelMap EdgeInfo) -> Maybe (LabelMap EdgeInfo)
          f Nothing = Just $ mapSingleton to info
          f (Just destMap) = Just $ mapInsert to info destMap
    getBlockEdges :: CmmBlock -> [((BlockId,BlockId),EdgeInfo)]
    getBlockEdges block =
      case branch of
        CmmBranch dest -> [mkEdge dest uncondWeight]
        CmmCondBranch _c t f l
          | l == Nothing ->
              [mkEdge f condBranchWeight,   mkEdge t condBranchWeight]
          | l == Just True ->
              [mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight]
          | l == Just False ->
              [mkEdge f likelyCondWeight,   mkEdge t unlikelyCondWeight]
        (CmmSwitch _e ids) ->
          let switchTargets = switchTargetsToList ids
              -- Switches with more than 10 targets give every target weight
              -- -1 instead of switchWeight. Presumably this keeps any single
              -- target of a large table from being favoured; confirm.
              adjustedWeight =
                if (length switchTargets > 10) then -1 else switchWeight
          in map (\x -> mkEdge x adjustedWeight) switchTargets
        (CmmCall { cml_cont = Just cont})  -> [mkEdge cont callWeight]
        (CmmForeignCall {Cmm.succ = cont}) -> [mkEdge cont callWeight]
        (CmmCall { cml_cont = Nothing })   -> []
        -- Any other block-ending node is unexpected: report it instead of
        -- guessing edge weights.
        other ->
            pprPanic "getCfg: unknown successor cause"
              (ppr branch <+> text "=>" <+> ppr (G.successors other))
      where
        bid = G.entryLabel block
        mkEdgeInfo = EdgeInfo (CmmSource branch) . fromIntegral
        mkEdge target weight = ((bid,target), mkEdgeInfo weight)
        branch = lastNode block :: CmmNode O C
    blocks = revPostorder graph :: [CmmBlock]
-- | The CFG edges that 'classifyEdges', started at @root@, classifies as
-- 'Backward'.
findBackEdges :: BlockId -> CFG -> Edges
findBackEdges root cfg =
    [ edge | (edge, Backward) <- classified ]
  where
    classified :: [((BlockId,BlockId),EdgeType)]
    classified = classifyEdges root (getSuccessors cfg) (edgeList cfg)
-- | Heuristically adjust the edge weights of a procedure's CFG.
--
-- Data declarations are returned unchanged. For procedures, three passes
-- run in this order:
--
--   1. 'increaseBackEdgeWeight', starting from the graph entry ('g_entry');
--   2. 'penalizeInfoTables', using the procedure's @info@ map;
--   3. 'favourFewerPreds'.
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
    favourFewerPreds  .
    penalizeInfoTables info .
    increaseBackEdgeWeight (g_entry graph) $ cfg
  where
    -- Add D.backEdgeBonus to every back edge (per findBackEdges from 'root')
    -- whose weight is positive. Back edges with weight <= 0 are set to 0.
    increaseBackEdgeWeight :: BlockId -> CFG -> CFG
    increaseBackEdgeWeight root cfg =
        let backedges = findBackEdges root cfg
            update weight
              -- Non-positive edges stay non-positive (clamped to 0).
              | weight <= 0 = 0
              | otherwise
              = weight + fromIntegral (D.backEdgeBonus weights)
        -- updateEdgeWeight panics on a missing edge. The back edges are
        -- classified from edgeList of this same CFG, so presumably they
        -- always exist.
        in  foldl'  (\cfg edge -> updateEdgeWeight update edge cfg)
                    cfg backedges

    -- Subtract D.infoTablePenalty from every edge whose target is a key of
    -- 'info'. Going by the name, these are presumably the blocks carrying an
    -- info table; confirm.
    penalizeInfoTables :: LabelMap a -> CFG -> CFG
    penalizeInfoTables info cfg =
        mapWeights fupdate cfg
      where
        fupdate :: BlockId -> BlockId -> EdgeWeight -> EdgeWeight
        fupdate _ to weight
          | mapMember to info
          = weight - (fromIntegral $ D.infoTablePenalty weights)
          | otherwise = weight

    -- For every block with exactly two successors of equal weight, add 1 to
    -- the edge whose target has fewer counted predecessors and subtract 1
    -- from the other. Nothing changes on a tie. Predecessors are counted once,
    -- on this pass's input CFG, using only edges accepted by
    -- 'fallthroughTarget'.
    favourFewerPreds :: CFG -> CFG
    favourFewerPreds cfg =
        let
            -- Predecessor graph restricted to candidate fallthrough edges.
            revCfg =
              reverseEdges $ filterEdges
                              (\_from -> fallthroughTarget)  cfg
            predCount n = length $ getSuccessorEdges revCfg n
            nodes = getCfgNodes cfg
            -- Adjustments for (first successor, second successor): the
            -- successor with fewer predecessors gets +1.
            modifiers :: Int -> Int -> (EdgeWeight, EdgeWeight)
            modifiers preds1 preds2
              | preds1 <  preds2 = ( 1,-1)
              | preds1 == preds2 = ( 0, 0)
              | otherwise        = (-1, 1)
            update cfg node
              | [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node
              , w1 <- edgeWeight e1
              , w2 <- edgeWeight e2
              -- Only adjust when the weights express no preference yet.
              , w1 == w2
              , (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
              = (\cfg' ->
                  (adjustEdgeWeight cfg' (+mod2) node s2))
                  (adjustEdgeWeight cfg  (+mod1) node s1)
              | otherwise
              = cfg
        in setFoldl update cfg nodes
      where
        -- An edge is counted as a potential fallthrough unless its target is
        -- a key of 'info'. Of the remaining edges, only those created by the
        -- code generator or by a Cmm branch / conditional branch count.
        fallthroughTarget :: BlockId -> EdgeInfo -> Bool
        fallthroughTarget to (EdgeInfo source _weight)
          | mapMember to info = False
          | AsmCodeGen <- source = True
          | CmmSource (CmmBranch {}) <- source = True
          | CmmSource (CmmCondBranch {}) <- source = True
          | otherwise = False
-- | For every node of the CFG, whether it lies on a cycle.
--
-- Members of a cyclic strongly connected component map to 'True'; nodes
-- forming an acyclic component map to 'False'.
loopMembers :: CFG -> LabelMap Bool
loopMembers cfg = foldl' record mapEmpty sccs
  where
    sccs = stronglyConnCompFromEdgedVerticesOrd
             [ DigraphNode bid bid (getSuccessors cfg bid)
             | bid <- setElems (getCfgNodes cfg) ]

    record :: LabelMap Bool -> SCC BlockId -> LabelMap Bool
    record acc (AcyclicSCC bid) = mapInsert bid False acc
    record acc (CyclicSCC bids) =
      foldl' (\acc' bid -> mapInsert bid True acc') acc bids