{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
  ( recordMemRefUses,
    freeVarSubstitutions,
    translateAccessSummary,
    aggSummaryLoopTotal,
    aggSummaryLoopPartial,
    aggSummaryMapPartial,
    aggSummaryMapTotal,
    noMemOverlap,
  )
where

import Control.Monad
import Data.Function ((&))
import Data.List (foldl', intersect, partition, uncons)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-----------------------------------------------------
-- Some translations of Accesses and Ixfuns        --
-----------------------------------------------------

-- | Checks whether the index function can be translated at the current program
-- point and also returns the substitutions.  It comes down to answering the
-- question: "can one perform enough substitutions (from the bottom-up scalar
-- table) until all vars appearing in the index function are defined in the
-- current scope?"
--
-- Only integer-typed entries of the 'ScalarTab' are used as substitutions.  If
-- some free variable is neither in scope nor substitutable, the result is
-- 'Nothing'.  A variable that already has a substitution is never looked up
-- again, so variables shared between several substituted expressions are
-- resolved only once, and the search terminates even if a variable is
-- reachable from its own substitution.
freeVarSubstitutions ::
  FreeIn a =>
  ScopeTab rep ->
  ScalarTab ->
  a ->
  Maybe FreeVarSubsts
freeVarSubstitutions scope0 scals0 indfun =
  go mempty $ namesToList $ freeIn indfun
  where
    -- Worklist loop: @subs0@ holds the substitutions found so far and @fvs@
    -- the variables that still have to be resolved.
    go :: FreeVarSubsts -> [VName] -> Maybe FreeVarSubsts
    go subs0 [] = Just subs0
    go subs0 fvs = do
      -- Deduplicate so each variable is resolved at most once per round.
      let unresolved = S.toList $ S.fromList $ filter (needsSubstitution subs0) fvs
      -- We require that all free variables can be substituted
      (subs, new_fvs) <- unzip <$> mapM getSubstitution unresolved
      go (subs0 <> mconcat subs) $ concat new_fvs

    -- A variable needs a substitution when it is not defined in the current
    -- scope and has not been substituted already.
    needsSubstitution :: FreeVarSubsts -> VName -> Bool
    needsSubstitution subs v = v `M.notMember` scope0 && v `M.notMember` subs

    -- Look up an integer-typed scalar expression for @v@; also return the
    -- free variables of that expression, which must be resolved in turn.
    getSubstitution :: VName -> Maybe (FreeVarSubsts, [VName])
    getSubstitution v
      | Just pe <- M.lookup v scals0,
        IntType _ <- primExpType pe =
          Just (M.singleton v $ TPrimExp pe, namesToList $ freeIn pe)
    getSubstitution _ = Nothing

-- | Translates free variables in an access summary.
--
-- Every LMAD of a 'Set' summary is rewritten with the substitutions computed
-- by 'freeVarSubstitutions' for the given scope and scalar table.  When no
-- complete set of substitutions exists, or the summary is already
-- 'Undeterminable', the result is 'Undeterminable'.
translateAccessSummary :: ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary
translateAccessSummary _ _ Undeterminable = Undeterminable
translateAccessSummary scope0 scals0 (Set lmad_set) =
  case freeVarSubstitutions scope0 scals0 lmad_set of
    Nothing -> Undeterminable
    Just substs -> Set $ S.map (IxFun.substituteInLMAD substs) lmad_set

-- | This function computes the written and read memory references for the current statement
--
-- Each reference is a triple as produced by 'getDirAliasedIxfn'.  Where the
-- statement touches only part of an array, the index function is sliced
-- accordingly.  The result is 'Nothing' when the statement is not summarised
-- here, or when a required 'getDirAliasedIxfn' lookup fails for the array
-- being indexed, updated or copied.  Unsupported statements include
-- if-then-else, loops, calls and non-'Alloc' @Op@s.
getUseSumFromStm ::
  (Op rep ~ MemOp inner, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  CoalsTab ->
  Stm (Aliases rep) ->
  -- | A pair of written and written+read memory locations, along with their
  -- associated array and the index function used
  Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
getUseSumFromStm td_env coal_tab (Let _ _ (BasicOp (Index arr (Slice slc))))
  | Just (MemBlock _ shp _ _) <- getScopeMemInfo arr (scope td_env),
    length slc == length (shapeDims shp) && all isFix slc = do
      -- Every dimension is fixed, so a single element of @arr@ is read.
      (mem_b, mem_arr, ixfn_arr) <- getDirAliasedIxfn td_env coal_tab arr
      let new_ixfn = IxFun.slice ixfn_arr $ Slice $ map (fmap pe64) slc
      pure ([], [(mem_b, mem_arr, new_ixfn)])
  where
    isFix :: DimIndex d -> Bool
    isFix DimFix {} = True
    isFix _ = False
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Index {})) = Just ([], []) -- incomplete slices
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp FlatIndex {})) = Just ([], []) -- incomplete slices
getUseSumFromStm td_env coal_tab (Let (Pat pes) _ (BasicOp (ArrayLit ses _))) =
  let rds = mapMaybe (getDirAliasedIxfn td_env coal_tab) $ mapMaybe seName ses
      wrts = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) pes
   in Just (wrts, wrts ++ rds)
  where
    seName :: SubExp -> Maybe VName
    seName (Var a) = Just a
    seName (Constant _) = Nothing
-- In place update @x[slc] <- a@. In the "in-place update" case,
--   summaries should be added after the old variable @x@ has
--   been added in the active coalesced table.
getUseSumFromStm td_env coal_tab (Let (Pat [x']) _ (BasicOp (Update _ _x (Slice slc) a_se))) = do
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x')
  let x_ixfn_slc = IxFun.slice x_ixfn $ Slice $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
      -- A variable source is also recorded as read when it can be resolved.
      rds = case a_se of
        Constant _ -> []
        Var a -> maybeToList $ getDirAliasedIxfn td_env coal_tab a
  Just ([r1], r1 : rds)
getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Copy x))) = do
  -- y = copy x
  wrt <- getDirAliasedIxfn td_env coal_tab $ patElemName y
  rd <- getDirAliasedIxfn td_env coal_tab x
  pure ([wrt], [wrt, rd])
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Copy {})) =
  error "getUseSumFromStm: a Copy must bind exactly one pattern element"
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Concat _i (a :| bs) _ses))) =
  -- concat
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) (a : bs)
   in Just (ws, ws ++ rs)
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Manifest _perm x))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = maybeToList $ getDirAliasedIxfn td_env coal_tab x
   in Just (ws, ws ++ rs)
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Replicate _shp se))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = case se of
        Constant _ -> []
        Var x -> maybeToList $ getDirAliasedIxfn td_env coal_tab x
   in Just (ws, ws ++ rs)
getUseSumFromStm td_env coal_tab (Let (Pat [x]) _ (BasicOp (FlatUpdate _ (FlatSlice offset slc) v))) = do
  -- Like 'Update': if the destination cannot be resolved, give up ('Nothing')
  -- rather than silently recording no accesses for this write.
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x)
  let x_ixfn_slc =
        IxFun.flatSlice x_ixfn $ FlatSlice (pe64 offset) $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
  Just ([r1], r1 : maybeToList (getDirAliasedIxfn td_env coal_tab v))
getUseSumFromStm _ _ (Let Pat {} _ BasicOp {}) = Just ([], [])
getUseSumFromStm _ _ (Let Pat {} _ (Op (Alloc _ _))) = Just ([], [])
getUseSumFromStm _ _ _ =
  -- if-then-else, loops are supposed to be treated separately,
  -- calls are not supported, and Ops are not yet supported
  Nothing

-- | This function:
--     1. computes the written and read memory references for the current statement
--          (by calling @getUseSumFromStm@)
--     2. fails the entries in active coalesced table for which the write set
--          overlaps the uses of the destination (to that point)
--
-- It returns the updated active coalescing table and inhibit table.  If
-- @getUseSumFromStm@ yields 'Nothing', every active entry whose variable table
-- contains a name bound by the statement's pattern is failed with
-- 'markFailedCoal'.  Otherwise each active entry is checked on its own.
-- Entries whose writes may overlap their (translated) previous destination
-- uses are dropped from the active table and failed with 'markFailedCoal'.
-- The remaining entries get the statement's accesses added to their memory
-- references.
recordMemRefUses ::
  (CanBeAliased (Op rep), RepTypes rep, Op rep ~ MemOp inner, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  BotUpEnv ->
  Stm (Aliases rep) ->
  (CoalsTab, InhibitTab)
recordMemRefUses td_env bu_env stm =
  let active_tab = activeCoals bu_env
      inhibit_tab = inhibit bu_env
      active_etries = M.toList active_tab
      bound_names = patNames (stmPat stm)
   in case getUseSumFromStm td_env active_tab stm of
        Nothing ->
          foldl'
            ( \state (m_b, entry) ->
                if not $ null $ bound_names `intersect` M.keys (vartab entry)
                  then markFailedCoal state m_b
                  else state
            )
            (active_tab, inhibit_tab)
            active_etries
        Just use_sums ->
          let (mb_wrts, prev_uses, mb_lmads) =
                unzip3 $ map (checkOverlapAndExpand use_sums active_tab) active_etries

              -- keep only the entries that do not overlap with the memory
              -- blocks defined in @pat@ or @inner_free_vars@.
              -- the others must be recorded in @inhibit_tab@ because
              -- they violate the 3rd safety condition.
              active_tab1 :: CoalsTab
              active_tab1 =
                M.fromList
                  [ (k, addLmads wrts uses etry')
                    | (Just wrts, (uses, prev_use, (k, etry))) <-
                        zip mb_wrts $ zip3 mb_lmads prev_uses active_etries,
                      let etry' = etry {memrefs = (memrefs etry) {dstrefs = prev_use}}
                  ]
              failed_tab :: CoalsTab
              failed_tab =
                M.fromList [entry | (Nothing, entry) <- zip mb_wrts active_etries]
              (_, inhibit_tab1) =
                foldl' markFailedCoal (failed_tab, inhibit_tab) $ M.keys failed_tab
           in (active_tab1, inhibit_tab1)
  where
    -- For one active entry @(m_b, etry)@, return three things: the statement's
    -- writes to (aliases of) @m_b@, or 'Nothing' if they may overlap the
    -- previous uses of the destination; the translated previous destination
    -- uses; and the statement's uses to add to the entry.
    checkOverlapAndExpand ::
      ([(VName, VName, IxFun)], [(VName, VName, IxFun)]) ->
      CoalsTab ->
      (VName, CoalsEntry) ->
      (Maybe AccessSummary, AccessSummary, AccessSummary)
    checkOverlapAndExpand (stm_wrts, stm_uses) active_tab (m_b, etry) =
      let alias_m_b = getAliases mempty m_b
          stm_uses' = filter ((`notNameIn` alias_m_b) . tupFst) stm_uses
          all_aliases = foldl' getAliases mempty $ namesToList $ alsmem etry
          ixfns = map tupThd $ filter ((`nameIn` all_aliases) . tupSnd) stm_uses'
          lmads'' = summarise ixfns
          wrt_ixfns = map tupThd $ filter ((`nameIn` alias_m_b) . tupFst) stm_wrts
          wrt_lmads' = summarise wrt_ixfns
          prev_use =
            translateAccessSummary (scope td_env) (scalarTable td_env) $
              dstrefs $
                memrefs etry
          -- Aliases recorded for the entry keyed by the memory of the first use.
          original_mem_aliases =
            case uncons (map tupFst stm_uses) of
              Just (first_mem, _) -> maybe mempty alsmem $ M.lookup first_mem active_tab
              Nothing -> mempty
          (wrt_lmads'', lmads) =
            if m_b `nameIn` original_mem_aliases
              then (wrt_lmads' <> lmads'', Set mempty)
              else (wrt_lmads', lmads'')
          wrt_lmads
            | noMemOverlap td_env prev_use wrt_lmads'' = Just wrt_lmads''
            | otherwise = Nothing
       in (wrt_lmads, prev_use, lmads)

    -- Turn index functions into an access summary.  If any of them cannot be
    -- translated into a single LMAD, the summary is 'Undeterminable'.
    summarise :: [IxFun] -> AccessSummary
    summarise = maybe Undeterminable (Set . S.fromList) . traverse mbLmad

    tupFst :: (a, b, c) -> a
    tupFst (a, _, _) = a
    tupSnd :: (a, b, c) -> b
    tupSnd (_, b, _) = b
    tupThd :: (a, b, c) -> c
    tupThd (_, _, c) = c

    -- Add @m@ and its entry in the 'm_alias' table of @td_env@ to @acc@.
    getAliases :: Names -> VName -> Names
    getAliases acc m =
      oneName m
        <> acc
        <> fromMaybe mempty (M.lookup m (m_alias td_env))

    -- Translate an index function to the current program point.  Only index
    -- functions consisting of exactly one LMAD are accepted.
    mbLmad :: IxFun -> Maybe LmadRef
    mbLmad indfun
      | Just subs <- freeVarSubstitutions (scope td_env) (scals bu_env) indfun,
        IxFun.IxFun (lmad :| []) _ _ <- IxFun.substituteInIxFun subs indfun =
          Just lmad
    mbLmad _ = Nothing

    -- Prepend the given uses and writes to the entry's memory references.
    addLmads :: AccessSummary -> AccessSummary -> CoalsEntry -> CoalsEntry
    addLmads wrts uses etry =
      etry {memrefs = MemRefs uses wrts <> memrefs etry}

-- | Check for memory overlap of two access summaries.
--
-- This check is conservative, so unless we can guarantee that there is no
-- overlap, we return 'False'.
noMemOverlap :: (CanBeAliased (Op rep), RepTypes rep) => TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
-- Every LMAD of the first summary is compared against every LMAD of the
-- second one. The result is 'True' only if each pair is proven disjoint by at
-- least one of 'IxFun.disjoint', 'IxFun.disjoint2' or 'IxFun.disjoint3'. An
-- empty summary never overlaps with anything. In all other situations (an
-- 'Undeterminable' summary, or a non-negative variable that cannot be turned
-- into a 'PrimExp') the check conservatively answers 'False'.
noMemOverlap _ _ (Set mr)
  | S.null mr = True
noMemOverlap _ (Set mr) _
  | S.null mr = True
noMemOverlap td_env (Set is0) (Set js0)
  | Just non_negs <-
      mapM (primExpFromSubExpM (vnameToPrimExp scope_tab scalars) . Var) $
        namesToList non_neg_names =
      all (\i -> all (disjointLmads non_negs i) js) is
  where
    scope_tab = scope td_env
    scalars = scalarTable td_env
    non_neg_names = nonNegatives td_env
    -- Substitute known scalar definitions until nothing changes, so that the
    -- disjointness checks reason about the most basic variables available.
    substLmad = fixPoint (IxFun.substituteInLMAD $ TPrimExp <$> scalars)
    is = map substLmad $ S.toList is0
    js = map substLmad $ S.toList js0
    less_thans = map (fmap $ fixPoint $ substituteInPrimExp scalars) $ knownLessThan td_env
    asserts = map (fixPoint (substituteInPrimExp scalars) . primExpFromSubExp Bool) $ td_asserts td_env
    -- Two LMADs are disjoint as soon as any one of the checks can prove it.
    disjointLmads non_negs i j =
      IxFun.disjoint less_thans non_neg_names i j
        || IxFun.disjoint2 () () less_thans non_neg_names i j
        || IxFun.disjoint3 (typeOf <$> scope_tab) asserts less_thans non_negs i j
noMemOverlap _ _ _ = False

-- | Computes the total aggregated access summary for a loop by expanding the
-- access summary given according to the iterator variable and bounds of the
-- loop.
--
-- Corresponds to:
--
-- \[
--   \bigcup_{j=lb}^{j<ub} Access_j
-- \]
--
-- If the summary, once translated into the loop's scope, only mentions
-- variables that are already in scope before the loop, it is returned as is.
-- Otherwise, when the loop bounds are known, every LMAD gets an extra outer
-- dimension covering the @ub - lb@ iterations. When the bounds are unknown,
-- the result is 'Undeterminable'.
aggSummaryLoopTotal ::
  MonadFreshNames m =>
  ScopeTab rep ->
  ScopeTab rep ->
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopTotal _ _ _ _ Undeterminable = pure Undeterminable
aggSummaryLoopTotal _ _ _ _ (Set l)
  | S.null l = pure $ Set mempty
aggSummaryLoopTotal scope_bef scope_loop scals_loop _ access
  | Set ls <- translateAccessSummary scope_loop scals_loop access,
    all (`M.member` scope_bef) $ namesToList $ foldMap freeIn $ S.toList ls =
      pure $ Set ls
aggSummaryLoopTotal _ _ scalars_loop (Just (iterator_var, (lower_bound, upper_bound))) (Set lmads) =
  -- The third argument of 'aggSummaryOne' is the number of iterations
  -- (the span), i.e. @ub - lb@, consistent with 'aggSummaryLoopPartial'.
  concatMapM
    ( aggSummaryOne iterator_var lower_bound (upper_bound - lower_bound)
        . fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop)
    )
    (S.toList lmads)
aggSummaryLoopTotal _ _ _ _ _ = pure Undeterminable

-- | For a given iteration of the loop $i$, computes the aggregated loop access
-- summary of all later iterations.
--
-- Corresponds to:
--
-- \[
--   \bigcup_{j=i+1}^{j<n} Access_j
-- \]
--
-- Returns 'Undeterminable' if the input summary is 'Undeterminable' or if the
-- loop bounds are unknown.
aggSummaryLoopPartial ::
  MonadFreshNames m =>
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopPartial _ _ Undeterminable = pure Undeterminable
aggSummaryLoopPartial _ Nothing _ = pure Undeterminable
aggSummaryLoopPartial scalars_loop (Just (iterator_var, (_, upper_bound))) (Set lmads) =
  -- For every LMAD (after substituting known scalars), add an outer dimension
  -- ranging over the iterations i+1 .. n-1: it starts at @i + 1@ and spans the
  -- remaining @n - i - 1@ iterations. 'aggSummaryOne' yields 'Undeterminable'
  -- when the iterator cannot be factored out of the LMAD.
  concatMapM
    ( aggSummaryOne iterator_var (i + 1) (upper_bound - i - 1)
        . fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop)
    )
    (S.toList lmads)
  where
    i = typedLeafExp iterator_var

-- | For a given map with $k$ dimensions and an index $i$ for each dimension,
-- compute the aggregated access summary of all other threads.
--
-- For the innermost dimension, this corresponds to
--
-- \[
--   \bigcup_{j=0}^{j<i} Access_j \cup \bigcup_{j=i+1}^{j<n} Access_j
-- \]
--
-- where $Access_j$ describes the point accesses in the map. As we move up in
-- dimensionality, the previous access summaries are kept, in addition to the
-- total aggregation of the inner dimensions. For outer dimensions, the equation
-- is the same, the point accesses in $Access_j$ are replaced with the total
-- aggregation of the inner dimensions.
aggSummaryMapPartial :: MonadFreshNames m => ScalarTab -> [(VName, SubExp)] -> LmadRef -> m AccessSummary
aggSummaryMapPartial _ [] = const $ pure mempty
aggSummaryMapPartial scalars dims =
  -- Reverse dims so we work from the inside out.
  helper mempty (reverse dims) . Set . S.singleton
  where
    substScalars = fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)
    -- @helper acc dims' as@: @acc@ collects the accesses of the other threads
    -- found so far; @as@ is the total aggregation of the dimensions visited.
    helper acc [] _ = pure acc
    helper Undeterminable _ _ = pure Undeterminable
    helper _ _ Undeterminable = pure Undeterminable
    helper (Set acc) ((gtid, size) : rest) (Set as) = do
      partial_as <- aggSummaryMapPartialOne scalars (gtid, size) (Set as)
      -- Substitute known scalars before aggregating (as the partial case
      -- does), so a dependency on @gtid@ hidden behind a scalar binding is not
      -- mistaken for an LMAD that is independent of @gtid@.
      total_as <-
        concatMapM
          ( aggSummaryOne gtid 0 (TPrimExp $ primExpFromSubExp (IntType Int64) size)
              . substScalars
          )
          (S.toList as)
      helper (Set acc <> partial_as) rest total_as

-- | Given an access summary $a$, a thread id $i$ and the size $n$ of the
-- dimension, compute the partial map summary, i.e. the accesses of every
-- thread in that dimension except $i$.
--
-- Corresponds to
--
-- \[
--   \bigcup_{j=0}^{j<i} a_j \cup \bigcup_{j=i+1}^{j<n} a_j
-- \]
--
-- A dimension of constant size one has no other threads, so the result is
-- empty in that case.
aggSummaryMapPartialOne :: MonadFreshNames m => ScalarTab -> (VName, SubExp) -> AccessSummary -> m AccessSummary
aggSummaryMapPartialOne _ _ Undeterminable = pure Undeterminable
aggSummaryMapPartialOne _ (_, Constant n) (Set _)
  | oneIsh n = pure mempty
aggSummaryMapPartialOne scalars (gtid, size) (Set lmads0) =
  concatMapM aggregateRange [before_me, after_me]
  where
    me = typedLeafExp gtid
    dim_size = isInt64 $ primExpFromSubExp (IntType Int64) size
    -- Threads 0 .. i-1: start at zero and span i threads.
    before_me = (0, me)
    -- Threads i+1 .. n-1: start right after i and span the rest.
    after_me = (me + 1, dim_size - me - 1)
    substituted =
      map (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)) $
        S.toList lmads0
    aggregateRange (start, count) =
      concatMapM (aggSummaryOne gtid start count) substituted

-- | Computes the total access summary over a multi-dimensional map, i.e. the
-- union of the accesses of all threads in the given segment space.
--
-- Known scalars are substituted into the LMADs first. The dimensions are then
-- aggregated from the innermost outwards, each one adding an outer dimension
-- to every LMAD via 'aggSummaryOne'. With no dimensions (a single thread) the
-- substituted input summary is returned unchanged, and an 'Undeterminable'
-- input always stays 'Undeterminable'.
aggSummaryMapTotal :: MonadFreshNames m => ScalarTab -> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal _ _ Undeterminable = pure Undeterminable
aggSummaryMapTotal _ _ (Set lmads0)
  | S.null lmads0 = pure mempty
aggSummaryMapTotal scalars segspace (Set lmads0) =
  foldM aggregateDim (Set lmads) (reverse segspace)
  where
    lmads =
      S.fromList $
        map (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)) $
          S.toList lmads0
    -- Aggregate one dimension over all of its threads 0 .. size-1.
    aggregateDim Undeterminable _ = pure Undeterminable
    aggregateDim (Set acc) (gtid, size) =
      concatMapM
        (aggSummaryOne gtid 0 $ TPrimExp $ primExpFromSubExp (IntType Int64) size)
        (S.toList acc)

-- | Helper function that aggregates the accesses of a single LMAD according
-- to a given iterator variable, a lower bound and a span.
--
-- If successful, the result is an LMAD with an extra outer dimension of size
-- @spn@. Its offset is the original offset with the iterator replaced by the
-- lower bound. Its stride is the (simplified) difference between the offsets
-- of two consecutive iterations.
--
-- The LMAD is returned unchanged if the iterator does not occur in it at all.
-- The function returns 'Undeterminable' if the iterator occurs in the
-- dimensions of the input LMAD, or if the computed stride still depends on
-- the iteration.
aggSummaryOne :: MonadFreshNames m => VName -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> LmadRef -> m AccessSummary
aggSummaryOne iterator_var lower_bound spn lmad@(IxFun.LMAD offset0 dims0)
  | iterator_var `nameIn` freeIn dims0 = pure Undeterminable
  | iterator_var `notNameIn` freeIn offset0 = pure $ Set $ S.singleton lmad
  | otherwise = do
      k <- newVName "k"
      let at_k = typedLeafExp k
          stride =
            TPrimExp . constFoldPrimExp . simplify . untyped $
              substIter (at_k + 1) offset0 - substIter at_k offset0
          outer_dim = IxFun.LMADDim stride spn 0 IxFun.Inc
          aggregated =
            IxFun.LMAD (substIter lower_bound offset0) (outer_dim : map shiftPerm dims0)
      pure $
        if k `nameIn` freeIn aggregated
          then Undeterminable
          else Set $ S.singleton aggregated
  where
    -- The new outer dimension takes permutation index 0, so shift the rest.
    shiftPerm dim = dim {IxFun.ldPerm = IxFun.ldPerm dim + 1}
    -- Replace every occurrence of the iterator by the given expression.
    substIter e =
      TPrimExp . substituteInPrimExp (M.singleton iterator_var (untyped e)) . untyped

-- | Wraps a variable name as a leaf expression of primitive type @i64@, typed
-- as a 'TPrimExp' over 'Int64'.
typedLeafExp :: VName -> TPrimExp Int64 VName
typedLeafExp = isInt64 . flip LeafExp (IntType Int64)