{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.ExplicitMemory.Simplify
       ( simplifyExplicitMemory
       , simplifyStms
       )
where

import Control.Monad
import Data.List (find)

import qualified Futhark.Representation.AST.Syntax as AST
import Futhark.Representation.AST.Syntax
  hiding (Prog, BasicOp, Exp, Body, Stm,
          Pattern, PatElem, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Representation.ExplicitMemory
import Futhark.Representation.Kernels.Simplify (simplifyKernelOp)
import Futhark.Pass.ExplicitAllocations
  (simplifiable, arraySizeInBytesExp)
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import qualified Futhark.Optimise.Simplify.Engine as Engine
import qualified Futhark.Optimise.Simplify as Simplify
import Futhark.Construct
import Futhark.Pass
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Lore
import Futhark.Util

-- | Simplification operations for the 'ExplicitMemory' representation.
--
-- Host-level operations are simplified with 'simplifyKernelOp' and then
-- lifted with 'simplifiable'.  The innermost operation type at this
-- level is @()@, so "simplifying" it just returns @()@ unchanged and
-- produces no additional statements.
simpleExplicitMemory :: Simplify.SimpleOps ExplicitMemory
simpleExplicitMemory =
  simplifiable $ simplifyKernelOp $ const $ return ((), mempty)

-- | Simplify a whole 'ExplicitMemory' program.
--
-- Uses 'simpleExplicitMemory' and 'callKernelRules'.  The hoist
-- blockers are 'blockers', with 'Engine.blockHoistBranch' replaced by
-- the local predicate @blockAllocs@.
simplifyExplicitMemory :: Prog ExplicitMemory -> PassM (Prog ExplicitMemory)
simplifyExplicitMemory =
  Simplify.simplifyProg simpleExplicitMemory callKernelRules
  blockers { Engine.blockHoistBranch = blockAllocs }
  where -- An 'Alloc' statement is blocked unless memory simplification
        -- is enabled in the symbol table ('ST.simplifyMemory').
        blockAllocs vtable _ (Let _ _ (Op Alloc{})) =
          not $ ST.simplifyMemory vtable
        -- Do not hoist statements that produce arrays.  This is
        -- because in the ExplicitMemory representation, multiple
        -- arrays can be located in the same memory block, and moving
        -- their creation out of a branch can thus cause memory
        -- corruption.  At this point in the compiler we have probably
        -- already moved all the array creations that matter.
        blockAllocs _ _ (Let pat _ _) =
          not $ all primType $ patternTypes pat

-- | Simplify a sequence of 'ExplicitMemory' statements.
--
-- The statements are simplified in the scope obtained from 'askScope',
-- using the same operations, rules and hoist blockers as
-- 'simplifyExplicitMemory' (except for its branch-hoisting override).
-- Returns the symbol table produced by 'Simplify.simplifyStms'
-- together with the simplified statements.
simplifyStms :: (HasScope ExplicitMemory m, MonadFreshNames m) =>
                Stms ExplicitMemory -> m (ST.SymbolTable (Wise ExplicitMemory),
                                          Stms ExplicitMemory)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms simpleExplicitMemory callKernelRules blockers
    scope stms

-- | Block predicate that holds for an 'Alloc' statement when the
-- allocated name appears in the result, according to the usage table.
--
-- Only allocations whose pattern has an empty context and exactly one
-- value element are considered.  Every other statement is not blocked.
isResultAlloc :: Op lore ~ MemOp op => Engine.BlockPred lore
isResultAlloc _ usage (Let (AST.Pattern [] [bindee]) _ (Op Alloc{})) =
  UT.isInResult (patElemName bindee) usage
isResultAlloc _ _ _ = False

-- | The roots of what to hoist.  For now these are only the variable
-- names that represent array and memory-block sizes: the names free in
-- the array dimensions of the statement's pattern, plus, for an
-- 'Alloc' statement, the names free in its size argument.
getShapeNames :: (ExplicitMemorish lore, Op lore ~ MemOp op) =>
                 Stm (Wise lore) -> Names
getShapeNames stm =
  let ts = map patElemType $ patternElements $ stmPattern stm
      allocSize = case stmExp stm of
                    Op (Alloc size _) -> freeIn size
                    _                 -> mempty
  in freeIn (concatMap arrayDims ts) <> allocSize

-- | Block predicate that holds exactly for statements whose expression
-- is an 'Alloc'.  The symbol table and usage table are ignored.
isAlloc :: Op lore ~ MemOp op => Engine.BlockPred lore
isAlloc _ _ (Let _ _ (Op Alloc{})) = True
isAlloc _ _ _                      = False

-- | Hoist blockers for 'ExplicitMemory'.
--
-- The fields override 'Engine.noExtraHoistBlockers' as follows:
--
-- * 'Engine.blockHoistPar' blocks every allocation ('isAlloc').
--
-- * 'Engine.blockHoistSeq' blocks allocations whose memory is part of
--   the result ('isResultAlloc').
--
-- * 'Engine.getArraySizes' uses 'getShapeNames'.
--
-- * 'Engine.isAllocation' uses 'isAlloc'.  Since 'isAlloc' ignores its
--   symbol-table and usage-table arguments, empty ones are passed.
blockers :: Simplify.HoistBlockers ExplicitMemory
blockers = Engine.noExtraHoistBlockers {
    Engine.blockHoistPar    = isAlloc
  , Engine.blockHoistSeq    = isResultAlloc
  , Engine.getArraySizes    = getShapeNames
  , Engine.isAllocation     = isAlloc mempty mempty
  }

-- | The rule book used when simplifying 'ExplicitMemory' code.
--
-- It consists of 'standardRules' extended with three top-down rules
-- ('copyCopyToCopy', 'removeIdentityCopy' and
-- 'unExistentialiseMemory').  No extra bottom-up rules are added.
callKernelRules :: RuleBook (Wise ExplicitMemory)
callKernelRules = standardRules <> ruleBook topDown bottomUp
  where topDown = [ RuleBasicOp copyCopyToCopy
                  , RuleBasicOp removeIdentityCopy
                  , RuleIf unExistentialiseMemory
                  ]
        bottomUp = []

-- | If a branch is returning some existential memory, but the size of
-- the array is not existential, then we can create a block of the
-- proper size and always return there.
unExistentialiseMemory :: TopDownRuleIf (Wise ExplicitMemory)
unExistentialiseMemory :: RuleIf (Wise ExplicitMemory) (SymbolTable (Wise ExplicitMemory))
unExistentialiseMemory SymbolTable (Wise ExplicitMemory)
vtable Pattern (Wise ExplicitMemory)
pat StmAux (ExpAttr (Wise ExplicitMemory))
_ (SubExp
cond, BodyT (Wise ExplicitMemory)
tbranch, BodyT (Wise ExplicitMemory)
fbranch, IfAttr (BranchType (Wise ExplicitMemory))
ifattr)
  | SymbolTable (Wise ExplicitMemory) -> Bool
forall lore. SymbolTable lore -> Bool
ST.simplifyMemory SymbolTable (Wise ExplicitMemory)
vtable,
    [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable <- ([(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
   VName, Space)]
 -> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
 -> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
      VName, Space)])
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
-> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
hasConcretisableMemory [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
forall a. Monoid a => a
mempty ([PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
 -> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
      VName, Space)])
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
forall a b. (a -> b) -> a -> b
$ PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
Pattern (Wise ExplicitMemory)
pat,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
-> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable = RuleM (Wise ExplicitMemory) () -> Rule (Wise ExplicitMemory)
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM (Wise ExplicitMemory) () -> Rule (Wise ExplicitMemory))
-> RuleM (Wise ExplicitMemory) () -> Rule (Wise ExplicitMemory)
forall a b. (a -> b) -> a -> b
$ do

      -- Create non-existential memory blocks big enough to hold the
      -- arrays.
      ([(VName, VName)]
arr_to_mem, [(VName, VName)]
oldmem_to_mem) <-
        ([((VName, VName), (VName, VName))]
 -> ([(VName, VName)], [(VName, VName)]))
-> RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))]
-> RuleM (Wise ExplicitMemory) ([(VName, VName)], [(VName, VName)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [((VName, VName), (VName, VName))]
-> ([(VName, VName)], [(VName, VName)])
forall a b. [(a, b)] -> ([a], [b])
unzip (RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))]
 -> RuleM
      (Wise ExplicitMemory) ([(VName, VName)], [(VName, VName)]))
-> RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))]
-> RuleM (Wise ExplicitMemory) ([(VName, VName)], [(VName, VName)])
forall a b. (a -> b) -> a -> b
$ [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
-> ((PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)
    -> RuleM (Wise ExplicitMemory) ((VName, VName), (VName, VName)))
-> RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable (((PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
   VName, Space)
  -> RuleM (Wise ExplicitMemory) ((VName, VName), (VName, VName)))
 -> RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))])
-> ((PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)
    -> RuleM (Wise ExplicitMemory) ((VName, VName), (VName, VName)))
-> RuleM (Wise ExplicitMemory) [((VName, VName), (VName, VName))]
forall a b. (a -> b) -> a -> b
$ \(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
arr_pe, VName
oldmem, Space
space) -> do
          SubExp
size <- String
-> Exp (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"size" (ExpT (Wise ExplicitMemory) -> RuleM (Wise ExplicitMemory) SubExp)
-> RuleM (Wise ExplicitMemory) (ExpT (Wise ExplicitMemory))
-> RuleM (Wise ExplicitMemory) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
                  PrimExp VName
-> RuleM
     (Wise ExplicitMemory) (Exp (Lore (RuleM (Wise ExplicitMemory))))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (TypeBase Shape NoUniqueness -> PrimExp VName
arraySizeInBytesExp (TypeBase Shape NoUniqueness -> PrimExp VName)
-> TypeBase Shape NoUniqueness -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
PatElemT attr -> TypeBase Shape NoUniqueness
patElemType PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
arr_pe)
          VName
mem <- String
-> Exp (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"mem" (Exp (Lore (RuleM (Wise ExplicitMemory)))
 -> RuleM (Wise ExplicitMemory) VName)
-> Exp (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) VName
forall a b. (a -> b) -> a -> b
$ Op (Wise ExplicitMemory) -> ExpT (Wise ExplicitMemory)
forall lore. Op lore -> ExpT lore
Op (Op (Wise ExplicitMemory) -> ExpT (Wise ExplicitMemory))
-> Op (Wise ExplicitMemory) -> ExpT (Wise ExplicitMemory)
forall a b. (a -> b) -> a -> b
$ SubExp -> Space -> MemOp (HostOp (Wise ExplicitMemory) ())
forall inner. SubExp -> Space -> MemOp inner
Alloc SubExp
size Space
space
          ((VName, VName), (VName, VName))
-> RuleM (Wise ExplicitMemory) ((VName, VName), (VName, VName))
forall (m :: * -> *) a. Monad m => a -> m a
return ((PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
arr_pe, VName
mem), (VName
oldmem, VName
mem))

      -- Update the branches to contain Copy expressions putting the
      -- arrays where they are expected.
      let updateBody :: BodyT (Wise ExplicitMemory)
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
updateBody BodyT (Wise ExplicitMemory)
body = RuleM
  (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (RuleM
   (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
 -> RuleM
      (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory)))))
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
forall a b. (a -> b) -> a -> b
$ do
            [SubExp]
res <- Body (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind Body (Lore (RuleM (Wise ExplicitMemory)))
BodyT (Wise ExplicitMemory)
body
            [SubExp]
-> RuleM (Wise ExplicitMemory) (BodyT (Wise ExplicitMemory))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp]
 -> RuleM (Wise ExplicitMemory) (BodyT (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) [SubExp]
-> RuleM (Wise ExplicitMemory) (BodyT (Wise ExplicitMemory))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
              (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
 -> SubExp -> RuleM (Wise ExplicitMemory) SubExp)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [SubExp]
-> RuleM (Wise ExplicitMemory) [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> SubExp -> RuleM (Wise ExplicitMemory) SubExp
updateResult (PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
Pattern (Wise ExplicitMemory)
pat) [SubExp]
res
          updateResult :: PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> SubExp -> RuleM (Wise ExplicitMemory) SubExp
updateResult PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem (Var VName
v)
            | Just VName
mem <- VName -> [(VName, VName)] -> Maybe VName
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem) [(VName, VName)]
arr_to_mem,
              (VarWisdom
_, MemArray PrimType
pt Shape
shape NoUniqueness
u (ArrayIn VName
_ IxFun
ixfun)) <- PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
forall attr. PatElemT attr -> attr
patElemAttr PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem = do
                VName
v_copy <- String -> RuleM (Wise ExplicitMemory) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM (Wise ExplicitMemory) VName)
-> String -> RuleM (Wise ExplicitMemory) VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
v String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_nonext_copy"
                let v_pat :: PatternT (MemInfo SubExp NoUniqueness MemBind)
v_pat = [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> PatternT (MemInfo SubExp NoUniqueness MemBind)
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem VName
v_copy (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$
                                        PrimType
-> Shape
-> NoUniqueness
-> MemBind
-> MemInfo SubExp NoUniqueness MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape NoUniqueness
u (MemBind -> MemInfo SubExp NoUniqueness MemBind)
-> MemBind -> MemInfo SubExp NoUniqueness MemBind
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem IxFun
ixfun]
                Stm (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (RuleM (Wise ExplicitMemory)))
 -> RuleM (Wise ExplicitMemory) ())
-> Stm (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) ()
forall a b. (a -> b) -> a -> b
$ Pattern ExplicitMemory
-> StmAux (ExpAttr ExplicitMemory)
-> ExpT (Wise ExplicitMemory)
-> Stm (Wise ExplicitMemory)
forall lore.
(Attributes lore, CanBeWise (Op lore)) =>
Pattern lore
-> StmAux (ExpAttr lore) -> Exp (Wise lore) -> Stm (Wise lore)
mkWiseLetStm Pattern ExplicitMemory
PatternT (MemInfo SubExp NoUniqueness MemBind)
v_pat (() -> StmAux ()
forall attr. attr -> StmAux attr
defAux ()) (ExpT (Wise ExplicitMemory) -> Stm (Wise ExplicitMemory))
-> ExpT (Wise ExplicitMemory) -> Stm (Wise ExplicitMemory)
forall a b. (a -> b) -> a -> b
$ BasicOp (Wise ExplicitMemory) -> ExpT (Wise ExplicitMemory)
forall lore. BasicOp lore -> ExpT lore
BasicOp (VName -> BasicOp (Wise ExplicitMemory)
forall lore. VName -> BasicOp lore
Copy VName
v)
                SubExp -> RuleM (Wise ExplicitMemory) SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> RuleM (Wise ExplicitMemory) SubExp)
-> SubExp -> RuleM (Wise ExplicitMemory) SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v_copy
            | Just VName
mem <- VName -> [(VName, VName)] -> Maybe VName
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem) [(VName, VName)]
oldmem_to_mem =
                SubExp -> RuleM (Wise ExplicitMemory) SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> RuleM (Wise ExplicitMemory) SubExp)
-> SubExp -> RuleM (Wise ExplicitMemory) SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
mem
          updateResult PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
_ SubExp
se =
            SubExp -> RuleM (Wise ExplicitMemory) SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se
      BodyT (Wise ExplicitMemory)
tbranch' <- BodyT (Wise ExplicitMemory)
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
updateBody BodyT (Wise ExplicitMemory)
tbranch
      BodyT (Wise ExplicitMemory)
fbranch' <- BodyT (Wise ExplicitMemory)
-> RuleM
     (Wise ExplicitMemory) (Body (Lore (RuleM (Wise ExplicitMemory))))
updateBody BodyT (Wise ExplicitMemory)
fbranch
      Pattern (Lore (RuleM (Wise ExplicitMemory)))
-> Exp (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern (Lore (RuleM (Wise ExplicitMemory)))
Pattern (Wise ExplicitMemory)
pat (Exp (Lore (RuleM (Wise ExplicitMemory)))
 -> RuleM (Wise ExplicitMemory) ())
-> Exp (Lore (RuleM (Wise ExplicitMemory)))
-> RuleM (Wise ExplicitMemory) ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT (Wise ExplicitMemory)
-> BodyT (Wise ExplicitMemory)
-> IfAttr (BranchType (Wise ExplicitMemory))
-> ExpT (Wise ExplicitMemory)
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cond BodyT (Wise ExplicitMemory)
tbranch' BodyT (Wise ExplicitMemory)
fbranch' IfAttr (BranchType (Wise ExplicitMemory))
ifattr
  where onlyUsedIn :: VName -> VName -> Bool
onlyUsedIn VName
name VName
here = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> Bool)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((VName
name VName -> Names -> Bool
`nameIn`) (Names -> Bool)
-> (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
    -> Names)
-> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> Names
forall a. FreeIn a => a -> Names
freeIn) ([PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
 -> Bool)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> Bool
forall a b. (a -> b) -> a -> b
$
                                          (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> Bool)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/=VName
here) (VName -> Bool)
-> (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
    -> VName)
-> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName) ([PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
 -> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)])
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$
                                          PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
Pattern (Wise ExplicitMemory)
pat
        knownSize :: SubExp -> Bool
knownSize Constant{} = Bool
True
        knownSize (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
inContext VName
v
        inContext :: VName -> Bool
inContext = (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [VName]
forall attr. PatternT attr -> [VName]
patternContextNames PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
Pattern (Wise ExplicitMemory)
pat)

        hasConcretisableMemory :: [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
-> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
hasConcretisableMemory [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem
          | (VarWisdom
_, MemArray PrimType
_ Shape
shape NoUniqueness
_ (ArrayIn VName
mem IxFun
_)) <- PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
forall attr. PatElemT attr -> attr
patElemAttr PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem,
            Just (Int
j, Mem Space
space) <-
              (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
 -> TypeBase Shape NoUniqueness)
-> (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> (Int, TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
PatElemT attr -> TypeBase Shape NoUniqueness
patElemType ((Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
 -> (Int, TypeBase Shape NoUniqueness))
-> Maybe
     (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> Maybe (Int, TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
 -> Bool)
-> [(Int,
     PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))]
-> Maybe
     (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName
memVName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==) (VName -> Bool)
-> ((Int,
     PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
    -> VName)
-> (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
 -> VName)
-> ((Int,
     PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
    -> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int, PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))
-> PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
forall a b. (a, b) -> b
snd)
                                        ([Int]
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [(Int,
     PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Int
0::Int)..] ([PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
 -> [(Int,
      PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))])
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
-> [(Int,
     PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind))]
forall a b. (a -> b) -> a -> b
$ PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
-> [PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements PatternT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
Pattern (Wise ExplicitMemory)
pat),
            Just SubExp
tse <- Int -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
j ([SubExp] -> Maybe SubExp) -> [SubExp] -> Maybe SubExp
forall a b. (a -> b) -> a -> b
$ BodyT (Wise ExplicitMemory) -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT (Wise ExplicitMemory)
tbranch,
            Just SubExp
fse <- Int -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
j ([SubExp] -> Maybe SubExp) -> [SubExp] -> Maybe SubExp
forall a b. (a -> b) -> a -> b
$ BodyT (Wise ExplicitMemory) -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT (Wise ExplicitMemory)
fbranch,
            VName
mem VName -> VName -> Bool
`onlyUsedIn` PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem,
            (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all SubExp -> Bool
knownSize (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape),
            SubExp
fse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SubExp
tse =
              (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind)
pat_elem, VName
mem, Space
space) (PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
 Space)
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
-> [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind),
     VName, Space)]
forall a. a -> [a] -> [a]
: [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable
          | Bool
otherwise =
              [(PatElemT (VarWisdom, MemInfo SubExp NoUniqueness MemBind), VName,
  Space)]
fixable
unExistentialiseMemory SymbolTable (Wise ExplicitMemory)
_ Pattern (Wise ExplicitMemory)
_ StmAux (ExpAttr (Wise ExplicitMemory))
_ (SubExp, BodyT (Wise ExplicitMemory), BodyT (Wise ExplicitMemory),
 IfAttr (BranchType (Wise ExplicitMemory)))
_ = Rule (Wise ExplicitMemory)
forall lore. Rule lore
Skip

-- | Collapse a chain of copies into a single copy of the original array.
--
-- Two statement shapes are rewritten:
--
--   * @dst = copy a@ where @a = copy b@.  When the memory block holding
--     @a@ is in the same space as the one holding @dst@, and both arrays
--     use the same index function, the statement becomes
--     @dst = copy b@.  It is wrapped in the certificates that the
--     symbol-table lookup of @a@ returned.
--
--   * @dst = copy a@ where @a = rearrange perm b@ and @b = copy c@.  The
--     inner copy is skipped by emitting a fresh @rearrange perm c@ and
--     then @dst = copy@ of that.  The certificates of both
--     intermediate lookups are attached to the new rearrange.
--
-- Every other statement is left untouched ('Skip').
copyCopyToCopy :: (BinderOps lore,
                   LetAttr lore ~ (VarWisdom, MemBound u)) =>
                  TopDownRuleBasicOp lore
copyCopyToCopy vtable pat@(Pattern [] [pat_elem]) _ (Copy outer)
  | Just (BasicOp (Copy inner), outer_cs) <- ST.lookupExp outer vtable,
    -- Placement and layout of the intermediate copy.
    Just (_, MemArray _ _ _ (ArrayIn mid_mem mid_ixfun)) <-
      ST.lookup outer vtable >>= ST.entryLetBoundAttr,
    Just (Mem mid_space) <- ST.lookupType mid_mem vtable,
    -- Placement and layout of the array bound by this statement.
    (_, MemArray _ _ _ (ArrayIn res_mem res_ixfun)) <- patElemAttr pat_elem,
    Just (Mem res_space) <- ST.lookupType res_mem vtable,
    -- Only collapse when neither the memory space nor the layout changes.
    mid_space == res_space,
    mid_ixfun == res_ixfun =
      Simplify $ certifying outer_cs $ letBind_ pat $ BasicOp $ Copy inner

copyCopyToCopy vtable pat _ (Copy outer)
  | Just (BasicOp (Rearrange perm rearranged), outer_cs) <-
      ST.lookupExp outer vtable,
    Just (BasicOp (Copy source), rearranged_cs) <-
      ST.lookupExp rearranged vtable =
      Simplify $ do
        let cs = outer_cs <> rearranged_cs
        -- Permute the original array directly instead of its copy.
        permuted <- certifying cs $
          letExp "rearrange_v0" $ BasicOp $ Rearrange perm source
        letBind_ pat $ BasicOp $ Copy permuted

copyCopyToCopy _ _ _ _ = Skip

-- | Drop a copy whose destination coincides with its source.
--
-- For @dst = copy src@: the bound array's memory block and index
-- function are read from the pattern element, and @src@'s are read from
-- its let-bound attribute in the symbol table.  When both the memory
-- block and the index function are equal, the copy is replaced by the
-- plain alias @dst = src@.  Otherwise the rule does not fire ('Skip').
removeIdentityCopy :: (BinderOps lore,
                       LetAttr lore ~ (VarWisdom, MemBound u)) =>
                      TopDownRuleBasicOp lore
removeIdentityCopy vtable pat@(Pattern [] [res_elem]) _ (Copy arr)
  | (_, MemArray _ _ _ (ArrayIn res_mem res_ixfun)) <- patElemAttr res_elem,
    Just arr_attr <- ST.lookup arr vtable >>= ST.entryLetBoundAttr,
    (_, MemArray _ _ _ (ArrayIn arr_mem arr_ixfun)) <- arr_attr,
    res_mem == arr_mem,
    res_ixfun == arr_ixfun =
      Simplify $ letBind_ pat $ BasicOp $ SubExp $ Var arr

removeIdentityCopy _ _ _ _ = Skip