{-# LANGUAGE TypeFamilies, ScopedTypeVariables, CPP #-}
{-# OPTIONS_GHC -fprof-auto #-}
module BlockLayout
    ( sequenceTop )
where
#include "HsVersions.h"
import GhcPrelude
import Instruction
import NCGMonad
import CFG
import BlockId
import Cmm
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import DynFlags (gopt, GeneralFlag(..), DynFlags, backendMaintainsCfg)
import UniqFM
import Util
import Unique
import Digraph
import Outputable
import Maybes
import ListSetOps (removeDups)
import PprCmm ()
import OrdList
import Data.List
import Data.Foldable (toList)
import Hoopl.Graph
import qualified Data.Set as Set
-- | How many blocks at each end of a chain form its frontier in
-- 'combineNeighbourhood' (see 'takeL' / 'takeR' there).
neighbourOverlapp :: Int
neighbourOverlapp = 2
-- | 'fuseChains' stops at the first edge whose weight is at or below this
-- value, so lighter edges are never used to fuse chains.
fuseEdgeThreshold :: EdgeWeight
fuseEdgeThreshold = 0
-- | Maps each frontier block of a chain to that chain's whole frontier list
-- paired with the chain itself (used by 'combineNeighbourhood').
type FrontierMap = LabelMap ([BlockId],BlockChain)
-- | An ordered sequence of blocks; the final layout concatenates chains,
-- so the blocks of one chain end up placed consecutively.
newtype BlockChain
    = BlockChain { chainBlocks :: (OrdList BlockId) }
-- | Two chains are equal when they list the same blocks in the same order.
instance Eq (BlockChain) where
    c1 == c2 = fromOL (chainBlocks c1) == fromOL (chainBlocks c2)
-- | Chains are ordered by comparing their block lists lexicographically.
instance Ord (BlockChain) where
    compare c1 c2 = compare (fromOL (chainBlocks c1)) (fromOL (chainBlocks c2))
-- | Pretty-prints a chain as @(Chain: [b1, b2, ...])@.
instance Outputable (BlockChain) where
    ppr chain =
        parens $ text "Chain:" <+> ppr (fromOL (chainBlocks chain))
-- | A CFG edge: source block, destination block, and the edge weight.
data WeightedEdge = WeightedEdge !BlockId !BlockId EdgeWeight deriving (Eq)
-- | Edges are ordered lexicographically by weight, then source block, then
-- destination block.  This is consistent with the derived 'Eq'.
instance Ord WeightedEdge where
  compare (WeightedEdge from1 to1 weight1)
          (WeightedEdge from2 to2 weight2)
    = compare (weight1, from1, to1) (weight2, from2, to2)
-- | Pretty-prints an edge as @from->to[weight]@.
instance Outputable WeightedEdge where
    ppr (WeightedEdge src dst weight) =
        hcat [ppr src, text "->", ppr dst, brackets (ppr weight)]
-- | A list of weighted edges.  'fuseChains' expects it ordered heaviest
-- first, as 'sequenceChain' builds it.
type WeightedEdgeList = [WeightedEdge]
-- | 'True' when no block occurs more than once across all given chains.
-- On failure it traces the duplicated blocks and the chains with
-- 'pprTrace' before returning 'False'; it is used inside ASSERTs.
noDups :: [BlockChain] -> Bool
noDups chains
    | null dups = True
    | otherwise = pprTrace "Duplicates:" (ppr (map toList dups) $$ text "chains" <+> ppr chains ) False
  where
    allBlocks = concatMap chainToBlocks chains :: [BlockId]
    (_blocks, dups) = removeDups compare allBlocks
-- | Is the given block the first block of the chain?
inFront :: BlockId -> BlockChain -> Bool
inFront bid chain
  = headOL (chainBlocks chain) == bid
-- | Does the chain contain the given block anywhere?
chainMember :: BlockId -> BlockChain -> Bool
chainMember bid (BlockChain blks)
  = bid `elem` fromOL blks
-- | A chain consisting of just one block.
chainSingleton :: BlockId -> BlockChain
chainSingleton = BlockChain . unitOL
-- | Append one block to the end of a chain.
chainSnoc :: BlockChain -> BlockId -> BlockChain
chainSnoc chain lbl
  = BlockChain (snocOL (chainBlocks chain) lbl)
-- | Concatenate two chains: all blocks of the first, then all of the second.
chainConcat :: BlockChain -> BlockChain -> BlockChain
chainConcat first second
  = BlockChain (appOL (chainBlocks first) (chainBlocks second))
-- | The blocks of a chain, in chain order.
chainToBlocks :: BlockChain -> [BlockId]
chainToBlocks = fromOL . chainBlocks
-- | Split a chain just before the given block.
--
-- Returns the blocks preceding @bid@, and the rest of the chain, which
-- starts with @bid@.  Panics with "Block not in chain" if @bid@ does not
-- occur in the chain.
breakChainAt :: BlockId -> BlockChain
             -> (BlockChain,BlockChain)
breakChainAt bid (BlockChain blks)
    = case break (== bid) (fromOL blks) of
        -- 'break' yields an empty suffix exactly when bid is absent; match
        -- on it instead of calling the partial 'head'.
        (_, [])        -> panic "Block not in chain"
        (lblks, rblks) -> (BlockChain (toOL lblks),
                           BlockChain (toOL rblks))
-- | Up to @n@ blocks from the end of a chain, last block first.
takeR :: Int -> BlockChain -> [BlockId]
takeR n = take n . fromOLReverse . chainBlocks
-- | Up to @n@ blocks from the start of a chain, in chain order.
takeL :: Int -> BlockChain -> [BlockId]
takeL n = take n . fromOL . chainBlocks
-- | Fuse chains end-to-front along heavy edges.
--
-- @chains@ is keyed by each chain's last block.  Edges are processed in
-- list order, and processing stops at the first edge whose weight is at or
-- below 'fuseEdgeThreshold'.  An edge @src -> dst@ fuses two different
-- chains when @src@ ends one chain and @dst@ starts the other.
--
-- Returns the resulting chains (again keyed by last block) and the set of
-- edges that caused a fusion.
fuseChains :: WeightedEdgeList -> LabelMap BlockChain
           -> (LabelMap BlockChain, Set.Set WeightedEdge)
fuseChains weights chains
    = let (endMap, usedEdges, _frontMap)
              = foldl' tryFuse (chains, Set.empty, frontMap0) heavyEdges
      in (endMap, usedEdges)
    where
        -- The same chains, keyed by their first block instead.
        frontMap0 :: LabelMap BlockChain
        frontMap0 = mapFromList
            [ (headOL (chainBlocks c), c) | c <- mapElems chains ]

        -- Only the prefix of edges above the threshold is considered.
        heavyEdges :: WeightedEdgeList
        heavyEdges =
            takeWhile (\(WeightedEdge _ _ w) -> not (w <= fuseEdgeThreshold))
                      weights

        tryFuse :: (LabelMap BlockChain, Set.Set WeightedEdge, LabelMap BlockChain)
                -> WeightedEdge
                -> (LabelMap BlockChain, Set.Set WeightedEdge, LabelMap BlockChain)
        tryFuse acc@(ends, used, fronts) edge@(WeightedEdge src dst _)
            | Just c1 <- mapLookup src ends
            , Just c2 <- mapLookup dst fronts
            , c1 /= c2
            = let fused   = chainConcat c1 c2
                  fusedOL = chainBlocks fused
                  fronts' = mapInsert (headOL fusedOL) fused $
                            mapDelete dst fronts
                  ends'   = mapInsert (lastOL fusedOL) fused $
                            mapDelete src ends
              in (ends', Set.insert edge used, fronts')
            | otherwise
            = acc
-- | Concatenate chains that are connected near their boundaries.
--
-- Each chain is indexed by its first and last 'neighbourOverlapp' blocks.
-- Edges are tried in list order.  An edge is used when its source is among
-- the last blocks of one chain and its destination is among the first
-- blocks of a different chain.  The two chains are then concatenated
-- (source chain first) and both frontier maps are updated.  Unlike
-- 'fuseChains', the edge need not run from the very last block to the very
-- first one.
combineNeighbourhood :: WeightedEdgeList -> [BlockChain]
                     -> [BlockChain]
combineNeighbourhood edges chains
    = applyEdges edges endFrontier startFrontier
    where
        -- The blocks of a chain that form its front and end frontier.
        getFronts, getEnds :: BlockChain -> [BlockId]
        getFronts chain = takeL neighbourOverlapp chain
        getEnds chain = takeR neighbourOverlapp chain

        -- Map every frontier block to (whole frontier, owning chain).
        endFrontier, startFrontier :: FrontierMap
        endFrontier =
            mapFromList $ concatMap (\chain ->
                                let ends = getEnds chain :: [BlockId]
                                    entry = (ends,chain)
                                in map (\x -> (x,entry)) ends ) chains
        startFrontier =
            mapFromList $ concatMap (\chain ->
                                let front = getFronts chain
                                    entry = (front,chain)
                                in map (\x -> (x,entry)) front) chains

        applyEdges :: WeightedEdgeList -> FrontierMap -> FrontierMap
                   -> [BlockChain]
        -- No edges left.  Every surviving chain is in the end map, once per
        -- end block, hence the ordNub.
        applyEdges [] chainEnds _chainFronts =
            ordNub $ map snd $ mapElems chainEnds
        applyEdges ((WeightedEdge from to _w):edges) chainEnds chainFronts
            | Just (c1_e,c1) <- mapLookup from chainEnds
            , Just (c2_f,c2) <- mapLookup to chainFronts
            , c1 /= c2 -- never concatenate a chain with itself
            = let newChain = chainConcat c1 c2
                  newChainFrontier = getFronts newChain
                  newChainEnds = getEnds newChain
                  -- Remove the frontier entries of both old chains, then
                  -- register the merged chain under its own frontiers.
                  newFronts :: FrontierMap
                  newFronts =
                    let withoutOld =
                            foldl' (\m b -> mapDelete b m :: FrontierMap) chainFronts (c2_f ++ getFronts c1)
                        entry =
                            (newChainFrontier,newChain) -- let-bound for sharing
                    in foldl' (\m x -> mapInsert x entry m)
                              withoutOld newChainFrontier
                  newEnds =
                    let withoutOld = foldl' (\m b -> mapDelete b m) chainEnds (c1_e ++ getEnds c2)
                        entry = (newChainEnds,newChain) -- let-bound for sharing
                    in foldl' (\m x -> mapInsert x entry m)
                              withoutOld newChainEnds
              in applyEdges edges newEnds newFronts
            | otherwise
            = applyEdges edges chainEnds chainFronts
-- | Greedily form initial chains by visiting @blocks@ in order.
--
-- Each block is considered once; later repeats are skipped.  A block is
-- appended to an existing chain when all of these hold for @prev@, the
-- first entry of its predecessor list from 'getSuccEdgesSorted' on the
-- reversed CFG:
--
-- * @prev@ has already been placed;
-- * @prev@ is the last block of a chain;
-- * the block is the first not-yet-placed successor of @prev@.
--
-- Otherwise the block starts a new singleton chain.
--
-- Returns the chains keyed by their last block, and the @(prev, block)@
-- edges used to extend chains.
buildChains :: CFG -> [BlockId]
            -> ( LabelMap BlockChain
               , Set.Set (BlockId, BlockId))
buildChains succWeights blocks
  = let (chainMap, usedEdges, _placed) =
            foldl' visit (mapEmpty, Set.empty, setEmpty) blocks
    in (chainMap, usedEdges)
  where
    predWeights = reverseEdges succWeights

    -- Process one block, threading (chains, linked edges, placed blocks).
    visit :: (LabelMap BlockChain, Set.Set (BlockId, BlockId), LabelSet)
          -> BlockId
          -> (LabelMap BlockChain, Set.Set (BlockId, BlockId), LabelSet)
    visit state@(chainMap, usedEdges, placed) blk
      | setMember blk placed
      = state
      | otherwise
      = let (chainMap', newEdges) = attach placed chainMap blk
        in (chainMap', Set.union usedEdges newEdges, setInsert blk placed)

    -- Either extend the predecessor's chain or start a fresh one.
    attach :: LabelSet -> LabelMap BlockChain -> BlockId
           -> (LabelMap BlockChain, Set.Set (BlockId, BlockId))
    attach placed chainMap blk
      | (prev:_) <- map fst (getSuccEdgesSorted predWeights blk)
      , setMember prev placed
      , Just prevChain <- mapLookup prev chainMap
      , (best:_) <- filter (\s -> not (setMember s placed))
                           (map fst (getSuccEdgesSorted succWeights prev))
      , best == blk
      = ( mapInsert blk (chainSnoc prevChain blk) (mapDelete prev chainMap)
        , Set.singleton (prev, blk) )
      | otherwise
      = ( mapInsert blk (chainSingleton blk) chainMap
        , Set.empty )
-- | A block reduced to its label and successor labels, used by
-- 'sequenceChain' to run 'revPostorderFrom'.  The @e@ and @x@ parameters
-- are phantom.
newtype BlockNode e x = BN (BlockId,[BlockId])
-- | Label and successors are read straight out of the wrapped pair.
instance NonLocal (BlockNode) where
  entryLabel (BN node) = fst node
  successors (BN node) = snd node
-- | The label of a 'BlockNode'.
fromNode :: BlockNode C C -> BlockId
fromNode (BN (bid, _)) = bid
-- | Lay out the blocks of one procedure by building and merging chains.
--
-- 1. Blocks are visited in reverse postorder from the entry block and
--    greedily grouped into chains by 'buildChains'.
-- 2. The remaining positive-weight edges, heaviest first, fuse whole
--    chains end-to-front via 'fuseChains'.
-- 3. Chains connected near their boundaries are joined by
--    'combineNeighbourhood'.
-- 4. The chain containing the entry block is emitted first, split so the
--    entry block leads.  Blocks that ended up in no chain go last.
--
-- Finally 'dropJumps' removes jumps to the block laid out next.
sequenceChain :: forall a i. (Instruction i, Outputable i) => LabelMap a -> CFG
            -> [GenBasicBlock i] -> [GenBasicBlock i]
sequenceChain _info _weights    [] = []
sequenceChain _info _weights    [x] = [x]
sequenceChain  info weights'     blocks@((BasicBlock entry _):_) =
    let -- Chain construction only considers edges of positive weight.
        weights :: CFG
        weights
            = filterEdges (\_f _t edgeInfo -> edgeWeight edgeInfo > 0) weights'
        blockMap :: LabelMap (GenBasicBlock i)
        blockMap
            = foldl' (\m blk@(BasicBlock lbl _ins) ->
                        mapInsert lbl blk m)
                     mapEmpty blocks
        -- The traversal order uses the unfiltered CFG.
        toNode :: BlockId -> BlockNode C C
        toNode bid =
            BN (bid,map fst . getSuccEdgesSorted weights' $ bid)
        orderedBlocks :: [BlockId]
        orderedBlocks
            = map fromNode $
              revPostorderFrom (fmap (toNode . blockId) blockMap) entry
        (builtChains, builtEdges)
            = {-# SCC "buildChains" #-}
              buildChains weights orderedBlocks
        -- Heaviest edges first; edges already used by buildChains are dropped.
        rankedEdges :: WeightedEdgeList
        rankedEdges =
            map (\(from, to, weight) -> WeightedEdge from to weight) .
            filter (\(from, to, _)
                        -> not (Set.member (from,to) builtEdges)) .
            sortWith (\(_,_,w) -> - w) $ weightedEdgeList weights
        (fusedChains, fusedEdges)
            = ASSERT(noDups $ mapElems builtChains)
              {-# SCC "fuseChains" #-}
              fuseChains rankedEdges builtChains
        rankedEdges' =
            filter (\edge -> not $ Set.member edge fusedEdges) $ rankedEdges
        neighbourChains
            = ASSERT(noDups $ mapElems fusedChains)
              {-# SCC "groupNeighbourChains" #-}
              combineNeighbourhood rankedEdges' (mapElems fusedChains)
        -- Exactly one chain must hold the entry block; it is emitted first.
        (entryChain, chains')
            = ASSERT(noDups $ neighbourChains)
              case partition (chainMember entry) neighbourChains of
                ([chain], others) -> (chain, others)
                (found, _)        -> pprPanic "Entry point eliminated" (ppr found)
        -- If the entry block is not at the front of its chain, split the chain
        -- so that the part starting at the entry block comes first.
        (entryChain':entryRest)
            | inFront entry entryChain = [entryChain]
            | otherwise
            = let (beforeEntry, fromEntry) = breakChainAt entry entryChain
              in [fromEntry, beforeEntry]
        prepedChains
            = entryChain':(entryRest++chains') :: [BlockChain]
        blockList
            = (concatMap fromOL $ map chainBlocks prepedChains)
        -- Blocks not in any chain are appended after all chains.
        chainPlaced = setFromList $ blockList :: LabelSet
        unplaced =
            let blocks = mapKeys blockMap
                isPlaced b = setMember (b) chainPlaced
            in filter (\block -> not (isPlaced block)) blocks
        placedBlocks =
            blockList ++ unplaced
        getBlock bid = expectJust "Block placment" $ mapLookup bid blockMap
    in
        ASSERT(all (\bid -> mapMember bid blockMap) placedBlocks)
        dropJumps info $ map getBlock placedBlocks
-- | Remove a block's final jump when it targets the block laid out directly
-- after it.
--
-- The jump is removed only if it is the last instruction, it has exactly
-- one destination, that destination has no entry in @info@, and the
-- destination is the label of the next block in the list.
dropJumps :: forall a i. Instruction i => LabelMap a -> [GenBasicBlock i]
          -> [GenBasicBlock i]
dropJumps info = go
  where
    go [] = []
    go (BasicBlock lbl ins : rest)
      | jumpsToNext = BasicBlock lbl (init ins) : go rest
      | otherwise   = BasicBlock lbl ins : go rest
      where
        jumpsToNext = case rest of
          BasicBlock nextLbl _ : _
            | not (null ins)
            , [dest] <- jumpDestsOfInstr (last ins)
            , not (mapMember dest info)
            -> nextLbl == dest
          _ -> False
-- | Reorder the basic blocks of a native-code procedure.  Data declarations
-- are returned unchanged.
--
-- 'sequenceChain' is used when @Opt_CfgBlocklayout@ is set and the backend
-- maintains a CFG.  Otherwise the SCC-based 'sequenceBlocks' is used, and it
-- only consults the CFG when @Opt_WeightlessBlocklayout@ is off and the
-- backend maintains one.  'ncgMakeFarBranches' is applied in both cases.
sequenceTop
    :: (Instruction instr, Outputable instr)
    => DynFlags
    -> NcgImpl statics instr jumpDest -> CFG
    -> NatCmmDecl statics instr -> NatCmmDecl statics instr
sequenceTop _ _ _ top@(CmmData _ _) = top
sequenceTop dflags ncgImpl edgeWeights (CmmProc info lbl live (ListGraph blocks))
  = CmmProc info lbl live (ListGraph (ncgMakeFarBranches ncgImpl info reordered))
  where
    cfgAvailable = backendMaintainsCfg dflags

    reordered
      | gopt Opt_CfgBlocklayout dflags && cfgAvailable
      = sequenceChain info edgeWeights blocks
      | otherwise
      = sequenceBlocks sccWeights info blocks

    sccWeights
      | gopt Opt_WeightlessBlocklayout dflags || not cfgAvailable = Nothing
      | otherwise = Just edgeWeights
-- | SCC-based block ordering.
--
-- The entry block stays first.  The other blocks follow in reversed
-- 'sccBlocks' order and are then threaded by 'seqBlocks' and cleaned up by
-- 'dropJumps'.
sequenceBlocks :: Instruction inst => Maybe CFG -> LabelMap a
               -> [GenBasicBlock inst] -> [GenBasicBlock inst]
sequenceBlocks _edgeWeight _ [] = []
sequenceBlocks edgeWeights infos (entry:blocks)
  = dropJumps infos (seqBlocks infos nodes)
  where
    nodes = mkNode edgeWeights entry : reverse (flattenSCCs sccs)
    sccs  = sccBlocks edgeWeights blocks
  
-- | Strongly connected components of the blocks, using the out-edges
-- chosen by 'mkNode'.
sccBlocks
        :: Instruction instr
        => Maybe CFG -> [NatBasicBlock instr]
        -> [SCC (Node BlockId (NatBasicBlock instr))]
sccBlocks edgeWeights =
    stronglyConnCompFromEdgedVerticesUniqR . map (mkNode edgeWeights)
-- | Build the digraph node for a block.  Its out-edge list holds at most
-- one preferred successor, which 'seqBlocks' tries to place right after it.
mkNode :: (Instruction t)
       => Maybe CFG -> GenBasicBlock t
       -> Node BlockId (GenBasicBlock t)
mkNode edgeWeights block@(BasicBlock id instrs) =
    DigraphNode block id outEdges
  where
    outEdges :: [BlockId]
    outEdges
      -- With a CFG, take the first successor from 'getSuccEdgesSorted', but
      -- only if the block has at most two successors and that edge has
      -- positive weight.
      | Just successors <- fmap (`getSuccEdgesSorted` id) edgeWeights
      = case successors of
          [] -> []
          ((target,info):_)
            | successors `lengthExceeds` 2 || edgeWeight info <= 0 -> []
            | otherwise -> [target]
      -- Without a CFG, follow the final instruction if it has exactly one
      -- jump destination.  An empty block has no final instruction.
      | null instrs
      = []
      | otherwise
      = case jumpDestsOfInstr (last instrs) of
            [one] -> [one]
            _many -> []
-- | Emit blocks in node-list order, pulling each block's preferred successor
-- (its single out-edge, if any) directly after it.
--
-- A successor is pulled only if it is still pending and has no entry in
-- @infos@.  Otherwise emission resumes with the next still-pending block of
-- the original order.  Panics if a node has more than one out-edge.
seqBlocks :: LabelMap i -> [Node BlockId (GenBasicBlock t1)]
                        -> [GenBasicBlock t1]
seqBlocks infos blocks = emitFrom initialPending (map node_key blocks)
  where
    -- Blocks not yet emitted, keyed by label, each paired with its out-edges.
    initialPending =
        listToUFM [ (key, (payload, outs)) | DigraphNode payload key outs <- blocks ]

    -- Walk the original order, skipping labels that were already emitted.
    emitFrom _ [] = []
    emitFrom pending (key:keys) =
        case lookupDeleteUFM pending key of
          Just (entry, pending') -> emitChain pending' keys entry
          Nothing                -> emitFrom pending keys

    -- Emit a block, then try to continue with its preferred successor.
    emitChain pending keys (blk, []) =
        blk : emitFrom pending keys
    emitChain pending keys (blk@BasicBlock{}, [next])
        | mapMember next infos
        = blk : emitFrom pending keys
        | otherwise
        = case lookupDeleteUFM pending next of
            Just (nextEntry, pending') -> blk : emitChain pending' keys nextEntry
            Nothing                    -> blk : emitFrom pending keys
    emitChain _ _ (_, tooManyNextNodes) =
        pprPanic "seqBlocks" (ppr tooManyNextNodes)
-- | Look up a key and, if present, also return the map with that key removed.
lookupDeleteUFM :: Uniquable key => UniqFM elt -> key
                -> Maybe (elt, UniqFM elt)
lookupDeleteUFM m k =
    case lookupUFM m k of
      Nothing -> Nothing
      Just v  -> Just (v, delFromUFM m k)