{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Analysis.MemAlias
  ( analyzeSeqMem,
    analyzeGPUMem,
    canBeSameMemory,
    aliasesOf,
    MemAliases,
  )
where

import Control.Monad.Reader
import Data.Bifunctor
import Data.Function ((&))
import Data.Functor ((<&>))
import qualified Data.Map as M
import Data.Maybe (fromMaybe, mapMaybe)
import qualified Data.Set as S
import Futhark.IR.GPUMem
import Futhark.IR.SeqMem
import Futhark.Util
import Futhark.Util.Pretty

-- For our purposes, memory aliasing is a symmetric relation: if @a@ aliases
-- @b@, then @b@ also aliases @a@. However, this relation is not transitive.
-- Consider for instance the following:
--
-- @
--   let xs@mem_1 =
--     if ... then
--       replicate i 0 @ mem_2
--     else
--       replicate j 1 @ mem_3
-- @
--
-- Here, @mem_1@ aliases both @mem_2@ and @mem_3@, each of which aliases
-- @mem_1@, but @mem_2@ and @mem_3@ do not alias each other.
-- | For every tracked name, the set of names it may alias.
newtype MemAliases = MemAliases (M.Map VName Names)
  deriving (Eq, Show)

-- | Key-wise union of two alias maps; when a name is present on both sides,
-- its two alias sets are combined with the 'Semigroup' of 'Names'.
instance Semigroup MemAliases where
  (<>) (MemAliases lhs) (MemAliases rhs) = MemAliases (M.unionWith (<>) lhs rhs)

-- | The empty alias map: no names are tracked.
instance Monoid MemAliases where
  mempty = MemAliases M.empty

-- | Pretty-print by delegating to the underlying map.
instance Pretty MemAliases where
  ppr (MemAliases aliasMap) = ppr aliasMap

-- | @addAlias v1 v2 aliases@ records that @v1@ may alias @v2@. @v2@ is also
-- inserted as a key (with an empty alias set if it was absent), so that later
-- lookups of @v2@ succeed.
addAlias :: VName -> VName -> MemAliases -> MemAliases
addAlias v1 v2 aliases =
  mconcat [aliases, singleton v1 (oneName v2), singleton v2 mempty]

-- | An alias map tracking exactly one name with the given alias set.
singleton :: VName -> Names -> MemAliases
singleton v = MemAliases . M.singleton v

-- | Can the two names refer to the same memory? True if either name lists the
-- other among its aliases.
--
-- Calls 'error' if @v1@ is not tracked. Calls 'error' if @v2@ is not tracked
-- and @v1@'s alias set does not already contain @v2@.
canBeSameMemory :: MemAliases -> VName -> VName -> Bool
canBeSameMemory (MemAliases m) v1 v2 =
  -- '||' is lazy, so @v2@ is only looked up when @v1@'s set does not decide.
  v2 `nameIn` aliasesOrDie v1 || v1 `nameIn` aliasesOrDie v2
  where
    aliasesOrDie v =
      fromMaybe (error $ "VName not found in MemAliases: " <> pretty v) $
        M.lookup v m

-- | The names that the given name may alias; empty if the name is untracked.
aliasesOf :: MemAliases -> VName -> Names
aliasesOf (MemAliases m) v = M.findWithDefault mempty v m

-- | Is the name tracked, i.e. present as a key in the alias map?
isIn :: VName -> MemAliases -> Bool
isIn v (MemAliases m) =
  -- Direct O(log n) key lookup; building 'M.keysSet' first costs O(n) on
  -- every call, and this is called once per candidate alias pair.
  v `M.member` m

-- | Analysis environment. 'onInner' handles the representation-specific
-- inner operations: it receives the aliases accumulated so far and the
-- operation, and returns the updated aliases.
newtype Env inner = Env
  { onInner :: MemAliases -> inner -> MemAliasesM inner MemAliases
  }

-- | The analysis monad: read-only access to the environment.
type MemAliasesM inner = Reader (Env inner)

-- | Extend the aliases with the effect of a GPU host operation.
--
-- For every kind of 'SegOp', the statements of the kernel body are analysed,
-- starting from the aliases accumulated so far. Only the kernel body is
-- traversed; the reduction/scan/histogram operators are not examined.
--
-- Any other host operation introduces no aliases, so the incoming aliases are
-- returned unchanged. (The result of this handler becomes the new
-- accumulator, so returning 'mempty' here would discard every alias recorded
-- earlier.)
analyzeHostOp :: MemAliases -> HostOp GPUMem () -> MemAliasesM (HostOp GPUMem ()) MemAliases
analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) =
  analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m _ = return m

-- | For each (name, source) pair whose source is a variable already tracked in
-- the given aliases, record that the name may alias the source. Pairs whose
-- source is a constant or an untracked variable are ignored.
aliasPairs :: MemAliases -> [(VName, SubExp)] -> MemAliases
aliasPairs acc = foldr (uncurry addAlias) acc . mapMaybe (filterFun acc)

-- | Extend the aliases with the effect of a single statement.
--
-- * A single-element 'Alloc' starts tracking the bound name with no aliases.
-- * An inner operation is delegated to the environment's 'onInner' handler.
-- * For an @if@, both branches are analysed in turn. Each pattern name may
--   then alias the corresponding result of either branch.
-- * For a loop:
--   - pattern names and loop parameters may alias the initial values;
--   - the body is analysed;
--   - loop parameters may alias the body results (the loop back-edge);
--   - pattern names may alias the body results.
-- * Any other statement leaves the aliases unchanged.
analyzeStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  MemAliases ->
  Stm rep ->
  MemAliasesM inner MemAliases
analyzeStm m (Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) =
  return $ m <> singleton vname mempty
analyzeStm m (Let _ _ (Op (Inner inner))) = do
  on_inner <- asks onInner
  on_inner m inner
analyzeStm m (Let pat _ (If _ then_body else_body _)) = do
  m' <-
    analyzeStms (bodyStms then_body) m
      >>= analyzeStms (bodyStms else_body)
  let then_res = map resSubExp $ bodyResult then_body
      else_res = map resSubExp $ bodyResult else_body
  return $
    aliasPairs m' $
      zip (patNames pat) then_res <> zip (patNames pat) else_res
analyzeStm m (Let pat _ (DoLoop params _ body)) = do
  let m_init = aliasPairs m $ zip (patNames pat) (map snd params)
      m_params = aliasPairs m_init $ map (first paramName) params
  m_body <- analyzeStms (bodyStms body) m_params
  let body_res = map resSubExp $ bodyResult body
      -- On the next iteration each loop parameter is bound to the matching
      -- body result, so the two may be the same memory.
      m_backedge =
        aliasPairs m_body $ zip (map (paramName . fst) params) body_res
  return $ aliasPairs m_backedge $ zip (patNames pat) body_res
analyzeStm m _ = return m

-- | Keep a (name, source) pair only when the source is a variable that is
-- already tracked in the given aliases; the result pairs the two names.
filterFun :: MemAliases -> (VName, SubExp) -> Maybe (VName, VName)
filterFun tracked (v, se) = case se of
  Var src | src `isIn` tracked -> Just (v, src)
  _ -> Nothing

-- | Thread the aliases through a sequence of statements, left to right.
analyzeStms :: (Mem rep inner, LetDec rep ~ LetDecMem) => Stms rep -> MemAliases -> MemAliasesM inner MemAliases
analyzeStms stms start = foldM analyzeStm start stms

-- | Analyse one function. Each memory-typed parameter starts out tracked with
-- no aliases; the body statements are then analysed from that state.
analyzeFun :: (Mem rep inner, LetDec rep ~ LetDecMem) => FunDef rep -> MemAliasesM inner MemAliases
analyzeFun f =
  -- Parameters that do not match 'MemMem' are skipped by the comprehension.
  [singleton v mempty | Param _ v (MemMem _) <- funDefParams f]
    & mconcat
    & analyzeStms (bodyStms $ funDefBody f)

-- | One step of transitive closure. Each name additionally aliases every name
-- that its current aliases alias. 'analyze' iterates this to a fixed point.
transitiveClosure :: MemAliases -> MemAliases
transitiveClosure ma@(MemAliases m) = extended <> ma
  where
    extended = M.foldMapWithKey step m
    step k ns = singleton k $ foldMap (aliasesOf ma) $ namesToList ns

-- | Memory aliases of a sequential program, made symmetric. Inner operations
-- do not affect the aliases.
analyzeSeqMem :: Prog SeqMem -> MemAliases
analyzeSeqMem prog =
  completeBijection $ runReader (analyze prog) (Env ignoreInner)
  where
    ignoreInner aliases _ = return aliases

-- | Memory aliases of a GPU program, made symmetric. Host operations are
-- handled by 'analyzeHostOp'.
analyzeGPUMem :: Prog GPUMem -> MemAliases
analyzeGPUMem =
  completeBijection . flip runReader (Env analyzeHostOp) . analyze

-- | Analyse every function and combine the results. 'transitiveClosure' is
-- then applied until the aliases stop changing.
analyze :: (Mem rep inner, LetDec rep ~ LetDecMem) => Prog rep -> MemAliasesM inner MemAliases
analyze prog =
  -- The per-function analyses do not depend on each other, so they can be
  -- run independently and merged with the (associative) Monoid instance.
  traverse analyzeFun (progFuns prog)
    <&> mconcat
    <&> fixPoint transitiveClosure

-- | Make the alias relation symmetric: whenever @k@ aliases @a@, also record
-- that @a@ aliases @k@.
completeBijection :: MemAliases -> MemAliases
completeBijection ma@(MemAliases m) = mirrored <> ma
  where
    mirrored =
      mconcat
        [ singleton alias (oneName k)
          | (k, ns) <- M.toList m,
            alias <- namesToList ns
        ]