{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform array short-circuiting on memory-annotated programs.
module Futhark.Optimise.ArrayShortCircuiting
  ( optimiseSeqMem,
    optimiseGPUMem,
    optimiseMCMem,
  )
where

import Control.Monad
import Control.Monad.Reader
import Data.Function ((&))
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.Mem.IxFun (substituteInIxFun)
import Futhark.IR.SeqMem
import Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Pass (Pass (..))
import Futhark.Pass qualified as Pass
import Futhark.Util

-- | Read-only context available while rewriting the statements of one
-- function body.
data Env inner = Env
  { -- | Coalescing table computed for the function being rewritten.
    envCoalesceTab :: CoalsTab,
    -- | Backend-specific rewrite applied to every 'Inner' operation.
    onInner :: inner -> UpdateM inner inner,
    -- | Statements whose first bound name is in this set are dropped by
    -- 'removeAllocsInStms'.
    memAllocsToRemove :: Names
  }

-- | The rewriting monad: a pure reader over 'Env'.
type UpdateM inner = Reader (Env inner)

-- | Array short-circuiting for sequential programs with memory annotations.
-- The inner operation type here is 'NoOp', so inner operations are passed
-- through unchanged ('pure').
optimiseSeqMem :: Pass SeqMem SeqMem
optimiseSeqMem =
  pass flag desc mkCoalsTab pure replaceInParams
  where
    flag = "short-circuit"
    desc = "Array Short-Circuiting"

-- | Array short-circuiting for GPU programs with memory annotations.
-- Host operations are rewritten by 'replaceInHostOp'.
optimiseGPUMem :: Pass GPUMem GPUMem
optimiseGPUMem =
  pass flag desc mkCoalsTabGPU replaceInHostOp replaceInParams
  where
    flag = "short-circuit-gpu"
    desc = "Array Short-Circuiting (GPU)"

-- | Array short-circuiting for multicore programs with memory annotations.
-- Multicore operations are rewritten by 'replaceInMCOp'.
optimiseMCMem :: Pass MCMem MCMem
optimiseMCMem =
  pass flag desc mkCoalsTabMC replaceInMCOp replaceInParams
  where
    flag = "short-circuit-mc"
    desc = "Array Short-Circuiting (MC)"

-- | Redirect function parameters to coalesced memory.
--
-- * A memory parameter whose own name has a coalescing entry is renamed to
--   that entry's destination memory, and the destination name is added to
--   the returned set (statements binding those names are later dropped).
-- * An array parameter whose memory block has a coalescing entry keeps its
--   name and index function but is placed in the destination memory.
-- * Every other parameter is returned unchanged.
--
-- Parameters are returned in their original order.
replaceInParams :: CoalsTab -> [Param FParamMem] -> (Names, [Param FParamMem])
replaceInParams coalstab = L.mapAccumL replaceInParam mempty
  where
    replaceInParam to_remove p@(Param attrs name dec) =
      case dec of
        MemMem _
          | Just entry <- M.lookup name coalstab ->
              ( oneName (dstmem entry) <> to_remove,
                Param attrs (dstmem entry) dec
              )
        MemArray pt shp u (ArrayIn m ixf)
          | Just entry <- M.lookup m coalstab ->
              ( to_remove,
                Param attrs name $ MemArray pt shp u $ ArrayIn (dstmem entry) ixf
              )
        _ -> (to_remove, p)

-- | Drop every statement whose first bound name is in 'memAllocsToRemove'.
--
-- Only the first pattern name is inspected. Statements with an empty
-- pattern are kept rather than crashing on a partial 'head'.
removeAllocsInStms :: Stms rep -> UpdateM inner (Stms rep)
removeAllocsInStms stms = do
  to_remove <- asks memAllocsToRemove
  let isRemoved stm =
        case patNames (stmPat stm) of
          v : _ -> v `nameIn` to_remove
          [] -> False
  pure $ stmsFromList $ filter (not . isRemoved) $ stmsToList stms

-- | Build an array short-circuiting pass.
--
-- @flag@ and @desc@ name the pass. @mk@ computes one coalescing table per
-- function from the alias-annotated program. @on_inner@ rewrites
-- backend-specific inner operations. @on_fparams@ rewrites a function's
-- parameters and returns the names whose binding statements are dropped
-- from the body.
pass ::
  (Mem rep inner, LetDec rep ~ LetDecMem, AliasableRep rep) =>
  String ->
  String ->
  (Prog (Aliases rep) -> Pass.PassM (M.Map Name CoalsTab)) ->
  (inner rep -> UpdateM (inner rep) (inner rep)) ->
  (CoalsTab -> [FParam (Aliases rep)] -> (Names, [FParam (Aliases rep)])) ->
  Pass rep rep
pass flag desc mk on_inner on_fparams = Pass flag desc run
  where
    -- Alias information is only used to build the coalescing tables; the
    -- rewrite itself is applied to the un-annotated 'prog'.
    run prog = do
      coaltabs <- mk $ AnlAls.aliasAnalysis prog
      Pass.intraproceduralTransformationWithConsts pure (rewriteFun coaltabs) prog

    -- Top-level constants are passed through unchanged ('pure' above), so
    -- the per-function rewrite ignores them.
    rewriteFun coaltabs _consts fun = do
      -- NOTE(review): 'M.!' fails if 'mk' produced no table for this
      -- function; presumably it always covers every function -- confirm.
      let tab = coaltabs M.! funDefName fun
          (to_remove, params') = on_fparams tab $ funDefParams fun
      pure $
        fun
          { funDefParams = params',
            funDefBody = rewriteBody tab to_remove $ funDefBody fun
          }

    rewriteBody tab to_remove body =
      body
        { bodyStms =
            runReader (updateStms $ bodyStms body) (Env tab on_inner to_remove),
          bodyResult = map (replaceResMem tab) $ bodyResult body
        }

-- | If a body result is a variable with an entry in the coalescing table,
-- replace it with that entry's destination memory. Otherwise return the
-- result unchanged.
replaceResMem :: CoalsTab -> SubExpRes -> SubExpRes
replaceResMem coaltab res =
  maybe res redirect $ subExpResVName res >>= (`M.lookup` coaltab)
  where
    redirect entry = res {resSubExp = Var (dstmem entry)}

-- | Rewrite every statement with 'replaceInStm', then drop the statements
-- selected by 'removeAllocsInStms'.
updateStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  UpdateM (inner rep) (Stms rep)
updateStms stms = mapM replaceInStm stms >>= removeAllocsInStms

-- | Rewrite a single statement to use coalesced memory.
--
-- Array pattern elements are redirected via 'lookupAndReplace' (left as-is
-- when it yields 'Nothing'). The expression is rewritten with the new
-- pattern. If any bound name is a key of some coalescing entry's 'vartab',
-- the certificates of every such entry are appended to the statement's own.
replaceInStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stm rep ->
  UpdateM (inner rep) (Stm rep)
replaceInStm (Let (Pat elems) (StmAux c a d) e) = do
  entries <- asks (M.elems . envCoalesceTab)
  elems' <- mapM replaceInPatElem elems
  e' <- replaceInExp elems' e
  let bound = map patElemName elems
      touches entry = not . null $ bound `L.intersect` M.keys (vartab entry)
      c' = case filter touches entries of
        [] -> c
        hits -> c <> foldMap certs hits
  pure $ Let (Pat elems') (StmAux c' a d) e'
  where
    replaceInPatElem :: PatElem LetDecMem -> UpdateM inner (PatElem LetDecMem)
    replaceInPatElem p@(PatElem vname (MemArray _ _ u _)) =
      fromMaybe p <$> lookupAndReplace vname PatElem u
    replaceInPatElem p = pure p

-- | Rewrite an expression so that it refers to coalesced memory.
--
-- The (already rewritten) pattern elements of the enclosing statement are
-- used to generalise the index functions of a 'Match''s return types.
-- 'Match' bodies go through 'replaceInIfBody'. For a 'Loop', the loop
-- parameters go through 'replaceInFParam', the body statements through
-- 'updateStms', and the body results through 'replaceResMem'. 'Inner'
-- operations are handed to the 'onInner' rewrite from the environment.
-- 'BasicOp', 'Apply', 'WithAcc' and all other operations are returned
-- unchanged.
replaceInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  [PatElem LetDecMem] ->
  Exp rep ->
  UpdateM (inner rep) (Exp rep)
replaceInExp pat_elems expr =
  case expr of
    BasicOp {} -> pure expr
    Apply {} -> pure expr
    WithAcc {} -> pure expr
    Match cond_ses cases defbody dec -> do
      defbody' <- replaceInIfBody defbody
      cases' <- forM cases $ \(Case vals body) ->
        Case vals <$> replaceInIfBody body
      rets' <- zipWithM (generalizeIxfun pat_elems) pat_elems (matchReturns dec)
      pure $ Match cond_ses cases' defbody' dec {matchReturns = rets'}
    Loop merge form (Body bdec stms res) -> do
      params' <- mapM (replaceInFParam . fst) merge
      stms' <- updateStms stms
      coalstab <- asks envCoalesceTab
      let merge' = zip params' (map snd merge)
          res' = map (replaceResMem coalstab) res
      pure $ Loop merge' form (Body bdec stms' res')
    Op (Inner i) -> do
      on_op <- asks onInner
      Op . Inner <$> on_op i
    Op _ -> pure expr

-- | Rewrite the statements inside the kernel body of a 'SegOp' using
-- 'updateStms'. The level, segment space, operators ('SegBinOp's or
-- 'HistOp's) and result types are carried over untouched.
replaceInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  UpdateM (inner rep) (SegOp lvl rep)
replaceInSegOp (SegMap lvl sp tps body) =
  SegMap lvl sp tps <$> replaceInKernelBodyStms body
replaceInSegOp (SegRed lvl sp binops tps body) =
  SegRed lvl sp binops tps <$> replaceInKernelBodyStms body
replaceInSegOp (SegScan lvl sp binops tps body) =
  SegScan lvl sp binops tps <$> replaceInKernelBodyStms body
replaceInSegOp (SegHist lvl sp hist_ops tps body) =
  SegHist lvl sp hist_ops tps <$> replaceInKernelBodyStms body

-- | Run 'updateStms' over the statements of a kernel body. Every other
-- field of the body is kept as is.
replaceInKernelBodyStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  KernelBody rep ->
  UpdateM (inner rep) (KernelBody rep)
replaceInKernelBodyStms kbody = do
  new_stms <- updateStms (kernelBodyStms kbody)
  pure kbody {kernelBodyStms = new_stms}

-- | Apply 'replaceInSegOp' to a GPU host operation when it is a 'SegOp'.
-- Every other kind of host operation is passed through unchanged.
replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
replaceInHostOp hostop =
  case hostop of
    SegOp segop -> fmap SegOp (replaceInSegOp segop)
    _ -> pure hostop

-- | Apply 'replaceInSegOp' to both the optional and the mandatory 'SegOp'
-- of a multicore 'ParOp'. Any other multicore operation is passed through
-- unchanged.
replaceInMCOp :: MCOp NoOp MCMem -> UpdateM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
replaceInMCOp mcop =
  case mcop of
    ParOp maybe_par segop -> do
      maybe_par' <- traverse replaceInSegOp maybe_par
      segop' <- replaceInSegOp segop
      pure (ParOp maybe_par' segop')
    _ -> pure mcop

-- | Possibly generalise the return type of a branch result.
--
-- This applies when both of the following hold:
--
-- * the pattern element is an array bound in memory @mem@ with index
--   function @ixf@;
-- * its name occurs in the variable table of some entry of the
--   coalescing table.
--
-- In that case the returned type keeps the element type, shape and
-- uniqueness of @m@. It returns in @mem@, using the index function
-- produced by 'existentialiseIxFun' over the names of @pat_elems@.
-- In every other case @m@ is returned unchanged.
generalizeIxfun :: [PatElem dec] -> PatElem LetDecMem -> BodyReturns -> UpdateM inner BodyReturns
generalizeIxfun
  pat_elems
  (PatElem vname (MemArray _ _ _ (ArrayIn mem ixf)))
  m@(MemArray pt shp u _) = do
    is_coalesced <- asks $ any (M.member vname . vartab) . envCoalesceTab
    pure $
      if is_coalesced
        then
          MemArray pt shp u . ReturnsInBlock mem $
            existentialiseIxFun (map patElemName pat_elems) ixf
        else m
generalizeIxfun _ _ m = pure m

-- | Rewrite a branch body. Its statements are passed through 'updateStms',
-- and each result is rewritten with 'replaceResMem' against the current
-- coalescing table. The body decoration is kept.
replaceInIfBody :: (Mem rep inner, LetDec rep ~ LetDecMem) => Body rep -> UpdateM (inner rep) (Body rep)
replaceInIfBody body@(Body _ stms res) = do
  new_stms <- updateStms stms
  new_res <- asks $ \env -> map (replaceResMem (envCoalesceTab env)) res
  pure body {bodyStms = new_stms, bodyResult = new_res}

-- | Rewrite the memory binding of an array-typed function parameter if its
-- name is found in the coalescing table (see 'lookupAndReplace').
--
-- The parameter's name, attributes and uniqueness are preserved; only the
-- memory binding is replaced. Non-array parameters, and array parameters
-- whose name is not in the table, are returned unchanged.
replaceInFParam :: Param FParamMem -> UpdateM inner (Param FParamMem)
replaceInFParam p@(Param attrs vname (MemArray _ _ u _)) =
  -- Rebuild with the original attributes. The previous code used
  -- 'mempty' here, which silently stripped attributes from any parameter
  -- that got coalesced.
  fromMaybe p <$> lookupAndReplace vname (Param attrs) u
replaceInFParam p = pure p

-- | Look up @vname@ in the union of the variable tables of every entry in
-- the coalescing table. 'foldMap' over a 'M.Map' combines with the
-- left-biased union, so when several entries bind @vname@, the entry with
-- the smallest key wins.
--
-- If the lookup succeeds, a new array memory binding is built. It uses the
-- coalesced element type, shape and memory block, together with the given
-- uniqueness @u@. Its index function has the recorded free-variable
-- substitutions applied repeatedly until it stops changing ('fixPoint').
-- The result of applying @f@ to @vname@ and that binding is returned in
-- 'Just'. If the lookup fails, 'Nothing' is returned.
lookupAndReplace ::
  VName ->
  (VName -> MemBound u -> a) ->
  u ->
  UpdateM inner (Maybe a)
lookupAndReplace vname f u = do
  coaltab <- asks envCoalesceTab
  pure $ rebuild <$> M.lookup vname (foldMap vartab coaltab)
  where
    rebuild (Coalesced _ (MemBlock pt shp mem ixf) subs) =
      let ixf' = fixPoint (substituteInIxFun subs) ixf
       in f vname (MemArray pt shp u (ArrayIn mem ixf'))