{-# LANGUAGE TypeFamilies #-}

-- | Interference analysis for Futhark programs.
module Futhark.Analysis.Interference (Graph, analyseProgGPU) where

import Control.Monad.Reader
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Data.Set (Set)
import Data.Set qualified as S
import Futhark.Analysis.LastUse (LastUseMap)
import Futhark.Analysis.LastUse qualified as LastUse
import Futhark.Analysis.MemAlias qualified as MemAlias
import Futhark.IR.GPUMem
import Futhark.Util (cartesian, invertMap)

-- | The set of memory block names ('VName's) that are in use at a given point
-- of the analysis.
type InUse = Names

-- | The set of memory block names ('VName's) that have reached their last use
-- and are therefore no longer in use.
type LastUsed = Names

-- | An interference graph. An element @(x, y)@ in the set means that there is
-- an undirected edge between @x@ and @y@: the lifetimes of @x@ and @y@ overlap,
-- so they "interfere" with each other. Pairs are assumed to be normalized such
-- that @x@ < @y@ before they are inserted, which prevents the same undirected
-- edge from being stored twice. Pairs where @x == y@ are not allowed.
type Graph a = Set (a, a)

-- | Construct the graph containing the single undirected edge between two
-- values. The pair is stored in normalized order (smaller element first). If
-- the two values are equal, the result is the empty graph, since self-edges
-- are not allowed.
makeEdge :: Ord a => a -> a -> Graph a
makeEdge v1 v2
  | v1 == v2 = mempty
  | otherwise = S.singleton (min v1 v2, max v1 v2)

-- | Compute interference information for a single statement.
--
-- Given the memory blocks in use before the statement, the result is a triple
-- of: the memory blocks in use after the statement; the memory blocks that
-- reached their last use in the statement (excluding the ones the statement
-- itself introduces); and the interference edges the statement gives rise to.
analyseStm ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Stm GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStm lumap inuse0 stm =
  inScopeOf stm $ do
    -- Memory blocks of the arrays bound by the pattern of this statement.
    new_mems <- memsOf $ patNames $ stmPat stm

    -- `new_mems` should interfere with any mems inside the statement expression
    let inuse_outside = inuse0 <> new_mems

    -- `inuse` is the set of memory blocks that are inuse at the end of any code
    -- bodies inside the expression. `lus` is the set of all memory blocks that
    -- have reached their last use in any code bodies inside the
    -- expression. `graph` is the interference graph computed for any code
    -- bodies inside the expression.
    (inuse, lus, graph) <- analyseExp lumap inuse_outside (stmExp stm)

    -- Memory blocks of the variables last used in this statement, restricted
    -- to the blocks that were actually in use.
    last_use_mems <-
      last_uses
        & namesToList
        & memsOf
        <&> namesIntersection inuse_outside

    pure
      ( (inuse_outside `namesSubtract` last_use_mems `namesSubtract` lus)
          <> new_mems,
        (lus <> last_use_mems) `namesSubtract` new_mems,
        graph
          <> cartesian
            makeEdge
            (namesToList inuse_outside)
            (namesToList $ inuse_outside <> inuse <> lus <> last_use_mems)
      )
  where
    -- The last-use map is consulted under the first name bound by the
    -- pattern. A statement that binds no names has no entry, so it is treated
    -- as having no last uses instead of crashing on an empty pattern.
    last_uses =
      case patNames (stmPat stm) of
        pat_name : _ -> M.lookup pat_name lumap & fromMaybe mempty
        [] -> mempty

    -- The memory blocks of those of the given names for which 'memInfo'
    -- finds one.
    memsOf :: LocalScope GPUMem m => [VName] -> m Names
    memsOf names = mapM memInfo names <&> catMaybes <&> namesFromList

-- | Add the interference caused by the merge parameters of a loop to the
-- analysis result of the loop body.
--
-- We conservatively treat all memory arguments to a DoLoop to
-- interfere with each other, as well as anything used inside the
-- loop.  This could potentially be improved by looking at the
-- interference computed by the loop body wrt. the loop arguments, but
-- probably very few programs would benefit from this.
--
-- The in-use and last-used sets are passed through unchanged; only the graph
-- is extended.
analyseLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  (InUse, LastUsed, Graph VName) ->
  (InUse, LastUsed, Graph VName)
analyseLoopParams merge (inuse, lastused, graph) =
  (inuse, lastused, cartesian makeEdge mems (mems <> inner_mems) <> graph)
  where
    -- Variables passed as initial values for memory-typed merge parameters.
    mems = mapMaybe isMemArg merge
    -- Everything the loop body reported as last used or still in use.
    inner_mems = namesToList lastused <> namesToList inuse
    isMemArg (Param _ _ MemMem {}, Var v) = Just v
    isMemArg _ = Nothing

-- | Compute interference information for the code bodies nested inside an
-- expression, given the memory blocks in use around it.
--
-- Only expressions that contain code are inspected: every branch of a
-- 'Match' (results combined with 'mconcat'), the body of a 'DoLoop' (extended
-- by 'analyseLoopParams'), and segmented operations. Any other expression
-- yields an empty result.
analyseExp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Exp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseExp lumap inuse_outside expr =
  case expr of
    Match _ cases defbody _ ->
      fmap mconcat $
        mapM (analyseBody lumap inuse_outside) $
          defbody : map caseBody cases
    DoLoop merge _ body ->
      analyseLoopParams merge <$> analyseBody lumap inuse_outside body
    Op (Inner (SegOp segop)) ->
      analyseSegOp lumap inuse_outside segop
    _ ->
      pure mempty

-- | Compute interference information for the statements of a kernel body.
-- Only the statements are analysed; the rest of the kernel body is not
-- inspected.
analyseKernelBody ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseKernelBody lumap inuse body =
  analyseStms lumap inuse $ kernelBodyStms body

-- | Compute interference information for the statements of a body. Only the
-- statements are analysed; the rest of the body is not inspected.
analyseBody ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Body GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseBody lumap inuse body =
  analyseStms lumap inuse $ bodyStms body

-- | Compute interference information for a sequence of statements.
--
-- The statements are processed in order with all of them in scope. The in-use
-- set produced by one statement is the input to the next, while the last-used
-- sets and graphs of all statements are accumulated. The returned in-use set
-- is the one after the final statement.
analyseStms ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStms lumap inuse0 stms =
  inScopeOf stms $ foldM helper (inuse0, mempty, mempty) $ stmsToList stms
  where
    -- Thread the in-use set through one statement and merge its last-used
    -- names and edges into the running totals.
    helper (inuse, lus, graph) stm = do
      (inuse', lus', graph') <- analyseStm lumap inuse stm
      pure (inuse', lus' <> lus, graph' <> graph)

-- | Compute interference information for a segmented operation.
--
-- For a 'SegMap' only the kernel body is analysed. 'SegRed' and 'SegScan'
-- are handled by 'segWithBinOps'. For a 'SegHist' the kernel body is analysed
-- first, after which every histogram operation is analysed starting from the
-- in-use set left by the body; the results for the operations are combined
-- with 'mconcat'.
analyseSegOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  SegOp lvl GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegOp lumap inuse (SegMap _ _ _ body) =
  analyseKernelBody lumap inuse body
analyseSegOp lumap inuse (SegRed _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegScan _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  (inuse'', lus'', graph') <-
    mconcat <$> mapM (analyseHistOp lumap inuse') histops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Compute interference information for a segmented operation that consists
-- of a kernel body and a list of binary operators.
--
-- The kernel body is analysed first. Every operator is then analysed starting
-- from the in-use set left by the body, and the operator results are combined
-- with 'mconcat'. The returned in-use set is the combined one of the
-- operators, while the last-used sets and graphs of body and operators are
-- merged.
segWithBinOps ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
segWithBinOps lumap inuse binops body = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  (inuse'', lus'', graph') <-
    mconcat <$> mapM (analyseSegBinOp lumap inuse') binops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Compute interference information for the lambda of a segmented binary
-- operator. The other fields of the operator are ignored.
analyseSegBinOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  SegBinOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegBinOp lumap inuse (SegBinOp _ lambda _ _) =
  analyseLambda lumap inuse lambda

-- | Compute interference information for the operator lambda ('histOp') of a
-- histogram operation.
analyseHistOp ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  HistOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseHistOp lumap inuse histop =
  analyseLambda lumap inuse (histOp histop)

-- | Compute interference information for the body of a lambda. The
-- parameters and return types of the lambda are ignored.
analyseLambda ::
  LocalScope GPUMem m =>
  LastUseMap ->
  InUse ->
  Lambda GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseLambda lumap inuse (Lambda _ body _) =
  analyseBody lumap inuse body

-- | Compute the interference graph of a whole program.
--
-- The constants and every function of the program are analysed with
-- 'analyseGPU' using the last-use information of the program, and the
-- resulting graphs are merged. The merged graph is finally extended with
-- 'applyAliases' using the memory aliases of the program.
analyseProgGPU :: Prog GPUMem -> Graph VName
analyseProgGPU prog =
  applyAliases (MemAlias.analyzeGPUMem prog) $
    onConsts (progConsts prog) <> foldMap onFun (progFuns prog)
  where
    (lumap, _) = LastUse.analyseGPUMem prog
    -- A function body is analysed in the scope of the function definition.
    onFun f =
      runReader (analyseGPU lumap $ bodyStms $ funDefBody f) $ scopeOf f
    -- The constants are analysed starting from an empty scope.
    onConsts stms =
      runReader (analyseGPU lumap stms) (mempty :: Scope GPUMem)

-- | Extend an interference graph with the edges implied by memory aliasing.
--
-- The result is built solely from the edges of the input graph: each edge
-- @(x, y)@ contributes an edge between every member of @x@ plus its aliases
-- and every member of @y@ plus its aliases (which includes @(x, y)@ itself).
applyAliases :: MemAlias.MemAliases -> Graph VName -> Graph VName
applyAliases aliases =
  -- For each pair @(x, y)@ in graph, all memory aliases of x should interfere
  -- with all memory aliases of y
  foldMap
    ( \(x, y) ->
        let xs = MemAlias.aliasesOf aliases x <> oneName x
            ys = MemAlias.aliasesOf aliases y <> oneName y
         in cartesian makeEdge (namesToList xs) (namesToList ys)
    )

-- | Perform interference analysis on the given statements. The result is the
-- interference graph computed by 'analyseGPU'', extended with edges between
-- memory blocks in 'DefaultSpace' whose element sizes differ.
analyseGPU ::
  LocalScope GPUMem m =>
  LastUseMap ->
  Stms GPUMem ->
  m (Graph VName)
analyseGPU lumap stms = do
  (_, _, graph) <- analyseGPU' lumap stms
  -- We need to insert edges between memory blocks which differ in size, if they
  -- are in DefaultSpace. The problem is that during memory expansion,
  -- DefaultSpace arrays in kernels are interleaved. If the element sizes of two
  -- merged memory blocks are different, threads might try to read and write to
  -- overlapping memory positions. More information here:
  -- https://munksgaard.me/technical-diary/2020-12-30.html#org210775b
  spaces <- M.filter (== DefaultSpace) <$> memSpaces stms
  -- Map from element size to the DefaultSpace memory blocks of that size.
  inv_size_map <-
    memSizes stms
      <&> flip M.restrictKeys (M.keysSet spaces)
      <&> invertMap
  -- Connect every pair of memory blocks taken from two different size groups.
  let new_edges =
        cartesian
          (\x y -> if x /= y then cartesian makeEdge x y else mempty)
          inv_size_map
          inv_size_map
  pure $ graph <> new_edges

-- | Return a mapping from memory blocks to their element sizes in the given
-- statements.
--
-- The element size is the byte size of the element type of an array residing
-- in the memory block (see 'memElemSize'). Statements nested inside segmented
-- operations, 'Match' branches and 'DoLoop' bodies are included. The
-- per-statement maps are combined with the 'Monoid' instance of 'Map', so if a
-- memory block is reported more than once, the first reported size is kept.
memSizes :: LocalScope GPUMem m => Stms GPUMem -> m (Map VName Int)
memSizes stms =
  inScopeOf stms $ mconcat <$> mapM memSizesStm (stmsToList stms)
  where
    -- Sizes for the names bound by a statement, followed by the sizes found
    -- inside its expression.
    memSizesStm :: LocalScope GPUMem m => Stm GPUMem -> m (Map VName Int)
    memSizesStm (Let pat _ e) = do
      arraySizes <- mconcat <$> mapM memElemSize (patNames pat)
      arraySizes' <- memSizesExp e
      pure $ arraySizes <> arraySizes'

    -- Sizes found in the code bodies nested inside an expression.
    memSizesExp :: LocalScope GPUMem m => Exp GPUMem -> m (Map VName Int)
    memSizesExp (Op (Inner (SegOp segop))) =
      let kstms = kernelBodyStms $ segBody segop
       in inScopeOf kstms $ mconcat <$> mapM memSizesStm (stmsToList kstms)
    memSizesExp (Match _ cases defbody _) =
      mconcat <$> mapM (memSizes . bodyStms) (defbody : map caseBody cases)
    memSizesExp (DoLoop _ _ body) =
      memSizes $ bodyStms body
    memSizesExp _ = pure mempty

-- | Return a mapping from memory blocks to the space they are allocated in.
--
-- Only allocation statements contribute entries. Allocations nested inside
-- segmented operations, 'Match' branches and 'DoLoop' bodies are included.
-- The monadic context is not consulted; the result is computed purely from the
-- statements.
memSpaces :: LocalScope GPUMem m => Stms GPUMem -> m (Map VName Space)
memSpaces stms =
  pure $ foldMap getSpacesStm stms
  where
    getSpacesStm :: Stm GPUMem -> Map VName Space
    getSpacesStm (Let (Pat [PatElem name _]) _ (Op (Alloc _ sp))) =
      M.singleton name sp
    -- An allocation is expected to bind exactly one name.
    getSpacesStm (Let _ _ (Op (Alloc _ _))) = error "impossible"
    getSpacesStm (Let _ _ (Op (Inner (SegOp segop)))) =
      foldMap getSpacesStm $ kernelBodyStms $ segBody segop
    getSpacesStm (Let _ _ (Match _ cases defbody _)) =
      foldMap (foldMap getSpacesStm . bodyStms) $ defbody : map caseBody cases
    getSpacesStm (Let _ _ (DoLoop _ _ body)) =
      foldMap getSpacesStm (bodyStms body)
    getSpacesStm _ = mempty

-- | Compute interference information for the segmented operations found in
-- the given statements.
--
-- Each statement is handled independently of the others and the results are
-- combined with 'mconcat'. A segmented operation is analysed with
-- 'analyseSegOp' starting from an empty in-use set. 'Match' branches and
-- 'DoLoop' bodies are searched recursively, the latter extended by
-- 'analyseLoopParams'. Any other statement contributes nothing.
analyseGPU' ::
  LocalScope GPUMem m =>
  LastUseMap ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseGPU' lumap stms =
  mconcat . toList <$> mapM helper stms
  where
    helper ::
      LocalScope GPUMem m =>
      Stm GPUMem ->
      m (InUse, LastUsed, Graph VName)
    helper stm@Let {stmExp = Op (Inner (SegOp segop))} =
      inScopeOf stm $ analyseSegOp lumap mempty segop
    helper stm@Let {stmExp = Match _ cases defbody _} =
      inScopeOf stm $
        mconcat
          <$> mapM (analyseGPU' lumap . bodyStms) (defbody : map caseBody cases)
    helper stm@Let {stmExp = DoLoop merge _ body} =
      fmap (analyseLoopParams merge) . inScopeOf stm $
        analyseGPU' lumap $
          bodyStms body
    helper stm =
      inScopeOf stm $ pure mempty

-- | Extract the memory information recorded for a name in the scope,
-- discarding uniqueness information. Index names carry no memory annotation
-- and are reported as primitives of their integer type.
nameInfoToMemInfo :: Mem rep inner => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo info =
  case info of
    FParamName summary -> noUniquenessReturns summary
    LParamName summary -> summary
    LetName summary -> letDecMem summary
    IndexName it -> MemPrim $ IntType it

-- | Look up the memory block that a variable resides in. The result is
-- 'Nothing' if the variable is not in scope or is not an array.
memInfo :: LocalScope GPUMem m => VName -> m (Maybe VName)
memInfo vname = do
  summary <- asksScope (fmap nameInfoToMemInfo . M.lookup vname)
  case summary of
    Just (MemArray _ _ _ (ArrayIn mem _)) ->
      pure $ Just mem
    _ ->
      pure Nothing

-- | Returns a mapping from memory block to element size. The input is the
-- 'VName' of a variable (supposedly an array), and the result is a singleton
-- mapping from the memory block of that array to the byte size of the array's
-- element type. If the variable is not in scope or is not an array, the
-- result is the empty mapping.
memElemSize :: LocalScope GPUMem m => VName -> m (Map VName Int)
memElemSize vname = do
  summary <- asksScope (fmap nameInfoToMemInfo . M.lookup vname)
  case summary of
    Just (MemArray pt _ _ (ArrayIn mem _)) ->
      pure $ M.singleton mem (primByteSize pt)
    _ ->
      pure mempty