{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Interference analysis for Futhark programs.
module Futhark.Analysis.Interference (Graph, analyseGPU) where

import Control.Monad.Reader
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Futhark.Analysis.LastUse (LastUseMap)
import Futhark.IR.GPUMem
import Futhark.Util (invertMap)

-- | The set of 'VName' currently in use.
type InUse = Names

-- | The set of 'VName' that are no longer in use.
type LastUsed = Names

-- | An interference graph. An element @(x, y)@ in the set means that there is
-- an undirected edge between @x@ and @y@: the lifetimes of @x@ and @y@ overlap,
-- so they "interfere" with each other. Pairs are assumed to be normalized so
-- that @x < y@ before they are inserted, which prevents the same edge from
-- being stored twice. Pairs where @x == y@ are not allowed. 'makeEdge'
-- produces edges in this form.
type Graph a = Set (a, a)

-- | Build the graph that contains the single undirected edge between two
-- values.
--
-- The edge is stored in normalized form, smaller endpoint first. If both
-- values are equal no edge is created and the empty graph is returned, since
-- self-edges are not allowed in a 'Graph'.
makeEdge :: Ord a => a -> a -> Graph a
makeEdge v1 v2
  | v1 == v2 = mempty
  | otherwise = S.singleton (min v1 v2, max v1 v2)

-- | Compute the cartesian product of two foldable collections, using the given
-- combinator function.
--
-- Every element of the first collection is paired with every element of the
-- second one (first collection outermost), the combinator is applied to each
-- pair, and the results are combined with '<>'. If either collection is empty
-- the result is 'mempty'.
cartesian :: (Monoid m, Foldable t) => (a -> a -> m) -> t a -> t a -> m
cartesian f xs ys =
  [(x, y) | x <- toList xs, y <- toList ys]
    & foldMap (uncurry f)

-- | Analyse a single statement.
--
-- Takes the last-use map, the set of memory blocks in use before the
-- statement, and the statement itself. Returns the memory blocks in use after
-- the statement, the memory blocks that reached their last use, and the
-- interference edges contributed by the statement.
analyseStm ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Stm GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStm lumap inuse0 stm =
  inScopeOf stm $ do
    -- The last-use map is looked up under the name of the first pattern
    -- element. A statement with an empty pattern has no such key, so it is
    -- treated as having no last uses instead of crashing on a partial 'head'.
    let last_uses =
          case patElems $ stmPat stm of
            pe : _ -> fromMaybe mempty $ M.lookup (patElemName pe) lumap
            [] -> mempty

    -- The memory blocks that 'memInfo' reports for the names bound by this
    -- statement.
    new_mems <-
      stmPat stm
        & patElems
        & mapM (memInfo . patElemName)
        <&> catMaybes
        <&> namesFromList

    -- `new_mems` should interfere with any mems inside the statement expression
    let inuse_outside = inuse0 <> new_mems

    -- `inuse` is the set of memory blocks that are inuse at the end of any code
    -- bodies inside the expression. `lus` is the set of all memory blocks that
    -- have reached their last use in any code bodies inside the
    -- expression. `graph` is the interference graph computed for any code
    -- bodies inside the expression.
    (inuse, lus, graph) <- analyseExp lumap inuse_outside (stmExp stm)

    -- Only memory blocks that were actually in use can be last-used here.
    last_use_mems <-
      namesToList last_uses
        & mapM memInfo
        <&> catMaybes
        <&> namesFromList
        <&> namesIntersection inuse_outside

    return
      ( (inuse_outside `namesSubtract` last_use_mems `namesSubtract` lus)
          <> new_mems,
        (lus <> last_use_mems) `namesSubtract` new_mems,
        graph
          <> cartesian
            makeEdge
            (namesToList inuse_outside)
            (namesToList $ inuse_outside <> inuse <> lus <> last_use_mems)
      )

-- | Add the interference caused by the merge parameters of a loop to the
-- result of analysing the loop body. The in-use and last-used sets are
-- passed through unchanged; only the graph grows.
--
-- We conservatively treat all memory arguments to a DoLoop to
-- interfere with each other, as well as anything used inside the
-- loop.  This could potentially be improved by looking at the
-- interference computed by the loop body wrt. the loop arguments, but
-- probably very few programs would benefit from this.
analyseLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  (InUse, LastUsed, Graph VName) ->
  (InUse, LastUsed, Graph VName)
analyseLoopParams merge (inuse, lastused, graph) =
  (inuse, lastused, cartesian makeEdge mems (mems <> inner_mems) <> graph)
  where
    -- Initial values of the merge parameters that are memory blocks.
    mems = mapMaybe isMemArg merge
    -- Everything the loop body reported as in use or last-used.
    inner_mems = namesToList lastused <> namesToList inuse
    -- Only a memory-typed parameter initialised from a variable counts.
    isMemArg (Param _ MemMem {}, Var v) = Just v
    isMemArg _ = Nothing

-- | Analyse the code bodies nested inside an expression.
--
-- Branches of an @If@ are analysed independently from the same starting
-- in-use set and their results are combined with '<>'. A loop body is analysed
-- and then extended by 'analyseLoopParams'. A @SegOp@ is delegated to
-- 'analyseSegOp'. Every other expression contributes nothing ('mempty').
analyseExp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Exp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseExp lumap inuse_outside expr =
  case expr of
    If _ then_body else_body _ -> do
      res1 <- analyseBody lumap inuse_outside then_body
      res2 <- analyseBody lumap inuse_outside else_body
      return $ res1 <> res2
    DoLoop merge _ body ->
      analyseLoopParams merge <$> analyseBody lumap inuse_outside body
    Op (Inner (SegOp segop)) ->
      analyseSegOp lumap inuse_outside segop
    _ ->
      return mempty

-- | Analyse the statements of a kernel body; see 'analyseStms'.
analyseKernelBody ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseKernelBody lumap inuse body =
  analyseStms lumap inuse $ kernelBodyStms body

-- | Analyse the statements of a body; see 'analyseStms'.
analyseBody ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Body GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseBody lumap inuse body =
  analyseStms lumap inuse $ bodyStms body

-- | Analyse a sequence of statements in order.
--
-- The in-use set is threaded from one statement to the next, starting from
-- the given set, while the last-used sets and the graphs of all statements
-- are accumulated with '<>'.
analyseStms ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStms lumap inuse0 stms =
  inScopeOf stms $ foldM helper (inuse0, mempty, mempty) $ stmsToList stms
  where
    helper (inuse, lus, graph) stm = do
      (inuse', lus', graph') <- analyseStm lumap inuse stm
      return (inuse', lus' <> lus, graph' <> graph)

-- | Analyse a segmented operation.
--
-- For @SegMap@ only the kernel body is analysed. For @SegRed@ and @SegScan@
-- the kernel body and the operators are analysed via 'segWithBinOps'. For
-- @SegHist@ the kernel body is analysed first and each histogram operation is
-- then analysed starting from the in-use set at the end of the body.
analyseSegOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  SegOp lvl GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegOp lumap inuse (SegMap _ _ _ body) =
  analyseKernelBody lumap inuse body
analyseSegOp lumap inuse (SegRed _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegScan _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  -- NOTE(review): the returned in-use set comes only from the histogram
  -- operations; for an empty `histops` list it is `mempty` rather than
  -- `inuse'`. Confirm that this is intended.
  (inuse'', lus'', graph') <-
    mconcat <$> mapM (analyseHistOp lumap inuse') histops
  return (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse a kernel body together with a list of segmented binary operators.
--
-- The kernel body is analysed first. Every operator is then analysed
-- starting from the in-use set at the end of the body, and the operators'
-- results are combined with 'mconcat'. The last-used sets and graphs of the
-- body and of the operators are merged.
segWithBinOps ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
segWithBinOps lumap inuse binops body = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  -- NOTE(review): the returned in-use set comes only from the operators; for
  -- an empty `binops` list it is `mempty` rather than `inuse'`. Confirm that
  -- this is intended.
  (inuse'', lus'', graph') <-
    mconcat
      <$> mapM
        (analyseSegBinOp lumap inuse')
        binops
  return (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse the lambda of a segmented binary operator; see 'analyseLambda'.
analyseSegBinOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  SegBinOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegBinOp lumap inuse (SegBinOp _ lambda _ _) =
  analyseLambda lumap inuse lambda

-- | Analyse the operator lambda of a histogram operation; see
-- 'analyseLambda'.
analyseHistOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  HistOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseHistOp lumap inuse histop =
  analyseLambda lumap inuse (histOp histop)

-- | Analyse the body of a lambda; see 'analyseBody'. The parameters and the
-- return types of the lambda are ignored.
analyseLambda ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Lambda GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseLambda lumap inuse (Lambda _ body _) =
  analyseBody lumap inuse body

-- | Perform interference analysis on the given statements and return the
-- resulting interference graph.
--
-- The graph computed by @analyseGPU'@ is extended with edges between any two
-- memory blocks in 'DefaultSpace' whose element sizes differ.
analyseGPU ::
  LocalScope GPUMem m =>
  LastUseMap ->
  Stms GPUMem ->
  m (Graph VName)
analyseGPU lumap stms = do
  (_, _, graph) <- analyseGPU' lumap stms
  -- We need to insert edges between memory blocks which differ in size, if they
  -- are in DefaultSpace. The problem is that during memory expansion,
  -- DefaultSpace arrays in kernels are interleaved. If the element sizes of two
  -- merged memory blocks are different, threads might try to read and write to
  -- overlapping memory positions. More information here:
  -- https://munksgaard.me/technical-diary/2020-12-30.html#org210775b
  spaces <- M.filter (== DefaultSpace) <$> memSpaces stms
  -- Element size -> the DefaultSpace memory blocks with that element size.
  inv_size_map <-
    memSizes stms
      <&> flip M.restrictKeys (M.keysSet spaces)
      <&> invertMap
  -- Connect every block of one size class with every block of each other
  -- size class; blocks within the same class get no extra edge.
  let new_edges =
        cartesian
          (\x y -> if x /= y then cartesian makeEdge x y else mempty)
          inv_size_map
          inv_size_map
  return $ graph <> new_edges

-- | Return a mapping from memory blocks to their element sizes in the given
-- statements.
--
-- The statements are traversed recursively: the bodies of segmented
-- operations, both branches of an @If@, and loop bodies are included. The
-- per-name sizes come from @memElemSize@, and the partial maps are combined
-- with '<>'.
memSizes :: LocalScope GPUMem m => Stms GPUMem -> m (Map VName Int)
memSizes stms =
  inScopeOf stms $ mconcat <$> mapM memSizesStm (stmsToList stms)
  where
    -- Sizes for the names bound by one statement, plus those found inside
    -- its expression.
    memSizesStm :: LocalScope GPUMem m => Stm GPUMem -> m (Map VName Int)
    memSizesStm (Let pat _ e) = do
      arraySizes <- mconcat <$> mapM memElemSize (patNames pat)
      arraySizes' <- memSizesExp e
      return $ arraySizes <> arraySizes'
    -- Sizes found in the code bodies nested inside an expression.
    memSizesExp :: LocalScope GPUMem m => Exp GPUMem -> m (Map VName Int)
    memSizesExp (Op (Inner (SegOp segop))) =
      memSizes $ kernelBodyStms $ segBody segop
    memSizesExp (If _ then_body else_body _) = do
      then_res <- memSizes $ bodyStms then_body
      else_res <- memSizes $ bodyStms else_body
      return $ then_res <> else_res
    memSizesExp (DoLoop _ _ body) =
      memSizes $ bodyStms body
    memSizesExp _ = return mempty

-- | Return a mapping from memory blocks to the space they are allocated in.
--
-- Every @Alloc@ statement found in the given statements contributes one entry
-- (the bound name mapped to the allocation's 'Space'). The bodies of kernels
-- ('SegOp'), both branches of 'If' and the bodies of 'DoLoop' are searched
-- recursively; all other statements contribute nothing.
--
-- The computation itself is pure: the result is only lifted into @m@ with
-- 'return', and the scope is never consulted.
memSpaces :: LocalScope GPUMem m => Stms GPUMem -> m (Map VName Space)
memSpaces stms = return $ getSpacesStms stms
  where
    -- Union of the allocations found in a sequence of statements.
    getSpacesStms :: Stms GPUMem -> Map VName Space
    getSpacesStms = foldMap getSpacesStm

    -- Allocations introduced by a single statement, including nested bodies.
    getSpacesStm :: Stm GPUMem -> Map VName Space
    getSpacesStm (Let (Pat [PatElem name _]) _ (Op (Alloc _ sp))) =
      M.singleton name sp
    getSpacesStm (Let _ _ (Op (Alloc _ _))) =
      -- An allocation is expected to bind exactly one name; anything else
      -- indicates a malformed program, so fail loudly and say where.
      error
        "Futhark.Analysis.Interference.memSpaces: impossible: \
        \Alloc statement does not bind exactly one name"
    getSpacesStm (Let _ _ (Op (Inner (SegOp segop)))) =
      getSpacesStms $ kernelBodyStms $ segBody segop
    getSpacesStm (Let _ _ (If _ then_body else_body _)) =
      getSpacesStms (bodyStms then_body)
        <> getSpacesStms (bodyStms else_body)
    getSpacesStm (Let _ _ (DoLoop _ _ body)) =
      getSpacesStms (bodyStms body)
    getSpacesStm _ = mempty

-- | Analyse a sequence of statements and combine the per-statement results.
--
-- Each statement yields a triple of in-use names, last-used names and an
-- interference graph; the triples are merged monoidally in statement order.
-- Only kernels, 'If' and 'DoLoop' statements can contribute anything; every
-- other statement yields 'mempty'.
analyseGPU' ::
  LocalScope GPUMem m =>
  LastUseMap ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseGPU' lumap stms =
  foldMap id <$> traverse analyseStm stms
  where
    analyseStm ::
      LocalScope GPUMem m =>
      Stm GPUMem ->
      m (InUse, LastUsed, Graph VName)
    analyseStm stm =
      case stmExp stm of
        -- A kernel is analysed on its own, starting from an empty in-use set.
        Op (Inner (SegOp segop)) ->
          inScopeOf stm $ analyseSegOp lumap mempty segop
        -- Both branches are analysed (then-branch first) and merged.
        If _ then_body else_body _ ->
          inScopeOf stm $
            (<>)
              <$> analyseGPU' lumap (bodyStms then_body)
              <*> analyseGPU' lumap (bodyStms else_body)
        -- The loop body result is post-processed with the merge parameters.
        DoLoop merge _ body ->
          analyseLoopParams merge
            <$> inScopeOf stm (analyseGPU' lumap (bodyStms body))
        _ ->
          inScopeOf stm $ return mempty

-- | Extract the memory annotation carried by a scope entry.
--
-- Function parameters have their uniqueness stripped, lambda parameters are
-- returned unchanged, let-bound names go through 'letDecMem', and index names
-- are reported as primitive integers of the index's type.
nameInfoToMemInfo :: Mem rep inner => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (LetName summary) = letDecMem summary
nameInfoToMemInfo (IndexName it) = MemPrim (IntType it)

-- | Look up the memory block in which the given variable resides.
--
-- Yields 'Nothing' when the name is not in scope or when its scope entry is
-- not an array residing in a memory block.
memInfo :: LocalScope GPUMem m => VName -> m (Maybe VName)
memInfo vname =
  arrayMem <$> asksScope (fmap nameInfoToMemInfo . M.lookup vname)
  where
    -- Project the memory block out of an array annotation, if there is one.
    arrayMem :: Maybe (MemBound NoUniqueness) -> Maybe VName
    arrayMem (Just (MemArray _ _ _ (ArrayIn mem _))) = Just mem
    arrayMem _ = Nothing

-- | Map the memory block of a variable to the byte size of its elements.
--
-- The argument is the 'VName' of a variable (supposedly an array). If the
-- scope records it as an array residing in a memory block, the result is a
-- singleton map from that block to the byte size of the array's primitive
-- element type. Otherwise, including when the name is not in scope, the
-- result is empty.
memElemSize :: LocalScope GPUMem m => VName -> m (Map VName Int)
memElemSize vname =
  elemSizeOf <$> asksScope (fmap nameInfoToMemInfo . M.lookup vname)
  where
    -- Singleton entry for an array annotation, nothing for anything else.
    elemSizeOf :: Maybe (MemBound NoUniqueness) -> Map VName Int
    elemSizeOf (Just (MemArray pt _ _ (ArrayIn mem _))) =
      M.singleton mem (primByteSize pt)
    elemSizeOf _ = mempty