{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

--
-- Copyright (c) 2010, João Dias, Simon Marlow, Simon Peyton Jones,
-- and Norman Ramsey
--
-- Modifications copyright (c) The University of Glasgow 2012
--
-- This module is a specialised and optimised version of
-- Compiler.Hoopl.Dataflow in the hoopl package.  In particular it is
-- specialised to the UniqSM monad.
--

module GHC.Cmm.Dataflow
  ( C, O, Block
  , lastNode, entryLabel
  , foldNodesBwdOO
  , foldRewriteNodesBwdOO
  , DataflowLattice(..), OldFact(..), NewFact(..), JoinedFact(..)
  , TransferFun, RewriteFun
  , Fact, FactBase
  , getFact, mkFactBase
  , analyzeCmmFwd, analyzeCmmBwd
  , rewriteCmmBwd
  , changedIf
  , joinOutFacts
  , joinFacts
  )
where

import GHC.Prelude

import GHC.Cmm
import GHC.Types.Unique.Supply

import Data.Array
import Data.Maybe
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Kind (Type)

import GHC.Cmm.Dataflow.Block
import GHC.Cmm.Dataflow.Graph
import GHC.Cmm.Dataflow.Collections
import GHC.Cmm.Dataflow.Label

-- | The shape of the fact(s) associated with something of extensibility @x@:
-- a closed ('C') end carries a whole 'FactBase' (one fact per label), while
-- an open ('O') end carries a single fact.
type family   Fact (x :: Extensibility) f :: Type
type instance Fact C f = FactBase f
type instance Fact O f = f

-- | The fact already recorded, passed as the first argument of a 'JoinFun'.
newtype OldFact a = OldFact a

-- | The incoming fact, passed as the second argument of a 'JoinFun'.
newtype NewFact a = NewFact a

-- | The result of joining OldFact and NewFact.
data JoinedFact a
    = Changed !a     -- ^ Result is different than OldFact.
    | NotChanged !a  -- ^ Result is the same as OldFact.

-- | Extract the joined fact, forgetting whether it changed.
getJoined :: JoinedFact a -> a
getJoined joined =
    case joined of
        Changed    x -> x
        NotChanged x -> x

-- | Tag a fact as 'Changed' when the flag is 'True' and as 'NotChanged'
-- otherwise.
changedIf :: Bool -> a -> JoinedFact a
changedIf didChange = if didChange then Changed else NotChanged

-- | Join an old fact with a new one, reporting whether the result differs
-- from the old fact.
type JoinFun a = OldFact a -> NewFact a -> JoinedFact a

-- | A dataflow lattice: its bottom element and its join operation.
data DataflowLattice a = DataflowLattice
    { fact_bot  :: a          -- ^ Bottom element of the lattice.
    , fact_join :: JoinFun a  -- ^ Join of an old and a new fact.
    }

-- | The direction in which an analysis traverses the graph.
data Direction = Fwd | Bwd

-- | Transfer function for a single 'CmmBlock': given the facts known so far,
-- produce a 'FactBase' of facts to be merged into the current result.
type TransferFun f = CmmBlock -> FactBase f -> FactBase f

-- | 'TransferFun' generalised over the node type @n@.
type TransferFun' (n :: Extensibility -> Extensibility -> Type) f =
    Block n C C -> FactBase f -> FactBase f


-- | Function for rewriting and analysis combined. To be used with
-- @rewriteCmm@.
--
-- Currently set to work with @UniqSM@ monad, but we could probably abstract
-- that away (if we do that, we might want to specialize the fixpoint algorithms
-- to the particular monads through SPECIALIZE).
type RewriteFun f = CmmBlock -> FactBase f -> UniqSM (CmmBlock, FactBase f)

-- | 'RewriteFun' generalised over the node type @n@.
type RewriteFun' (n :: Extensibility -> Extensibility -> Type) f =
    Block n C C -> FactBase f -> UniqSM (Block n C C, FactBase f)

-- | Run a dataflow analysis over a Cmm graph, starting from the given initial
-- facts: 'analyzeCmmBwd' propagates facts backward, 'analyzeCmmFwd' forward.
-- Both are thin wrappers around 'analyzeCmm'.
analyzeCmmBwd, analyzeCmmFwd
    :: (NonLocal node)
    => DataflowLattice f
    -> TransferFun' node f
    -> GenCmmGraph node
    -> FactBase f
    -> FactBase f
analyzeCmmBwd lattice transfer graph facts =
    analyzeCmm Bwd lattice transfer graph facts
analyzeCmmFwd lattice transfer graph facts =
    analyzeCmm Fwd lattice transfer graph facts

-- | Shared driver for 'analyzeCmmFwd' and 'analyzeCmmBwd'.
--
-- It extracts the block map of the closed/closed graph and hands it, together
-- with the graph's entry label, to 'fixpointAnalysis'.
analyzeCmm
    :: (NonLocal node)
    => Direction
    -> DataflowLattice f
    -> TransferFun' node f
    -> GenCmmGraph node
    -> FactBase f
    -> FactBase f
analyzeCmm dir lattice transfer cmmGraph initFact =
    {-# SCC analyzeCmm #-}
    let -- A closed/closed graph has no open entry or exit sequence, so the
        -- body's block map is the whole graph.
        body = case g_graph cmmGraph of
                   GMany NothingO bm NothingO -> bm
    in fixpointAnalysis dir lattice transfer (g_entry cmmGraph) body initFact

-- | Fixpoint algorithm: worklist iteration for 'analyzeCmm'.
--
-- Blocks are numbered by their position in the 'sortBlocks' order. The
-- worklist holds those numbers, and 'IntSet.minView' always picks the
-- smallest pending one. The transfer function's output facts for a block are
-- merged into the current 'FactBase' via 'updateFact', which also re-queues
-- the blocks that depend on any label whose fact changed. Iteration stops
-- when the worklist is empty.
fixpointAnalysis
    :: forall f node.
       (NonLocal node)
    => Direction
    -> DataflowLattice f
    -> TransferFun' node f
    -> Label
    -> LabelMap (Block node C C)
    -> FactBase f
    -> FactBase f
fixpointAnalysis direction lattice do_block entry blockmap = work allPending
  where
    -- Sorting the blocks helps to minimize the number of times we need to
    -- process blocks. For instance, for forward analysis we want to look at
    -- blocks in reverse postorder. Also, see comments for sortBlocks.
    ordered :: [Block node C C]
    ordered = sortBlocks direction entry blockmap

    lastIndex :: Int
    lastIndex = length ordered - 1

    byIndex :: Array Int (Block node C C)
    byIndex = {-# SCC "block_arr" #-} listArray (0, lastIndex) ordered

    -- Every block starts out pending.
    allPending :: IntHeap
    allPending = {-# SCC "start" #-} IntSet.fromDistinctAscList [0 .. lastIndex]

    dependents :: LabelMap IntHeap
    dependents = {-# SCC "dep_blocks" #-} mkDepBlocks direction ordered

    merge = updateFact (fact_join lattice) dependents

    work
        :: IntHeap     -- Worklist, i.e., blocks to process
        -> FactBase f  -- Current result (increases monotonically)
        -> FactBase f
    work pending !facts =
        case IntSet.minView pending of
            Nothing -> facts
            Just (ix, rest) ->
                let outFacts = {-# SCC "do_block" #-} do_block (byIndex ! ix) facts
                    -- For each of the outgoing edges, we join it with the current
                    -- information in facts and (if something changed) we update it
                    -- and add the affected blocks to the worklist.
                    (pending', facts') = {-# SCC "mapFoldWithKey" #-}
                        mapFoldlWithKey merge (rest, facts) outFacts
                in work pending' facts'

-- | Backward dataflow analysis combined with rewriting.
--
-- Returns the rewritten graph together with the final facts; a thin wrapper
-- around 'rewriteCmm' with direction 'Bwd'.
rewriteCmmBwd
    :: (NonLocal node)
    => DataflowLattice f
    -> RewriteFun' node f
    -> GenCmmGraph node
    -> FactBase f
    -> UniqSM (GenCmmGraph node, FactBase f)
rewriteCmmBwd lattice rwFun graph facts =
    rewriteCmm Bwd lattice rwFun graph facts

-- | Shared driver for combined analysis and rewriting.
--
-- It extracts the block map of the closed/closed graph and runs
-- 'fixpointRewrite' on it. It then returns the input graph with its block
-- map replaced by the rewritten blocks (other graph fields are kept),
-- together with the final facts.
rewriteCmm
    :: (NonLocal node)
    => Direction
    -> DataflowLattice f
    -> RewriteFun' node f
    -> GenCmmGraph node
    -> FactBase f
    -> UniqSM (GenCmmGraph node, FactBase f)
rewriteCmm dir lattice rwFun cmmGraph initFact = {-# SCC rewriteCmm #-} do
    let startBlocks = case g_graph cmmGraph of
                          GMany NothingO bm NothingO -> bm
    (rewritten, facts) <-
        fixpointRewrite dir lattice rwFun (g_entry cmmGraph) startBlocks initFact
    pure (cmmGraph { g_graph = GMany NothingO rewritten NothingO }, facts)

-- | Worklist iteration for 'rewriteCmm'.
--
-- This is like 'fixpointAnalysis', but the per-block function may also
-- rewrite the block. Each rewritten block is inserted, under its entry label,
-- into an accumulating block map that starts out as the input map. The
-- result is that map together with the final facts.
fixpointRewrite
    :: forall f node.
       NonLocal node
    => Direction
    -> DataflowLattice f
    -> RewriteFun' node f
    -> Label
    -> LabelMap (Block node C C)
    -> FactBase f
    -> UniqSM (LabelMap (Block node C C), FactBase f)
fixpointRewrite dir lattice do_block entry blockmap initFacts =
    go allPending blockmap initFacts
  where
    -- Sorting the blocks helps to minimize the number of times we need to
    -- process blocks. For instance, for forward analysis we want to look at
    -- blocks in reverse postorder. Also, see comments for sortBlocks.
    ordered :: [Block node C C]
    ordered = sortBlocks dir entry blockmap

    lastIndex :: Int
    lastIndex = length ordered - 1

    byIndex :: Array Int (Block node C C)
    byIndex = {-# SCC "block_arr_rewrite" #-} listArray (0, lastIndex) ordered

    -- Every block starts out pending.
    allPending :: IntHeap
    allPending = {-# SCC "start_rewrite" #-}
        IntSet.fromDistinctAscList [0 .. lastIndex]

    dependents :: LabelMap IntHeap
    dependents = {-# SCC "dep_blocks_rewrite" #-} mkDepBlocks dir ordered

    merge = updateFact (fact_join lattice) dependents

    go
        :: IntHeap                    -- Worklist, i.e., blocks to process
        -> LabelMap (Block node C C)  -- Rewritten blocks.
        -> FactBase f                 -- Current facts.
        -> UniqSM (LabelMap (Block node C C), FactBase f)
    go pending !done !facts =
        case IntSet.minView pending of
            Nothing -> return (done, facts)
            Just (ix, rest) -> do
                -- Note that we use the *original* block here. This is important.
                -- We're optimistically rewriting blocks even before reaching the
                -- fixed point, which means that the rewrite might be incorrect.
                -- So if the facts change, we need to rewrite the original block
                -- again (taking into account the new facts).
                (rewritten, outFacts) <- {-# SCC "do_block_rewrite" #-}
                    do_block (byIndex ! ix) facts
                let done' = mapInsert (entryLabel rewritten) rewritten done
                    (pending', facts') = {-# SCC "mapFoldWithKey_rewrite" #-}
                        mapFoldlWithKey merge (rest, facts) outFacts
                go pending' done' facts'


{-
Note [Unreachable blocks]
~~~~~~~~~~~~~~~~~~~~~~~~~
A block that is not in the domain of tfb_fbase is "currently unreachable".
A currently-unreachable block is not even analyzed.  Reason: consider
constant prop and this graph, with entry point L1:
  L1: x:=3; goto L4
  L2: x:=4; goto L4
  L4: if x>3 goto L2 else goto L5
Here L2 is actually unreachable, but if we process it with bottom input fact,
we'll propagate (x=4) to L4, and nuke the otherwise-good rewriting of L4.

* If a currently-unreachable block is not analyzed, then its rewritten
  graph will not be accumulated in tfb_rg.  And that is good:
  unreachable blocks simply do not appear in the output.

* Note that clients must be careful to provide a fact (even if bottom)
  for each entry point. Otherwise useful blocks may be garbage collected.

* Note that updateFact must set the change-flag if a label goes from
  not-in-fbase to in-fbase, even if its fact is bottom.  In effect the
  real fact lattice is
       UNR
       bottom
       the points above bottom

* Even if the fact is going from UNR to bottom, we still call the
  client's fact_join function because it might give the client
  some useful debugging information.

* All of this only applies for *forward* fixpoints.  For the backward
  case we must treat every block as reachable; it might finish with a
  'return', and therefore have no successors, for example.
-}


-----------------------------------------------------------------------------
--  Pieces that are shared by fixpointAnalysis and fixpointRewrite
-----------------------------------------------------------------------------

-- | Order the blocks for analysis: reverse postorder from the entry label for
-- a forward analysis, and the reverse of that list for a backward one (see
-- Note [Backward vs forward analysis]).
sortBlocks
    :: NonLocal n
    => Direction -> Label -> LabelMap (Block n C C) -> [Block n C C]
sortBlocks direction entry blockmap = orient (revPostorderFrom blockmap entry)
  where
    orient = case direction of
        Fwd -> id
        Bwd -> reverse

-- Note [Backward vs forward analysis]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- The forward and backward cases are not dual.  In the forward case, the entry
-- points are known, and one simply traverses the body blocks from those points.
-- In the backward case, something is known about the exit points, but a
-- backward analysis must also include reachable blocks that don't reach the
-- exit, as in a procedure that loops forever and has side effects.
-- For instance, let E be the entry and X the exit blocks (arrows indicate
-- control flow)
--   E -> X
--   E -> B
--   B -> C
--   C -> B
-- We do need to include B and C even though they're unreachable in the
-- *reverse* graph (that we could use for backward analysis):
--   E <- X
--   E <- B
--   B <- C
--   C <- B
-- So when sorting the blocks for the backward analysis, we simply take the
-- reverse of what is used for the forward one.


-- | Construct a mapping from a @Label@ to the block indexes that should be
-- re-analyzed if the facts at that @Label@ change. A block's index is its
-- position (counting from 0) in the given list.
--
-- Note that we're considering here the entry point of the block, so if the
-- facts change at the entry:
-- * for a backward analysis we need to re-analyze all the predecessors, but
-- * for a forward analysis, we only need to re-analyze the current block
--   (and that will in turn propagate facts into its successors).
mkDepBlocks :: NonLocal node => Direction -> [Block node C C] -> LabelMap IntSet
mkDepBlocks direction blocks = foldl' step mapEmpty (zip [0 ..] blocks)
  where
    step = case direction of
        -- A block depends only on the facts at its own entry label.
        Fwd -> \deps (ix, block) ->
            mapInsert (entryLabel block) (IntSet.singleton ix) deps
        -- A block depends on the facts at each successor's entry, so every
        -- successor label accumulates the indexes of its predecessors.
        Bwd -> \deps (ix, block) ->
            foldl' (addDep ix) deps (successors block)

    addDep ix deps lbl = mapInsertWith IntSet.union lbl (IntSet.singleton ix) deps

-- | After some new facts have been generated by analysing a block, we
-- fold this function over them to generate (a) the set of block
-- indices to (re-)analyse, and (b) the new FactBase.
--
-- A label with no fact yet always counts as a change (see Note [No old
-- fact]). Otherwise the lattice join decides. On a change, the fact is
-- stored and every block index registered for the label in the dependency
-- map is added to the worklist.
updateFact
    :: JoinFun f
    -> LabelMap IntSet
    -> (IntHeap, FactBase f)
    -> Label
    -> f -- out fact
    -> (IntHeap, FactBase f)
updateFact joinFn dep_blocks (todo, fbase) lbl new_fact =
    case lookupFact lbl fbase of
        -- See Note [No old fact]
        Nothing -> commit new_fact
        Just old_fact ->
            case joinFn (OldFact old_fact) (NewFact new_fact) of
                NotChanged _   -> (todo, fbase)
                Changed joined -> commit joined
  where
    -- Store the fact (forcing the updated map) and schedule dependents.
    commit fact =
        let !fbase' = mapInsert lbl fact fbase
        in (requeued, fbase')

    requeued =
        IntSet.union todo (mapFindWithDefault IntSet.empty lbl dep_blocks)

{-
Note [No old fact]
~~~~~~~~~~~~~~~~~~
We know that the new_fact is >= _|_, so we don't need to join.  However,
if the new fact is also _|_, and we have already analysed its block,
we don't need to record a change.  So there's a tradeoff here.  It turns
out that always recording a change is faster.
-}

----------------------------------------------------------------
--       Utilities
----------------------------------------------------------------

-- | Look up the fact for a label, falling back to the lattice's bottom
-- element when the label has no fact.
getFact  :: DataflowLattice f -> Label -> FactBase f -> f
getFact lat l fb = fromMaybe (fact_bot lat) (lookupFact l fb)

-- | Returns the result of joining the facts from all the successors of the
-- provided node or block.
--
-- Successors that have no entry in the 'FactBase' are skipped. If none of
-- them has a fact, the result is the lattice's 'fact_bot'.
joinOutFacts :: (NonLocal n) => DataflowLattice f -> n e C -> FactBase f -> f
joinOutFacts lattice nonLocal fact_base =
    -- 'mapMaybe' replaces the former isJust-filter-then-fromJust pattern,
    -- which relied on a partial function.
    joinFacts lattice (mapMaybe (`lookupFact` fact_base) (successors nonLocal))

-- | Join a list of facts from left to right, starting from 'fact_bot'.
--
-- The running accumulator is passed as the 'OldFact' and each list element
-- as the 'NewFact', matching how 'updateFact' labels its arguments. Only
-- the joined value is kept; the 'Changed'/'NotChanged' flag is discarded.
joinFacts :: DataflowLattice f -> [f] -> f
joinFacts lattice facts = foldl' join (fact_bot lattice) facts
  where
    join old new = getJoined $ fact_join lattice (OldFact old) (NewFact new)

-- | Returns the joined facts for each label.
--
-- Pairs are processed left to right. When a label already has a fact, the
-- fact accumulated so far is passed as the 'OldFact' and the incoming one as
-- the 'NewFact'. This is the same labelling 'updateFact' uses; previously the
-- two were swapped here.
mkFactBase :: DataflowLattice f -> [(Label, f)] -> FactBase f
mkFactBase lattice = foldl' add mapEmpty
  where
    join = fact_join lattice

    add result (l, f1) =
        let !newFact =
                case mapLookup l result of
                    Nothing -> f1
                    -- f2 is already stored (old); f1 is being added (new).
                    Just f2 -> getJoined $ join (OldFact f2) (NewFact f1)
        in mapInsert l newFact result

-- | Fold over the nodes of an open-open block from last to first.
--
-- The step function is applied to the last node first, with the initial
-- accumulator. The first node of the block is processed last.
--
-- Every accumulator produced between two steps is evaluated to WHNF before
-- it is passed on. The initial accumulator and the final result are not
-- forced here.
foldNodesBwdOO :: forall node f. (node O O -> f -> f) -> Block node O O -> f -> f
foldNodesBwdOO step = walk
  where
    -- The local signature keeps the scrutinee type rigid for GADT matching.
    walk :: Block node O O -> f -> f
    walk block acc = case block of
        BNil            -> acc
        BMiddle n       -> step n acc
        -- The tail is processed before the head node.
        BCons n rest    -> let !acc' = walk rest acc in step n acc'
        -- The trailing node is processed before the front part.
        BSnoc front n   -> let !acc' = step n acc in walk front acc'
        -- The back half is processed before the front half.
        BCat front back -> let !acc' = walk back acc in walk front acc'
{-# INLINABLE foldNodesBwdOO #-}

-- | Walk an open-open block from its last node back to its first. The
-- supplied function may replace each node with a block while threading an
-- accumulator @f@ (usually dataflow facts) through the walk.
--
-- Visiting order:
--
-- * Each node's rewrite receives the fact produced by everything after it.
-- * The replacement blocks are concatenated in the original node order,
--   using 'joinBlocksOO'.
--
-- Strictness: the fact handed to each step of the walk is forced to WHNF,
-- as is the fact returned by the earlier half of every composition. The
-- combined block is also forced before it is returned.
foldRewriteNodesBwdOO
    :: forall f node.
       (node O O -> f -> UniqSM (Block node O O, f))
    -> Block node O O
    -> f
    -> UniqSM (Block node O O, f)
foldRewriteNodesBwdOO rewriteOO initBlock initFacts = walk initBlock initFacts
  where
    walk :: Block node O O -> f -> UniqSM (Block node O O, f)
    walk block !fact = case block of
        BNil            -> return (BNil, fact)
        BMiddle n       -> rewriteOO n fact
        BCons n rest    -> composeBwd (rewriteOO n) (walk rest) fact
        BSnoc front n   -> composeBwd (walk front) (rewriteOO n) fact
        BCat front back -> composeBwd (walk front) (walk back) fact

    -- | @composeBwd earlier later@ runs the two rewrites in order:
    --
    -- 1. It runs @later@ on the incoming fact.
    -- 2. It feeds the resulting fact to @earlier@.
    -- 3. It places @earlier@'s block in front of @later@'s block.
    --
    -- This is backward-analysis order: code that comes later is processed
    -- first.
    composeBwd
        :: (f -> UniqSM (Block node O O, f))
        -> (f -> UniqSM (Block node O O, f))
        -> f
        -> UniqSM (Block node O O, f)
    composeBwd earlier later = \fact0 -> do
        (laterBlock, fact1) <- later fact0
        (earlierBlock, !fact2) <- earlier fact1
        let !joined = joinBlocksOO earlierBlock laterBlock
        return (joined, fact2)
    {-# INLINE composeBwd #-}
{-# INLINABLE foldRewriteNodesBwdOO #-}

-- | Concatenate two open-open blocks without building needless 'BCat' nodes.
--
-- The cases, checked in order:
--
-- * An empty side ('BNil') is dropped.
-- * A single-node side ('BMiddle') is attached to the other side with
--   'blockCons' (left) or 'blockSnoc' (right).
-- * Only when neither simplification applies is a 'BCat' built.
--
-- The left block is inspected before the right one.
joinBlocksOO :: Block n O O -> Block n O O -> Block n O O
joinBlocksOO front back = case (front, back) of
    (BNil, _)      -> back
    (_, BNil)      -> front
    (BMiddle n, _) -> blockCons n back
    (_, BMiddle n) -> blockSnoc front n
    _              -> BCat front back

-- | Synonym for 'IntSet'.
-- NOTE(review): the name suggests it serves as a priority queue of 'Int'
-- keys in code beyond this chunk; confirm against its uses.
type IntHeap = IntSet