{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module implements an optimization that tries to statically reuse
-- kernel-level allocations. The goal is to lower the static memory usage, which
-- might allow more programs to run using intra-group parallelism.
module Futhark.Optimise.MemoryBlockMerging (optimise) where

import Control.Exception
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.Map (Map, (!))
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import qualified Futhark.Analysis.Interference as Interference
import Futhark.Builder.Class
import Futhark.Construct
import Futhark.IR.GPUMem
import qualified Futhark.Optimise.MemoryBlockMerging.GreedyColoring as GreedyColoring
import Futhark.Pass (Pass (..), PassM)
import qualified Futhark.Pass as Pass
import Futhark.Util (invertMap)

-- | A mapping from the name bound by an allocation statement to the size
-- operand and memory space of that allocation.
type Allocs = Map VName (SubExp, Space)

-- | Collect the allocations performed by a single statement.
--
-- An @Alloc@ binding exactly one name yields a singleton map from that name
-- to the allocation's size and space.  For @If@ and @DoLoop@ the statements
-- of the branch\/loop bodies are searched recursively.  Every other
-- statement contributes nothing; in particular there is no case that
-- descends into a nested @SegOp@.
getAllocsStm :: Stm GPUMem -> Allocs
getAllocsStm (Let (Pat [PatElem name _]) _ (Op (Alloc se sp))) =
  M.singleton name (se, sp)
-- An allocation whose pattern does not consist of exactly one element is
-- treated as unreachable.
getAllocsStm (Let _ _ (Op (Alloc _ _))) = error "impossible"
getAllocsStm (Let _ _ (If _ then_body else_body _)) =
  foldMap getAllocsStm (bodyStms then_body)
    <> foldMap getAllocsStm (bodyStms else_body)
getAllocsStm (Let _ _ (DoLoop _ _ body)) =
  foldMap getAllocsStm (bodyStms body)
getAllocsStm _ = mempty

-- | Collect the allocations found among the kernel-body statements of a
-- @SegOp@, using 'getAllocsStm' on each statement.  All four constructors
-- are handled the same way; only the kernel body is inspected.
getAllocsSegOp :: SegOp lvl GPUMem -> Allocs
getAllocsSegOp (SegMap _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegRed _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegScan _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegHist _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)

-- | Rewrite the allocations in a statement according to the given map.
--
-- An @Alloc@ binding a single name that is present in the map is replaced
-- by a plain @SubExp@ expression holding the mapped value, so the name
-- becomes an alias for it.  Allocations not present in the map are left
-- untouched.  The rewrite is applied recursively to nested @SegOp@s, to both
-- branches of an @If@, and to the body of a @DoLoop@; all other statements
-- are returned unchanged.
setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem
setAllocsStm m stm@(Let (Pat [PatElem name _]) _ (Op (Alloc _ _)))
  | Just s <- M.lookup name m =
    stm {stmExp = BasicOp $ SubExp s}
-- Reached both for allocations with a different pattern shape and for
-- single-name allocations whose lookup above failed.
setAllocsStm _ stm@(Let _ _ (Op (Alloc _ _))) = stm
setAllocsStm m stm@(Let _ _ (Op (Inner (SegOp segop)))) =
  stm {stmExp = Op $ Inner $ SegOp $ setAllocsSegOp m segop}
setAllocsStm m stm@(Let _ _ (If cse then_body else_body dec)) =
  stm
    { stmExp =
        If
          cse
          (then_body {bodyStms = setAllocsStm m <$> bodyStms then_body})
          (else_body {bodyStms = setAllocsStm m <$> bodyStms else_body})
          dec
    }
setAllocsStm m stm@(Let _ _ (DoLoop merge form body)) =
  stm
    { stmExp =
        DoLoop merge form (body {bodyStms = setAllocsStm m <$> bodyStms body})
    }
setAllocsStm _ stm = stm

-- | Apply 'setAllocsStm' to every statement in the kernel body of a
-- @SegOp@.  All other components of the @SegOp@ are preserved as they are.
setAllocsSegOp ::
  Map VName SubExp ->
  SegOp lvl GPUMem ->
  SegOp lvl GPUMem
setAllocsSegOp m (SegMap lvl sp tps body) =
  SegMap lvl sp tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegRed lvl sp segbinops tps body) =
  SegRed lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegScan lvl sp segbinops tps body) =
  SegScan lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegHist lvl sp histops tps body) =
  SegHist lvl sp histops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}

-- | Build code computing the maximum of a set of sub-expressions.
--
-- The elements are combined pairwise, left to right, with an unsigned
-- 64-bit @UMax@; each intermediate result is bound with 'letSubExp' under
-- the name hint @"maxSubHelper"@.  A singleton set is returned directly
-- without emitting any statement.  The set must be non-empty; an empty set
-- triggers a call to 'error'.
maxSubExp :: MonadBuilder m => Set SubExp -> m SubExp
maxSubExp = helper . S.toList
  where
    -- Fold the first two elements into their maximum and continue with the
    -- result at the head of the remaining list.
    helper (s1 : s2 : sexps) = do
      z <- letSubExp "maxSubHelper" $ BasicOp $ BinOp (UMax Int64) s1 s2
      helper (z : sexps)
    helper [s] =
      return s
    helper [] = error "impossible"

-- | Decide whether an allocation's size is available in the given scope.
--
-- A size that is a variable qualifies only if that variable is a member of
-- the scope; any other size qualifies unconditionally.  The space component
-- of the pair is ignored.
isKernelInvariant :: Scope GPUMem -> (SubExp, space) -> Bool
isKernelInvariant scope (Var vname, _) = vname `M.member` scope
isKernelInvariant _ _ = True

-- | Run a monadic transformation on the kernel-body statements of a
-- @SegOp@ and return the same @SegOp@ with its body statements replaced by
-- the result.  Every other component of the @SegOp@ is kept as it is.
onKernelBodyStms ::
  MonadBuilder m =>
  SegOp lvl GPUMem ->
  (Stms GPUMem -> m (Stms GPUMem)) ->
  m (SegOp lvl GPUMem)
onKernelBodyStms (SegMap lvl space ts body) f = do
  stms <- f $ kernelBodyStms body
  return $ SegMap lvl space ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegRed lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  return $ SegRed lvl space binops ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegScan lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  return $ SegScan lvl space binops ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegHist lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  return $ SegHist lvl space binops ts $ body {kernelBodyStms = stms}

-- | This is the actual optimiser. Given an interference graph and a @SegOp@,
-- replace allocations and references to memory blocks inside with a (hopefully)
-- reduced number of allocations.
--
-- The steps are:
--
-- 1. Recursively optimise any kernels nested inside the body of the @SegOp@.
--
-- 2. Collect the allocations of the (already processed) @SegOp@ whose size
--    passes 'isKernelInvariant' with respect to the scope obtained from
--    'askScope'.
--
-- 3. Colour those allocations with 'GreedyColoring.colorGraph', using their
--    memory spaces and the interference graph.
--
-- 4. For every colour, emit statements computing the maximum size over the
--    allocations sharing that colour, followed by one @Alloc@ of that size
--    in the colour's space.
--
-- 5. Replace each coloured allocation with a reference to the memory block
--    of its colour, and prepend the statements from step 4 to the kernel
--    body.
optimiseKernel ::
  (MonadBuilder m, Rep m ~ GPUMem) =>
  Interference.Graph VName ->
  SegOp lvl GPUMem ->
  m (SegOp lvl GPUMem)
optimiseKernel graph segop0 = do
  segop <- onKernelBodyStms segop0 $ onKernels $ optimiseKernel graph
  scope_here <- askScope
  let allocs = M.filter (isKernelInvariant scope_here) $ getAllocsSegOp segop
      (colorspaces, coloring) =
        GreedyColoring.colorGraph
          (fmap snd allocs)
          graph
  -- One maximum-size expression per colour, in ascending colour order; the
  -- statements computing them are captured rather than emitted here.
  (maxes, maxstms) <-
    invertMap coloring
      & M.elems
      & mapM (maxSubExp . S.map (fst . (allocs !)))
      & collectStms
  -- One allocation per colour.  The position of a size in 'maxes' is used as
  -- its colour, which is why the lengths are asserted to agree.
  (colors, stms) <-
    assert (length maxes == M.size colorspaces) maxes
      & zip [0 ..]
      & mapM (\(i, x) -> letSubExp "color" $ Op $ Alloc x $ colorspaces ! i)
      & collectStms
  let segop' = setAllocsSegOp (fmap (colors !!) coloring) segop
  -- The size computations and the new allocations must precede the original
  -- body statements that now refer to them.
  onKernelBodyStms segop' $ \body_stms ->
    return $ maxstms <> stms <> body_stms

-- | Helper function that modifies kernels found inside some statements.
onKernels ::
  LocalScope GPUMem m =>
  (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)) ->
  Stms GPUMem ->
  m (Stms GPUMem)
onKernels :: (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem))
-> Seq (Stm GPUMem) -> m (Seq (Stm GPUMem))
onKernels SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)
f Seq (Stm GPUMem)
stms = Seq (Stm GPUMem) -> m (Seq (Stm GPUMem)) -> m (Seq (Stm GPUMem))
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Seq (Stm GPUMem)
stms (m (Seq (Stm GPUMem)) -> m (Seq (Stm GPUMem)))
-> m (Seq (Stm GPUMem)) -> m (Seq (Stm GPUMem))
forall a b. (a -> b) -> a -> b
$ (Stm GPUMem -> m (Stm GPUMem))
-> Seq (Stm GPUMem) -> m (Seq (Stm GPUMem))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm GPUMem -> m (Stm GPUMem)
helper Seq (Stm GPUMem)
stms
  where
    helper :: Stm GPUMem -> m (Stm GPUMem)
helper stm :: Stm GPUMem
stm@Let {stmExp :: forall rep. Stm rep -> Exp rep
stmExp = Op (Inner (SegOp segop))} = do
      SegOp SegLevel GPUMem
exp' <- SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)
f SegOp SegLevel GPUMem
segop
      Stm GPUMem -> m (Stm GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stm GPUMem -> m (Stm GPUMem)) -> Stm GPUMem -> m (Stm GPUMem)
forall a b. (a -> b) -> a -> b
$ Stm GPUMem
stm {stmExp :: ExpT GPUMem
stmExp = Op GPUMem -> ExpT GPUMem
forall rep. Op rep -> ExpT rep
Op (Op GPUMem -> ExpT GPUMem) -> Op GPUMem -> ExpT GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp SegOp SegLevel GPUMem
exp'}
    -- Conditional: apply the kernel transformation to the statements of both
    -- branches, then rebuild the 'If' with the same condition and decoration.
    -- Only the 'bodyStms' field of each branch body is replaced.
    helper stm@Let {stmExp = If c then_body else_body dec} = do
      then_body_stms <- f `onKernels` bodyStms then_body
      else_body_stms <- f `onKernels` bodyStms else_body
      return $
        stm
          { stmExp =
              If
                c
                (then_body {bodyStms = then_body_stms})
                (else_body {bodyStms = else_body_stms})
                dec
          }
    -- Loop: apply the kernel transformation to the statements of the loop
    -- body. The merge parameters and the loop form are passed through
    -- unchanged; only the 'bodyStms' field of the body is replaced.
    helper stm@Let {stmExp = DoLoop merge form body} = do
      body_stms <- f `onKernels` bodyStms body
      return $ stm {stmExp = DoLoop merge form (body {bodyStms = body_stms})}
    -- Any statement not matched by an earlier equation is returned untouched.
    helper stm = return stm

-- | Perform the reuse-allocations optimization.
--
-- A graph is first computed from the whole program with
-- 'Interference.analyseProgGPU'.  The program is then handed to
-- 'Pass.intraproceduralTransformation', which invokes @onStms@ with that
-- graph on the statement sequences it visits.
optimise :: Pass GPUMem GPUMem
optimise =
  Pass "reuse allocations" "reuse allocations" $ \prog ->
    let graph = Interference.analyseProgGPU prog
     in Pass.intraproceduralTransformation (onStms graph) prog
  where
    -- Apply 'optimiseKernel' (via 'onKernels') to the given statements, with
    -- the supplied scope made available through 'localScope'.
    onStms ::
      Interference.Graph VName ->
      Scope GPUMem ->
      Stms GPUMem ->
      PassM (Stms GPUMem)
    onStms graph scope stms = do
      let rewriteKernels =
            localScope scope $ optimiseKernel graph `onKernels` stms
      -- 'runBuilderT' yields the action's result paired with the statements
      -- accumulated by the builder; only the result (the rewritten
      -- statements) is kept.  Running through 'modifyNameSource' threads the
      -- pass's name source through the 'State' computation.
      fmap fst $ modifyNameSource $ runState (runBuilderT rewriteKernels mempty)