{-# LANGUAGE TypeFamilies #-}

module Futhark.Analysis.MemAlias
  ( analyzeSeqMem,
    analyzeGPUMem,
    aliasesOf,
    MemAliases,
  )
where

import Control.Monad
import Control.Monad.Reader
import Data.Bifunctor
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Set qualified as S
import Futhark.IR.GPUMem
import Futhark.IR.SeqMem
import Futhark.Util
import Futhark.Util.Pretty

-- For our purposes, memory aliasing is a symmetric relation: If @a@ aliases
-- @b@, @b@ also aliases @a@. However, this relationship is not transitive.
-- Consider for instance the following:
--
-- @
--   let xs@mem_1 =
--     if ... then
--       replicate i 0 @ mem_2
--     else
--       replicate j 1 @ mem_3
-- @
--
-- Here, @mem_1@ aliases both @mem_2@ and @mem_3@, each of which alias @mem_1@
-- but not each other.
--
-- Each key of the underlying map is associated with the names it has been
-- recorded as aliasing.
newtype MemAliases = MemAliases (M.Map VName Names)
  deriving (Show, Eq)

-- | Combining two alias maps keeps every key of either map; a key present in
-- both has its two alias sets combined with '<>'.
instance Semigroup MemAliases where
  MemAliases m1 <> MemAliases m2 = MemAliases $ M.unionWith (<>) m1 m2

-- | The identity is the alias map with no entries.
instance Monoid MemAliases where
  mempty = MemAliases mempty

-- | Renders one stacked entry per key: the key followed by @aliases:@, and
-- then its alias set on a single, indented line.
instance Pretty MemAliases where
  pretty (MemAliases m) = stack $ map f $ M.toList m
    where
      f (v, vs) = pretty v <+> "aliases:" </> indent 2 (oneLine $ pretty vs)

-- | @addAlias v1 v2 m@ records in @m@ that @v1@ aliases @v2@.
--
-- An entry for @v2@ (contributing no aliases of its own) is merged in as
-- well, so that @v2@ is guaranteed to be a key of the resulting map.
addAlias :: VName -> VName -> MemAliases -> MemAliases
addAlias v1 v2 m =
  m <> singleton v1 (oneName v2) <> singleton v2 mempty

-- | An alias map with exactly one entry, associating the given name with the
-- given set of aliases.
singleton :: VName -> Names -> MemAliases
singleton v ns = MemAliases $ M.singleton v ns

-- | Look up the aliases recorded for a name.  A name with no entry in the
-- map yields the empty set ('mempty').
aliasesOf :: MemAliases -> VName -> Names
aliasesOf (MemAliases m) v = fromMaybe mempty $ M.lookup v m

-- | Is the given name a key of the alias map?  This holds even when the name
-- has an empty alias set.
isIn :: VName -> MemAliases -> Bool
isIn v (MemAliases m) = v `S.member` M.keysSet m

-- | The read-only environment of the analysis.  'onInner' is the handler
-- that 'analyzeStm' invokes on representation-specific ('Inner') operations;
-- it receives the aliases found so far and returns the updated aliases.
newtype Env inner = Env {onInner :: MemAliases -> inner -> MemAliasesM inner MemAliases}

-- | The monad in which the analysis runs: a 'Reader' over the 'Env' for the
-- given inner-operation type.
type MemAliasesM inner a = Reader (Env inner) a

-- | Handler for the inner operations of the 'GPUMem' representation.
--
-- For every kind of 'SegOp' the statements of the kernel body are analysed,
-- starting from the aliases found so far.  'SizeOp', 'GPUBody' and 'OtherOp'
-- leave the aliases unchanged.
analyzeHostOp :: MemAliases -> HostOp NoOp GPUMem -> MemAliasesM (HostOp NoOp GPUMem) MemAliases
analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m SizeOp {} = pure m
analyzeHostOp m GPUBody {} = pure m
analyzeHostOp m (OtherOp NoOp) = pure m

-- | Extend the aliases found so far with those introduced by one statement.
--
-- * An allocation bound to a single name registers that name as a key with
--   no aliases.
--
-- * A representation-specific inner operation is delegated to the 'onInner'
--   handler of the environment.
--
-- * For a @Match@, all branch bodies (default branch first) are analysed in
--   sequence; each pattern name is then aliased to the corresponding result
--   of every branch, provided that result is a variable already known to
--   the alias map.
--
-- * For a @Loop@, the pattern names and then the loop parameters are aliased
--   to the corresponding initial values, the body is analysed, and finally
--   the pattern names are aliased to the corresponding body results.  In
--   each step only variables already known to the alias map are considered.
--
-- * Any other statement leaves the aliases unchanged.
analyzeStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  MemAliases ->
  Stm rep ->
  MemAliasesM (inner rep) MemAliases
analyzeStm m (Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) =
  pure $ m <> singleton vname mempty
analyzeStm m (Let _ _ (Op (Inner inner))) = do
  on_inner <- asks onInner
  on_inner m inner
analyzeStm m (Let pat _ (Match _ cases defbody _)) = do
  let bodies = defbody : map caseBody cases
  -- Thread the alias map through the statements of every branch.
  m' <- foldM (flip analyzeStms) m $ map bodyStms bodies
  -- Pair each pattern name with the matching result of each branch.
  foldMap (zip (patNames pat) . map resSubExp . bodyResult) bodies
    & mapMaybe (filterFun m')
    & foldr (uncurry addAlias) m'
    & pure
analyzeStm m (Let pat _ (Loop params _ body)) = do
  let m_init =
        map snd params
          & zip (patNames pat)
          & mapMaybe (filterFun m)
          & foldr (uncurry addAlias) m
      m_params =
        mapMaybe (filterFun m_init . first paramName) params
          & foldr (uncurry addAlias) m_init
  m_body <- analyzeStms (bodyStms body) m_params
  zip (patNames pat) (map resSubExp $ bodyResult body)
    & mapMaybe (filterFun m_body)
    & foldr (uncurry addAlias) m_body
    & pure
analyzeStm m _ = pure m

-- | Keep a @(name, subexpression)@ pair only if the subexpression is a
-- variable that is a key of the given alias map, returning the pair with the
-- variable's name unwrapped.  Anything else yields 'Nothing'.
filterFun :: MemAliases -> (VName, SubExp) -> Maybe (VName, VName)
filterFun m' (v, Var v') | v' `isIn` m' = Just (v, v')
filterFun _ _ = Nothing

-- | Analyse a sequence of statements in order, threading the alias map
-- through 'analyzeStm'.  Note that the statements come first and the initial
-- alias map second.
analyzeStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  MemAliases ->
  MemAliasesM (inner rep) MemAliases
analyzeStms =
  flip $ foldM analyzeStm

-- | Analyse a single function definition, returning its name together with
-- the aliases found in its body.
--
-- The initial alias map contains one entry (with no aliases) for every
-- function parameter that is a memory block.
analyzeFun ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  FunDef rep ->
  MemAliasesM (inner rep) (Name, MemAliases)
analyzeFun f =
  funDefParams f
    & mapMaybe justMem
    & mconcat
    & analyzeStms (bodyStms $ funDefBody f)
    <&> (funDefName f,)
  where
    -- Only parameters whose declaration is 'MemMem' contribute an entry.
    justMem (Param _ v (MemMem _)) = Just $ singleton v mempty
    justMem _ = Nothing

-- | Perform a single closure step: every key additionally receives the
-- aliases of each of its current aliases.  The input map is merged back in,
-- so no existing alias is lost.  Repeated application is needed to reach the
-- full closure.
transitiveClosure :: MemAliases -> MemAliases
transitiveClosure ma@(MemAliases m) =
  M.foldMapWithKey
    ( \k ns ->
        namesToList ns
          & foldMap (aliasesOf ma)
          & singleton k
    )
    m
    <> ma

-- | Produce aliases for constants and for each function.
--
-- The 'SeqMem' representation has no inner operations to analyse, so the
-- inner-operation handler returns the aliases unchanged.
analyzeSeqMem :: Prog SeqMem -> (MemAliases, M.Map Name MemAliases)
analyzeSeqMem prog = completeBijection $ runReader (analyze prog) $ Env $ \x _ -> pure x

-- | Produce aliases for constants and for each function.
--
-- Inner operations of the 'GPUMem' representation are handled by
-- 'analyzeHostOp'.
analyzeGPUMem :: Prog GPUMem -> (MemAliases, M.Map Name MemAliases)
analyzeGPUMem prog = completeBijection $ runReader (analyze prog) $ Env analyzeHostOp

-- | Analyse a whole program.
--
-- The first component of the result holds the aliases found in the program
-- constants (starting from an empty alias map); the second maps each
-- function name to the aliases found in that function.  Every resulting
-- alias map is closed by iterating 'transitiveClosure' to a fixed point.
analyze ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Prog rep ->
  MemAliasesM (inner rep) (MemAliases, M.Map Name MemAliases)
analyze prog =
  (,)
    <$> (progConsts prog & flip analyzeStms mempty <&> fixPoint transitiveClosure)
    <*> (progFuns prog & mapM analyzeFun <&> M.fromList <&> M.map (fixPoint transitiveClosure))

-- | Make every alias map in the analysis result symmetric: whenever @k@ is
-- recorded as aliasing @n@, @n@ is also recorded as aliasing @k@.  The
-- existing aliases are retained.
completeBijection :: (MemAliases, M.Map Name MemAliases) -> (MemAliases, M.Map Name MemAliases)
completeBijection (a, bs) = (f a, fmap f bs)
  where
    -- For each key @k@ and each of its aliases @n@, add the reverse entry
    -- @n -> {k}@, then merge with the original map.
    f ma@(MemAliases m) =
      M.foldMapWithKey (\k ns -> foldMap (`singleton` oneName k) (namesToList ns)) m <> ma