{-# LANGUAGE TypeFamilies #-}

-- | This pass attempts to lift allocations and asserts as far towards
-- the top of their body as possible. This helps memory short-circuiting
-- do a better job, as it is sensitive to statement ordering. It does
-- not try to hoist allocations across body boundaries.
module Futhark.Pass.LiftAllocations
  ( liftAllocationsSeqMem,
    liftAllocationsGPUMem,
    liftAllocationsMCMem,
  )
where

import Control.Monad.Reader
import Data.Sequence (Seq (..))
import Futhark.Analysis.Alias (aliasAnalysis)
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.SeqMem
import Futhark.Pass (Pass (..))

-- | Run the lifting transformation over every function in a program.
--
-- The program is first alias-annotated with 'aliasAnalysis' (the
-- lifting decisions need consumption information). Each function
-- body is then processed with 'liftAllocationsInBody', and the alias
-- annotations are stripped again with 'removeFunDefAliases'.
--
-- The first argument is the handler for representation-specific inner
-- operations. It is made available to the traversal through 'Env' and
-- invoked on every @Op (Inner ...)@ statement.
liftInProg ::
  (AliasableRep rep, Mem rep inner, OpReturns (inner (Aliases rep))) =>
  (inner (Aliases rep) -> LiftM (inner (Aliases rep)) (inner (Aliases rep))) ->
  Prog rep ->
  Prog rep
liftInProg onOp prog =
  prog
    { progFuns = removeFunDefAliases . onFun <$> progFuns (aliasAnalysis prog)
    }
  where
    onFun f = f {funDefBody = onBody (funDefBody f)}
    onBody body = runReader (liftAllocationsInBody body) (Env onOp)

-- | Lift allocations and asserts in a 'SeqMem' program. Inner
-- operations are passed through unchanged ('pure').
liftAllocationsSeqMem :: Pass SeqMem SeqMem
liftAllocationsSeqMem =
  Pass "lift allocations" "lift allocations" $
    pure . liftInProg pure

-- | Lift allocations and asserts in a 'GPUMem' program. Host
-- operations are handled by 'liftAllocationsInHostOp', which also
-- processes the kernel bodies of 'SegOp's.
liftAllocationsGPUMem :: Pass GPUMem GPUMem
liftAllocationsGPUMem =
  Pass "lift allocations gpu" "lift allocations gpu" $
    pure . liftInProg liftAllocationsInHostOp

-- | Lift allocations and asserts in an 'MCMem' program. Multicore
-- operations are handled by 'liftAllocationsInMCOp', which also
-- processes the kernel bodies of the 'SegOp's inside a 'ParOp'.
liftAllocationsMCMem :: Pass MCMem MCMem
liftAllocationsMCMem =
  Pass "lift allocations mc" "lift allocations mc" $
    pure . liftInProg liftAllocationsInMCOp

-- | Read-only environment of the lifting traversal.
newtype Env inner = Env
  { -- | Handler applied to representation-specific inner operations
    -- (the payload of @Op (Inner ...)@ statements).
    onInner :: inner -> LiftM inner inner
  }

-- | The monad the traversal runs in: a reader over the 'Env' holding
-- the handler for inner operations of type @inner@.
type LiftM inner = Reader (Env inner)

-- | Reorder the statements of a body so that allocations and asserts
-- (plus everything they depend on) come first. Nested bodies are
-- processed recursively via 'liftInsideStm'. The body result is left
-- unchanged.
liftAllocationsInBody ::
  (Mem rep inner, Aliased rep) =>
  Body rep ->
  LiftM (inner rep) (Body rep)
liftAllocationsInBody body = do
  stms <- liftAllocationsInStms (bodyStms body) mempty mempty mempty
  pure $ body {bodyStms = stms}

-- | Apply lifting to the bodies nested inside a single statement:
--
-- * inner operations, via the 'onInner' handler from the environment;
-- * every case body and the default body of a 'Match';
-- * the body of a 'DoLoop'.
--
-- Any other statement is returned unchanged. The statement itself is
-- not moved; that is decided by 'liftAllocationsInStms'.
liftInsideStm ::
  (Mem rep inner, Aliased rep) =>
  Stm rep ->
  LiftM (inner rep) (Stm rep)
liftInsideStm stm@(Let _ _ (Op (Inner inner))) = do
  on_inner <- asks onInner
  inner' <- on_inner inner
  pure $ stm {stmExp = Op $ Inner inner'}
liftInsideStm stm@(Let _ _ (Match cond_ses cases body dec)) = do
  cases' <- mapM (\(Case p b) -> Case p <$> liftAllocationsInBody b) cases
  body' <- liftAllocationsInBody body
  pure stm {stmExp = Match cond_ses cases' body' dec}
liftInsideStm stm@(Let _ _ (DoLoop params form body)) = do
  body' <- liftAllocationsInBody body
  pure stm {stmExp = DoLoop params form body'}
liftInsideStm stm = pure stm

-- | Walk the input statements from last to first and split them into
-- two groups:
--
-- * lifted: allocations, asserts, and every statement they
--   transitively depend on;
-- * the rest, which stay where they are.
--
-- The result is the lifted statements followed by the remaining ones.
-- Each group keeps its original relative order, because statements
-- are visited back to front and prepended.
liftAllocationsInStms ::
  (Mem rep inner, Aliased rep) =>
  -- | The input stms
  Stms rep ->
  -- | The lifted allocations and associated statements
  Stms rep ->
  -- | The other statements processed so far
  Stms rep ->
  -- | (Names we need to lift, consumed names)
  (Names, Names) ->
  LiftM (inner rep) (Stms rep)
liftAllocationsInStms Empty lifted acc _ = pure $ lifted <> acc
liftAllocationsInStms (stms :|> stm) lifted acc (to_lift, consumed) = do
  stm' <- liftInsideStm stm
  case stmExp stm' of
    BasicOp Assert {} -> liftStm stm'
    Op Alloc {} -> liftStm stm'
    _ -> do
      let pat_names = namesFromList $ patNames $ stmPat stm'
      -- Lift a statement if it:
      -- * binds a name that an already-lifted statement needs, or
      -- * uses a name that an already-lifted statement consumes, since
      --   it must still run before that consumption.
      -- stm' has the same free variables as stm: lifting inside nested
      -- bodies only reorders their statements.
      if (pat_names `namesIntersect` to_lift)
        || (consumed `namesIntersect` freeIn stm')
        then liftStm stm'
        else dontLiftStm stm'
  where
    liftStm stm' =
      liftAllocationsInStms stms (stm' :<| lifted) acc (to_lift', consumed')
      where
        -- Names bound here need not be searched for further up; the
        -- free variables of this statement now do.
        to_lift' =
          freeIn stm'
            <> (to_lift `namesSubtract` namesFromList (patNames (stmPat stm')))
        consumed' = consumed <> consumedInStm stm'
    dontLiftStm stm' =
      liftAllocationsInStms stms lifted (stm' :<| acc) (to_lift, consumed)

-- | Lift allocations inside the kernel body of a 'SegOp'. Only the
-- statements of the kernel body are processed. The operator lists of
-- 'SegRed', 'SegScan' and 'SegHist', the level, the space and the
-- result types are all passed through unchanged.
liftAllocationsInSegOp ::
  (Mem rep inner, Aliased rep) =>
  SegOp lvl rep ->
  LiftM (inner rep) (SegOp lvl rep)
liftAllocationsInSegOp (SegMap lvl sp tps body) =
  SegMap lvl sp tps <$> liftAllocationsInKernelBody body
liftAllocationsInSegOp (SegRed lvl sp binops tps body) =
  SegRed lvl sp binops tps <$> liftAllocationsInKernelBody body
liftAllocationsInSegOp (SegScan lvl sp binops tps body) =
  SegScan lvl sp binops tps <$> liftAllocationsInKernelBody body
liftAllocationsInSegOp (SegHist lvl sp histops tps body) =
  SegHist lvl sp histops tps <$> liftAllocationsInKernelBody body

-- | Lift allocations among the statements of a kernel body. The kernel
-- results are left unchanged.
liftAllocationsInKernelBody ::
  (Mem rep inner, Aliased rep) =>
  KernelBody rep ->
  LiftM (inner rep) (KernelBody rep)
liftAllocationsInKernelBody kbody = do
  stms <- liftAllocationsInStms (kernelBodyStms kbody) mempty mempty mempty
  pure kbody {kernelBodyStms = stms}

-- | Inner-operation handler for 'GPUMem'. It lifts allocations inside
-- the kernel body of a 'SegOp'. Every other host operation is
-- returned unchanged.
liftAllocationsInHostOp ::
  HostOp NoOp (Aliases GPUMem) ->
  LiftM (HostOp NoOp (Aliases GPUMem)) (HostOp NoOp (Aliases GPUMem))
liftAllocationsInHostOp (SegOp op) = SegOp <$> liftAllocationsInSegOp op
liftAllocationsInHostOp op = pure op

-- | Inner-operation handler for 'MCMem'. It lifts allocations inside
-- both the optional and the mandatory 'SegOp' of a 'ParOp'. Every
-- other multicore operation is returned unchanged.
liftAllocationsInMCOp ::
  MCOp NoOp (Aliases MCMem) ->
  LiftM (MCOp NoOp (Aliases MCMem)) (MCOp NoOp (Aliases MCMem))
liftAllocationsInMCOp (ParOp par op) =
  ParOp <$> traverse liftAllocationsInSegOp par <*> liftAllocationsInSegOp op
liftAllocationsInMCOp op = pure op