{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | The bulk of the short-circuiting implementation.
module Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing (mkCoalsTab, CoalsTab, mkCoalsTabGPU) where

import Control.Exception.Base qualified as Exc
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence (Seq (..))
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SeqMem
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.LastUse
import Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-- | A helper type describing representations that can be short-circuited.
--
-- This is a constraint synonym (hence @ConstraintKinds@): it bundles the
-- class constraints and type equalities that the functions in this module
-- require of a representation @rep@ whose operations are @'MemOp' inner@.
type Coalesceable rep inner =
  ( CreatesNewArrOp (OpWithAliases inner),
    ASTRep rep,
    CanBeAliased inner,
    Op rep ~ MemOp inner,
    HasMemBlock (Aliases rep),
    LetDec rep ~ LetDecMem,
    TopDownHelper (OpWithAliases inner)
  )

-- | Helper type for computing scalar tables on ops.
--
-- Wraps the representation-specific function that, given the current scope
-- and an 'Op', produces the scalar table (a map from names to 'PrimExp's)
-- contributed by that op.
newtype ComputeScalarTableOnOp rep = ComputeScalarTableOnOp
  { scalarTableOnOp :: ScopeTab rep -> Op (Aliases rep) -> ScalarTableM rep (M.Map VName (PrimExp VName))
  }

-- | The monad used while computing scalar tables: a 'Reader' whose
-- environment is the representation-specific 'ComputeScalarTableOnOp'.
type ScalarTableM rep a = Reader (ComputeScalarTableOnOp rep) a

-- | The read-only environment of 'ShortCircuitM'.
--
-- Holds the representation-specific handler invoked on an 'Op'.  The handler
-- receives the last-use table, the pattern bound by the statement, the op
-- itself, the top-down environment and the current bottom-up environment,
-- and returns the updated bottom-up environment.
newtype ShortCircuitReader rep = ShortCircuitReader
  { onOp :: LUTabFun -> Pat (VarAliases, LetDecMem) -> Op (Aliases rep) -> TopdownEnv rep -> BotUpEnv -> ShortCircuitM rep BotUpEnv
  }

-- | The monad in which short-circuiting runs: a reader over the
-- representation-specific 'ShortCircuitReader', layered on a state monad
-- that threads the 'VNameSource' used for generating fresh names.
--
-- All instances are obtained from the underlying transformer stack via
-- @GeneralizedNewtypeDeriving@.
newtype ShortCircuitM rep a = ShortCircuitM (ReaderT (ShortCircuitReader rep) (State VNameSource) a)
  deriving (Functor, Applicative, Monad, MonadReader (ShortCircuitReader rep), MonadState VNameSource)

-- | Fresh names are drawn from the 'VNameSource' kept in the state layer of
-- 'ShortCircuitM', so reading and writing the name source is simply 'get'
-- and 'put'.
instance MonadFreshNames (ShortCircuitM rep) where
  putNameSource = put
  getNameSource = get

-- | A 'TopdownEnv' in which every field is 'mempty'.  Used as the base
-- record that callers refine with record-update syntax.
emptyTopdownEnv :: TopdownEnv rep
emptyTopdownEnv =
  TopdownEnv
    { alloc = mempty,
      scope = mempty,
      inhibited = mempty,
      v_alias = mempty,
      m_alias = mempty,
      nonNegatives = mempty,
      scalarTable = mempty,
      knownLessThan = mempty,
      td_asserts = mempty
    }

-- | A 'BotUpEnv' in which every field is 'mempty': no scalars, no active or
-- successful coalescings, and nothing inhibited.
emptyBotUpEnv :: BotUpEnv
emptyBotUpEnv =
  BotUpEnv
    { scals = mempty,
      activeCoals = mempty,
      successCoals = mempty,
      inhibit = mempty
    }

--------------------------------------------------------------------------------
--- Main Coalescing Transformation computes a successful coalescing table    ---
--------------------------------------------------------------------------------

-- | Given a 'FunDef' in 'SeqMem' representation, compute its coalescing
-- table.
--
-- This instantiates 'mkCoalsTabFun' with the 'SeqMem' last-use analysis
-- ('lastUseSeqMem'), the 'SeqMem' op handler ('shortCircuitSeqMem'), and a
-- scalar-table function for ops that always yields an empty table.
mkCoalsTab :: (MonadFreshNames m) => FunDef (Aliases SeqMem) -> m CoalsTab
mkCoalsTab =
  mkCoalsTabFun
    (snd . lastUseSeqMem)
    (ShortCircuitReader shortCircuitSeqMem)
    (ComputeScalarTableOnOp $ const $ const $ pure mempty)

-- | Given a 'FunDef' in 'GPUMem' representation, compute its coalescing
-- table.
--
-- This instantiates 'mkCoalsTabFun' with the 'GPUMem' last-use analysis
-- ('lastUseGPUMem'), the 'GPUMem' op handler ('shortCircuitGPUMem'), and the
-- 'GPUMem' scalar-table function ('computeScalarTableGPUMem').
mkCoalsTabGPU :: (MonadFreshNames m) => FunDef (Aliases GPUMem) -> m CoalsTab
mkCoalsTabGPU =
  mkCoalsTabFun
    (snd . lastUseGPUMem)
    (ShortCircuitReader shortCircuitGPUMem)
    (ComputeScalarTableOnOp computeScalarTableGPUMem)

-- | Given a function, compute the coalescing table.
--
-- The representation-specific pieces are passed in as arguments; the
-- function itself builds the initial top-down environment from the function
-- parameters and body, runs 'fixPointCoalesce', and threads the caller's
-- name source through the 'ShortCircuitM' computation.
mkCoalsTabFun ::
  (MonadFreshNames m, Coalesceable rep inner, FParamInfo rep ~ FParamMem) =>
  -- | Computes the last-use table of the function.
  (FunDef (Aliases rep) -> LUTabFun) ->
  -- | The handler used for the representation's 'Op's.
  ShortCircuitReader rep ->
  -- | How to compute the scalar table contribution of an 'Op'.
  ComputeScalarTableOnOp rep ->
  -- | The function to analyse.
  FunDef (Aliases rep) ->
  m CoalsTab
mkCoalsTabFun lufun r computeScalarOnOp fun@(FunDef _ _ _ _ fpars body) = do
  -- First compute last-use information
  let lutab = lufun fun
      unique_mems = getUniqueMemFParam fpars
      -- The scalar table is gathered over all top-level statements of the
      -- body, in a scope made of the function's own scope and the body's
      -- statements.
      scalar_table =
        runReader
          ( concatMapM
              (computeScalarTable $ scopeOf fun <> scopeOf (bodyStms body))
              (stmsToList $ bodyStms body)
          )
          computeScalarOnOp
      topenv =
        emptyTopdownEnv
          { scope = scopeOfFParams fpars,
            alloc = unique_mems,
            scalarTable = scalar_table,
            nonNegatives = foldMap paramSizes fpars
          }
      ShortCircuitM m = fixPointCoalesce lutab fpars body topenv
  -- Run the computation against the caller's name source, writing the
  -- updated source back afterwards.
  modifyNameSource $ runState (runReaderT m r)

-- | The names free in the shape of an array parameter; 'mempty' for any
-- parameter that is not a 'MemArray'.
paramSizes :: Param FParamMem -> Names
paramSizes (Param _ _ (MemArray _ shp _ _)) = freeIn shp
paramSizes _ = mempty

-- | Short-circuit handler for a 'SeqMem' 'Op'.
--
-- Because 'SeqMem' doesn't have any special operations, simply return the
-- input 'BotUpEnv' unchanged; the first four arguments are ignored.
shortCircuitSeqMem :: LUTabFun -> Pat (VarAliases, LetDecMem) -> Op (Aliases SeqMem) -> TopdownEnv SeqMem -> BotUpEnv -> ShortCircuitM SeqMem BotUpEnv
shortCircuitSeqMem _ _ _ _ = pure

-- | Short-circuit handler for 'GPUMem' 'Op'.
--
-- 'SegOp's and 'GPUBody' are handed to 'shortCircuitGPUMemHelper' (after
-- some op-specific preparation); 'Alloc', 'SizeOp' and 'OtherOp' leave the
-- bottom-up environment untouched.
shortCircuitGPUMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Op (Aliases GPUMem) ->
  TopdownEnv GPUMem ->
  BotUpEnv ->
  ShortCircuitM GPUMem BotUpEnv
shortCircuitGPUMem _ _ (Alloc _ _) _ bu_env = pure bu_env
shortCircuitGPUMem lutab pat (Inner (SegOp (SegMap lvl@SegThread {} space _ kernel_body))) td_env bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  shortCircuitGPUMemHelper 0 lvl lutab pat space kernel_body td_env bu_env
shortCircuitGPUMem lutab pat (Inner (SegOp (SegMap lvl@SegGroup {} space _ kernel_body))) td_env bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  shortCircuitGPUMemHelper 0 lvl lutab pat space kernel_body td_env bu_env
shortCircuitGPUMem lutab pat (Inner (SegOp (SegMap lvl@SegThreadInGroup {} space _ kernel_body))) td_env bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  shortCircuitGPUMemHelper 0 lvl lutab pat space kernel_body td_env bu_env
shortCircuitGPUMem lutab pat (Inner (SegOp (SegRed lvl space binops _ kernel_body))) td_env bu_env =
  -- When handling 'SegRed', we first invalidate all active coalesce-entries
  -- where any of the variables in 'vartab' are also free in the list of
  -- 'SegBinOp'. In other words, anything that is used as part of the reduction
  -- step should probably not be coalesced.
  let to_fail =
        M.filter
          ( \entry ->
              namesFromList (M.keys $ vartab entry)
                `namesIntersect` foldMap (freeIn . segBinOpLambda) binops
          )
          $ activeCoals bu_env
      (active, inh) =
        L.foldl' markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
      -- Only the *number* of reduction results is needed: one per return
      -- type of each operator's lambda.
      num_reds = sum $ map (length . lambdaReturnType . segBinOpLambda) binops
   in shortCircuitGPUMemHelper num_reds lvl lutab pat space kernel_body td_env bu_env'
shortCircuitGPUMem lutab pat (Inner (SegOp (SegScan lvl space binops _ kernel_body))) td_env bu_env =
  -- Like in the handling of 'SegRed', we do not want to coalesce anything that
  -- is used in the 'SegBinOp'
  let to_fail =
        M.filter
          ( \entry ->
              namesFromList (M.keys $ vartab entry)
                `namesIntersect` foldMap (freeIn . segBinOpLambda) binops
          )
          $ activeCoals bu_env
      (active, inh) =
        L.foldl' markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
   in shortCircuitGPUMemHelper 0 lvl lutab pat space kernel_body td_env bu_env'
shortCircuitGPUMem lutab pat (Inner (SegOp (SegHist lvl space histops _ kernel_body))) td_env bu_env = do
  -- Need to take zipped patterns and histDest (flattened) and insert transitive coalesces
  let to_fail =
        M.filter
          ( \entry ->
              namesFromList (M.keys $ vartab entry)
                `namesIntersect` foldMap (freeIn . histOp) histops
          )
          $ activeCoals bu_env
      (active, inh) =
        L.foldl' markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
  bu_env'' <- shortCircuitGPUMemHelper 0 lvl lutab pat space kernel_body td_env bu_env'
  pure $
    L.foldl' insertHistCoals bu_env'' $
      zip (patElems pat) $
        concatMap histDest histops
  where
    -- For a pattern element and the histogram destination it is paired
    -- with: if both have memory info in scope and the pattern's memory
    -- block already has a successful coalescing entry, record an optdep on
    -- that entry and register the (un-updated) entry as active for the
    -- destination's memory block.  Otherwise leave the environment as is.
    insertHistCoals acc (PatElem p _, hist_dest) =
      case ( getScopeMemInfo p $ scope td_env,
             getScopeMemInfo hist_dest $ scope td_env
           ) of
        (Just (MemBlock _ _ p_mem _), Just (MemBlock _ _ dest_mem _)) ->
          case M.lookup p_mem $ successCoals acc of
            Just entry ->
              -- Update this entry with an optdep for the memory block of hist_dest
              let entry' = entry {optdeps = M.insert p p_mem $ optdeps entry}
               in acc
                    { successCoals = M.insert p_mem entry' $ successCoals acc,
                      activeCoals = M.insert dest_mem entry $ activeCoals acc
                    }
            Nothing -> acc
        _ -> acc
shortCircuitGPUMem lutab pat (Inner (GPUBody _ body)) td_env bu_env = do
  -- A 'GPUBody' is treated as a kernel with a one-element iteration space;
  -- the space needs fresh names for its flat index and its single dimension.
  fresh1 <- newNameFromString "gpubody"
  fresh2 <- newNameFromString "gpubody"
  shortCircuitGPUMemHelper
    0
    -- Construct a 'SegLevel' corresponding to a single thread
    ( SegThread SegNoVirt $
        Just $
          KernelGrid
            (Count $ Constant $ IntValue $ Int64Value 1)
            (Count $ Constant $ IntValue $ Int64Value 1)
    )
    lutab
    pat
    (SegSpace fresh1 [(fresh2, Constant $ IntValue $ Int64Value 1)])
    (bodyToKernelBody body)
    td_env
    bu_env
shortCircuitGPUMem _ _ (Inner (SizeOp _)) _ bu_env = pure bu_env
shortCircuitGPUMem _ _ (Inner (OtherOp ())) _ bu_env = pure bu_env

-- | Remove the last (index, size) pair from a 'SegSpace'.
--
-- Uses 'init', so the space must have at least one dimension; an empty
-- dimension list is an error.
dropLastSegSpace :: SegSpace -> SegSpace
dropLastSegSpace space = space {unSegSpace = init $ unSegSpace space}

-- | Is this 'SegLevel' built with the 'SegThread' constructor?
isSegThread :: SegLevel -> Bool
isSegThread SegThread {} = True
isSegThread _ = False

-- | Computes the slice written at the end of a thread in a 'SegOp'.
--
-- * For a 'Returns' result, every dimension of the space is fixed
--   ('DimFix') to that dimension's index variable, as an @i64@ leaf.
--
-- * For a 'RegTileReturns' result, each dimension becomes a 'DimSlice'
--   whose three arguments are @x * block_tile * reg_tile@,
--   @block_tile * reg_tile@ and @1@, where @x@ is the dimension's index
--   variable.  Tile dimensions and space dimensions are paired with
--   'zipWith', so any surplus on either side is ignored.
--
-- * Any other kind of result yields 'Nothing'.
threadSlice :: SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice space Returns {} =
  Just $
    Slice $
      map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst) $
        unSegSpace space
threadSlice space (RegTileReturns _ dims _) =
  Just
    $ Slice
    $ zipWith
      ( \(_, block_tile_size0, reg_tile_size0) (x0, _) ->
          let x = pe64 $ Var x0
              block_tile_size = pe64 block_tile_size0
              reg_tile_size = pe64 reg_tile_size0
           in DimSlice (x * block_tile_size * reg_tile_size) (block_tile_size * reg_tile_size) 1
      )
      dims
    $ unSegSpace space
threadSlice _ _ = Nothing

-- | Repackage an ordinary 'Body' as a 'KernelBody'.  The body decoration and
-- statements are reused unchanged; each 'SubExpRes' in the result becomes a
-- 'Returns' with manifest 'ResultNoSimplify', keeping its certificates and
-- sub-expression.
bodyToKernelBody :: Body (Aliases GPUMem) -> KernelBody (Aliases GPUMem)
bodyToKernelBody (Body dec stms res) =
  KernelBody dec stms $ map toKernelResult res
  where
    toKernelResult (SubExpRes cert subexp) = Returns ResultNoSimplify cert subexp

-- | A helper for all the different kinds of 'SegOp'.
--
-- Consists of four parts:
--
-- 1. Create coalescing relations between the pattern elements and the kernel
-- body results using 'makeSegMapCoals'.
--
-- 2. Process the statements of the 'KernelBody'.
--
-- 3. Check the overlap between the different threads.
--
-- 4. Mark active coalescings as finished, since a 'SegOp' is an array creation
-- point.
--
-- Note that step 1 only calls 'makeSegMapCoals' when @num_reds@ is zero; when
-- there are reduction results, only the filtered tables from
-- 'filterSafetyCond2and5' are used.
shortCircuitGPUMemHelper ::
  -- | The number of returns for which we should drop the last seg space
  Int ->
  SegLevel ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  SegSpace ->
  KernelBody (Aliases GPUMem) ->
  TopdownEnv GPUMem ->
  BotUpEnv ->
  ShortCircuitM GPUMem BotUpEnv
shortCircuitGPUMemHelper num_reds lvl lutab pat@(Pat ps0) space0 kernel_body td_env bu_env = do
  -- We need to drop the last element of the 'SegSpace' for pattern elements
  -- that correspond to reductions.
  let ps_space_and_res =
        zip3 ps0 (replicate num_reds (dropLastSegSpace space0) <> repeat space0) $
          kernelBodyResult kernel_body
  -- Create coalescing relations between pattern elements and kernel body
  -- results
  let (actv0, inhibit0) =
        filterSafetyCond2and5
          (activeCoals bu_env)
          (inhibit bu_env)
          (scals bu_env)
          td_env
          (patElems pat)
      (actv_return, inhibit_return) =
        if num_reds > 0
          then (actv0, inhibit0)
          else -- Strict left fold: the accumulator is consumed right away, so
          -- there is no reason to build a chain of thunks.
            L.foldl' (makeSegMapCoals lvl td_env kernel_body) (actv0, inhibit0) ps_space_and_res

  -- Start from empty references, we'll update with aggregates later.
  let actv0' = M.map (\etry -> etry {memrefs = mempty}) $ actv0 <> actv_return
  -- Process kernel body statements
  bu_env' <-
    mkCoalsTabStms lutab (kernelBodyStms kernel_body) td_env $
      bu_env {activeCoals = actv0', inhibit = inhibit_return}

  -- Re-attach the references that were blanked out before processing the
  -- body, so each entry holds both the body's references and the earlier ones.
  let actv_coals_after =
        M.mapWithKey
          ( \k etry ->
              etry
                { memrefs = memrefs etry <> maybe mempty memrefs (M.lookup k $ actv0 <> actv_return)
                }
          )
          $ activeCoals bu_env'

  -- Check partial overlap.
  let checkPartialOverlap bu_env_f (k, entry) = do
        -- The access made by one thread when writing its result for a
        -- pattern element that is coalesced in @entry@.  Results for which
        -- 'threadSlice' gives no slice are 'Undeterminable'.
        let sliceThreadAccess (p, space, res) =
              case M.lookup (patElemName p) $ vartab entry of
                Just (Coalesced _ (MemBlock _ _ _ ixf) _) ->
                  maybe
                    Undeterminable
                    ( ixfunToAccessSummary
                        . IxFun.slice ixf
                        . fullSlice (IxFun.shape ixf)
                    )
                    $ threadSlice space res
                Nothing -> mempty
            thread_writes = foldMap sliceThreadAccess ps_space_and_res
            source_writes = srcwrts (memrefs entry) <> thread_writes
        -- Only the destination uses added since entering this 'SegOp' are
        -- relevant, hence the subtraction of the uses recorded in @bu_env@.
        destination_uses <-
          case dstrefs (memrefs entry)
            `accessSubtract` dstrefs (maybe mempty memrefs $ M.lookup k $ activeCoals bu_env) of
            Set s ->
              concatMapM
                (aggSummaryMapPartial (scalarTable td_env) $ unSegSpace space0)
                (S.toList s)
            Undeterminable -> pure Undeterminable
        let res = noMemOverlap td_env destination_uses source_writes
        if res
          then pure bu_env_f
          else do
            let (ac, inh) = markFailedCoal (activeCoals bu_env_f, inhibit bu_env_f) k
            pure $ bu_env_f {activeCoals = ac, inhibit = inh}

  bu_env'' <-
    foldM
      checkPartialOverlap
      (bu_env' {activeCoals = actv_coals_after})
      $ M.toList actv_coals_after

  -- Aggregate the per-thread summaries over the whole 'SegSpace'.
  let updateMemRefs entry = do
        wrts <- aggSummaryMapTotal (scalarTable td_env) (unSegSpace space0) $ srcwrts $ memrefs entry
        uses <- aggSummaryMapTotal (scalarTable td_env) (unSegSpace space0) $ dstrefs $ memrefs entry

        -- Add destination uses from the pattern
        let uses' =
              foldMap
                ( \case
                    PatElem _ (_, MemArray _ _ _ (ArrayIn p_mem p_ixf))
                      | p_mem `nameIn` alsmem entry ->
                          ixfunToAccessSummary p_ixf
                    _ -> mempty
                )
                ps0

        pure $ entry {memrefs = MemRefs (uses <> uses') wrts}

  actv <- mapM updateMemRefs $ activeCoals bu_env''
  let bu_env''' = bu_env'' {activeCoals = actv}

  -- Process pattern and return values
  let mergee_writes =
        mapMaybe
          ( \(p, _, _) ->
              fmap (p,) $
                getDirAliasedIxfn' td_env (activeCoals bu_env''') $
                  patElemName p
          )
          ps_space_and_res

  -- Now, for each mergee write, we need to check that it doesn't overlap with any previous uses of the destination.
  let checkMergeeOverlap bu_env_f (p, (m_b, _, ixf)) =
        let as = ixfunToAccessSummary ixf
         in -- Should be @bu_env@ here, because we need to check overlap
            -- against previous uses.
            case M.lookup m_b $ activeCoals bu_env of
              Just coal_entry -> do
                let mrefs =
                      memrefs coal_entry
                    res = noMemOverlap td_env as $ dstrefs mrefs
                    fail_res =
                      let (ac, inh) = markFailedCoal (activeCoals bu_env_f, inhibit bu_env_f) m_b
                       in bu_env_f {activeCoals = ac, inhibit = inh}

                if res
                  then case M.lookup (patElemName p) $ vartab coal_entry of
                    Nothing -> pure bu_env_f
                    Just (Coalesced knd mbd@(MemBlock _ _ _ ixfn) _) -> pure $
                      case freeVarSubstitutions (scope td_env) (scalarTable td_env) ixfn of
                        Just fv_subst
                          -- The coalescing only succeeds when the coalesced
                          -- index function has the same permutation as the
                          -- one the pattern element has in scope.  If the
                          -- pattern element has no memory info in scope we
                          -- conservatively fail the coalescing rather than
                          -- crash on a partial 'fromJust'.
                          | Just pat_mb <- getScopeMemInfo (patElemName p) $ scope td_env,
                            ixfunPermutation ixfn == ixfunPermutation (ixfun pat_mb) ->
                              let entry =
                                    coal_entry
                                      { vartab =
                                          M.insert
                                            (patElemName p)
                                            (Coalesced knd mbd fv_subst)
                                            (vartab coal_entry)
                                      }
                                  (ac, suc) =
                                    markSuccessCoal (activeCoals bu_env_f, successCoals bu_env_f) m_b entry
                               in bu_env_f {activeCoals = ac, successCoals = suc}
                        _ -> fail_res
                  else pure fail_res
              _ -> pure bu_env_f

  foldM checkMergeeOverlap bu_env''' mergee_writes

-- | The permutation of an index function, read off the first LMAD only: the
-- 'IxFun.ldPerm' field of each of its dimensions, in order.  'NE.head' is
-- total here because 'IxFun.ixfunLMADs' is a 'NonEmpty' list.
ixfunPermutation :: IxFun -> [Int]
ixfunPermutation = map IxFun.ldPerm . IxFun.lmadDims . NE.head . IxFun.ixfunLMADs

-- | Given a pattern element and the corresponding kernel result, try to put the
-- kernel result directly in the memory block of pattern element
--
-- The function is a fold step over @(pattern element, space, result)@ triples
-- and returns the updated active-coalescing and inhibit tables:
--
-- * A 'Returns' of a variable is considered for coalescing only if the
--   returned array has memory info in scope (including the kernel body's own
--   statements), the level is 'SegThread', and the pattern's and the returned
--   array's memory blocks are both plain memory in the same space.
--
--     * If the pattern memory is not itself an active coalescing source
--       (the non-transitive case), a new entry targeting the pattern memory
--       is added, provided the pattern index function has a single LMAD, the
--       pair is not inhibited, and 'addInvAliassesVarTab' succeeds.
--
--     * Otherwise (the transitive case) the new entry targets the destination
--       of the existing entry, under the same inhibit/alias conditions.
--
--     In both cases the index function recorded for the returned array is the
--     destination index function sliced at the thread's indices.  If any
--     condition fails, the tables are returned unchanged.
--
-- * A 'WriteReturns' marks the memory of the written array as failed.
--
-- * Anything else (including a 'Returns' rejected by the guards above) marks
--   as failed the memory of every array that is free in the result and has
--   memory info in the top-down scope.
makeSegMapCoals :: SegLevel -> TopdownEnv GPUMem -> KernelBody (Aliases GPUMem) -> (CoalsTab, InhibitTab) -> (PatElem (VarAliases, LetDecMem), SegSpace, KernelResult) -> (CoalsTab, InhibitTab)
makeSegMapCoals lvl td_env kernel_body (active, inhb) (PatElem pat_name (_, MemArray _ _ _ (ArrayIn pat_mem pat_ixf)), space, Returns _ _ (Var return_name))
  | Just mb@(MemBlock tp return_shp return_mem _) <-
      getScopeMemInfo return_name $ scope td_env <> scopeOf (kernelBodyStms kernel_body),
    isSegThread lvl,
    MemMem pat_space <- runReader (lookupMemInfo pat_mem) $ removeScopeAliases $ scope td_env,
    MemMem return_space <- runReader (lookupMemInfo return_mem) $ removeScopeAliases $ scope td_env <> scopeOf (kernelBodyStms kernel_body) <> scopeOfSegSpace space,
    pat_space == return_space =
      case M.lookup pat_mem active of
        Nothing ->
          -- We are not in a transitive case
          if IxFun.hasOneLmad pat_ixf
            then case ( maybe False (pat_mem `nameIn`) $ M.lookup return_mem inhb,
                        Coalesced InPlaceCoal mb mempty
                          & M.singleton return_name
                          & flip (addInvAliassesVarTab td_env) return_name
                          & fmap
                            ( M.adjust
                                ( \(Coalesced knd (MemBlock pt shp _ _) subst) ->
                                    Coalesced
                                      knd
                                      ( MemBlock pt shp pat_mem $
                                          IxFun.slice pat_ixf $
                                            fullSlice (IxFun.shape pat_ixf) thread_index_slice
                                      )
                                      subst
                                )
                                return_name
                            )
                      ) of
              (False, Just vtab) ->
                (active <> M.singleton return_mem (CoalsEntry pat_mem pat_ixf (oneName pat_mem) vtab mempty mempty), inhb)
              _ -> (active, inhb)
            else (active, inhb)
        Just trans ->
          case ( maybe False (dstmem trans `nameIn`) $ M.lookup return_mem inhb,
                 Coalesced InPlaceCoal (MemBlock tp return_shp (dstmem trans) (dstind trans)) mempty
                   & M.singleton return_name
                   & flip (addInvAliassesVarTab td_env) return_name
                   & fmap
                     ( M.adjust
                         ( \(Coalesced knd (MemBlock pt shp mem ixf@(IxFun.IxFun _ base_shape _)) subst) ->
                             Coalesced
                               knd
                               ( MemBlock pt shp mem $
                                   IxFun.slice ixf $
                                     fullSlice base_shape thread_index_slice
                               )
                               subst
                         )
                         return_name
                     )
               ) of
            (False, Just vtab) ->
              -- Record an optimistic dependency on the pattern element unless
              -- the transitive destination is the pattern memory itself.
              let opts = if dstmem trans == pat_mem then mempty else M.insert pat_name pat_mem $ optdeps trans
               in ( M.insert
                      return_mem
                      ( CoalsEntry
                          (dstmem trans)
                          (dstind trans)
                          (oneName pat_mem <> alsmem trans)
                          vtab
                          opts
                          mempty
                      )
                      active,
                    inhb
                  )
            _ -> (active, inhb)
  where
    -- One 'DimFix' per entry of the 'SegSpace', indexing with the name bound
    -- by that entry as an @i64@ leaf expression.
    thread_index_slice :: Slice (TPrimExp Int64 VName)
    thread_index_slice =
      Slice $
        map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst) $
          unSegSpace space
makeSegMapCoals _ td_env _ x (_, _, WriteReturns _ _ return_name _) =
  case getScopeMemInfo return_name $ scope td_env of
    Just (MemBlock _ _ return_mem _) -> markFailedCoal x return_mem
    Nothing -> error "Should not happen?"
makeSegMapCoals _ td_env _ x (_, _, result) =
  freeIn result
    & namesToList
    & mapMaybe (flip getScopeMemInfo $ scope td_env)
    & foldr (\(MemBlock _ _ mem _) -> flip markFailedCoal mem) x

-- | Extend a (possibly partial) slice so that it has one entry per
-- dimension of the given shape.  The entries already present in the
-- slice are kept as they are; for every shape dimension @d@ beyond
-- those, a @DimSlice 0 d 1@ is appended.
fullSlice :: [TPrimExp Int64 VName] -> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice dims (Slice given) =
  Slice (given ++ [DimSlice 0 d 1 | d <- uncovered])
  where
    -- The shape dimensions that the given slice says nothing about.
    uncovered = drop (length given) dims

-- | Compute the table of successful coalescings for a body, re-running
-- the analysis until the inhibited table gains no new keys.
--
-- One round consists of:
--
--   1. running 'mkCoalsTabStms' on the statements of the body, starting
--      from an empty bottom-up environment whose @inhibit@ table is the
--      @inhibited@ table of the top-down environment;
--
--   2. resolving the active entries that belong to memory blocks of the
--      function parameters (see @handleFunctionParams@ below);
--
--   3. discarding successful entries whose optimistic dependencies did
--      not hold (see @fixPointFilterDeps@ below).
--
-- If the resulting inhibited table has a key that is not already in
-- @inhibited topenv@, the whole analysis is repeated with the extended
-- table; otherwise the filtered success table is returned.
--
-- Calls 'error' if the active table is non-empty after step 2.
fixPointCoalesce ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  [Param FParamMem] ->
  Body (Aliases rep) ->
  TopdownEnv rep ->
  ShortCircuitM rep CoalsTab
fixPointCoalesce lutab fpar bdy topenv = do
  buenv <-
    mkCoalsTabStms lutab (bodyStms bdy) topenv $
      emptyBotUpEnv {inhibit = inhibited topenv}
  let -- Allow short-circuiting function parameters that are unique and have
      -- matching index functions, otherwise mark as failed.
      handleFunctionParams (a, i, s) (_, u, MemBlock _ _ m ixf) =
        case (u, M.lookup m a) of
          (Unique, Just entry)
            | dstind entry == ixf ->
                let (a', s') = markSuccessCoal (a, s) m entry
                 in (a', i, s')
          _ ->
            let (a', i') = markFailedCoal (a, i) m
             in (a', i', s)
      -- Strict left fold: the accumulator is a triple of tables that is
      -- inspected at every step anyway, so there is no point in building
      -- a chain of thunks.
      (actv_tab', inhb_tab', succ_tab') =
        L.foldl'
          handleFunctionParams
          (activeCoals buenv, inhibit buenv, successCoals buenv)
          (getArrMemAssocFParam fpar)

      (succ_tab'', failed_optdeps) = fixPointFilterDeps succ_tab' M.empty
      inhb_tab'' = M.unionWith (<>) failed_optdeps inhb_tab'
  if not (M.null actv_tab')
    then error ("COALESCING ROOT: BROKEN INV, active not empty: " ++ show (M.keys actv_tab'))
    else
      if M.null (inhb_tab'' `M.difference` inhibited topenv)
        then pure succ_tab''
        else fixPointCoalesce lutab fpar bdy (topenv {inhibited = inhb_tab''})
  where
    -- Apply 'filterDeps' to every key of the table, and repeat until a
    -- pass leaves the number of entries unchanged.
    fixPointFilterDeps :: CoalsTab -> InhibitTab -> (CoalsTab, InhibitTab)
    fixPointFilterDeps coaltab inhbtab
      | M.size coaltab == M.size coaltab' = (coaltab', inhbtab')
      | otherwise = fixPointFilterDeps coaltab' inhbtab'
      where
        (coaltab', inhbtab') =
          L.foldl' filterDeps (coaltab, inhbtab) (M.keys coaltab)

    -- Mark the entry of memory block @mb@ as failed if at least one of
    -- its optimistic dependencies failed.  A key that is no longer in
    -- the table (removed earlier in the same pass) is left alone.
    filterDeps :: (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
    filterDeps (coal, inhb) mb =
      case M.lookup mb coal of
        Nothing -> (coal, inhb)
        Just coal_etry
          | M.null (M.filterWithKey (failedOptDep coal) (optdeps coal_etry)) ->
              (coal, inhb) -- all ok
          | otherwise ->
              -- optimistic dependencies failed for the current
              -- memblock; extend inhibited mem-block mergings.
              markFailedCoal (coal, inhb) mb

    -- An optimistic dependency @(r, mr)@ has failed when @mr@ has no
    -- entry in the table, or when its entry does not have @r@ in its
    -- @vartab@.
    failedOptDep :: CoalsTab -> VName -> VName -> Bool
    failedOptDep coal r mr =
      case M.lookup mr coal of
        Nothing -> True
        Just coal_etry -> r `M.notMember` vartab coal_etry

-- | Perform short-circuiting on 'Stms'.
--
-- The top-down environment is threaded forwards through the statements
-- (via 'updateTopdownEnv'), while the bottom-up environment is threaded
-- backwards: the statements following a statement are processed first,
-- and their result is handed to 'mkCoalsTabStm' for that statement.
mkCoalsTabStms ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stms (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStms lutab all_stms = go all_stms
  where
    -- Non-negative names collected from the patterns of all the
    -- statements in this sequence; added to the top-down environment
    -- given to every individual statement.
    pat_non_negs = foldMap (nonNegativesInPat . stmPat) all_stms

    go Empty _ bu = pure bu
    go (s :<| rest) td bu =
      -- Bottom up: first the remaining statements, then this one.
      go rest td_after bu >>= mkCoalsTabStm lutab s td_for_stm
      where
        -- Top down: the environment as updated by this statement.
        td_after = updateTopdownEnv td s
        td_for_stm =
          td_after {nonNegatives = nonNegatives td_after <> pat_non_negs}

-- | Array (register) coalescing can have one of three shapes:
--      a) @let y    = copy(b^{lu})@
--      b) @let y    = concat(a, b^{lu})@
--      c) @let y[i] = b^{lu}@
--   The intent is to use the memory block of the left-hand side
--     for the right-hand side variable, meaning to store @b@ in
--     @m_y@ (rather than @m_b@).
--   The following five safety conditions are necessary:
--      1. the right-hand side is lastly-used in the current statement
--      2. the allocation of @m_y@ dominates the creation of @b@
--         ^ relax it by hoisting the allocation of @m_y@
--      3. there is no use of the left-hand side memory block @m_y@
--           during the liveness of @b@, i.e., in between its last use
--           and its creation.
--         ^ relax it by pointwise/interval-based checking
--      4. @b@ is a newly created array, i.e., does not alias anything
--         ^ relax it to support existential memory blocks for if-then-else
--      5. the new index function of @b@ corresponding to memory block @m_y@
--           can be translated at the definition of @b@, and the
--           same for all variables aliasing @b@.
--   Observation: during the live range of @b@, @m_b@ can only be used by
--                variables aliased with @b@, because @b@ is newly created.
--                relax it: in case @m_b@ is existential due to an if-then-else
--                          then the checks should be extended to the actual
--                          array-creation points.
mkCoalsTabStm ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stm (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStm :: forall {k} (rep :: k) inner.
Coalesceable rep inner =>
InhibitTab
-> Stm (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStm InhibitTab
_ (Let (Pat [PatElem (LetDec (Aliases rep))
pe]) StmAux (ExpDec (Aliases rep))
_ Exp (Aliases rep)
e) TopdownEnv rep
td_env BotUpEnv
bu_env
  | Just PrimExp VName
primexp <- forall {k} (m :: * -> *) (rep :: k) v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp (forall {k} (rep :: k).
(CanBeAliased (Op rep), RepTypes rep) =>
ScopeTab rep
-> Map VName (PrimExp VName) -> VName -> Maybe (PrimExp VName)
vnameToPrimExp (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)) Exp (Aliases rep)
e =
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ BotUpEnv
bu_env {scals :: Map VName (PrimExp VName)
scals = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec (Aliases rep))
pe) PrimExp VName
primexp (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)}
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
patt StmAux (ExpDec (Aliases rep))
_ (Match [SubExp]
_ [Case (Body (Aliases rep))]
cases Body (Aliases rep)
defbody MatchDec (BranchType (Aliases rep))
_)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms = forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec (Aliases rep))
patt
      -- ToDo: 1. we need to record existential memory blocks in alias table on the top-down pass.
      --       2. need to extend the scope table

      --  i) Filter @activeCoals@ by the 2ND AND 5th safety conditions:
      (CoalsTab
activeCoals0, InhibitTab
inhibit0) =
        forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (ConsumedInExp, LetDecMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms

      -- ii) extend @activeCoals@ by transferring the pattern-elements bindings existent
      --     in @activeCoals@ to the body results of the then and else branches, but only
      --     if the current pattern element can be potentially coalesced and also
      --     if the current pattern element satisfies safety conditions 2 & 5.
      res_mem_def :: [MemBodyResult]
res_mem_def = forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (ConsumedInExp, LetDecMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms Body (Aliases rep)
defbody
      res_mem_cases :: [[MemBodyResult]]
res_mem_cases = forall a b. (a -> b) -> [a] -> [b]
map (forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (ConsumedInExp, LetDecMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      subs_def :: FreeVarSubsts
subs_def = forall aliases.
Pat (aliases, LetDecMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (LetDec (Aliases rep))
patt forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult Body (Aliases rep)
defbody
      subs_cases :: [FreeVarSubsts]
subs_cases = forall a b. (a -> b) -> [a] -> [b]
map (forall aliases.
Pat (aliases, LetDecMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (LetDec (Aliases rep))
patt forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Body rep -> Result
bodyResult forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      actv_def_i :: CoalsTab
actv_def_i = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_def) CoalsTab
activeCoals0 [MemBodyResult]
res_mem_def
      actv_cases_i :: [CoalsTab]
actv_cases_i = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\FreeVarSubsts
subs [MemBodyResult]
res -> forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs) CoalsTab
activeCoals0 [MemBodyResult]
res) [FreeVarSubsts]
subs_cases [[MemBodyResult]]
res_mem_cases

      -- eliminate the original pattern binding of the if statement,
      -- @let x = if y[0,0] > 0 then map (+y[0,0]) a else map (+1) b@
      -- @let y[0] = x@
      -- should succeed because @m_y@ is used before @x@ is created.
      aux :: Map VName a -> MemBodyResult -> Map VName a
aux Map VName a
ac (MemBodyResult VName
m_b VName
_ VName
_ VName
m_r) = if VName
m_b forall a. Eq a => a -> a -> Bool
== VName
m_r then Map VName a
ac else forall k a. Ord k => k -> Map k a -> Map k a
M.delete VName
m_b Map VName a
ac
      actv_def :: CoalsTab
actv_def = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux CoalsTab
actv_def_i [MemBodyResult]
res_mem_def
      actv_cases :: [CoalsTab]
actv_cases = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux) [CoalsTab]
actv_cases_i [[MemBodyResult]]
res_mem_cases

  -- iii) process the then and else bodies
  BotUpEnv
res_def <- forall {k} (rep :: k) inner.
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body (Aliases rep)
defbody) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv_def})
  [BotUpEnv]
res_cases <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\Case (Body (Aliases rep))
c CoalsTab
a -> forall {k} (rep :: k) inner.
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall body. Case body -> body
caseBody Case (Body (Aliases rep))
c) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
a})) [Case (Body (Aliases rep))]
cases [CoalsTab]
actv_cases
  let (CoalsTab
actv_def0, CoalsTab
succ_def0, InhibitTab
inhb_def0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_def, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_def, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_def)

      -- iv) optimistically mark the pattern successful:
      ((CoalsTab
activeCoals1, InhibitTab
inhibit1), CoalsTab
successCoals1) =
        forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( forall {c}.
[(CoalsTab, Map VName c)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun
              ( (CoalsTab
actv_def0, CoalsTab
succ_def0)
                  forall a. a -> [a] -> [a]
: forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases) (forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases)
              )
          )
          ((CoalsTab
activeCoals0, InhibitTab
inhibit0), BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
          (forall a. [[a]] -> [[a]]
L.transpose forall a b. (a -> b) -> a -> b
$ [MemBodyResult]
res_mem_def forall a. a -> [a] -> [a]
: [[MemBodyResult]]
res_mem_cases)

      --  v) unify coalescing results of all branches by taking the union
      --     of all entries in the current/then/else success tables.

      actv_res :: CoalsTab
actv_res = forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
activeCoals1 forall a b. (a -> b) -> a -> b
$ CoalsTab
actv_def0 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases

      succ_res :: CoalsTab
succ_res = forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (forall k a. Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
M.unionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
successCoals1 forall a b. (a -> b) -> a -> b
$ CoalsTab
succ_def0 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases

      -- vi) The step of filtering by 3rd safety condition is not
      --       necessary, because we perform index analysis of the
      --       source/destination uses, and they should have been
      --       filtered during the analysis of the then/else bodies.
      inhibit_res :: InhibitTab
inhibit_res =
        forall (f :: * -> *) k a.
(Foldable f, Ord k) =>
(a -> a -> a) -> f (Map k a) -> Map k a
M.unionsWith
          forall a. Semigroup a => a -> a -> a
(<>)
          ( InhibitTab
inhibit1
              forall a. a -> [a] -> [a]
: forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                ( \CoalsTab
actv InhibitTab
inhb ->
                    let failed :: CoalsTab
failed = forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
actv forall a b. (a -> b) -> a -> b
$ forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry CoalsTab
actv CoalsTab
activeCoals0
                     in forall a b. (a, b) -> b
snd forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
failed, InhibitTab
inhb) (forall k a. Map k a -> [k]
M.keys CoalsTab
failed)
                )
                (CoalsTab
actv_def0 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases)
                (InhibitTab
inhb_def0 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> InhibitTab
inhibit [BotUpEnv]
res_cases)
          )
  forall (f :: * -> *) a. Applicative f => a -> f a
pure
    BotUpEnv
bu_env
      { activeCoals :: CoalsTab
activeCoals =
          CoalsTab
actv_res,
        successCoals :: CoalsTab
successCoals = CoalsTab
succ_res,
        inhibit :: InhibitTab
inhibit = InhibitTab
inhibit_res
      }
  where
    foldfun :: [(CoalsTab, Map VName c)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab, InhibitTab), CoalsTab)
_ [] =
      forall a. HasCallStack => String -> a
error String
"Imposible Case 1!!!"
    foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab
act, InhibitTab
_), CoalsTab
_) [MemBodyResult]
mem_body_results
      | Maybe CoalsEntry
Nothing <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (MemBodyResult -> VName
patMem forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [MemBodyResult]
mem_body_results) CoalsTab
act =
          forall a. HasCallStack => String -> a
error String
"Imposible Case 2!!!"
    foldfun
      [(CoalsTab, Map VName c)]
acc
      ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc)
      mem_body_results :: [MemBodyResult]
mem_body_results@(MemBodyResult VName
m_b VName
_ VName
_ VName
_ : [MemBodyResult]
_)
        | Just CoalsEntry
info <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
act,
          Just [c]
_ <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(CoalsTab, Map VName c)]
acc =
            -- Optimistically promote to successful coalescing and append!
            let info' :: CoalsEntry
info' =
                  CoalsEntry
info
                    { optdeps :: Map VName VName
optdeps =
                        forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr
                          (\MemBodyResult
mbr -> forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (MemBodyResult -> VName
bodyName MemBodyResult
mbr) (MemBodyResult -> VName
bodyMem MemBodyResult
mbr))
                          (CoalsEntry -> Map VName VName
optdeps CoalsEntry
info)
                          [MemBodyResult]
mem_body_results
                    }
                (CoalsTab
act', CoalsTab
succc') = (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (CoalsTab
act, CoalsTab
succc) VName
m_b CoalsEntry
info'
             in ((CoalsTab
act', InhibitTab
inhb), CoalsTab
succc')
    foldfun
      [(CoalsTab, Map VName c)]
acc
      ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc)
      mem_body_results :: [MemBodyResult]
mem_body_results@(MemBodyResult VName
m_b VName
_ VName
_ VName
_ : [MemBodyResult]
_)
        | Just CoalsEntry
info <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
act,
          forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall a. Eq a => a -> a -> Bool
(==) VName
m_b forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results,
          Just [CoalsEntry]
info' <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(CoalsTab, Map VName c)]
acc =
            -- Treating special case resembling:
            -- @let x0 = map (+1) a                                  @
            -- @let x3 = if cond then let x1 = x0 with [0] <- 2 in x1@
            -- @                 else let x2 = x0 with [1] <- 3 in x2@
            -- @let z[1] = x3                                        @
            -- In this case the result active table should be the union
            -- of the @m_x@ entries of the then and else active tables.
            let info'' :: CoalsEntry
info'' =
                  forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry CoalsEntry
info [CoalsEntry]
info'
                act' :: CoalsTab
act' = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_b CoalsEntry
info'' CoalsTab
act
             in ((CoalsTab
act', InhibitTab
inhb), CoalsTab
succc)
    foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc) (MemBodyResult
mbr : [MemBodyResult]
_) =
      -- one of the branches has failed coalescing,
      -- hence remove the coalescing of the result.

      ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) (MemBodyResult -> VName
patMem MemBodyResult
mbr), CoalsTab
succc)
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ (DoLoop [(FParam (Aliases rep), SubExp)]
arginis LoopForm (Aliases rep)
lform Body (Aliases rep)
body)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms = forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec (Aliases rep))
pat

      --  i) Filter @activeCoals@ by the 2nd, 3rd AND 5th safety conditions. In
      --  other words, for each active coalescing target, the creation of the
      --  array we're trying to merge should happen before the allocation of the
      --  merge target and the index function should be translateable.
      (CoalsTab
actv0, InhibitTab
inhibit0) =
        forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (ConsumedInExp, LetDecMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms
      -- ii) Extend @activeCoals@ by transferring the pattern-elements bindings
      --     existent in @activeCoals@ to the loop-body results, but only if:
      --       (a) the pattern element is a candidate for coalescing,        &&
      --       (b) the pattern element satisfies safety conditions 2 & 5,
      --           (conditions (a) and (b) have already been checked above), &&
      --       (c) the memory block of the corresponding body result is
      --           allocated outside the loop, i.e., non-existential,        &&
      --       (d) the init name is lastly-used in the initialization
      --           of the loop variant.
      --     Otherwise fail and remove from active-coalescing table!
      bdy_ress :: Result
bdy_ress = forall {k} (rep :: k). Body rep -> Result
bodyResult Body (Aliases rep)
body
      ([(VName, VName)]
patmems, [(VName, VName)]
argmems, [(VName, VName)]
inimems, [(VName, VName)]
resmems) =
        forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
L.unzip4 forall a b. (a -> b) -> a -> b
$
          forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (CoalsTab
-> (PatElem (ConsumedInExp, LetDecMem),
    (Param (FParamInfo rep), SubExp), SubExp)
-> Maybe
     ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
mapmbFun CoalsTab
actv0) (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (ConsumedInExp, LetDecMem)]
pat_val_elms [(FParam (Aliases rep), SubExp)]
arginis forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bdy_ress) -- td_env'

      -- remove the other pattern elements from the active coalescing table:
      coal_pat_names :: Names
coal_pat_names = [VName] -> Names
namesFromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(VName, VName)]
patmems
      (CoalsTab
actv1, InhibitTab
inhibit1) =
        forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( \(CoalsTab
act, InhibitTab
inhb) (VName
b, MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_b IxFun
_) ->
              if VName
b VName -> Names -> Bool
`nameIn` Names
coal_pat_names
                then (CoalsTab
act, InhibitTab
inhb) -- ok
                else (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) VName
m_b -- remove from active
          )
          (CoalsTab
actv0, InhibitTab
inhibit0)
          (forall aliases.
Pat (aliases, LetDecMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (LetDec (Aliases rep))
pat)

      -- iii) Process the loop's body.
      --      If the memory blocks of the loop result and loop variant param differ
      --      then make the original memory block of the loop result conflict with
      --      the original memory block of the loop parameter. This is done in
      --      order to prevent the coalescing of @a1@, @a0@, @x@ and @db@ in the
      --      same memory block of @y@ in the example below:
      --      @loop(a1 = a0) = for i < n do @
      --      @    let x = map (stencil a1) (iota n)@
      --      @    let db = copy x          @
      --      @    in db                    @
      --      @let y[0] = a1                @
      --      Meaning the coalescing of @x@ in @let db = copy x@ should fail because
      --      @a1@ appears in the definition of @let x = map (stencil a1) (iota n)@.
      res_mem_bdy :: [MemBodyResult]
res_mem_bdy = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
resmems
      res_mem_arg :: [MemBodyResult]
res_mem_arg = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
argmems
      res_mem_ini :: [MemBodyResult]
res_mem_ini = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
inimems

      actv2 :: CoalsTab
actv2 =
        let subs_res :: FreeVarSubsts
subs_res = forall aliases.
Pat (aliases, LetDecMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (LetDec (Aliases rep))
pat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult Body (Aliases rep)
body
            actv11 :: CoalsTab
actv11 = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_res) CoalsTab
actv1 [MemBodyResult]
res_mem_bdy
            subs_arg :: FreeVarSubsts
subs_arg = forall aliases.
Pat (aliases, LetDecMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (LetDec (Aliases rep))
pat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
arginis
            actv12 :: CoalsTab
actv12 = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_arg) CoalsTab
actv11 [MemBodyResult]
res_mem_arg
            subs_ini :: FreeVarSubsts
subs_ini = forall aliases.
Pat (aliases, LetDecMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (LetDec (Aliases rep))
pat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(FParam (Aliases rep), SubExp)]
arginis
         in forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_ini) CoalsTab
actv12 [MemBodyResult]
res_mem_ini

      -- The code below adds an aliasing relation to the loop-arg memory
      --   so as to prevent, e.g., the coalescing of an iterative stencil
      --   (you need a buffer for the result and a separate one for the stencil).
      -- @ let b =               @
      -- @    loop (a) for i<N do@
      -- @        stencil a      @
      -- @  ...                  @
      -- @  y[slc_y] = b         @
      -- This should fail coalescing because we are aliasing @m_a@ with
      --   the memory block of the result.
      insertMemAliases :: CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
tab (MemBodyResult VName
_ VName
_ VName
_ VName
m_r, MemBodyResult VName
_ VName
_ VName
_ VName
m_a) =
        if VName
m_r forall a. Eq a => a -> a -> Bool
== VName
m_a
          then CoalsTab
tab
          else case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_r CoalsTab
tab of
            Maybe CoalsEntry
Nothing -> CoalsTab
tab
            Just CoalsEntry
etry ->
              forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_r (CoalsEntry
etry {alsmem :: Names
alsmem = CoalsEntry -> Names
alsmem CoalsEntry
etry forall a. Semigroup a => a -> a -> a
<> VName -> Names
oneName VName
m_a}) CoalsTab
tab
      actv3 :: CoalsTab
actv3 = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
actv2 (forall a b. [a] -> [b] -> [(a, b)]
zip [MemBodyResult]
res_mem_bdy [MemBodyResult]
res_mem_arg)
      -- analysing the loop body starts from a null memory-reference set;
      --  the results of the loop body iteration are aggregated later
      actv4 :: CoalsTab
actv4 = forall a b k. (a -> b) -> Map k a -> Map k b
M.map (\CoalsEntry
etry -> CoalsEntry
etry {memrefs :: MemRefs
memrefs = forall a. Monoid a => a
mempty}) CoalsTab
actv3
  BotUpEnv
res_env_body <-
    forall {k} (rep :: k) inner.
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms
      InhibitTab
lutab
      (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)
      TopdownEnv rep
td_env'
      ( BotUpEnv
bu_env
          { activeCoals :: CoalsTab
activeCoals = CoalsTab
actv4,
            inhibit :: InhibitTab
inhibit = InhibitTab
inhibit1
          }
      )
  let scals_loop :: Map VName (PrimExp VName)
scals_loop = BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
res_env_body
      (CoalsTab
res_actv0, CoalsTab
res_succ0, InhibitTab
res_inhb0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_env_body, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_env_body, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_env_body)
      -- iv) Aggregate memory references across loop and filter unsound coalescing
      -- a) Filter the active-table by the FIRST SOUNDNESS condition, namely:
      --     W_i does not overlap with Union_{j=i+1..n} U_j,
      --     where W_i corresponds to the Write set of src mem-block m_b,
      --     and U_j correspond to the uses of the destination
      --     mem-block m_y, in which m_b is coalesced into.
      --     W_i and U_j correspond to the accesses within the loop body.
      mb_loop_idx :: Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx = forall {k} (rep :: k).
LoopForm (Aliases rep)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mbLoopIndexRange LoopForm (Aliases rep)
lform
  CoalsTab
res_actv1 <- forall k (m :: * -> *) v.
(Eq k, Monad m) =>
(v -> m Bool) -> Map k v -> m (Map k v)
filterMapM1 (Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep Bool
loopSoundness1Entry Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv0

  -- b) Update the memory-reference summaries across loop:
  --   W = Union_{i=0..n-1} W_i Union W_{before-loop}
  --   U = Union_{i=0..n-1} U_i Union U_{before-loop}
  CoalsTab
res_actv2 <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep CoalsEntry
aggAcrossLoopEntry (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env' forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k) a. Scoped rep a => a -> Scope rep
scopeOf (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)) Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv1

  -- c) check soundness of the successful promotions for:
  --      - the entries that have been promoted to success during the loop-body pass
  --      - for all the entries of active table
  --    Filter the entries by the SECOND SOUNDNESS CONDITION, namely:
  --      Union_{i=1..n-1} W_i does not overlap the before-the-loop uses
  --        of the destination memory block.
  let res_actv3 :: CoalsTab
res_actv3 = forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
res_actv2

  let tmp_succ :: CoalsTab
tmp_succ =
        forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (forall {k} {a} {p}. Ord k => Map k a -> k -> p -> Bool
okLookup CoalsTab
actv3) forall a b. (a -> b) -> a -> b
$
          forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_succ0 (BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
      ver_succ :: CoalsTab
ver_succ = forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
tmp_succ
  let suc_fail :: CoalsTab
suc_fail = forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
tmp_succ CoalsTab
ver_succ
      (CoalsTab
res_succ, InhibitTab
res_inhb1) = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_succ0, InhibitTab
res_inhb0) forall a b. (a -> b) -> a -> b
$ forall k a. Map k a -> [k]
M.keys CoalsTab
suc_fail
      --
      act_fail :: CoalsTab
act_fail = forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_actv0 CoalsTab
res_actv3
      (CoalsTab
_, InhibitTab
res_inhb) = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_actv0, InhibitTab
res_inhb1) forall a b. (a -> b) -> a -> b
$ forall k a. Map k a -> [k]
M.keys CoalsTab
act_fail
      res_actv :: CoalsTab
res_actv =
        forall k a b. (k -> a -> b) -> Map k a -> Map k b
M.mapWithKey (forall {k}.
Ord k =>
Map k CoalsEntry -> k -> CoalsEntry -> CoalsEntry
addBeforeLoop CoalsTab
actv3) CoalsTab
res_actv3

      -- v) optimistically mark the pattern successful if there is any chance to succeed
      ((CoalsTab
fin_actv1, InhibitTab
fin_inhb1), CoalsTab
fin_succ1) =
        forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((CoalsTab, InhibitTab), CoalsTab)
-> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
-> ((CoalsTab, InhibitTab), CoalsTab)
foldFunOptimPromotion ((CoalsTab
res_actv, InhibitTab
res_inhb), CoalsTab
res_succ) forall a b. (a -> b) -> a -> b
$
          forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
L.zip4 [(VName, VName)]
patmems [(VName, VName)]
argmems [(VName, VName)]
resmems [(VName, VName)]
inimems
      (CoalsTab
fin_actv2, InhibitTab
fin_inhb2) =
        forall a k b. (a -> k -> b -> a) -> a -> Map k b -> a
M.foldlWithKey
          ( \(CoalsTab, InhibitTab)
acc VName
k CoalsEntry
_ ->
              if VName
k VName -> Names -> Bool
`nameIn` [VName] -> Names
namesFromList (forall a b. (a -> b) -> [a] -> [b]
map (forall dec. Param dec -> VName
paramName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
arginis)
                then (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab, InhibitTab)
acc VName
k
                else (CoalsTab, InhibitTab)
acc
          )
          (CoalsTab
fin_actv1, InhibitTab
fin_inhb1)
          CoalsTab
fin_actv1
  forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
fin_actv2, successCoals :: CoalsTab
successCoals = CoalsTab
fin_succ1, inhibit :: InhibitTab
inhibit = InhibitTab
fin_inhb2}
  where
    allocs_bdy :: AllocTab
allocs_bdy = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl forall {k} {rep :: k} {inner}.
(Op rep ~ MemOp inner) =>
AllocTab -> Stm rep -> AllocTab
getAllocs (forall {k} (rep :: k). TopdownEnv rep -> AllocTab
alloc TopdownEnv rep
td_env') forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body (Aliases rep)
body
    td_env_allocs :: TopdownEnv rep
td_env_allocs = TopdownEnv rep
td_env' {alloc :: AllocTab
alloc = AllocTab
allocs_bdy, scope :: ScopeTab rep
scope = forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env' forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k) a. Scoped rep a => a -> Scope rep
scopeOf (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)}
    td_env' :: TopdownEnv rep
td_env' = forall {k} (rep :: k).
TopdownEnv rep
-> [(FParam rep, SubExp)]
-> LoopForm (Aliases rep)
-> TopdownEnv rep
updateTopdownEnvLoop TopdownEnv rep
td_env [(FParam (Aliases rep), SubExp)]
arginis LoopForm (Aliases rep)
lform
    getAllocs :: AllocTab -> Stm rep -> AllocTab
getAllocs AllocTab
tab (Let (Pat [PatElem (LetDec rep)
pe]) StmAux (ExpDec rep)
_ (Op (Alloc SubExp
_ Space
sp))) =
      forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) Space
sp AllocTab
tab
    getAllocs AllocTab
tab Stm rep
_ = AllocTab
tab
    okLookup :: Map k a -> k -> p -> Bool
okLookup Map k a
tab k
m p
_
      | Just a
_ <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup k
m Map k a
tab = Bool
True
    okLookup Map k a
_ k
_ p
_ = Bool
False
    --
    mapmbFun :: CoalsTab
-> (PatElem (ConsumedInExp, LetDecMem),
    (Param (FParamInfo rep), SubExp), SubExp)
-> Maybe
     ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
mapmbFun CoalsTab
actv0 (PatElem (ConsumedInExp, LetDecMem)
patel, (Param (FParamInfo rep)
arg, SubExp
ini), SubExp
bdyres)
      | VName
b <- forall dec. PatElem dec -> VName
patElemName PatElem (ConsumedInExp, LetDecMem)
patel,
        (ConsumedInExp
_, MemArray PrimType
_ ShapeBase SubExp
_ NoUniqueness
_ (ArrayIn VName
m_b IxFun
_)) <- forall dec. PatElem dec -> dec
patElemDec PatElem (ConsumedInExp, LetDecMem)
patel,
        VName
a <- forall dec. Param dec -> VName
paramName Param (FParamInfo rep)
arg,
        Var VName
a0 <- SubExp
ini,
        Var VName
r <- SubExp
bdyres,
        Just CoalsEntry
coal_etry <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
actv0,
        Just Coalesced
_ <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
b (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
coal_etry),
        Just (MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_a IxFun
_) <- forall {k} (rep :: k).
HasMemBlock rep =>
VName -> Scope rep -> Maybe ArrayMemBound
getScopeMemInfo VName
a (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env_allocs),
        Just (MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_a0 IxFun
_) <- forall {k} (rep :: k).
HasMemBlock rep =>
VName -> Scope rep -> Maybe ArrayMemBound
getScopeMemInfo VName
a0 (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env_allocs),
        Just (MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_r IxFun
_) <- forall {k} (rep :: k).
HasMemBlock rep =>
VName -> Scope rep -> Maybe ArrayMemBound
getScopeMemInfo VName
r (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env_allocs),
        Just Names
nms <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
a InhibitTab
lutab,
        VName
a0 VName -> Names -> Bool
`nameIn` Names
nms,
        VName
m_r forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` forall k a. Map k a -> [k]
M.keys (forall {k} (rep :: k). TopdownEnv rep -> AllocTab
alloc TopdownEnv rep
td_env_allocs) =
          forall a. a -> Maybe a
Just ((VName
b, VName
m_b), (VName
a, VName
m_a), (VName
a0, VName
m_a0), (VName
r, VName
m_r))
    mapmbFun CoalsTab
_ (PatElem (ConsumedInExp, LetDecMem)
_patel, (Param (FParamInfo rep)
_arg, SubExp
_ini), SubExp
_bdyres) = forall a. Maybe a
Nothing
    foldFunOptimPromotion ::
      ((CoalsTab, InhibitTab), CoalsTab) ->
      ((VName, VName), (VName, VName), (VName, VName), (VName, VName)) ->
      ((CoalsTab, InhibitTab), CoalsTab)
    foldFunOptimPromotion :: ((CoalsTab, InhibitTab), CoalsTab)
-> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
-> ((CoalsTab, InhibitTab), CoalsTab)
foldFunOptimPromotion ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc) ((VName
b, VName
m_b), (VName
a, VName
m_a), (VName
_r, VName
m_r), (VName
b_i, VName
m_i))
      | VName
m_r forall a. Eq a => a -> a -> Bool
== VName
m_i,
        Just CoalsEntry
info <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_i CoalsTab
act,
        Just Map VName Coalesced
vtab_i <- forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
TopdownEnv rep
-> Map VName Coalesced -> VName -> Maybe (Map VName Coalesced)
addInvAliassesVarTab TopdownEnv rep
td_env (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info) VName
b_i =
          forall a. HasCallStack => Bool -> a -> a
Exc.assert
            (VName
m_r forall a. Eq a => a -> a -> Bool
== VName
m_b Bool -> Bool -> Bool
&& VName
m_a forall a. Eq a => a -> a -> Bool
== VName
m_b)
            ((forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_b (CoalsEntry
info {vartab :: Map VName Coalesced
vartab = Map VName Coalesced
vtab_i}) CoalsTab
act, InhibitTab
inhb), CoalsTab
succc)
      | VName
m_r forall a. Eq a => a -> a -> Bool
== VName
m_i =
          forall a. HasCallStack => Bool -> a -> a
Exc.assert
            (VName
m_r forall a. Eq a => a -> a -> Bool
== VName
m_b Bool -> Bool -> Bool
&& VName
m_a forall a. Eq a => a -> a -> Bool
== VName
m_b)
            ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) VName
m_b, CoalsTab
succc)
      | Just CoalsEntry
info_b0 <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
act,
        Just CoalsEntry
info_a0 <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_a CoalsTab
act,
        Just CoalsEntry
info_i <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_i CoalsTab
act,
        forall k a. Ord k => k -> Map k a -> Bool
M.member VName
m_r CoalsTab
succc,
        Just Map VName Coalesced
vtab_i <- forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
TopdownEnv rep
-> Map VName Coalesced -> VName -> Maybe (Map VName Coalesced)
addInvAliassesVarTab TopdownEnv rep
td_env (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info_i) VName
b_i,
        [Just CoalsEntry
info_b, Just CoalsEntry
info_a] <- forall a b. (a -> b) -> [a] -> [b]
map (VName, CoalsEntry) -> Maybe CoalsEntry
translateIxFnInScope [(VName
b, CoalsEntry
info_b0), (VName
a, CoalsEntry
info_a0)] =
          let info_b' :: CoalsEntry
info_b' = CoalsEntry
info_b {optdeps :: Map VName VName
optdeps = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
b_i VName
m_i forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName VName
optdeps CoalsEntry
info_b}
              info_a' :: CoalsEntry
info_a' = CoalsEntry
info_a {optdeps :: Map VName VName
optdeps = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
b_i VName
m_i forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName VName
optdeps CoalsEntry
info_a}
              info_i' :: CoalsEntry
info_i' =
                CoalsEntry
info_i
                  { optdeps :: Map VName VName
optdeps = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
b VName
m_b forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName VName
optdeps CoalsEntry
info_i,
                    memrefs :: MemRefs
memrefs = forall a. Monoid a => a
mempty,
                    vartab :: Map VName Coalesced
vartab = Map VName Coalesced
vtab_i
                  }
              act' :: CoalsTab
act' = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_i CoalsEntry
info_i' CoalsTab
act
              (CoalsTab
act1, CoalsTab
succc1) =
                forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
                  (\(CoalsTab, CoalsTab)
acc (VName
m, CoalsEntry
info) -> (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (CoalsTab, CoalsTab)
acc VName
m CoalsEntry
info)
                  (CoalsTab
act', CoalsTab
succc)
                  [(VName
m_b, CoalsEntry
info_b'), (VName
m_a, CoalsEntry
info_a')]
           in -- ToDo: make sure that ixfun translates and update substitutions (?)
              ((CoalsTab
act1, InhibitTab
inhb), CoalsTab
succc1)
    foldFunOptimPromotion ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc) ((VName
_, VName
m_b), (VName
_a, VName
m_a), (VName
_r, VName
m_r), (VName
_b_i, VName
m_i)) =
      forall a. HasCallStack => Bool -> a -> a
Exc.assert
        (VName
m_r forall a. Eq a => a -> a -> Bool
/= VName
m_i)
        (forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) [VName
m_b, VName
m_a, VName
m_r, VName
m_i], CoalsTab
succc)

    translateIxFnInScope :: (VName, CoalsEntry) -> Maybe CoalsEntry
translateIxFnInScope (VName
x, CoalsEntry
info)
      | Just (Coalesced CoalescedKind
knd mbd :: ArrayMemBound
mbd@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ IxFun
ixfn) FreeVarSubsts
_subs0) <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
x (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info),
        forall {k} (rep :: k). TopdownEnv rep -> VName -> Bool
isInScope TopdownEnv rep
td_env (CoalsEntry -> VName
dstmem CoalsEntry
info) =
          let scope_tab :: ScopeTab rep
scope_tab =
                forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env
                  forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k) dec.
(FParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfFParams (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(FParam (Aliases rep), SubExp)]
arginis)
           in case forall {k} a (rep :: k).
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions ScopeTab rep
scope_tab (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) IxFun
ixfn of
                Just FreeVarSubsts
fv_subst ->
                  forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ CoalsEntry
info {vartab :: Map VName Coalesced
vartab = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
x (CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
knd ArrayMemBound
mbd FreeVarSubsts
fv_subst) (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info)}
                Maybe FreeVarSubsts
Nothing -> forall a. Maybe a
Nothing
    translateIxFnInScope (VName, CoalsEntry)
_ = forall a. Maybe a
Nothing
    se0 :: SubExp
se0 = IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
    mbLoopIndexRange ::
      LoopForm (Aliases rep) ->
      Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
    mbLoopIndexRange :: forall {k} (rep :: k).
LoopForm (Aliases rep)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mbLoopIndexRange (WhileLoop VName
_) = forall a. Maybe a
Nothing
    mbLoopIndexRange (ForLoop VName
inm IntType
_inttp SubExp
seN [(LParam (Aliases rep), VName)]
_) = forall a. a -> Maybe a
Just (VName
inm, (SubExp -> TPrimExp Int64 VName
pe64 SubExp
se0, SubExp -> TPrimExp Int64 VName
pe64 SubExp
seN))
    addBeforeLoop :: Map k CoalsEntry -> k -> CoalsEntry -> CoalsEntry
addBeforeLoop Map k CoalsEntry
actv_bef k
m_b CoalsEntry
etry =
      case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup k
m_b Map k CoalsEntry
actv_bef of
        Maybe CoalsEntry
Nothing -> CoalsEntry
etry
        Just CoalsEntry
etry0 ->
          CoalsEntry
etry {memrefs :: MemRefs
memrefs = CoalsEntry -> MemRefs
memrefs CoalsEntry
etry0 forall a. Semigroup a => a -> a -> a
<> CoalsEntry -> MemRefs
memrefs CoalsEntry
etry}
    aggAcrossLoopEntry :: ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep CoalsEntry
aggAcrossLoopEntry ScopeTab rep
scope_loop Map VName (PrimExp VName)
scal_tab Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
idx CoalsEntry
etry = do
      AccessSummary
wrts <-
        forall {k} (m :: * -> *) (rep :: k).
MonadFreshNames m =>
ScopeTab rep
-> ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> AccessSummary
-> m AccessSummary
aggSummaryLoopTotal (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) ScopeTab rep
scope_loop Map VName (PrimExp VName)
scal_tab Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
idx forall a b. (a -> b) -> a -> b
$
          (MemRefs -> AccessSummary
srcwrts forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry
      AccessSummary
uses <-
        forall {k} (m :: * -> *) (rep :: k).
MonadFreshNames m =>
ScopeTab rep
-> ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> AccessSummary
-> m AccessSummary
aggSummaryLoopTotal (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) ScopeTab rep
scope_loop Map VName (PrimExp VName)
scal_tab Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
idx forall a b. (a -> b) -> a -> b
$
          (MemRefs -> AccessSummary
dstrefs forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ CoalsEntry
etry {memrefs :: MemRefs
memrefs = AccessSummary -> AccessSummary -> MemRefs
MemRefs AccessSummary
uses AccessSummary
wrts}
    loopSoundness1Entry :: Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep Bool
loopSoundness1Entry Map VName (PrimExp VName)
scal_tab Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
idx CoalsEntry
etry = do
      let wrt_i :: AccessSummary
wrt_i = (MemRefs -> AccessSummary
srcwrts forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry
      AccessSummary
use_p <-
        forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> AccessSummary
-> m AccessSummary
aggSummaryLoopPartial (Map VName (PrimExp VName)
scal_tab forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k). TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
idx forall a b. (a -> b) -> a -> b
$
          MemRefs -> AccessSummary
dstrefs forall a b. (a -> b) -> a -> b
$
            CoalsEntry -> MemRefs
memrefs CoalsEntry
etry
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
(CanBeAliased (Op rep), RepTypes rep) =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env' AccessSummary
wrt_i AccessSummary
use_p
    loopSoundness2Entry :: CoalsTab -> VName -> CoalsEntry -> Bool
    loopSoundness2Entry :: CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
old_actv VName
m_b CoalsEntry
etry =
      case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
old_actv of
        Maybe CoalsEntry
Nothing -> Bool
True
        Just CoalsEntry
etry0 ->
          let uses_before :: AccessSummary
uses_before = (MemRefs -> AccessSummary
dstrefs forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry0
              write_loop :: AccessSummary
write_loop = (MemRefs -> AccessSummary
srcwrts forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry
           in forall {k} (rep :: k).
(CanBeAliased (Op rep), RepTypes rep) =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env AccessSummary
write_loop AccessSummary
uses_before

-- The case of in-place update:
--   @let x' = x with slice <- elm@
--
-- This equation only applies when 'getArrMemAssoc' yields exactly one
-- array/memory association for the pattern; @m_x@ is the memory name found
-- in that association.  When the guard does not hold, this equation is
-- skipped and matching continues with the equations that follow it.
--
-- The returned bottom-up environment is @bu_env@ with only its
-- @activeCoals@ and @inhibit@ fields replaced; @successCoals@ is read
-- but not modified here.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ e@(BasicOp (Update safety x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat =
      do
        -- (a) filter by the 3rd safety for @elm@ and @x'@
        let (actv, inhbt) = recordMemRefUses td_env bu_env stm
            -- (b) if @x'@ is in active coalesced table, then add an entry for @x@ as well
            (actv', inhbt') =
              case M.lookup m_x actv of
                -- @m_x@ is not an active coalescing candidate: nothing to do.
                Nothing -> (actv, inhbt)
                Just info ->
                  case M.lookup (patElemName x') (vartab info) of
                    -- @m_x@ is active but @x'@ is not recorded in its
                    -- variable table: give up on coalescing @m_x@.
                    Nothing ->
                      markFailedCoal (actv, inhbt) m_x
                    Just (Coalesced k mblk@(MemBlock _ _ _ x_indfun) _) ->
                      case freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun of
                        -- Substitutions for the index function were found
                        -- and the destination memory is in scope: record
                        -- the same entry for both @x@ and @x'@.
                        Just fv_subs
                          | isInScope td_env (dstmem info) ->
                              let coal_etry_x = Coalesced k mblk fv_subs
                                  info' =
                                    info
                                      { vartab =
                                          M.insert x coal_etry_x $
                                            M.insert (patElemName x') coal_etry_x (vartab info)
                                      }
                               in (M.insert m_x info' actv, inhbt)
                        -- Either no substitutions could be computed or the
                        -- destination memory is not in scope: fail @m_x@.
                        _ ->
                          markFailedCoal (actv, inhbt) m_x

            -- (c) this stm is also a potential source for coalescing, so process it
            --     (only done when the update is marked 'Unsafe').
            actv''
              | safety == Unsafe =
                  mkCoalsHelper3PatternMatch
                    pat
                    e
                    lutab
                    td_env
                    (successCoals bu_env)
                    actv'
                    inhbt'
              | otherwise = actv'
        pure $
          bu_env {activeCoals = actv'', inhibit = inhbt'}

-- The case of flat in-place update:
--   @let x' = x with flat-slice <- elm@
mkCoalsTabStm InhibitTab
lutab stm :: Stm (Aliases rep)
stm@(Let pat :: Pat (LetDec (Aliases rep))
pat@(Pat [PatElem (LetDec (Aliases rep))
x']) StmAux (ExpDec (Aliases rep))
_ e :: Exp (Aliases rep)
e@(BasicOp (FlatUpdate VName
x FlatSlice SubExp
_ VName
_elm))) TopdownEnv rep
td_env BotUpEnv
bu_env
  | [(VName
_, MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_x IxFun
_)] <- forall aliases.
Pat (aliases, LetDecMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (LetDec (Aliases rep))
pat =
      do
        -- (a) filter by the 3rd safety for @elm@ and @x'@
        let (CoalsTab
actv, InhibitTab
inhbt) = forall {k} (rep :: k) inner.
(CanBeAliased (Op rep), RepTypes rep, Op rep ~ MemOp inner,
 HasMemBlock (Aliases rep)) =>
TopdownEnv rep
-> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab)
recordMemRefUses TopdownEnv rep
td_env BotUpEnv
bu_env Stm (Aliases rep)
stm
            -- (b) if @x'@ is in active coalesced table, then add an entry for @x@ as well
            (CoalsTab
actv', InhibitTab
inhbt') =
              case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_x CoalsTab
actv of
                Maybe CoalsEntry
Nothing -> (CoalsTab
actv, InhibitTab
inhbt)
                Just CoalsEntry
info ->
                  case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec (Aliases rep))
x') (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info) of
                    Maybe Coalesced
Nothing ->
                      -- error "In ArrayCoalescing.hs, fun mkCoalsTabStm, case in-place update!"
                      -- this case should not happen, but if it does, just fail conservatively
                      (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
actv, InhibitTab
inhbt) VName
m_x
                    Just (Coalesced CoalescedKind
k mblk :: ArrayMemBound
mblk@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ IxFun
x_indfun) FreeVarSubsts
_) ->
                      case forall {k} a (rep :: k).
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) IxFun
x_indfun of
                        Just FreeVarSubsts
fv_subs
                          | forall {k} (rep :: k). TopdownEnv rep -> VName -> Bool
isInScope TopdownEnv rep
td_env (CoalsEntry -> VName
dstmem CoalsEntry
info) ->
                              let coal_etry_x :: Coalesced
coal_etry_x = CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
k ArrayMemBound
mblk FreeVarSubsts
fv_subs
                                  info' :: CoalsEntry
info' =
                                    CoalsEntry
info
                                      { vartab :: Map VName Coalesced
vartab =
                                          forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
x Coalesced
coal_etry_x forall a b. (a -> b) -> a -> b
$
                                            forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec (Aliases rep))
x') Coalesced
coal_etry_x (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
info)
                                      }
                               in (forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_x CoalsEntry
info' CoalsTab
actv, InhibitTab
inhbt)
                        Maybe FreeVarSubsts
_ ->
                          (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
actv, InhibitTab
inhbt) VName
m_x

            -- (c) this stm is also a potential source for coalescing, so process it
            actv'' :: CoalsTab
actv'' = forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
Pat (ConsumedInExp, LetDecMem)
-> Exp (Aliases rep)
-> InhibitTab
-> TopdownEnv rep
-> CoalsTab
-> CoalsTab
-> InhibitTab
-> CoalsTab
mkCoalsHelper3PatternMatch Pat (LetDec (Aliases rep))
pat Exp (Aliases rep)
e InhibitTab
lutab TopdownEnv rep
td_env (BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env) CoalsTab
actv' InhibitTab
inhbt'
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
          BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv'', inhibit :: InhibitTab
inhibit = InhibitTab
inhbt'}
--
mkCoalsTabStm InhibitTab
_ (Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ (BasicOp Update {})) TopdownEnv rep
_ BotUpEnv
_ =
  forall a. HasCallStack => String -> a
error forall a b. (a -> b) -> a -> b
$ String
"In ArrayCoalescing.hs, fun mkCoalsTabStm, illegal pattern for in-place update: " forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> String
show Pat (LetDec (Aliases rep))
pat
-- default handling
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ (Op Op (Aliases rep)
op)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  -- Process body
  InhibitTab
-> Pat (ConsumedInExp, LetDecMem)
-> MemOp (OpWithAliases inner)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
on_op <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall {k} (rep :: k).
ShortCircuitReader rep
-> InhibitTab
-> Pat (ConsumedInExp, LetDecMem)
-> Op (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
onOp
  InhibitTab
-> Pat (ConsumedInExp, LetDecMem)
-> MemOp (OpWithAliases inner)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
on_op InhibitTab
lutab Pat (LetDec (Aliases rep))
pat Op (Aliases rep)
op TopdownEnv rep
td_env BotUpEnv
bu_env
mkCoalsTabStm InhibitTab
lutab stm :: Stm (Aliases rep)
stm@(Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ Exp (Aliases rep)
e) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  --   i) Filter @activeCoals@ by the 3rd safety condition:
  --      this is now relaxed by use of LMAD eqs:
  --      the memory referenced in stm are added to memrefs::dstrefs
  --      in corresponding coal-tab entries.
  let (CoalsTab
activeCoals', InhibitTab
inhibit') = forall {k} (rep :: k) inner.
(CanBeAliased (Op rep), RepTypes rep, Op rep ~ MemOp inner,
 HasMemBlock (Aliases rep)) =>
TopdownEnv rep
-> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab)
recordMemRefUses TopdownEnv rep
td_env BotUpEnv
bu_env Stm (Aliases rep)
stm
      -- mkCoalsHelper1FilterActive pat (freeIn e) (scope td_env) (scals bu_env)
      --                           (activeCoals bu_env) (inhibit bu_env)

      --  ii) promote any of the entries in @activeCoals@ to @successCoals@ as long as
      --        - this statement defined a variable consumed in a coalesced statement
      --        - and safety conditions 2, 4, and 5 are satisfied.
      --      AND extend @activeCoals@ table for any definition of a variable that
      --      aliases a coalesced variable.
      safe_4 :: Bool
safe_4 = forall {k} (rep :: k). CreatesNewArrOp (Op rep) => Exp rep -> Bool
createsNewArrOK Exp (Aliases rep)
e
      ((CoalsTab
activeCoals'', InhibitTab
inhibit''), CoalsTab
successCoals') =
        forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Bool
-> ((CoalsTab, InhibitTab), CoalsTab)
-> (VName, ArrayMemBound)
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun Bool
safe_4) ((CoalsTab
activeCoals', InhibitTab
inhibit'), BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env) (forall aliases.
Pat (aliases, LetDecMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (LetDec (Aliases rep))
pat)

      -- iii) record a potentially coalesced statement in @activeCoals@
      activeCoals''' :: CoalsTab
activeCoals''' = forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
Pat (ConsumedInExp, LetDecMem)
-> Exp (Aliases rep)
-> InhibitTab
-> TopdownEnv rep
-> CoalsTab
-> CoalsTab
-> InhibitTab
-> CoalsTab
mkCoalsHelper3PatternMatch Pat (LetDec (Aliases rep))
pat Exp (Aliases rep)
e InhibitTab
lutab TopdownEnv rep
td_env CoalsTab
successCoals' CoalsTab
activeCoals'' (forall {k} (rep :: k). TopdownEnv rep -> InhibitTab
inhibited TopdownEnv rep
td_env)
  forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
activeCoals''', inhibit :: InhibitTab
inhibit = InhibitTab
inhibit'', successCoals :: CoalsTab
successCoals = CoalsTab
successCoals'}
  where
    foldfun :: Bool
-> ((CoalsTab, InhibitTab), CoalsTab)
-> (VName, ArrayMemBound)
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun Bool
safe_4 ((CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc) (VName
b, MemBlock PrimType
tp ShapeBase SubExp
shp VName
mb IxFun
_b_indfun) =
      case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
mb CoalsTab
a_acc of
        Maybe CoalsEntry
Nothing -> ((CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
        Just info :: CoalsEntry
info@(CoalsEntry VName
x_mem IxFun
_ Names
_ Map VName Coalesced
vtab Map VName VName
_ MemRefs
_) ->
          let failed :: (CoalsTab, InhibitTab)
failed = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
a_acc, InhibitTab
inhb) VName
mb
           in case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
b Map VName Coalesced
vtab of
                Maybe Coalesced
Nothing ->
                  -- we hit the definition of some variable @b@ aliased with
                  --    the coalesced variable @x@, hence extend @activeCoals@, e.g.,
                  --       @let x = map f arr  @
                  --       @let b = alias x  @ <- current statement
                  --       @ ... use of b ...  @
                  --       @let c = alias b    @ <- currently fails
                  --       @let y[i] = x       @
                  -- where @alias@ can be @transpose@, @slice@, @rotate@, @reshape@.
                  -- We use getTransitiveAlias helper function to track the aliasing
                  --    through the td_env, and to find the updated ixfun of @b@:
                  case forall {k} (rep :: k).
HasMemBlock (Aliases rep) =>
TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, IxFun)
getDirAliasedIxfn TopdownEnv rep
td_env CoalsTab
a_acc VName
b of
                    Maybe (VName, VName, IxFun)
Nothing -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc)
                    Just (VName
_, VName
_, IxFun
b_indfun') ->
                      case forall {k} a (rep :: k).
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) IxFun
b_indfun' of
                        Maybe FreeVarSubsts
Nothing -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc)
                        Just FreeVarSubsts
fv_subst ->
                          let mem_info :: Coalesced
mem_info = CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
TransitiveCoal (PrimType -> ShapeBase SubExp -> VName -> IxFun -> ArrayMemBound
MemBlock PrimType
tp ShapeBase SubExp
shp VName
x_mem IxFun
b_indfun') FreeVarSubsts
fv_subst
                              info' :: CoalsEntry
info' = CoalsEntry
info {vartab :: Map VName Coalesced
vartab = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
b Coalesced
mem_info Map VName Coalesced
vtab}
                           in ((forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
mb CoalsEntry
info' CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
                Just (Coalesced CoalescedKind
k mblk :: ArrayMemBound
mblk@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ IxFun
new_indfun) FreeVarSubsts
_) ->
                  -- we are at the definition of the coalesced variable @b@
                  -- if 2,4,5 hold promote it to successful coalesced table,
                  -- or if e = transpose, etc. then postpone decision for later on
                  let safe_2 :: Bool
safe_2 = forall {k} (rep :: k). TopdownEnv rep -> VName -> Bool
isInScope TopdownEnv rep
td_env VName
x_mem
                   in case forall {k} a (rep :: k).
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (forall {k} (rep :: k). TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) IxFun
new_indfun of
                        Just FreeVarSubsts
fv_subst
                          | Bool
safe_2 ->
                              let mem_info :: Coalesced
mem_info = CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
k ArrayMemBound
mblk FreeVarSubsts
fv_subst
                                  info' :: CoalsEntry
info' = CoalsEntry
info {vartab :: Map VName Coalesced
vartab = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
b Coalesced
mem_info Map VName Coalesced
vtab}
                               in if Bool
safe_4
                                    then -- array creation point, successful coalescing verified!

                                      let (CoalsTab
a_acc', CoalsTab
s_acc') = (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (CoalsTab
a_acc, CoalsTab
s_acc) VName
mb CoalsEntry
info'
                                       in ((CoalsTab
a_acc', InhibitTab
inhb), CoalsTab
s_acc')
                                    else -- this is an invertible alias case of the kind
                                    -- @ let b    = alias a @
                                    -- @ let x[i] = b @
                                    -- do not promote, but update the index function

                                      ((forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
mb CoalsEntry
info' CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
                        Maybe FreeVarSubsts
_ -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc) -- fail!

-- | Summarise the memory accessed through an index function.  An index
-- function whose LMAD list holds exactly one element yields the singleton
-- set containing that LMAD; every other index function is summarised as
-- 'Undeterminable'.
ixfunToAccessSummary :: IxFun.IxFun (TPrimExp Int64 VName) -> AccessSummary
ixfunToAccessSummary ixfun =
  case ixfun of
    IxFun.IxFun (single_lmad NE.:| []) _ _ -> Set (S.singleton single_lmad)
    _ -> Undeterminable

-- | Check safety conditions 2 and 5 and update new substitutions:
-- called on the pat-elements of loop and if-then-else expressions.
--
-- The safety conditions are: The allocation of merge target should dominate the
-- creation of the array we're trying to merge and the new index function of the
-- array can be translated at the definition site of b. The latter requires that
-- any variables used in the index function of the target array are available at
-- the definition site of b.
--
-- The arguments are, in order: the table of active coalescing candidates, the
-- inhibit table, the scalar table used when substituting free variables, the
-- top-down environment, and the pattern elements to examine.  The pattern
-- elements are folded over strictly from left to right.  Elements that are
-- not arrays, or whose memory block has no entry in the active table, leave
-- both tables untouched.  For the remaining ones the entry of the memory
-- block is either refreshed with a new 'vartab' binding for the array, or
-- handed to 'markFailedCoal'.
filterSafetyCond2and5 ::
  HasMemBlock (Aliases rep) =>
  CoalsTab ->
  InhibitTab ->
  ScalarTab ->
  TopdownEnv rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  (CoalsTab, InhibitTab)
filterSafetyCond2and5 act_coal inhb_coal scals_env td_env =
  L.foldl' checkPatElem (act_coal, inhb_coal)
  where
    -- Free-variable substitutions for an index function at the current
    -- program point; 'Nothing' when 'freeVarSubstitutions' cannot produce them.
    substsFor ixfn = freeVarSubstitutions (scope td_env) scals_env ixfn

    -- Process one pattern element against the tables accumulated so far.
    checkPatElem tabs@(acc, inhb) patel
      | (_, MemArray tp0 shp0 _ (ArrayIn m_b _)) <- patElemDec patel,
        Just info@(CoalsEntry x_mem _ _ vtab _ _) <- M.lookup m_b acc =
          -- @b@ is an array in memory block @m_b@, and @m_b@ is a memory
          -- block we are currently trying to coalesce.
          let b = patElemName patel
              new_entry =
                case M.lookup b vtab of
                  -- @b@ is not yet recorded: try to obtain its index
                  -- function through 'getDirAliasedIxfn' and translate it.
                  Nothing -> do
                    (_, _, b_indfun') <- getDirAliasedIxfn td_env acc b
                    Coalesced TransitiveCoal (MemBlock tp0 shp0 x_mem b_indfun')
                      <$> substsFor b_indfun'
                  -- @b@ is already recorded: the destination memory must be
                  -- in scope (condition 2) and its index function must be
                  -- translatable here (condition 5).
                  Just (Coalesced k (MemBlock pt shp _ new_indfun) _)
                    | isInScope td_env x_mem ->
                        Coalesced k (MemBlock pt shp x_mem new_indfun)
                          <$> substsFor new_indfun
                    | otherwise -> Nothing
           in case new_entry of
                Nothing -> markFailedCoal tabs m_b
                Just mem_info ->
                  let info' = info {vartab = M.insert b mem_info vtab}
                   in (M.insert m_b info' acc, inhb)
      | otherwise = tabs

-- | Pattern matches a potentially coalesced statement and
--   records a new association in @activeCoals@.
--
--   The arguments are, in order: the pattern and expression of the
--   statement, the last-use table, the top-down environment, the table of
--   successful coalescings, the table of active coalescings and the inhibit
--   table.  The result is the (possibly extended) active table.
--
--   'genCoalStmtInfo' is consulted exactly once: when it yields 'Nothing'
--   the active table is returned unchanged, otherwise every candidate it
--   reports is folded (strictly, left to right) into the active table.
mkCoalsHelper3PatternMatch ::
  HasMemBlock (Aliases rep) =>
  Pat (VarAliases, LetDecMem) ->
  Exp (Aliases rep) ->
  LUTabFun ->
  TopdownEnv rep ->
  CoalsTab ->
  CoalsTab ->
  InhibitTab ->
  CoalsTab
mkCoalsHelper3PatternMatch pat e lutab td_env successCoals_tab activeCoals_tab inhibit_tab =
  case genCoalStmtInfo lutab (scope td_env) pat e of
    Nothing -> activeCoals_tab
    Just clst -> L.foldl' processNewCoalesce activeCoals_tab clst
  where
    -- Try to register the candidate "put @b@ (in @m_b@) into the memory of
    -- @x@ (in @m_x@)".  The guards are checked in order and the first one
    -- that fails leaves the accumulated table untouched.
    processNewCoalesce acc (knd, alias_fn, x, m_x, ind_x, b, m_b, _, tp_b, shp_b)
      | IxFun.hasOneLmad ind_yx,
        not (areAnyAliased td_env m_b [m_yx]), -- m_b \= m_yx
        isInScope td_env m_yx, -- nameIn m_yx (alloc td_env)
        -- fail if this coalescing has been inhibited
        not is_inhibited,
        -- fail early due to non-invertible aliasing
        Just vtab' <- addInvAliassesVarTab td_env (M.singleton b mem_info) b =
          -- Finally update the @activeCoals@ table with a fresh
          --   binding for @m_b@; if such one exists then overwrite.
          -- Also, add all variables from the alias chain of @b@ to
          --   @vartab@, for example, in the case of a sequence:
          --   @ b0 = if cond then ... else ... @
          --   @ b1 = alias0 b0 @
          --   @ b  = alias1 b1 @
          --   @ x[j] = b @
          -- Then @b1@ and @b0@ should also be added to @vartab@ if
          --   @alias1@ and @alias0@ are invertible, otherwise fail early!
          M.insert m_b (CoalsEntry m_yx ind_yx mem_yx_al vtab' opts' mempty) acc
      | otherwise = acc
      where
        -- In-place updates are resolved against the active table, every
        -- other kind of coalescing against the successful one.
        proper_coals_tab
          | InPlaceCoal <- knd = activeCoals_tab
          | otherwise = successCoals_tab

        -- test whether we are in a transitive coalesced case, i.e.,
        --      @let b = scratch ...@
        --      @.....@
        --      @let x[j] = b@
        --      @let y[i] = x@
        -- and compose the index function of @x@ with that of @y@,
        -- and update aliasing of the @m_b@ entry to also contain @m_y@
        -- on top of @m_x@, i.e., transitively, any use of @m_y@ should
        -- be checked for the lifetime of @b@.
        (m_yx, ind_yx, mem_yx_al, x_deps) =
          case M.lookup m_x proper_coals_tab of
            Nothing ->
              (m_x, alias_fn ind_x, oneName m_x, M.empty)
            Just (CoalsEntry m_y ind_y y_al vtab x_deps0 _) ->
              let ind = case M.lookup x vtab of
                    Just (Coalesced _ (MemBlock _ _ _ ixf) _) -> ixf
                    Nothing -> ind_y
               in (m_y, alias_fn ind, oneName m_x <> y_al, x_deps0)

        -- The entry recorded for @b@ itself; it starts out with no
        -- free-variable substitutions.
        mem_info = Coalesced knd (MemBlock tp_b shp_b m_yx ind_yx) M.empty

        -- Optimistic dependencies of the new entry: none when the target
        -- memory is @m_x@ itself, otherwise those of @x@ plus @x@.
        opts'
          | m_yx == m_x = M.empty
          | otherwise = M.insert x m_x x_deps

        -- Whether the inhibit table lists @m_yx@ for @m_b@.
        is_inhibited = maybe False (m_yx `nameIn`) (M.lookup m_b inhibit_tab)

-- | Inspect a single statement and, if it has one of the supported
-- shapes, list the coalescing candidates it gives rise to.
--
-- Each candidate is a tuple
-- @(kind, ixfun_transform, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb)@ where
-- @m_x@/@ind_x@ are the memory block and index function taken from the
-- (single-element) pattern, @b@ is the source array whose memory
-- information @(m_b, ind_b, tpb, shpb)@ is looked up in the scope, and
-- @ixfun_transform@ is the index-function transformation recorded for
-- the candidate.
--
-- A candidate is only produced when @b@ belongs to the last-use set that
-- the last-use table records for the pattern variable.  Every other
-- situation yields 'Nothing'.
genCoalStmtInfo ::
  HasMemBlock (Aliases rep) =>
  LUTabFun ->
  ScopeTab rep ->
  Pat (VarAliases, LetDecMem) ->
  Exp (Aliases rep) ->
  Maybe [(CoalescedKind, IxFun -> IxFun, VName, VName, IxFun, VName, VName, IxFun, PrimType, Shape)]
-- CASE a) @let x <- copy(b^{lu})@
genCoalStmtInfo lutab scopetab pat (BasicOp (Copy b))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    b `nameIn` last_uses =
      Just [(CopyCoal, id, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb)]
-- CASE c) @let x[i] = b^{lu}@
--
-- The last-use table is consulted with the name bound by the pattern
-- (@x'@), whereas the candidate records the name of the array being
-- updated (@x@).
genCoalStmtInfo lutab scopetab pat (BasicOp (Update _ x slice_x (Var b)))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    b `nameIn` last_uses =
      Just [(InPlaceCoal, restrictToSlice, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb)]
  where
    -- Slice an index function with the update's slice, converted to
    -- 64-bit primitive expressions.
    restrictToSlice :: IxFun -> IxFun
    restrictToSlice ixf =
      IxFun.slice ixf . Slice . map (fmap pe64) $ unSlice slice_x
-- Same as CASE c), but for a flat-slice update.
genCoalStmtInfo lutab scopetab pat (BasicOp (FlatUpdate x slice_x b))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    b `nameIn` last_uses =
      Just [(InPlaceCoal, restrictToFlatSlice, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb)]
  where
    -- Flat-slice an index function with the update's flat slice,
    -- converted to 64-bit primitive expressions.
    restrictToFlatSlice :: IxFun -> IxFun
    restrictToFlatSlice ixf =
      case slice_x of
        FlatSlice offset dims ->
          IxFun.flatSlice ixf $
            FlatSlice (pe64 offset) (map (fmap pe64) dims)
-- CASE b) @let x = concat(a, b^{lu})@
genCoalStmtInfo lutab scopetab pat (BasicOp (Concat concat_dim (b0 :| bs) _))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab =
      let zero = pe64 $ intConst Int64 0
          -- A slice covering a whole dimension.
          wholeDim = unitSlice zero . pe64
          -- The fold state is @(candidates, running offset along
          -- concat_dim, still ok)@.  Once the flag turns False the
          -- remaining operands are ignored, but candidates found so far
          -- are kept.
          visit st@(_, _, False) _ = st
          visit (found, offs, True) b =
            case getScopeMemInfo b scopetab of
              Just (MemBlock tpb shpb@(Shape dims@(_ : _)) m_b ind_b)
                | Just d <- maybeNth concat_dim dims ->
                    let next_offs = offs + pe64 d
                        -- The region of @x@ occupied by operand @b@.
                        slc =
                          Slice $
                            map wholeDim (take concat_dim dims)
                              <> [unitSlice offs (pe64 d)]
                              <> map wholeDim (drop (concat_dim + 1) dims)
                        candidate =
                          (ConcatCoal, (`IxFun.slice` slc), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb)
                     in if b `nameIn` last_uses
                          then (found ++ [candidate], next_offs, True)
                          else (found, next_offs, True)
              _ -> (found, offs, False)
          (res, _, _) = foldl visit ([], zero, True) (b0 : bs)
       in if null res then Nothing else Just res
-- CASE other than a), b), or c) not supported
genCoalStmtInfo _ _ _ _ = Nothing

-- | A pattern element of a branching statement paired with the
-- corresponding result variable of one of its bodies, together with the
-- memory blocks of both.
data MemBodyResult = MemBodyResult
  { -- | Memory block of the pattern element.
    patMem :: VName,
    -- | Name of the pattern element.
    _patName :: VName,
    -- | Name of the variable returned by the body.
    bodyName :: VName,
    -- | Memory block of the variable returned by the body.
    bodyMem :: VName
  }

-- | Pair up the pattern elements of a branching statement with the
--   results of one of its bodies, keeping only the pairs that are
--   candidates for coalescing.  A pair is kept when the pattern element
--   is an array, the body result is a variable with known memory
--   information, and the pattern element is registered in the variable
--   table of the active coalescing entry of its memory block.
findMemBodyResult ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  ScopeTab rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  Body (Aliases rep) ->
  [MemBodyResult]
findMemBodyResult activeCoals_tab scope_env patelms bdy =
  catMaybes $ zipWith candidate patelms $ map resSubExp $ bodyResult bdy
  where
    -- The result variables may be bound inside the body itself.
    inner_scope = scope_env <> scopeOf (bodyStms bdy)

    candidate patel (Var r)
      | (_, MemArray _ _ _ (ArrayIn m_b _)) <- patElemDec patel = do
          let b = patElemName patel
          MemBlock _ _ m_r _ <- getScopeMemInfo r inner_scope
          coal_etry <- M.lookup m_b activeCoals_tab
          _ <- M.lookup b (vartab coal_etry)
          Just $ MemBodyResult m_b b r m_r
    candidate _ _ = Nothing

-- | Transfer coalescing from an if-pattern to the result of the
--   then|else body in the active coalesced table.  Among other things,
--   this inserts @(r,m_r)@ in the optimistic-dependency set of @m_b@'s
--   entry and @(b,m_b)@ in the optimistic-dependency set of @m_r@'s
--   entry.  Ultimately, @m_b@ can be merged if @m_r@ can be merged (and
--   vice versa).  This is checked by a fix-point iteration at the
--   function-definition level.
transferCoalsToBody ::
  M.Map VName (TPrimExp Int64 VName) ->
  CoalsTab ->
  MemBodyResult ->
  CoalsTab
transferCoalsToBody exist_subs activeCoals_tab (MemBodyResult m_b b r m_r) =
  case M.lookup m_b activeCoals_tab of
    Just etry
      | Just (Coalesced knd (MemBlock btp shp _ ind_b) subst_b) <-
          M.lookup b (vartab etry) ->
          -- By definition of the if-statement, r and b have the same
          -- basic type, shape and index function, so no rebasing is
          -- needed.  Whether it is translatable is checked at the
          -- definition point of r.
          let r_ixfun = IxFun.substituteInIxFun exist_subs ind_b
              r_substs = M.union exist_subs subst_b
              r_info =
                Coalesced knd (MemBlock btp shp (dstmem etry) r_ixfun) r_substs
              deps_with_b = M.insert b m_b (optdeps etry)
           in if m_r == m_b
                then -- already unified, just add a binding for @r@
                  let unified =
                        etry
                          { optdeps = deps_with_b,
                            vartab = M.insert r r_info (vartab etry)
                          }
                   in M.insert m_r unified activeCoals_tab
                else -- make them both optimistically depend on each other
                  let b_entry = etry {optdeps = M.insert r m_r (optdeps etry)}
                      -- Here the @ind_b@ field of @r_info@ should be
                      -- translated across the existential introduced by
                      -- the if-then-else.
                      r_entry =
                        etry
                          { vartab = M.singleton r r_info,
                            optdeps = deps_with_b
                          }
                   in M.insert m_b b_entry $
                        M.insert m_r r_entry activeCoals_tab
    -- Neither lookup can fail, because both were already checked in
    -- @findMemBodyResult@.
    _ -> error "Impossible"

-- | Build a substitution table that maps the name of each pattern
-- element to the corresponding result, expressed as a 64-bit integer
-- expression.  A pair contributes an entry when the result is a variable
-- and the pattern element is an @i64@ scalar, or when the result is an
-- @i64@ constant; all other pairs are dropped.
mkSubsTab ::
  Pat (aliases, LetDecMem) ->
  [SubExp] ->
  M.Map VName (TPrimExp Int64 VName)
mkSubsTab pat res =
  M.fromList $ catMaybes $ zipWith i64Subst (patElems pat) res
  where
    i64Subst pe (Var v)
      | (_, MemPrim (IntType Int64)) <- patElemDec pe =
          Just (patElemName pe, le64 v)
    i64Subst pe se@(Constant (IntValue (Int64Value _))) =
      Just (patElemName pe, pe64 se)
    i64Subst _ _ = Nothing

-- | Compute a table from scalar variable names to the primitive
-- expressions that define them, for one statement and the statements
-- nested inside it.
--
-- * A single-variable binding whose expression converts to a 'PrimExp'
--   (via 'primExpFromExp') contributes one entry.
-- * Loops and matches are traversed recursively, with the scope extended
--   by the bindings of the body being traversed (and, for loops, by the
--   loop parameters and loop form).
-- * Operations are delegated to the 'scalarTableOnOp' handler taken from
--   the reader environment.
-- * Anything else contributes nothing.
computeScalarTable ::
  (Coalesceable rep inner) =>
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ScalarTableM rep (M.Map VName (PrimExp VName))
computeScalarTable scope_table (Let (Pat [pe]) _ e)
  | Just primexp <- primExpFromExp (vnameToPrimExp scope_table mempty) e =
      pure $ M.singleton (patElemName pe) primexp
computeScalarTable scope_table (Let _ _ (DoLoop loop_inits loop_form body)) =
  concatMapM
    ( computeScalarTable $
        scope_table
          <> scopeOfFParams (map fst loop_inits)
          <> scopeOf loop_form
          <> scopeOf (bodyStms body)
    )
    (stmsToList $ bodyStms body)
computeScalarTable scope_table (Let _ _ (Match _ cases body _)) = do
  body_tab <-
    concatMapM
      (computeScalarTable $ scope_table <> scopeOf (bodyStms body))
      (stmsToList $ bodyStms body)
  cases_tab <-
    concatMapM
      ( \(Case _ b) ->
          -- Traverse the statements of the case body itself.  The scope
          -- is extended with the same statements, so scope and traversal
          -- agree.
          concatMapM
            (computeScalarTable $ scope_table <> scopeOf (bodyStms b))
            (stmsToList $ bodyStms b)
      )
      cases
  pure $ body_tab <> cases_tab
computeScalarTable scope_table (Let _ _ (Op op)) = do
  on_op <- asks scalarTableOnOp
  on_op scope_table op
computeScalarTable _ _ = pure mempty

-- | The 'scalarTableOnOp' handler for the GPU memory representation.
-- Allocations, size operations and other operations contribute nothing.
-- Segmented operations and GPU bodies are traversed statement by
-- statement with 'computeScalarTable', with the scope extended by the
-- statements of the traversed body (and, for segmented operations, by
-- the segment space).
computeScalarTableGPUMem :: ScopeTab GPUMem -> Op (Aliases GPUMem) -> ScalarTableM GPUMem (M.Map VName (PrimExp VName))
computeScalarTableGPUMem scope_table op =
  case op of
    Alloc _ _ -> pure mempty
    Inner (SizeOp _) -> pure mempty
    Inner (OtherOp ()) -> pure mempty
    Inner (SegOp segop) ->
      let kernel_stms = kernelBodyStms $ segBody segop
          inner_scope =
            scope_table
              <> scopeOf kernel_stms
              <> scopeOfSegSpace (segSpace segop)
       in concatMapM (computeScalarTable inner_scope) (stmsToList kernel_stms)
    Inner (GPUBody _ body) ->
      let body_stms = bodyStms body
          inner_scope = scope_table <> scopeOf body_stms
       in concatMapM (computeScalarTable inner_scope) (stmsToList body_stms)

-- | Monadic filter over the values of a map: keep exactly those entries
-- whose value satisfies the effectful predicate.  Entries are visited in
-- ascending key order.
filterMapM1 :: (Eq k, Monad m) => (v -> m Bool) -> M.Map k v -> m (M.Map k v)
filterMapM1 f m = do
  kept <- filterM (\(_, v) -> f v) (M.toAscList m)
  pure (M.fromAscList kept)