{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module implements an optimization that tries to statically reuse
-- kernel-level allocations. The goal is to lower the static memory usage, which
-- might allow more programs to run using intra-group parallelism.
module Futhark.Optimise.ReuseAllocations (optimise) where

import Control.Exception
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.Map (Map, (!))
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import qualified Futhark.Analysis.Interference as Interference
import qualified Futhark.Analysis.LastUse as LastUse
import Futhark.Binder.Class
import Futhark.Construct
import Futhark.IR.GPUMem
import qualified Futhark.Optimise.ReuseAllocations.GreedyColoring as GreedyColoring
import Futhark.Pass (Pass (..), PassM)
import qualified Futhark.Pass as Pass
import Futhark.Util (invertMap)

-- | A mapping from allocation names to their size and space, i.e. the two
-- arguments of the @Alloc@ operation that created each allocation.
type Allocs = Map VName (SubExp, Space)

-- | Collect the allocations performed by a statement.  The branches of an
-- @if@ and the body of a loop are searched recursively; other expressions
-- (including nested kernels) contribute nothing.
--
-- An allocation statement must bind exactly one value and no context;
-- anything else is reported with 'error'.
getAllocsStm :: Stm GPUMem -> Allocs
getAllocsStm stm =
  case stmExp stm of
    Op (Alloc size space) ->
      case stmPattern stm of
        Pattern [] [pe] -> M.singleton (patElemName pe) (size, space)
        _ -> error "impossible"
    If _ tbranch fbranch _ ->
      fromStms (bodyStms tbranch) <> fromStms (bodyStms fbranch)
    DoLoop _ _ _ loopbody ->
      fromStms (bodyStms loopbody)
    _ -> mempty
  where
    fromStms :: Stms GPUMem -> Allocs
    fromStms = foldMap getAllocsStm

-- | Collect the allocations found in the body of a kernel, using
-- 'getAllocsStm' on each of its statements.
getAllocsSegOp :: SegOp lvl GPUMem -> Allocs
getAllocsSegOp segop = foldMap getAllocsStm $ kernelBodyStms kbody
  where
    -- Every kind of kernel carries its body as the last field.
    kbody = case segop of
      SegMap _ _ _ b -> b
      SegRed _ _ _ _ b -> b
      SegScan _ _ _ _ b -> b
      SegHist _ _ _ _ b -> b

-- | Rewrite allocations according to the given map.  An allocation statement
-- binding a single name that is a key of the map is replaced by
-- @BasicOp (SubExp s)@, where @s@ is the mapped subexpression.  The rewrite
-- is applied recursively to @if@ branches, loop bodies and nested kernels;
-- all other statements are returned unchanged.
setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem
setAllocsStm m stm =
  case stmExp stm of
    Op (Alloc _ _)
      | Pattern [] [pe] <- stmPattern stm,
        Just s <- M.lookup (patElemName pe) m ->
        stm {stmExp = BasicOp $ SubExp s}
      | otherwise -> stm
    Op (Inner (SegOp segop)) ->
      stm {stmExp = Op $ Inner $ SegOp $ setAllocsSegOp m segop}
    If cse tbranch fbranch dec ->
      stm {stmExp = If cse (onBody tbranch) (onBody fbranch) dec}
    DoLoop ctx vals form loopbody ->
      stm {stmExp = DoLoop ctx vals form (onBody loopbody)}
    _ -> stm
  where
    onBody b = b {bodyStms = setAllocsStm m <$> bodyStms b}

-- | Apply 'setAllocsStm' with the given map to every statement in the body
-- of a kernel, leaving the rest of the kernel untouched.
setAllocsSegOp ::
  Map VName SubExp ->
  SegOp lvl GPUMem ->
  SegOp lvl GPUMem
setAllocsSegOp m segop =
  case segop of
    SegMap lvl space ts kbody -> SegMap lvl space ts $ update kbody
    SegRed lvl space ops ts kbody -> SegRed lvl space ops ts $ update kbody
    SegScan lvl space ops ts kbody -> SegScan lvl space ops ts $ update kbody
    SegHist lvl space ops ts kbody -> SegHist lvl space ops ts $ update kbody
  where
    update kbody =
      kbody {kernelBodyStms = setAllocsStm m <$> kernelBodyStms kbody}

-- | Emit statements computing the maximum of a set of sizes, combining them
-- pairwise from left to right with an unsigned 64-bit maximum
-- (@UMax Int64@), and return the resulting subexpression.  A singleton set
-- emits nothing.  The set must be non-empty; an empty set is reported with
-- 'error'.
maxSubExp :: MonadBinder m => Set SubExp -> m SubExp
maxSubExp sizes =
  case S.toList sizes of
    [] -> error "impossible"
    first : rest -> go first rest
  where
    go acc [] = return acc
    go acc (next : rest) = do
      acc' <- letSubExp "maxSubHelper" $ BasicOp $ BinOp (UMax Int64) acc next
      go acc' rest

-- | The set of names bound anywhere inside an expression.  This covers
-- statements in @if@ branches and loop bodies, the bodies of nested kernels,
-- and the binders introduced by a loop itself: its merge parameters, its
-- @for@ index and its loop element variables.  Names bound by the pattern of
-- the statement containing the expression are not included.
definedInExp :: Exp GPUMem -> Set VName
definedInExp (Op (Inner (SegOp segop))) =
  definedInSegOp segop
definedInExp (If _ then_body else_body _) =
  foldMap definedInStm (bodyStms then_body)
    <> foldMap definedInStm (bodyStms else_body)
definedInExp (DoLoop ctx vals form body) =
  -- The loop's own binders are only in scope inside the loop.  If they were
  -- left out, an allocation whose size depends on one of them would be
  -- considered kernel invariant.  Its size would then be referenced by the
  -- max computation placed at the start of the kernel body, where the name
  -- is not in scope.
  S.fromList (map (paramName . fst) (ctx <> vals))
    <> formVars
    <> foldMap definedInStm (bodyStms body)
  where
    formVars = case form of
      ForLoop i _ _ loopvars -> S.fromList $ i : map (paramName . fst) loopvars
      WhileLoop _ -> mempty
definedInExp _ = mempty

-- | The set of names bound by a statement: every name in its pattern
-- (context and value parts alike), plus everything bound inside its
-- expression as reported by 'definedInExp'.
definedInStm :: Stm GPUMem -> Set VName
definedInStm (Let (Pattern ctx vals) _ e) =
  definedInExp e <> S.fromList (map patElemName (ctx ++ vals))

-- | The set of names bound by the statements of a kernel's body, including
-- names bound inside their expressions (see 'definedInStm').  Only the
-- kernel body is inspected; the kernel's segment space is not.
definedInSegOp :: SegOp lvl GPUMem -> Set VName
definedInSegOp segop =
  foldMap definedInStm . kernelBodyStms $
    case segop of
      SegMap _ _ _ kbody -> kbody
      SegRed _ _ _ _ kbody -> kbody
      SegScan _ _ _ _ kbody -> kbody
      SegHist _ _ _ _ kbody -> kbody

-- | Decide whether an allocation's size is invariant to the given kernel.
-- Constant sizes always are.  A variable size is invariant unless it is
-- bound somewhere inside the kernel body (per 'definedInSegOp').  The space
-- component is ignored.
isKernelInvariant :: SegOp lvl GPUMem -> (SubExp, space) -> Bool
isKernelInvariant segop (size, _) =
  case size of
    Var v -> v `S.notMember` definedInSegOp segop
    _ -> True

-- | Run a monadic transformation on the statements of a kernel's body and
-- rebuild the kernel around the result.  Every other part of the kernel
-- (level, space, operators, result types) is kept as is.
onKernelBodyStms ::
  MonadBinder m =>
  SegOp lvl GPUMem ->
  (Stms GPUMem -> m (Stms GPUMem)) ->
  m (SegOp lvl GPUMem)
onKernelBodyStms segop f =
  case segop of
    SegMap lvl space ts kbody -> SegMap lvl space ts <$> update kbody
    SegRed lvl space ops ts kbody -> SegRed lvl space ops ts <$> update kbody
    SegScan lvl space ops ts kbody -> SegScan lvl space ops ts <$> update kbody
    SegHist lvl space ops ts kbody -> SegHist lvl space ops ts <$> update kbody
  where
    update kbody = do
      stms <- f $ kernelBodyStms kbody
      return kbody {kernelBodyStms = stms}

-- | This is the actual optimiser. Given an interference graph and a `SegOp`,
-- replace allocations and references to memory blocks inside with a (hopefully)
-- reduced number of allocations.
--
-- Kernels nested inside the body are optimised first.  Then the
-- kernel-invariant allocations of this kernel (see 'isKernelInvariant') are
-- coloured with 'GreedyColoring.colorGraph', keyed by their memory space.
-- For every colour a single new allocation is created, in that colour's
-- space, whose size is the maximum of the sizes of the allocations sharing
-- the colour.  The original allocations are rewritten to refer to these new
-- blocks ('setAllocsSegOp'), and the statements computing the sizes and
-- performing the new allocations are prepended to the kernel body.
optimiseKernel ::
  (MonadBinder m, Rep m ~ GPUMem) =>
  Interference.Graph VName ->
  SegOp lvl GPUMem ->
  m (SegOp lvl GPUMem)
optimiseKernel graph segop0 = do
  segop <- onKernelBodyStms segop0 $ onKernels $ optimiseKernel graph
  let allocs = M.filter (isKernelInvariant segop) $ getAllocsSegOp segop
      (colorspaces, coloring) =
        GreedyColoring.colorGraph
          (fmap snd allocs)
          graph
  -- One size per colour: the maximum over all allocations of that colour.
  (maxes, maxstms) <-
    invertMap coloring
      & M.elems
      & mapM (maxSubExp . S.map (fst . (allocs !)))
      & collectStms
  -- One new allocation per colour, placed in that colour's memory space.
  (colors, stms) <-
    assert (length maxes == M.size colorspaces) maxes
      & zip [0 ..]
      & mapM (\(i, x) -> letSubExp "color" $ Op $ Alloc x $ colorspaces ! i)
      & collectStms
  let segop' = setAllocsSegOp (fmap (colors !!) coloring) segop
  -- Prepend the size computations and the new allocations so that they are
  -- in scope for every statement of the (rewritten) kernel body.
  onKernelBodyStms segop' $ \kstms -> return $ maxstms <> stms <> kstms

-- | Helper function that modifies kernels found inside some statements.
--
-- Applies @f@ to every 'SegOp' kernel in the given statements, recursing
-- into both branches of @If@ expressions and into the bodies of @DoLoop@s.
-- Every other statement is returned unchanged.
--
-- Scoping: the whole statement sequence is brought into scope before any
-- statement is visited, so @f@ can see bindings made by earlier sibling
-- statements (not merely the statement currently being visited).  When
-- descending into a loop body, the loop's merge parameters and its loop form
-- (e.g. the index of a @for@ loop) are brought into scope as well.
onKernels ::
  LocalScope GPUMem m =>
  (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)) ->
  Stms GPUMem ->
  m (Stms GPUMem)
onKernels f orig_stms =
  -- Futhark names are unique, so putting all siblings in scope at once
  -- cannot shadow anything; it only adds bindings that were missing before.
  inScopeOf orig_stms $ mapM helper orig_stms
  where
    helper stm@Let {stmExp = Op (Inner (SegOp segop))} = do
      segop' <- f segop
      return $ stm {stmExp = Op $ Inner $ SegOp segop'}
    helper stm@Let {stmExp = If c then_body else_body dec} = do
      -- The recursive calls bring each branch's own statements into scope.
      then_body_stms <- f `onKernels` bodyStms then_body
      else_body_stms <- f `onKernels` bodyStms else_body
      return $
        stm
          { stmExp =
              If
                c
                (then_body {bodyStms = then_body_stms})
                (else_body {bodyStms = else_body_stms})
                dec
          }
    helper stm@Let {stmExp = DoLoop ctx vals form body} = do
      -- Statements in the loop body may refer to the merge parameters and
      -- to the names bound by the loop form, so make those visible to @f@.
      let loop_scope :: Scope GPUMem
          loop_scope = scopeOfFParams (map fst (ctx ++ vals)) <> scopeOf form
      body_stms <- localScope loop_scope $ f `onKernels` bodyStms body
      return $ stm {stmExp = DoLoop ctx vals form (body {bodyStms = body_stms})}
    helper stm = return stm

-- | Perform the reuse-allocations optimization.
--
-- An interference graph is built for every function body from the program's
-- last-use information ('LastUse.analyseProg' followed by
-- 'Interference.analyseGPU'), and the per-function graphs are combined with
-- 'foldMap'.  Each statement sequence handed out by
-- 'Pass.intraproceduralTransformation' then has 'optimiseKernel' applied to
-- every kernel it contains, via 'onKernels'.
optimise :: Pass GPUMem GPUMem
optimise = Pass "reuse allocations" "reuse allocations" run
  where
    run prog =
      Pass.intraproceduralTransformation (onStms interference) prog
      where
        lumap = fst $ LastUse.analyseProg prog
        interference = foldMap funGraph (progFuns prog)
        -- Interference graph of a single function, analysed in that
        -- function's own scope.
        funGraph fundef =
          runReader
            (Interference.analyseGPU lumap (bodyStms (funDefBody fundef)))
            (scopeOf fundef)

    onStms ::
      Interference.Graph VName ->
      Scope GPUMem ->
      Stms GPUMem ->
      PassM (Stms GPUMem)
    onStms interference scope stms =
      -- Only the rewritten statements are kept; statements added directly
      -- to the outer binder (the second component) are dropped.
      -- NOTE(review): presumably 'optimiseKernel' emits none there - confirm.
      modifyNameSource (runState (runBinderT action mempty)) & fmap fst
      where
        action =
          localScope scope (optimiseKernel interference `onKernels` stms)