{-# LANGUAGE TypeFamilies #-}

module Futhark.Analysis.MemAlias
  ( analyzeSeqMem,
    analyzeGPUMem,
    aliasesOf,
    MemAliases,
  )
where

import Control.Monad.Reader
import Data.Bifunctor
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Map qualified as M
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Set qualified as S
import Futhark.IR.GPUMem
import Futhark.IR.SeqMem
import Futhark.Util
import Futhark.Util.Pretty

-- For our purposes, memory aliasing is a symmetric relation: if @a@ aliases
-- @b@, then @b@ also aliases @a@. However, the relation is not transitive.
-- Consider for instance the following:
--
-- @
--   let xs@mem_1 =
--     if ... then
--       replicate i 0 @ mem_2
--     else
--       replicate j 1 @ mem_3
-- @
--
-- Here, @mem_1@ aliases both @mem_2@ and @mem_3@, each of which aliases
-- @mem_1@, but @mem_2@ and @mem_3@ do not alias each other.
-- | A mapping from the name of a memory block to the names of the memory
-- blocks it aliases. A name that is present as a key with an empty set is a
-- known memory block without any recorded aliases.
newtype MemAliases = MemAliases (M.Map VName Names)
  deriving (Show, Eq)

-- | Combining two alias maps takes the union of their keys; when a key occurs
-- in both maps, its two alias sets are themselves combined with '<>'.
instance Semigroup MemAliases where
  MemAliases m1 <> MemAliases m2 = MemAliases $ M.unionWith (<>) m1 m2

-- | The empty alias map: no memory blocks are known.
instance Monoid MemAliases where
  mempty = MemAliases mempty

-- | Renders one entry per key of the map, stacked in the order given by
-- 'M.toList': the key, the label @aliases:@, and then the alias set on a
-- single line, indented by two.
instance Pretty MemAliases where
  pretty (MemAliases m) = stack $ map prettyEntry $ M.toList m
    where
      prettyEntry (v, vs) =
        pretty v <+> "aliases:" </> indent 2 (oneLine $ pretty vs)

-- | @addAlias v1 v2 m@ extends @m@ so that @v2@ is in the alias set of @v1@.
--
-- The entry is only added in one direction. @v2@ is additionally inserted as
-- a key with an empty alias set (leaving any existing aliases of @v2@
-- untouched, since maps are merged with '<>'), so that it is registered in
-- the map even if nothing else has mentioned it yet.
addAlias :: VName -> VName -> MemAliases -> MemAliases
addAlias v1 v2 m =
  m <> singleton v1 (oneName v2) <> singleton v2 mempty

-- | An alias map with the single key @v@ mapped to the alias set @ns@.
singleton :: VName -> Names -> MemAliases
singleton v ns = MemAliases $ M.singleton v ns

-- | The alias set recorded for @v@, or the empty set if @v@ is not a key of
-- the map.
aliasesOf :: MemAliases -> VName -> Names
aliasesOf (MemAliases m) v = M.findWithDefault mempty v m

-- | Is @v@ a key of the alias map, i.e. has it been registered as a memory
-- block (with or without aliases)?
--
-- Looks the key up directly in the map rather than first materialising the
-- set of all keys.
isIn :: VName -> MemAliases -> Bool
isIn v (MemAliases m) = v `M.member` m

-- | The environment of the analysis, parameterised over the type of
-- rep-specific inner operations.
newtype Env inner = Env
  { -- | Handler that 'analyzeStm' invokes for @Op (Inner ...)@ statements.
    -- It receives the aliases found so far and the inner operation, and
    -- produces the updated aliases.
    onInner :: MemAliases -> inner -> MemAliasesM inner MemAliases
  }

-- | The monad in which the analysis runs: a 'Reader' over the 'Env' that
-- holds the handler for inner operations of type @inner@.
type MemAliasesM inner a = Reader (Env inner) a

-- | Handler for GPU host operations, used as 'onInner' by 'analyzeGPUMem'.
--
-- For every kind of 'SegOp' the statements of the kernel body are analysed
-- starting from the given aliases. 'SizeOp', 'GPUBody' and 'OtherOp' leave
-- the aliases unchanged.
analyzeHostOp :: MemAliases -> HostOp GPUMem () -> MemAliasesM (HostOp GPUMem ()) MemAliases
analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m SizeOp {} = pure m
-- NOTE(review): the statements inside a 'GPUBody' are not traversed here;
-- confirm that this is intentional.
analyzeHostOp m GPUBody {} = pure m
analyzeHostOp m (OtherOp ()) = pure m

-- | Extend the aliases @m@ with whatever a single statement contributes.
--
-- * An 'Alloc' bound to a single pattern element registers that name as a
--   memory block with no aliases.
--
-- * An 'Inner' operation is delegated to the 'onInner' handler of the
--   environment.
--
-- * For a 'Match', the statements of every branch body are analysed in turn
--   (default body first), and then each pattern name is made to alias the
--   corresponding result of every branch, provided that result is a variable
--   already registered in the alias map.
--
-- * For a 'DoLoop', the pattern names and the loop parameter names are made
--   to alias the initial loop values that are registered memory blocks; the
--   loop body is then analysed, and finally the pattern names are made to
--   alias the registered results of the body.
--
-- * Any other statement leaves @m@ unchanged.
analyzeStm :: (Mem rep inner, LetDec rep ~ LetDecMem) => MemAliases -> Stm rep -> MemAliasesM inner MemAliases
analyzeStm m (Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) =
  pure $ m <> singleton vname mempty
analyzeStm m (Let _ _ (Op (Inner inner))) = do
  on_inner <- asks onInner
  on_inner m inner
analyzeStm m (Let pat _ (Match _ cases defbody _)) = do
  let bodies = defbody : map caseBody cases
  -- Thread the alias map through the statements of every branch.
  m' <- foldM (flip analyzeStms) m $ map bodyStms bodies
  -- Pair each pattern name positionally with the results of each branch and
  -- keep only the pairs whose result is a registered memory block.
  foldMap (zip (patNames pat) . map resSubExp . bodyResult) bodies
    & mapMaybe (filterFun m')
    & foldr (uncurry addAlias) m'
    & pure
analyzeStm m (Let pat _ (DoLoop params _ body)) = do
  let -- Pattern names alias the initial values of the loop parameters.
      -- NOTE(review): this pairs pattern names with loop parameters purely
      -- by position; presumably the two lists line up -- confirm.
      m_init =
        map snd params
          & zip (patNames pat)
          & mapMaybe (filterFun m)
          & foldr (uncurry addAlias) m
      -- Loop parameter names alias their own initial values.
      m_params =
        mapMaybe (filterFun m_init . first paramName) params
          & foldr (uncurry addAlias) m_init
  m_body <- analyzeStms (bodyStms body) m_params
  -- Pattern names also alias whatever the loop body returns.
  zip (patNames pat) (map resSubExp $ bodyResult body)
    & mapMaybe (filterFun m_body)
    & foldr (uncurry addAlias) m_body
    & pure
analyzeStm m _ = pure m

-- | Keep a @(name, result)@ pair only when the result is a variable that is
-- registered as a key in the given alias map, returning the pair of names.
-- Constants and unregistered variables yield 'Nothing'.
filterFun :: MemAliases -> (VName, SubExp) -> Maybe (VName, VName)
filterFun m' (v, Var v') | v' `isIn` m' = Just (v, v')
filterFun _ _ = Nothing

-- | Analyse a sequence of statements in order, starting from the given
-- aliases and threading the result of each 'analyzeStm' into the next.
analyzeStms :: (Mem rep inner, LetDec rep ~ LetDecMem) => Stms rep -> MemAliases -> MemAliasesM inner MemAliases
analyzeStms =
  flip $ foldM analyzeStm

-- | Analyse a single function definition.
--
-- Every function parameter whose declaration is a 'MemMem' is registered as
-- a memory block with no aliases; the statements of the function body are
-- then analysed starting from that map.
analyzeFun :: (Mem rep inner, LetDec rep ~ LetDecMem) => FunDef rep -> MemAliasesM inner MemAliases
analyzeFun f =
  funDefParams f
    & mapMaybe justMem
    & mconcat
    & analyzeStms (bodyStms $ funDefBody f)
  where
    -- Memory-typed parameters become keys with an empty alias set; all other
    -- parameters are ignored.
    justMem (Param _ v (MemMem _)) = Just $ singleton v mempty
    justMem _ = Nothing

-- | Perform one step towards the transitive closure of the alias map.
--
-- For every key @k@ with alias set @ns@, the aliases of each member of @ns@
-- are added to the alias set of @k@. The result is combined with the
-- original map, so no existing entry is lost. A single application only
-- follows chains of length two; 'analyze' applies this function via
-- 'fixPoint'.
transitiveClosure :: MemAliases -> MemAliases
transitiveClosure ma@(MemAliases m) =
  M.foldMapWithKey
    ( \k ns ->
        namesToList ns
          & foldMap (aliasesOf ma)
          & singleton k
    )
    m
    <> ma

-- | Compute the memory aliases of a 'SeqMem' program.
--
-- Inner operations are of unit type in this representation, so the handler
-- simply returns the aliases unchanged. The result of 'analyze' is passed
-- through 'completeBijection' so that aliasing is recorded in both
-- directions.
analyzeSeqMem :: Prog SeqMem -> MemAliases
analyzeSeqMem prog =
  completeBijection $ runReader (analyze prog) $ Env $ \x _ -> pure x

-- | Compute the memory aliases of a 'GPUMem' program.
--
-- Inner (host) operations are handled by 'analyzeHostOp'. The result of
-- 'analyze' is passed through 'completeBijection' so that aliasing is
-- recorded in both directions.
analyzeGPUMem :: Prog GPUMem -> MemAliases
analyzeGPUMem prog =
  completeBijection $ runReader (analyze prog) $ Env analyzeHostOp

-- | Analyse every function of the program and combine the per-function
-- alias maps with '<>', starting from an empty map. The combined map is then
-- passed through @'fixPoint' 'transitiveClosure'@.
--
-- NOTE(review): 'fixPoint' comes from "Futhark.Util"; presumably it
-- re-applies 'transitiveClosure' until the map stops changing -- confirm.
analyze :: (Mem rep inner, LetDec rep ~ LetDecMem) => Prog rep -> MemAliasesM inner MemAliases
analyze prog =
  progFuns prog
    & foldM (\m f -> (<>) m <$> analyzeFun f) (MemAliases mempty)
    <&> fixPoint transitiveClosure

-- | Make the alias map symmetric.
--
-- For every key @k@ and every @n@ in its alias set, @k@ is added to the
-- alias set of @n@. The result is combined with the original map, so all
-- existing entries are retained.
completeBijection :: MemAliases -> MemAliases
completeBijection ma@(MemAliases m) =
  M.foldMapWithKey reverseEntries m <> ma
  where
    -- For the entry @k -> ns@, produce @n -> {k}@ for each @n@ in @ns@.
    reverseEntries k ns = foldMap (`singleton` oneName k) (namesToList ns)