{-# LANGUAGE TypeFamilies #-}

-- | Interference analysis for Futhark programs.
module Futhark.Analysis.Interference (Graph, analyseProgGPU) where

import Control.Monad
import Control.Monad.Reader
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Data.Set (Set)
import Data.Set qualified as S
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.Analysis.LastUse (LUTabFun)
import Futhark.Analysis.LastUse qualified as LastUse
import Futhark.Analysis.MemAlias qualified as MemAlias
import Futhark.IR.GPUMem
import Futhark.Util (cartesian, invertMap)

-- | The set of 'VName's (memory blocks, in this analysis) currently in use.
type InUse = Names

-- | The set of 'VName's (memory blocks, in this analysis) that have reached
-- their last use and are therefore no longer in use.
type LastUsed = Names

-- | An interference graph. An element @(x, y)@ in the set means that there is
-- an undirected edge between @x@ and @y@, and therefore the lifetimes of @x@
-- and @y@ overlap and they "interfere" with each other. We assume that pairs
-- are always normalized, such that @x@ < @y@, before inserting. This should
-- prevent any duplicates. We also don't allow any pairs where @x == y@.
-- Build edges with 'makeEdge', which enforces both invariants.
type Graph a = Set (a, a)

-- | A graph containing just the edge between two values, normalized so that
-- the smaller value comes first. Equal values give the empty graph, because
-- self-edges are not allowed. Combine the result with an existing graph
-- using '<>' to insert the edge.
makeEdge :: (Ord a) => a -> a -> Graph a
makeEdge v1 v2
  | v1 == v2 = mempty
  | otherwise = S.singleton (min v1 v2, max v1 v2)

-- | Analyse a single statement, given the memory blocks in use before it.
--
-- Returns three things:
--
-- * the memory blocks in use after the statement;
-- * the memory blocks that reach their last use in this statement, including
--   inside any nested bodies;
-- * the interference edges the statement gives rise to.
analyseStm ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Stm GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStm lumap inuse0 stm =
  inScopeOf stm $ do
    -- Memory blocks of the statement's own results.
    new_mems <-
      stmPat stm
        & patElems
        & mapM (memInfo . patElemName)
        <&> catMaybes
        <&> namesFromList

    -- `new_mems` should interfere with any mems inside the statement expression
    let inuse_outside = inuse0 <> new_mems

    -- `inuse` is the set of memory blocks that are inuse at the end of any code
    -- bodies inside the expression. `lus` is the set of all memory blocks that
    -- have reached their last use in any code bodies inside the
    -- expression. `graph` is the interference graph computed for any code
    -- bodies inside the expression.
    (inuse, lus, graph) <- analyseExp lumap inuse_outside (stmExp stm)

    -- The last-use table is keyed by the name of the first pattern element.
    -- A statement with an empty pattern has no key, so it contributes no last
    -- uses. This avoids the partial 'head' the lookup used to go through.
    let last_used_here =
          case patElems (stmPat stm) of
            pe : _ -> fromMaybe mempty $ M.lookup (patElemName pe) lumap
            [] -> mempty

    -- Only memory blocks that are actually in use here can end here.
    last_use_mems <-
      last_used_here
        & namesToList
        & mapM memInfo
        <&> catMaybes
        <&> namesFromList
        <&> namesIntersection inuse_outside

    pure
      ( (inuse_outside `namesSubtract` last_use_mems `namesSubtract` lus)
          <> new_mems,
        (lus <> last_use_mems) `namesSubtract` new_mems,
        graph
          <> cartesian
            makeEdge
            (namesToList inuse_outside)
            (namesToList $ inuse_outside <> inuse <> lus <> last_use_mems)
      )

-- | Add interference edges for the memory arguments of a 'Loop' to the
-- result of analysing its body.
--
-- We conservatively treat all memory arguments to a 'Loop' as interfering
-- with each other and with anything used inside the loop. This could
-- potentially be improved by looking at the interference computed by the
-- loop body with respect to the loop arguments, but probably very few
-- programs would benefit from this.
analyseLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  (InUse, LastUsed, Graph VName) ->
  (InUse, LastUsed, Graph VName)
analyseLoopParams merge (inuse, lastused, graph) =
  (inuse, lastused, cartesian makeEdge mems (mems <> inner_mems) <> graph)
  where
    -- Variables passed as initial values for memory-typed loop parameters.
    mems = mapMaybe isMemArg merge
    -- Everything the loop body left in use or last-used.
    inner_mems = namesToList lastused <> namesToList inuse
    isMemArg (Param _ _ MemMem {}, Var v) = Just v
    isMemArg _ = Nothing

-- | Analyse the code bodies nested inside an expression, given the memory
-- blocks in use around it. Only 'Match', 'Loop' and 'SegOp' expressions have
-- their bodies analysed. Every other expression contributes 'mempty'.
analyseExp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Exp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseExp lumap inuse_outside expr =
  case expr of
    Match _ cases defbody _ ->
      -- Every branch starts from the same in-use set. The per-branch results
      -- are combined componentwise with '<>'.
      mconcat
        <$> mapM
          (analyseBody lumap inuse_outside)
          (defbody : map caseBody cases)
    Loop merge _ body ->
      analyseLoopParams merge <$> analyseBody lumap inuse_outside body
    Op (Inner (SegOp segop)) ->
      analyseSegOp lumap inuse_outside segop
    _ ->
      pure mempty

-- | Analyse the statements of a 'KernelBody'. The kernel results themselves
-- are not inspected.
analyseKernelBody ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseKernelBody lumap inuse body =
  analyseStms lumap inuse $ kernelBodyStms body

-- | Analyse the statements of a 'Body'. The body results themselves are not
-- inspected.
analyseBody ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Body GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseBody lumap inuse body =
  analyseStms lumap inuse $ bodyStms body

-- | Analyse a sequence of statements in order. The in-use set is threaded
-- from each statement to the next, while last uses and graph edges are
-- accumulated. The statements are brought into scope for the duration.
analyseStms ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStms lumap inuse0 stms =
  inScopeOf stms $ foldM helper (inuse0, mempty, mempty) $ stmsToList stms
  where
    helper (inuse, lus, graph) stm = do
      (inuse', lus', graph') <- analyseStm lumap inuse stm
      pure (inuse', lus' <> lus, graph' <> graph)

-- | Analyse a 'SegOp'. The kernel body is analysed first. For 'SegRed',
-- 'SegScan' and 'SegHist', the operator lambdas are then each analysed
-- starting from the in-use set left by the kernel body.
analyseSegOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  SegOp lvl GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegOp lumap inuse (SegMap _ _ _ body) =
  analyseKernelBody lumap inuse body
analyseSegOp lumap inuse (SegRed _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegScan _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  -- NOTE(review): with an empty `histops`, `inuse''` is 'mempty' rather than
  -- `inuse'`. Confirm that callers do not rely on it in that case.
  (inuse'', lus'', graph') <-
    mconcat <$> mapM (analyseHistOp lumap inuse') histops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse a kernel body followed by a list of binary operators, as found
-- in 'SegRed' and 'SegScan'. Each operator's lambda starts from the in-use
-- set left by the kernel body. The per-operator results are combined with
-- '<>'.
segWithBinOps ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
segWithBinOps lumap inuse binops body = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  -- NOTE(review): with an empty `binops`, `inuse''` is 'mempty' rather than
  -- `inuse'`. Confirm that callers do not rely on it in that case.
  (inuse'', lus'', graph') <-
    mconcat <$> mapM (analyseSegBinOp lumap inuse') binops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse the lambda of a 'SegBinOp'. The other fields are ignored.
analyseSegBinOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  SegBinOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegBinOp lumap inuse (SegBinOp _ lambda _ _) =
  analyseLambda lumap inuse lambda

-- | Analyse the operator lambda ('histOp') of a 'HistOp'.
analyseHistOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  HistOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseHistOp lumap inuse = analyseLambda lumap inuse . histOp

-- | Analyse the body of a 'Lambda'.
--
-- NOTE(review): the lambda parameters are not brought into scope here. If
-- `memInfo` is used on them inside the body, it presumably finds nothing.
-- Verify whether that is intended.
analyseLambda ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Lambda GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseLambda lumap inuse = analyseBody lumap inuse . lambdaBody

-- | Compute the memory interference graph of a whole program.
--
-- The top-level constants and each function are analysed separately, using
-- the last-use tables from "Futhark.Analysis.LastUse". Each resulting graph
-- is extended to memory aliases ("Futhark.Analysis.MemAlias"), and all the
-- graphs are unioned.
analyseProgGPU :: Prog GPUMem -> Graph VName
analyseProgGPU prog =
  onConsts (progConsts prog) <> foldMap onFun (progFuns prog)
  where
    (consts_aliases, funs_aliases) = MemAlias.analyzeGPUMem prog
    (lumap_consts, lumap) = LastUse.lastUseGPUMem $ AnlAls.aliasAnalysis prog

    onFun f =
      let fname = funDefName f
          -- Missing entries fall back to empty tables instead of crashing,
          -- which matches the alias lookup. An empty last-use table records
          -- no last uses, so it should only make the result more
          -- conservative.
          fun_aliases = fromMaybe mempty $ M.lookup fname funs_aliases
          fun_lumap = M.findWithDefault mempty fname lumap
       in applyAliases fun_aliases $
            runReader (analyseGPU fun_lumap $ bodyStms $ funDefBody f) $
              scopeOf f

    onConsts stms =
      applyAliases consts_aliases $
        runReader (analyseGPU lumap_consts stms) (mempty :: Scope GPUMem)

-- | Extend an interference graph to memory aliases. For each edge @(x, y)@,
-- every memory alias of @x@ (and @x@ itself) interferes with every memory
-- alias of @y@ (and @y@ itself). The original edges are therefore retained.
applyAliases :: MemAlias.MemAliases -> Graph VName -> Graph VName
applyAliases aliases = foldMap edgesFor
  where
    edgesFor (x, y) = cartesian makeEdge (withAliases x) (withAliases y)
    withAliases v = namesToList $ MemAlias.aliasesOf aliases v <> oneName v

-- | Perform interference analysis on the given statements and return the
-- resulting interference graph.
--
-- The lifetime-based edges come from 'analyseGPU''. On top of those, every
-- pair of 'DefaultSpace' memory blocks with different element sizes is made
-- to interfere.
analyseGPU ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  Stms GPUMem ->
  m (Graph VName)
analyseGPU lumap stms = do
  (_, _, graph) <- analyseGPU' lumap stms
  -- We need to insert edges between memory blocks which differ in size, if they
  -- are in DefaultSpace. The problem is that during memory expansion,
  -- DefaultSpace arrays in kernels are interleaved. If the element sizes of two
  -- merged memory blocks are different, threads might try to read and write to
  -- overlapping memory positions. More information here:
  -- https://munksgaard.me/technical-diary/2020-12-30.html#org210775b
  spaces <- M.filter (== DefaultSpace) <$> memSpaces stms
  -- Group the DefaultSpace memory blocks by element size.
  inv_size_map <-
    memSizes stms
      <&> flip M.restrictKeys (M.keysSet spaces)
      <&> invertMap
  -- Size classes are disjoint, so it suffices to compare their keys. Each
  -- unordered pair is visited once (s1 < s2). Because 'makeEdge' normalizes
  -- its pairs, this gives the same edges as visiting both orders.
  let size_classes = M.toList inv_size_map
      new_edges =
        cartesian
          ( \(s1, xs) (s2, ys) ->
              if s1 < s2 then cartesian makeEdge xs ys else mempty
          )
          size_classes
          size_classes
  pure $ graph <> new_edges

-- | Return a mapping from memory blocks to the element size (as computed by
-- 'primByteSize') of the arrays stored in them, for every array bound by a
-- pattern in the given statements.
--
-- Statements nested inside 'SegOp' kernel bodies, 'Match' branches and 'Loop'
-- bodies are traversed as well; each nested statement sequence is brought into
-- scope before its patterns are inspected. Partial results are combined with
-- the left-biased '<>' of 'Map', so when several arrays share a memory block,
-- the size of the first one encountered is the one kept.
memSizes :: (LocalScope GPUMem m) => Stms GPUMem -> m (Map VName Int)
memSizes stms =
  inScopeOf stms $ mconcat <$> mapM sizesOfStm (stmsToList stms)
  where
    -- Sizes contributed by a single statement: first its own pattern, then
    -- anything nested inside its expression.
    sizesOfStm :: (LocalScope GPUMem m) => Stm GPUMem -> m (Map VName Int)
    sizesOfStm (Let pat _ e) = do
      fromPat <- mconcat <$> mapM memElemSize (patNames pat)
      fromExp <- sizesOfExp e
      pure (fromPat <> fromExp)

    -- Recurse into the statement sequences held by compound expressions;
    -- every other expression contributes nothing.
    sizesOfExp :: (LocalScope GPUMem m) => Exp GPUMem -> m (Map VName Int)
    sizesOfExp (Op (Inner (SegOp segop))) =
      memSizes $ kernelBodyStms $ segBody segop
    sizesOfExp (Match _ cases defbody _) =
      mconcat <$> mapM (memSizes . bodyStms) (defbody : map caseBody cases)
    sizesOfExp (Loop _ _ body) =
      memSizes $ bodyStms body
    sizesOfExp _ =
      pure mempty

-- | Return a mapping from memory blocks to the space they are allocated in.
--
-- Every 'Alloc' in the given statements is recorded, including allocations
-- nested inside 'SegOp' kernel bodies, 'Match' branches and 'Loop' bodies.
-- The result is computed purely from the statements; the scope is not
-- consulted.
memSpaces :: (LocalScope GPUMem m) => Stms GPUMem -> m (Map VName Space)
memSpaces stms =
  pure $ foldMap getSpacesStm stms
  where
    getSpacesStm :: Stm GPUMem -> Map VName Space
    getSpacesStm (Let (Pat [PatElem name _]) _ (Op (Alloc _ sp))) =
      M.singleton name sp
    getSpacesStm (Let _ _ (Op (Alloc _ _))) =
      -- Only single-element 'Alloc' patterns are handled above; any other
      -- pattern shape is treated as an invariant violation.
      error "memSpaces: Alloc statement must bind exactly one name"
    getSpacesStm (Let _ _ (Op (Inner (SegOp segop)))) =
      foldMap getSpacesStm $ kernelBodyStms $ segBody segop
    getSpacesStm (Let _ _ (Match _ cases defbody _)) =
      foldMap (foldMap getSpacesStm . bodyStms) $ defbody : map caseBody cases
    getSpacesStm (Let _ _ (Loop _ _ body)) =
      foldMap getSpacesStm $ bodyStms body
    getSpacesStm _ =
      mempty

-- | Compute interference information for a sequence of statements outside of
-- kernels.
--
-- Each statement is analysed independently and the per-statement triples are
-- combined with '<>':
--
-- * a 'SegOp' is handed to 'analyseSegOp' with an empty in-use set;
-- * every branch of a 'Match' is analysed recursively;
-- * a 'Loop' body is analysed recursively, and the result is then passed
--   through 'analyseLoopParams' together with the loop's merge parameters;
-- * any other statement contributes nothing.
--
-- A statement's own bindings are brought into scope only while that
-- statement is being analysed.
analyseGPU' ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseGPU' lumap stms = do
  perStm <- mapM analyseTopStm stms
  pure $ mconcat $ toList perStm
  where
    analyseTopStm ::
      (LocalScope GPUMem m) =>
      Stm GPUMem ->
      m (InUse, LastUsed, Graph VName)
    analyseTopStm stm =
      case stmExp stm of
        Op (Inner (SegOp segop)) ->
          inScopeOf stm $ analyseSegOp lumap mempty segop
        Match _ cases defbody _ ->
          inScopeOf stm $
            mconcat <$> mapM analyseBody (defbody : map caseBody cases)
        Loop merge _ body ->
          analyseLoopParams merge <$> inScopeOf stm (analyseBody body)
        _ ->
          inScopeOf stm $ pure mempty

    -- Analyse the statements of a nested body with the same last-use table.
    analyseBody ::
      (LocalScope GPUMem m) =>
      Body GPUMem ->
      m (InUse, LastUsed, Graph VName)
    analyseBody = analyseGPU' lumap . bodyStms

-- | Project the memory information out of a scope entry, without uniqueness.
--
-- Function parameters have their uniqueness dropped via
-- 'noUniquenessReturns', lambda parameters are returned as-is, let-bound
-- names go through 'letDecMem', and index variables are described as
-- primitive values of their integer type.
nameInfoToMemInfo :: (Mem rep inner) => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (LetName summary) = letDecMem summary
nameInfoToMemInfo (IndexName it) = MemPrim (IntType it)

-- | Look up the memory block backing the given variable in the current scope.
--
-- Returns 'Nothing' when the name is not in scope, or when its memory
-- information is not an array placed in a memory block ('ArrayIn').
memInfo :: (LocalScope GPUMem m) => VName -> m (Maybe VName)
memInfo vname =
  asksScope $ \scope ->
    case nameInfoToMemInfo <$> M.lookup vname scope of
      Just (MemArray _ _ _ (ArrayIn mem _)) -> Just mem
      _ -> Nothing

-- | Return a singleton mapping from the memory block backing the given
-- variable to the element size (as computed by 'primByteSize') of that
-- variable's primitive type.
--
-- The input is expected to be the 'VName' of an array. If the name is not in
-- scope, or is not an array placed in a memory block ('ArrayIn'), the result
-- is the empty map.
memElemSize :: (LocalScope GPUMem m) => VName -> m (Map VName Int)
memElemSize vname =
  asksScope $ \scope ->
    case nameInfoToMemInfo <$> M.lookup vname scope of
      Just (MemArray pt _ _ (ArrayIn mem _)) ->
        M.singleton mem (primByteSize pt)
      _ ->
        mempty