{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.Mem.Simplify
  ( simplifyProgGeneric,
    simplifyStmsGeneric,
    simpleGeneric,
    SimplifyMemory,
  )
where

import Control.Monad
import Data.List (find)
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (simplifiable)
import Futhark.Util

-- | The constraints a representation must satisfy for the
-- memory-specific simplification rules of this module to be usable.
type SimplifyMemory rep inner =
  ( -- The representation itself: simplifiable, a memory
    -- representation, and carrying the memory decorations.
    Simplify.SimplifiableRep rep,
    Mem rep inner,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    -- Operations, and building statements in the wisdom-annotated rep.
    CanBeWise (Op rep),
    BuilderOps (Wise rep),
    -- Capabilities required of the wisdom-annotated inner operation.
    OpReturns (OpWithWisdom inner),
    ST.IndexOp (OpWithWisdom inner),
    AliasedOp (OpWithWisdom inner)
  )

-- | The simplifier operations for a memory representation.  This is
-- 'simplifiable' from "Futhark.Pass.ExplicitAllocations", used at the
-- 'SimplifyMemory' constraints.
--
-- The first argument maps an operation to a usage table.  The second
-- simplifies an operation, yielding the simplified operation together
-- with any statements it produced.
simpleGeneric ::
  (SimplifyMemory rep inner) =>
  (OpWithWisdom inner -> UT.UsageTable) ->
  Simplify.SimplifyOp rep (OpWithWisdom inner) ->
  Simplify.SimpleOps rep
simpleGeneric = simplifiable

-- | Simplify a whole program in a memory representation, using the
-- memory rule book 'callKernelRules' and the hoist blockers 'blockers'.
-- Hoisting out of branches is additionally restricted by the local
-- @blockAllocs@ predicate.
simplifyProgGeneric ::
  (SimplifyMemory rep inner) =>
  Simplify.SimpleOps rep ->
  Prog rep ->
  PassM (Prog rep)
simplifyProgGeneric ops =
  Simplify.simplifyProg
    ops
    callKernelRules
    blockers {Engine.blockHoistBranch = blockAllocs}
  where
    -- An allocation is blocked from leaving a branch unless the
    -- symbol table permits memory simplification.
    blockAllocs vtable _ (Let _ _ (Op Alloc {})) =
      not $ ST.simplifyMemory vtable
    -- Do not hoist statements that produce arrays.  This is
    -- because in the KernelsMem representation, multiple
    -- arrays can be located in the same memory block, and moving
    -- their creation out of a branch can thus cause memory
    -- corruption.  At this point in the compiler we have probably
    -- already moved all the array creations that matter.
    blockAllocs _ _ (Let pat _ _) =
      not $ all primType $ patTypes pat

-- | Simplify a sequence of statements in a memory representation,
-- using the memory rule book 'callKernelRules' and the hoist blockers
-- 'blockers'.  The statements are simplified in the scope obtained
-- from the monad via 'askScope'.  Unlike 'simplifyProgGeneric', no
-- extra branch-hoisting blocker is installed here.
simplifyStmsGeneric ::
  ( HasScope rep m,
    MonadFreshNames m,
    SimplifyMemory rep inner
  ) =>
  Simplify.SimpleOps rep ->
  Stms rep ->
  m (Stms rep)
simplifyStmsGeneric ops stms = do
  scope <- askScope
  Simplify.simplifyStms ops callKernelRules blockers scope stms

-- | Holds for an allocation that binds a single name which the usage
-- table reports as part of a result ('UT.isInResult').  Holds for no
-- other statement.  The symbol table argument is ignored.
isResultAlloc :: (Op rep ~ MemOp op) => Engine.BlockPred rep
isResultAlloc _ usage (Let (Pat [pe]) _ (Op Alloc {})) =
  UT.isInResult (patElemName pe) usage
isResultAlloc _ _ _ = False

-- | Holds exactly for statements whose expression is an 'Alloc'
-- operation.  The symbol table and usage table arguments are ignored.
isAlloc :: (Op rep ~ MemOp op) => Engine.BlockPred rep
isAlloc _ _ (Let _ _ (Op Alloc {})) = True
isAlloc _ _ _ = False

-- | Hoist blockers for memory representations, starting from
-- 'Engine.noExtraHoistBlockers':
--
-- * 'Engine.blockHoistPar' blocks every allocation ('isAlloc');
--
-- * 'Engine.blockHoistSeq' blocks allocations whose name is part of a
--   result ('isResultAlloc');
--
-- * 'Engine.isAllocation' recognises allocations via 'isAlloc'.
blockers ::
  (Op rep ~ MemOp inner) =>
  Simplify.HoistBlockers rep
blockers =
  Engine.noExtraHoistBlockers
    { Engine.blockHoistPar = isAlloc,
      Engine.blockHoistSeq = isResultAlloc,
      -- 'isAlloc' ignores its symbol table and usage table arguments,
      -- so empty ones are supplied.
      Engine.isAllocation = isAlloc mempty mempty
    }

-- | The rule book of the memory simplifier.  It extends 'standardRules'
-- with the memory-specific top-down rules of this module
-- ('copyCopyToCopy', 'unExistentialiseMemory', 'decertifySafeAlloc')
-- and adds no bottom-up rules.
callKernelRules :: (SimplifyMemory rep inner) => RuleBook (Wise rep)
callKernelRules =
  standardRules
    <> ruleBook
      [ RuleBasicOp copyCopyToCopy,
        RuleMatch unExistentialiseMemory,
        RuleOp decertifySafeAlloc
      ]
      []

-- | If a branch is returning some existential memory, but the size of
-- the array is not existential, and the index function of the array
-- does not refer to any names in the pattern, then we can create a
-- block of the proper size and always return there.
--
-- The rewrite proceeds in three steps:
--
-- * For each such array, a fresh block is allocated before the 'Match'.
--
-- * Every branch copies its array result into that block.
--
-- * The branch result for the old memory is replaced by the new block.
--
-- The rule fires only when 'ST.simplifyMemory' holds for the symbol
-- table.
unExistentialiseMemory :: (SimplifyMemory rep inner) => TopDownRuleMatch (Wise rep)
unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
  | ST.simplifyMemory vtable,
    fixable <- foldl hasConcretisableMemory mempty $ patElems pat,
    not $ null fixable = Simplify $ do
      -- Create non-existential memory blocks big enough to hold the
      -- arrays.
      (arr_to_mem, oldmem_to_mem) <-
        fmap unzip $
          forM fixable $ \(arr_pe, mem_size, oldmem, space) -> do
            size <- toSubExp "size" mem_size
            mem <- letExp "mem" $ Op $ Alloc size space
            pure ((patElemName arr_pe, mem), (oldmem, mem))

      -- Update the branches to contain Copy expressions putting the
      -- arrays where they are expected.
      let updateBody body = buildBody_ $ do
            res <- bodyBind body
            zipWithM updateResult (patElems pat) res
          -- Each branch result is handled according to its pattern element:
          -- * a fixable array is copied into its new block;
          -- * the old existential memory is replaced by the new block;
          -- * anything else is returned unchanged.
          updateResult pat_elem (SubExpRes cs (Var v))
            | Just mem <- lookup (patElemName pat_elem) arr_to_mem,
              (_, MemArray pt shape u (ArrayIn _ ixfun)) <- patElemDec pat_elem = do
                v_copy <- newVName $ baseString v <> "_nonext_copy"
                let v_pat =
                      Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]
                addStm $ mkWiseStm v_pat (defAux ()) $ BasicOp (Copy v)
                pure $ SubExpRes cs $ Var v_copy
            | Just mem <- lookup (patElemName pat_elem) oldmem_to_mem =
                pure $ SubExpRes cs $ Var mem
          updateResult _ se =
            pure se
      cases' <- mapM (traverse updateBody) cases
      defbody' <- updateBody defbody
      letBind pat $ Match cond cases' defbody' ifdec
  where
    -- True when @name@ is not free in any pattern element other than
    -- the one named @here@.
    onlyUsedIn name here =
      not . any ((name `nameIn`) . freeIn) . filter ((/= here) . patElemName) $
        patElems pat

    -- A dimension is known if it is a constant, or a variable that is
    -- not bound by this very pattern.
    knownSize Constant {} = True
    knownSize (Var v) = not $ inContext v
    inContext = (`elem` patNames pat)

    -- Collect the array pattern elements whose memory can be replaced
    -- by a fresh allocation.  Each entry records the element, the block
    -- size in bytes, the old memory block, and its space.
    hasConcretisableMemory fixable pat_elem
      | (_, MemArray pt shape _ (ArrayIn mem ixfun)) <- patElemDec pat_elem,
        -- The array's memory must itself be bound by the pattern, at
        -- position j.
        Just (j, Mem space) <-
          fmap patElemType
            <$> find
              ((mem ==) . patElemName . snd)
              (zip [(0 :: Int) ..] $ patElems pat),
        Just cases_ses <- mapM (maybeNth j . bodyResult . caseBody) cases,
        Just defbody_se <- maybeNth j $ bodyResult defbody,
        mem `onlyUsedIn` patElemName pat_elem,
        length (IxFun.base ixfun) == shapeRank shape, -- See #1325
        all knownSize (shapeDims shape),
        not $ freeIn ixfun `namesIntersect` namesFromList (patNames pat),
        -- Some case must return a different memory block than the
        -- default body.
        any (defbody_se /=) cases_ses =
          -- Element byte size times the product of the index function's
          -- base shape.
          let mem_size =
                untyped $ product $ primByteSize pt : map sExt64 (IxFun.base ixfun)
           in (pat_elem, mem_size, mem, space) : fixable
      | otherwise =
          fixable
unExistentialiseMemory _ _ _ _ = Skip

-- | If we are copying something that is itself a copy, just copy the
-- original one instead.
--
-- Two shapes are recognised:
--
-- * @copy (copy v2)@ becomes @copy v2@.  This applies when the
--   intermediate copy and the destination are in memory blocks of the
--   same space and have equal index functions.  The intermediate
--   copy's certificates are kept.
--
-- * @copy (rearrange perm (copy v2))@ becomes
--   @copy (rearrange perm v2)@.  The certificates of both bypassed
--   statements are attached to the new rearrange.
copyCopyToCopy ::
  ( BuilderOps rep,
    LetDec rep ~ (VarWisdom, MemBound u)
  ) =>
  TopDownRuleBasicOp rep
copyCopyToCopy vtable pat@(Pat [pat_elem]) _ (Copy v1)
  | Just (BasicOp (Copy v2), v1_cs) <- ST.lookupExp v1 vtable,
    Just (_, MemArray _ _ _ (ArrayIn srcmem src_ixfun)) <-
      ST.entryLetBoundDec =<< ST.lookup v1 vtable,
    Just (Mem src_space) <- ST.lookupType srcmem vtable,
    (_, MemArray _ _ _ (ArrayIn destmem dest_ixfun)) <- patElemDec pat_elem,
    Just (Mem dest_space) <- ST.lookupType destmem vtable,
    src_space == dest_space,
    dest_ixfun == src_ixfun =
      Simplify $ certifying v1_cs $ letBind pat $ BasicOp $ Copy v2
copyCopyToCopy vtable pat _ (Copy v0)
  | Just (BasicOp (Rearrange perm v1), v0_cs) <- ST.lookupExp v0 vtable,
    Just (BasicOp (Copy v2), v1_cs) <- ST.lookupExp v1 vtable = Simplify $ do
      v0' <-
        certifying (v0_cs <> v1_cs) $
          letExp "rearrange_v0" $
            BasicOp $
              Rearrange perm v2
      letBind pat $ BasicOp $ Copy v0'
copyCopyToCopy _ _ _ _ = Skip

-- If an allocation is statically known to be safe, then we can remove
-- the certificates on it.  This can help hoist things that would
-- otherwise be stuck inside loops or branches.
--
-- The guard does not check for 'Alloc' explicitly.  It requires that:
--
-- * the statement carries certificates;
-- * its pattern binds exactly one memory block;
-- * its operation satisfies 'safeOp'.
--
-- The statement is then rebound with its attributes but without
-- 'certifying', which drops the certificates.
decertifySafeAlloc :: (SimplifyMemory rep inner) => TopDownRuleOp (Wise rep)
decertifySafeAlloc _ pat (StmAux cs attrs _) op
  | cs /= mempty,
    [Mem _] <- patTypes pat,
    safeOp op =
      Simplify $ attributing attrs $ letBind pat $ Op op
decertifySafeAlloc _ _ _ _ = Skip