module GHC.Core.Opt.Exitify ( exitifyProgram ) where
import GHC.Prelude
import GHC.Types.Var
import GHC.Types.Id
import GHC.Types.Id.Info
import GHC.Core
import GHC.Core.Utils
import GHC.Utils.Monad.State
import GHC.Builtin.Uniques
import GHC.Types.Var.Set
import GHC.Types.Var.Env
import GHC.Core.FVs
import GHC.Data.FastString
import GHC.Core.Type
import GHC.Utils.Misc( mapSnd )
import Data.Bifunctor
import Control.Monad
-- | Run the exitification pass over a whole program.
--
-- Every right-hand side is rebuilt by @go@, which threads an 'InScopeSet'
-- through all binders. Whenever it meets a recursive @let@ in which some
-- binder is a join point, the (already rewritten) group is handed to
-- 'exitifyRec' and the resulting bindings are re-assembled with 'mkLets'.
-- All other recursive lets, and all other expression forms, are rebuilt
-- unchanged apart from the recursion into their sub-expressions.
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    goTopLvl :: CoreBind -> CoreBind
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs)  = Rec (map (second (go in_scope_toplvl)) pairs)

    -- Every top-level binder is treated as in scope in every top-level RHS.
    in_scope_toplvl :: InScopeSet
    in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    -- Rebuild an expression, extending the in-scope set at each binder.
    go :: InScopeSet -> CoreExpr -> CoreExpr
    go _        e@(Var {})      = e
    go _        e@(Lit {})      = e
    go _        e@(Type {})     = e
    go _        e@(Coercion {}) = e
    go in_scope (Cast e' c)     = Cast (go in_scope e') c
    go in_scope (Tick t e')     = Tick t (go in_scope e')
    go in_scope (App e1 e2)     = App (go in_scope e1) (go in_scope e2)

    go in_scope (Lam v e')
      = Lam v (go (in_scope `extendInScopeSet` v) e')

    go in_scope (Case scrut bndr ty alts)
      = Case (go in_scope scrut) bndr ty (map go_alt alts)
      where
        -- The case binder is in scope in every alternative; each
        -- alternative additionally binds its own pattern variables.
        in_scope1 = in_scope `extendInScopeSet` bndr
        go_alt (Alt dc pats rhs)
          = Alt dc pats (go (in_scope1 `extendInScopeSetList` pats) rhs)

    go in_scope (Let (NonRec bndr rhs) body)
      -- A non-recursive binder scopes over the body only, not its own RHS.
      = Let (NonRec bndr (go in_scope rhs))
            (go (in_scope `extendInScopeSet` bndr) body)

    go in_scope (Let (Rec pairs) body)
      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
      | otherwise   = Let (Rec pairs') body'
      where
        -- Only groups that bind (at least one) join point are exitified.
        is_join_rec = any (isJoinId . fst) pairs
        -- Recursive binders scope over all RHSs and the body.
        in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
        -- Inner groups are rewritten first, so exitifyRec sees their result.
        pairs'      = mapSnd (go in_scope') pairs
        body'       = go in_scope' body
-- | The exitification monad. Its state is the list of exit join points
-- created so far, each paired with its right-hand side; 'addExit' conses
-- new entries onto the front.
type ExitifyM = State [(Id, CoreExpr)]
-- | Exitify one recursive binding group that binds join points.
--
-- Each RHS is stripped of its @idJoinArity@ leading lambdas, and then its
-- tail positions are walked by @go@. A tail expression whose free
-- variables are disjoint from the group's binders contains no recursive
-- call. Such an expression is a candidate for floating out as a fresh,
-- non-recursive exit join point; if it is floated, its original position
-- is replaced by a jump to the new exit.
--
-- The result is the list of exit join points (as 'NonRec' bindings)
-- followed by the rewritten 'Rec' group, ready to be wrapped with 'mkLets'.
--
-- @in_scope@ is used only when choosing fresh names for the exits. The
-- caller in this module passes a set that already contains the group's own
-- binders.
exitifyRec :: InScopeSet -> [(Var, CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid, rhs) <- exits ] ++ [Rec pairs']
  where
    -- Every RHS annotated with its free variables.
    ann_pairs :: [(Var, CoreExprWithFVs)]
    ann_pairs = map (second freeVars) pairs

    -- An expression mentioning none of these makes no recursive call.
    recursive_calls :: VarSet
    recursive_calls = mkVarSet (map fst pairs)

    -- Rewrite every RHS, collecting the floated exits in the monad state.
    (pairs', exits) = (`runState` []) $
        forM ann_pairs $ \(x, rhs) -> do
            -- Go past the lambdas of the join point.
            let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
            body' <- go args body
            return (x, mkLams args body')

    ---------------------
    -- Walk tail positions only, floating out exit expressions on the way.
    go :: [Var]            -- Binders introduced between the joinrec and this
                           -- point, in binding order. An exit may need to
                           -- abstract over them.
       -> CoreExprWithFVs  -- Current expression, in tail position
       -> ExitifyM CoreExpr

    -- First, whatever its shape: if the expression makes no recursive call,
    -- try to turn it into an exit join point.
    go captured ann_e
        | let fvs = dVarSetToVarSet (freeVarsOf ann_e)
        , disjointVarSet fvs recursive_calls
        = go_exit captured (deAnnotate ann_e) fvs

    -- Case alternatives are in tail position; the scrutinee is not.
    go captured (_, AnnCase scrut bndr ty alts) = do
        alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do
            rhs' <- go (captured ++ [bndr] ++ pats) rhs
            return (Alt dc pats rhs')
        return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
        -- Non-recursive join point: its RHS body and the let body are both
        -- in tail position. The RHS is processed first (state order matters
        -- for the order of collected exits).
        | AnnNonRec j rhs <- ann_bind
        , Just join_arity <- isJoinId_maybe j
        = do let (params, join_body) = collectNAnnBndrs join_arity rhs
             join_body' <- go (captured ++ params) join_body
             body'      <- go (captured ++ [j]) body
             return $ Let (NonRec j (mkLams params join_body')) body'

        -- Recursive join points: all RHS bodies and the let body are in
        -- tail position. The pattern match (rather than 'head') keeps an
        -- empty group from crashing; it falls through to the plain case.
        | AnnRec rec_pairs@((j0, _) : _) <- ann_bind
        , isJoinId j0
        = do let js = map fst rec_pairs
             rec_pairs' <- forM rec_pairs $ \(j, rhs) -> do
                 let (params, join_body) = collectNAnnBndrs (idJoinArity j) rhs
                 join_body' <- go (captured ++ js ++ params) join_body
                 return (j, mkLams params join_body')
             body' <- go (captured ++ js) body
             return $ Let (Rec rec_pairs') body'

        -- Ordinary let: only the body is in tail position.
        | otherwise
        = do body' <- go (captured ++ bindersOf bind) body
             return $ Let bind body'
      where
        bind = deAnnBind ann_bind

    -- Not floatable and no tail-position sub-expressions: leave it alone.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    -- Handle a tail expression that makes no recursive call.
    go_exit :: [Var]       -- Binders captured locally (see 'go')
            -> CoreExpr    -- The candidate exit expression
            -> VarSet      -- Its free variables
            -> ExitifyM CoreExpr
    go_exit captured e fvs
      -- Already a jump to a join point whose arguments are all captured
      -- variables. Leave it alone (presumably so that re-running the pass
      -- does not float it again).
      | (Var f, args) <- collectArgs e
      , isJoinId f
      , all isCapturedVarArg args
      = return e

      -- Nothing worth floating.
      | not is_interesting
      = return e

      -- We would have to abstract over a join point, which we cannot do.
      | captures_join_points
      = return e

      -- Float it: bind the abstracted expression as a new exit join point
      -- and jump to it from here.
      | otherwise
      = do let rhs   = mkLams abs_vars e
               avoid = in_scope `extendInScopeSetList` captured
           v <- addExit avoid (length abs_vars) rhs
           return $ mkVarApps (Var v) abs_vars

      where
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _       = False

        -- Interesting iff some free variable bound outside this group's
        -- RHS (i.e. not captured) is a local Id.
        is_interesting = anyVarSet isLocalId $
                         fvs `minusVarSet` mkVarSet captured

        -- The captured variables the exit must take as parameters, in
        -- binding order. 'foldr' visits 'captured' from the right, and each
        -- picked variable is deleted from the remaining set. So if several
        -- captured binders share a unique (shadowing), only the innermost
        -- one is picked.
        abs_vars = snd $ foldr pick (fvs, []) captured
          where
            pick v (fvs', acc)
              | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
              | otherwise           = (fvs', acc)

        -- Abstracted Ids get a vanilla IdInfo, since they become lambda
        -- binders of the exit. NOTE(review): presumably info attached to
        -- the original binder need not hold for the lambda-bound copy.
        zap v | isId v    = setIdInfo v vanillaIdInfo
              | otherwise = v

        captures_join_points = any isJoinId abs_vars
-- | Choose a binder for a new exit join point of the given type and arity.
--
-- The binder is derived from a template named @exit@ (with
-- 'initExitJoinUnique') via 'uniqAway'. The set it is made unique against
-- contains:
--
--   * @in_scope@,
--   * every exit join point created so far (read from the 'ExitifyM' state),
--   * the template itself.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    fs <- get
    let avoid = (in_scope `extendInScopeSetList` map fst fs)
                  `extendInScopeSet` exit_id_tmpl
    return (uniqAway avoid exit_id_tmpl)
  where
    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique Many ty
                     `asJoinId` join_arity
-- | Register a new exit join point with right-hand side @rhs@ and return
-- its binder.
--
-- The binder's type is @exprType rhs@, and its name is chosen by
-- 'mkExitJoinId' to avoid @in_scope@. The new @(binder, rhs)@ pair is
-- consed onto the front of the 'ExitifyM' state.
addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    v  <- mkExitJoinId in_scope (exprType rhs) join_arity
    fs <- get
    put ((v, rhs) : fs)
    return v