{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.Mem.Simplify
  ( simplifyProgGeneric,
    simplifyStmsGeneric,
    simpleGeneric,
    SimplifyMemory,
    memRuleBook,
  )
where

import Control.Monad
import Data.List (find)
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (simplifiable)
import Futhark.Util

-- | Some constraints that must hold for the simplification rules to work.
--
-- This is a constraint synonym.  Here @rep@ is the representation
-- being simplified and @inner@ is the operation type constructor that
-- is paired with it in the 'Mem' constraint.  The let-bound
-- decoration of @rep@ must be 'LetDecMem' and its expression and body
-- decorations must be unit; the remaining constraints concern
-- @inner@, @'OpC' rep@, and the 'Wise' version of the representation.
type SimplifyMemory rep inner =
  ( Simplify.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    CanBeWise (OpC rep),
    BuilderOps (Wise rep),
    OpReturns (inner (Wise rep)),
    ST.IndexOp (inner (Wise rep)),
    AliasedOp (inner (Wise rep)),
    Mem rep inner,
    CanBeWise inner,
    RephraseOp inner
  )

-- | Construct the 'Simplify.SimpleOps' for a representation with
-- memory, given a function producing the usage table of an inner
-- operation and a simplifier for inner operations.  This is exactly
-- 'simplifiable', re-exported under the 'SimplifyMemory' constraint.
simpleGeneric ::
  (SimplifyMemory rep inner) =>
  (inner (Wise rep) -> UT.UsageTable) ->
  Simplify.SimplifyOp rep (inner (Wise rep)) ->
  Simplify.SimpleOps rep
simpleGeneric = simplifiable

-- | Simplify an entire program with 'Simplify.simplifyProg', using
-- the given rule book and simplification operations.
--
-- The hoist blockers are the module-wide 'blockers', except that the
-- branch-hoisting predicate is replaced by a local one (@blockAllocs@)
-- that treats allocations and array-producing statements specially.
simplifyProgGeneric ::
  (SimplifyMemory rep inner) =>
  RuleBook (Wise rep) ->
  Simplify.SimpleOps rep ->
  Prog rep ->
  PassM (Prog rep)
simplifyProgGeneric rules ops =
  Simplify.simplifyProg
    ops
    rules
    blockers {Engine.blockHoistBranch = blockAllocs}
  where
    -- For an allocation, the predicate holds exactly when the symbol
    -- table says memory is not to be simplified.
    blockAllocs vtable _ (Let _ _ (Op Alloc {})) =
      not $ ST.simplifyMemory vtable
    -- Do not hoist statements that produce arrays.  This is
    -- because in the KernelsMem representation, multiple
    -- arrays can be located in the same memory block, and moving
    -- their creation out of a branch can thus cause memory
    -- corruption.  At this point in the compiler we have probably
    -- already moved all the array creations that matter.
    blockAllocs _ _ (Let pat _ _) =
      not $ all primType $ patTypes pat

-- | Simplify a sequence of statements with 'Simplify.simplifyStms',
-- using the given rule book and simplification operations.  The
-- scope is obtained from the surrounding monad via 'askScope', and
-- the module-wide 'blockers' are used unmodified.
simplifyStmsGeneric ::
  ( HasScope rep m,
    MonadFreshNames m,
    SimplifyMemory rep inner
  ) =>
  RuleBook (Wise rep) ->
  Simplify.SimpleOps rep ->
  Stms rep ->
  m (Stms rep)
simplifyStmsGeneric rules ops stms = do
  scope <- askScope
  Simplify.simplifyStms
    ops
    rules
    blockers
    scope
    stms

-- | Holds for an 'Alloc' statement binding exactly one name, when
-- that name is marked as being in the result according to the usage
-- table ('UT.isInResult').  False for every other statement.
isResultAlloc :: (OpC rep ~ MemOp op) => Engine.BlockPred rep
isResultAlloc _ usage (Let (Pat [pe]) _ (Op Alloc {})) =
  UT.isInResult (patElemName pe) usage
isResultAlloc _ _ _ = False

-- | Holds exactly for statements whose expression is an 'Alloc'.
-- The symbol table and usage table arguments are ignored.
isAlloc :: (OpC rep ~ MemOp op) => Engine.BlockPred rep
isAlloc _ _ (Let _ _ (Op Alloc {})) = True
isAlloc _ _ _ = False

-- | The hoist blockers shared by the simplifiers in this module.
-- Starting from 'Engine.noExtraHoistBlockers', 'isAlloc' is used as
-- the predicate for parallel hoisting, 'isResultAlloc' for
-- sequential hoisting, and allocations are recognised with 'isAlloc'.
blockers ::
  (OpC rep ~ MemOp inner) =>
  Simplify.HoistBlockers rep
blockers =
  Engine.noExtraHoistBlockers
    { Engine.blockHoistPar = isAlloc,
      Engine.blockHoistSeq = isResultAlloc,
      -- 'isAlloc' ignores its two table arguments, so empty ones
      -- suffice to turn it into a plain statement predicate.
      Engine.isAllocation = isAlloc mempty mempty
    }

-- | Standard collection of simplification rules for representations
-- with memory.  This is 'standardRules' extended with two top-down
-- rules defined in this module (one for existential memory returned
-- from a match, one for removing certificates from safe
-- allocations); no bottom-up rules are added.
memRuleBook :: (SimplifyMemory rep inner) => RuleBook (Wise rep)
memRuleBook =
  standardRules
    <> ruleBook
      [ RuleMatch unExistentialiseMemory,
        RuleOp decertifySafeAlloc
      ]
      []

-- | If a branch is returning some existential memory, but the size of
-- the array is not existential, and the index function of the array
-- does not refer to any names in the pattern, then we can create a
-- block of the proper size and always return there.
--
-- The rule only fires when 'ST.simplifyMemory' holds for the symbol
-- table and at least one pattern element qualifies (see
-- @hasConcretisableMemory@ below); otherwise it yields 'Skip'.
unExistentialiseMemory :: (SimplifyMemory rep inner) => TopDownRuleMatch (Wise rep)
unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
  | ST.simplifyMemory vtable,
    fixable <- foldl hasConcretisableMemory mempty $ patElems pat,
    not $ null fixable = Simplify $ do
      -- Create non-existential memory blocks big enough to hold the
      -- arrays.  We remember both which array goes into which new
      -- block, and which old (existential) block each new block
      -- replaces.
      (arr_to_mem, oldmem_to_mem) <-
        fmap unzip $
          forM fixable $ \(arr_pe, mem_size, oldmem, space) -> do
            size <- toSubExp "size" mem_size
            mem <- letExp "mem" $ Op $ Alloc size space
            pure ((patElemName arr_pe, mem), (oldmem, mem))

      -- Update the branches to contain copy statements (expressed as
      -- a 'Replicate' with an empty shape) putting the arrays where
      -- they are expected.
      let updateBody body = buildBody_ $ do
            res <- bodyBind body
            zipWithM updateResult (patElems pat) res
          updateResult pat_elem (SubExpRes cs (Var v))
            -- The result is one of the fixable arrays: copy it into
            -- the freshly allocated block, keeping its type and index
            -- function, and return the copy instead.
            | Just mem <- lookup (patElemName pat_elem) arr_to_mem,
              (_, MemArray pt shape u (ArrayIn _ ixfun)) <- patElemDec pat_elem = do
                v_copy <- newVName $ baseString v <> "_nonext_copy"
                let v_pat =
                      Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]
                addStm $ mkWiseStm v_pat (defAux ()) $ BasicOp $ Replicate mempty $ Var v
                pure $ SubExpRes cs $ Var v_copy
            -- The result is one of the replaced memory blocks: return
            -- the new block in its place.
            | Just mem <- lookup (patElemName pat_elem) oldmem_to_mem =
                pure $ SubExpRes cs $ Var mem
          -- Everything else is returned unchanged.
          updateResult _ se =
            pure se
      cases' <- mapM (traverse updateBody) cases
      defbody' <- updateBody defbody
      letBind pat $ Match cond cases' defbody' ifdec
  where
    -- True if @name@ is not free in any pattern element other than
    -- the one called @here@.
    onlyUsedIn name here =
      not . any ((name `nameIn`) . freeIn) . filter ((/= here) . patElemName) $
        patElems pat
    -- A size is known if it is a constant or a variable that is not
    -- bound by the pattern itself.
    knownSize Constant {} = True
    knownSize (Var v) = not $ inContext v
    inContext = (`elem` patNames pat)

    -- Fold step: prepend @pat_elem@ (with the byte size of the block
    -- it needs, its current memory block, and that block's space) to
    -- the accumulator if it qualifies; otherwise leave the
    -- accumulator unchanged.
    hasConcretisableMemory fixable pat_elem
      | -- The element is an array residing in memory block @mem@...
        (_, MemArray pt shape _ (ArrayIn mem ixfun)) <- patElemDec pat_elem,
        -- ...and @mem@ is itself bound by this pattern, at position
        -- @j@, with memory type.
        Just (j, Mem space) <-
          fmap patElemType
            <$> find
              ((mem ==) . patElemName . snd)
              (zip [(0 :: Int) ..] $ patElems pat),
        -- The @j@th result of every case and of the default body.
        Just cases_ses <- mapM (maybeNth j . bodyResult . caseBody) cases,
        Just defbody_se <- maybeNth j $ bodyResult defbody,
        mem `onlyUsedIn` patElemName pat_elem,
        length (IxFun.base ixfun) == shapeRank shape, -- See #1325
        all knownSize (shapeDims shape),
        not $ freeIn ixfun `namesIntersect` namesFromList (patNames pat),
        -- Some case returns a different memory result than the
        -- default body.
        any (defbody_se /=) cases_ses =
          let mem_size =
                untyped $ product $ primByteSize pt : map sExt64 (IxFun.base ixfun)
           in (pat_elem, mem_size, mem, space) : fixable
      | otherwise =
          fixable
unExistentialiseMemory _ _ _ _ = Skip

-- If an allocation is statically known to be safe, then we can remove
-- the certificates on it.  This can help hoist things that would
-- otherwise be stuck inside loops or branches.
--
-- The rule fires only when the statement actually carries
-- certificates, binds a single value of memory type, and its
-- operation satisfies 'safeOp'.  The statement is then re-bound with
-- its attributes preserved but without the certificates.
decertifySafeAlloc :: (SimplifyMemory rep inner) => TopDownRuleOp (Wise rep)
decertifySafeAlloc _ pat (StmAux cs attrs _) op
  | cs /= mempty,
    [Mem _] <- patTypes pat,
    safeOp op =
      Simplify $ attributing attrs $ letBind pat $ Op op
decertifySafeAlloc _ _ _ _ = Skip