{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules
  , removeUnnecessaryCopy
  )
where

import Control.Monad
import Data.Either
import Data.List (find, isSuffixOf, partition, sort)
import Data.Maybe
import qualified Data.Map.Strict as M

import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.DataDependencies
import Futhark.Optimise.Simplify.ClosedForm
import Futhark.Optimise.Simplify.Rule
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Construct
import Futhark.Util

-- | The top-down half of 'standardRules'.  Each entry wraps a rule
-- function in the constructor for the kind of expression the rule is
-- written against ('RuleDoLoop', 'RuleGeneric', 'RuleIf' or
-- 'RuleBasicOp').
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules = [ RuleDoLoop hoistLoopInvariantMergeVariables
               , RuleDoLoop simplifyClosedFormLoop
               , RuleDoLoop simplifKnownIterationLoop
               , RuleDoLoop simplifyLoopVariables
               , RuleGeneric constantFoldPrimFun
               , RuleIf ruleIf
               , RuleIf hoistBranchInvariant
               , RuleBasicOp ruleBasicOp
               ]

-- | The bottom-up half of 'standardRules'.  Each entry wraps a rule
-- function in the constructor for the kind of expression the rule is
-- written against ('RuleDoLoop', 'RuleIf' or 'RuleBasicOp').
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules = [ RuleDoLoop removeRedundantMergeVariables
                , RuleIf removeDeadBranchResult
                , RuleBasicOp simplifyIndex
                , RuleBasicOp simplifyConcat
                ]

-- | Coerce an integer-typed 'PrimExp' to 'Int32'.  An expression
-- whose type is an integer type other than 'Int32' is converted with
-- 'sExt'; anything else (already 'Int32', or not an integer type at
-- all) is returned unchanged.
asInt32PrimExp :: PrimExp v -> PrimExp v
asInt32PrimExp pe
  | IntType it <- primExpType pe, it /= Int32 =
      sExt Int32 pe
  | otherwise =
      pe

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book is built from 'topDownRules' and 'bottomUpRules'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules = ruleBook topDownRules bottomUpRules

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
--
-- Summary of what the rule does:
--
-- * It only fires when the loop has no context merge parameters and
--   at least one value merge parameter is not used after the loop.
--
-- * Merge parameters are partitioned into "keep" and "discard" by
--   @resIsNecessary@.  If nothing is discarded the rule yields 'Skip'.
--
-- * Otherwise the loop is rebuilt with only the kept parameters,
--   results and pattern elements.  Each discarded parameter is bound
--   by a dummy statement at the top of the new loop body, so that
--   remaining references to its name stay in scope.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat _ (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    null ctx = -- FIXME: things get tricky if we can remove all vals
               -- but some ctxs are still used.  We take the easy way
               -- out for now.
  let -- Split the body result to line up with the context and value
      -- merge parameters respectively.
      (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body
      necessaryForReturned =
        findNecessaryForReturned usedAfterLoopOrInForm
        (zip (map fst $ ctx++val) $ ctx_es++val_es) (dataDependencies body)

      -- A merge parameter must be kept if it is used after the loop,
      -- is in the set computed by findNecessaryForReturned, or is
      -- mentioned in the merge parameter annotations or the loop form.
      resIsNecessary ((v,_), _) =
        usedAfterLoop v ||
        paramName v `nameIn` necessaryForReturned ||
        referencedInPat v ||
        referencedInForm v

      (keep_ctx, discard_ctx) =
        partition resIsNecessary $ zip ctx ctx_es
      (keep_valpart, discard_valpart) =
        partition (resIsNecessary . snd) $
        zip (patternValueElements pat) $ zip val val_es

      (keep_valpatelems, keep_val) = unzip keep_valpart
      (_discard_valpatelems, discard_val) = unzip discard_valpart
      (ctx', ctx_es') = unzip keep_ctx
      (val', val_es') = unzip keep_val

      body' = body { bodyResult = ctx_es' ++ val_es' }
      free_in_keeps = freeIn keep_valpatelems

      -- A context pattern element survives if its name occurs free in
      -- a kept value pattern element or in another context pattern
      -- element.
      stillUsedContext pat_elem =
        patElemName pat_elem `nameIn`
        (free_in_keeps <>
         freeIn (filter (/=pat_elem) $ patternContextElements pat))

      pat' = pat { patternValueElements = keep_valpatelems
                 , patternContextElements =
                     filter stillUsedContext $ patternContextElements pat }
  in if ctx' ++ val' == ctx ++ val
     then Skip
     else Simplify $ do
       -- We can't just remove the bindings in 'discard', since the loop
       -- body may still use their names in (now-dead) expressions.
       -- Hence, we add them inside the loop, fully aware that dead-code
       -- removal will eventually get rid of them.  Some care is
       -- necessary to handle unique bindings.
       body'' <- insertStmsM $ do
         mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
         mapM_ (uncurry letBindNames) $ dummyStms discard_val
         return body'
       letBind pat' $ DoLoop ctx' val' form body''
  where -- For each value name bound by the pattern: is it used directly
        -- according to the usage table?
        pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat
        -- Names of the value merge parameters whose corresponding
        -- pattern element is used.
        used_vals = map fst $ filter snd $ zip (map (paramName . fst) val) pat_used
        usedAfterLoop = flip elem used_vals . paramName
        usedAfterLoopOrInForm p =
          usedAfterLoop p || paramName p `nameIn` freeIn form
        -- Names occurring free in the merge parameters themselves.
        patAnnotNames = freeIn $ map fst $ ctx++val
        referencedInPat = (`nameIn` patAnnotNames) . paramName
        referencedInForm = (`nameIn` freeIn form) . paramName

        dummyStms = map dummyStm
        -- Bind a discarded merge parameter to its initial value.  When
        -- the parameter is unique and the initial value is a variable,
        -- bind it to a copy of that variable instead.
        dummyStm ((p,e), _)
          | unique (paramDeclType p),
            Var v <- e            = ([paramName p], BasicOp $ Copy v)
          | otherwise             = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
    -- Figure out which of the elements of loopresult are
    -- loop-invariant, and hoist them out.
  case (((FParam lore, SubExp), SubExp)
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp])
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([((FParam lore, SubExp), SubExp)]
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
       [(FParam lore, SubExp)]
-> [SubExp] -> [((FParam lore, SubExp), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(FParam lore, SubExp)]
merge [SubExp]
res of
    ([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
      -- Nothing is invariant.
      Rule lore
forall lore. Rule lore
Skip
    ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      -- We have moved something invariant out of the loop.
      let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody { bodyResult :: [SubExp]
bodyResult = [SubExp]
res' }
          invariantShape :: (a, VName) -> Bool
          invariantShape :: (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) = VName
shapemerge VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem`
                                           ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
          ([(PatElemT (LetDec lore), VName)]
implpat',[(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
    [(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
          implinvariant' :: [(Ident, SubExp)]
implinvariant' = [ (PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p,VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant ]
          implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
          explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
          ([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
      [(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1,SubExp
v2) ->
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
  where merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
        res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody

        implpat :: [(PatElemT (LetDec lore), VName)]
implpat = [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
                  ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
        explpat :: [(PatElemT (LetDec lore), VName)]
explpat = [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
                  ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val

        namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val

        removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam,b
mergeInit) [(PatElemT dec, VName)]
explpat' =
          case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
            ([(PatElemT dec
patelem,VName
_)], [(PatElemT dec, VName)]
rest) ->
              ((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
            ([(PatElemT dec, VName)]
_,      [(PatElemT dec, VName)]
_) ->
              (Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')

        checkInvariance :: ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance
          ((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp)
          ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
          | Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam)) Bool -> Bool -> Bool
||
            TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
            SubExp -> Bool
isInvariant SubExp
resExp,
            -- Also do not remove the condition in a while-loop.
            Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
          let (Maybe (Ident, SubExp)
bnd, [(PatElemT (LetDec lore), VName)]
explpat'') =
                (FParam lore, SubExp)
-> [(PatElemT (LetDec lore), VName)]
-> (Maybe (Ident, SubExp), [(PatElemT (LetDec lore), VName)])
forall dec dec b.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (FParam lore
mergeParam,SubExp
mergeInit) [(PatElemT (LetDec lore), VName)]
explpat'
          in (([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (FParam lore -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent FParam lore
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
              [(PatElemT (LetDec lore), VName)]
explpat'', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
          where
            -- A non-unique merge variable is invariant if the corresponding
            -- subexp in the result is EITHER:
            --
            --  (0) a variable of the same name as the parameter, where
            --  all existential parameters are already known to be
            --  invariant
            isInvariant :: SubExp -> Bool
isInvariant (Var VName
v2)
              | FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
                Names -> FParam lore -> Bool
allExistentialInvariant
                ([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant) FParam lore
mergeParam
            --  (1) or identical to the initial value of the parameter.
            isInvariant SubExp
_ = SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp

        checkInvariance ((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp) ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps) =
          ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', (FParam lore
mergeParam,SubExp
mergeInit)(FParam lore, SubExp)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. a -> [a] -> [a]
:[(FParam lore, SubExp)]
merge', SubExp
resExpSubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:[SubExp]
resExps)

        allExistentialInvariant :: Names -> FParam lore -> Bool
allExistentialInvariant Names
namesOfInvariant FParam lore
mergeParam =
          (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          FParam lore -> Names
forall a. FreeIn a => a -> Names
freeIn FParam lore
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam)
        invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
          Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams) Bool -> Bool -> Bool
||
          VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant

-- | A function that, given a subexpression, returns its type, or
-- 'Nothing' when the lookup yields no type for it.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function.  It receives a 'VarLookup', a 'TypeLookup' and the
-- 'BasicOp' under consideration, and produces either 'Nothing' or a
-- replacement 'BasicOp' paired with 'Certificates'.
--
-- NOTE(review): presumably 'Nothing' means "rule does not apply" and
-- the certificates are to be attached to the replacement -- confirm
-- against the code that applies 'simpleRules'.
type SimpleRule lore = VarLookup lore -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certificates)

-- | Every 'SimpleRule' defined in this module, in the order in which
-- they are listed here.
simpleRules :: [SimpleRule lore]
simpleRules =
  [ simplifyBinOp,
    simplifyCmpOp,
    simplifyUnOp,
    simplifyConvOp,
    simplifyAssert,
    copyScratchToScratch,
    simplifyIdentityReshape,
    simplifyReshapeReshape,
    simplifyReshapeScratch,
    simplifyReshapeReplicate,
    simplifyReshapeIota,
    improveReshape
  ]

-- | Hand a for-loop over to 'loopClosedForm'.  The rule only fires
-- for loops that have no context merge parameters and no loop
-- variables (i.e. no arrays being iterated over); every other loop is
-- skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ pat _ (ctx, val, form, body) =
  case (ctx, form) of
    ([], ForLoop i _ bound []) ->
      Simplify (loopClosedForm pat val (oneName i) bound body)
    _ ->
      Skip

-- | Simplify the arrays that a for-in loop iterates over.
--
-- For every loop variable @(p, arr)@ the indexing @arr[i, ...]@ (with
-- @i@ the loop counter) is handed to 'simplifyIndexing'.  The rule
-- fires if at least one such indexing can be simplified, and each
-- result is then handled as follows:
--
--  * If the element is a plain subexpression whose computation
--    contains no 'Index' expressions, @p@ is bound at the top of the
--    loop body and removed from the loop variables.
--
--  * If the element is @arr'[i, slice']@ where neither @slice'@ nor
--    the statements produced by the simplification mention @i@, those
--    statements and the array @arr'[0:w:1, slice']@ are emitted ahead
--    of the rebuilt loop, and @p@ iterates over that array instead.
--
--  * Otherwise the loop variable is kept unchanged.
--
-- If every loop variable ends up unchanged, the rule reports
-- 'cannotSimplify'.  Loops that are not for-loops are skipped.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat _ (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    any isJust simplifiable = Simplify $ do
      -- Check if the simplifications throw away more information than
      -- we are comfortable with at this stage.
      (maybe_loop_vars, body_prefix_stms) <-
        localScope (scopeOf form) $
          unzip <$> zipWithM onLoopVar loop_vars simplifiable
      if maybe_loop_vars == map Just loop_vars
        then cannotSimplify
        else do
          body' <- insertStmsM $ do
            -- Bindings for the removed loop variables go first in the
            -- new body, followed by the original body.
            addStms $ mconcat body_prefix_stms
            resultBodyM =<< bodyBind body
          letBind pat $
            DoLoop ctx val (ForLoop i it num_iters $ catMaybes maybe_loop_vars) body'
  where
    -- Type lookup that also knows about the loop counter, which is
    -- not in the symbol table of the enclosing scope.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    -- The enclosing symbol table extended with the names bound by the
    -- loop form.
    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Try to simplify @arr[i, ...]@; the final argument tells
    -- 'simplifyIndexing' whether the loop parameter is consumed in
    -- the body.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        (paramName p `nameIn` consumed_in_body)

    -- We only want this simplification if the result does not refer
    -- to 'i' at all, or does not contain accesses.
    --
    -- Returns the (possibly replaced) loop variable, or 'Nothing' if
    -- it was eliminated, together with statements to put at the top
    -- of the loop body.
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- The check has to be on the remaining dimensions: the
            -- full slice starts with @DimFix (Var i)@ and therefore
            -- always mentions 'i', which previously made this branch
            -- unreachable.
            not $ i `nameIn` freeIn slice' -> do
              addStms x_stms
              w <- arraySize 0 <$> lookupType arr'
              for_in_partial <-
                certifying cs $
                  letExp "for_in_partial" $
                    BasicOp $
                      Index arr' $
                        DimSlice (intConst Int32 0) w (intConst Int32 1) : slice'
              return (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
              x_stms' <- collectStms_ $
                certifying cs $ do
                  addStms x_stms
                  letBindNames [paramName p] $ BasicOp $ SubExp se
              return (Nothing, x_stms')
        _ -> return (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ = Skip

-- | Remove for-loops whose iteration count is a constant zero or one.
--
-- With zero iterations, each pattern element is bound directly to the
-- initial value of the corresponding merge parameter.  With one
-- iteration, the loop is unrolled in place: the merge parameters are
-- bound to their initial values, the loop counter to zero, each loop
-- variable to element zero of its array, and the pattern names to the
-- results of the loop body.  Every other loop is skipped.
simplifKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifKnownIterationLoop _ pat _ (ctx, val, ForLoop i it (Constant iters) loop_vars, body)
  | zeroIsh iters = Simplify $ do
      bindToInitial (patternContextElements pat) ctx
      bindToInitial (patternValueElements pat) val
  | oneIsh iters = Simplify $ do
      forM_ (ctx ++ val) $ \(param, initial) ->
        bindTo (paramName param) initial

      bindTo i (intConst it 0)

      forM_ loop_vars $ \(lparam, arr) -> do
        let first_elem = DimFix (intConst Int32 0) : fullSlice (paramType lparam) []
        letBindNames [paramName lparam] $ BasicOp $ Index arr first_elem

      -- Some of the sizes in the types here might be temporarily wrong
      -- until copy propagation fixes it up.
      results <- bodyBind body
      zipWithM_ bindTo (patternNames pat) results
  where
    -- Bind a single name to a subexpression.
    bindTo name se = letBindNames [name] $ BasicOp $ SubExp se

    -- Bind pattern elements pairwise to the initial values of the
    -- given merge parameters.
    bindToInitial elems merge =
      zipWithM_ (bindTo . patElemName) elems (map snd merge)
simplifKnownIterationLoop _ _ _ _ = Skip

-- | Rewrite @copy(x)@ to plain @x@ when the copy serves no purpose.
--
-- The rule fires only when @x@ itself is not consumed according to
-- the usage table, and additionally either
--
--   * @x@ is not used at all and may be consumed (see @consumable@), or
--
--   * the name bound to the copy is never consumed.
--
-- This simplistic rule is only valid before we introduce memory.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | originalNotConsumed && (originalDeadAndOurs || copyNotConsumed) =
      Simplify $ letBindNames [copyName] $ BasicOp $ SubExp $ Var v
  where
    -- Name that the copy statement binds.
    copyName = patElemName d

    originalNotConsumed = not $ UT.isConsumed v used
    originalDeadAndOurs = not (UT.used v used) && consumable
    copyNotConsumed = not $ UT.isConsumed copyName used

    -- We need to make sure we can even consume the original.  This
    -- check is hacky and much too conservative (only function
    -- parameters with a unique declared type qualify), because the
    -- necessary information is not conveniently available here.
    consumable
      | Just (FParamName info) <- M.lookup v (ST.toScope vtable) =
          unique (declTypeOf info)
      | otherwise = False
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Simplification of comparison operators.
--
-- Three cases are handled, in order:
--
--   1. both operands are syntactically identical, so the outcome
--      depends only on whether the operator is reflexive;
--
--   2. both operands are constants, so the comparison is evaluated
--      with 'doCmpOp';
--
--   3. an integer constant is compared for equality against a
--      variable defined as a bool-to-int conversion.
--
-- Anything else is left alone.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2 = constRes (BoolValue reflexive)
  where
    -- Result of comparing a value against itself.
    reflexive = case cmp of
      CmpEq{} -> True
      CmpUle{} -> True
      CmpSle{} -> True
      FCmpLe{} -> True
      CmpLle -> True
      CmpUlt{} -> False
      CmpSlt{} -> False
      FCmpLt{} -> False
      CmpLlt -> False
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  doCmpOp cmp v1 v2 >>= constRes . BoolValue
simplifyCmpOp look _ (CmpOp CmpEq{} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI{} b), cs) <- look v =
      -- The converted boolean is compared against a literal: 1 gives
      -- the boolean itself, 0 its negation, and any other literal
      -- yields False.
      let folded = case (valueIntegral x :: Int) of
            1 -> SubExp b
            0 -> UnOp Not b
            _ -> SubExp (Constant (BoolValue False))
       in Just (folded, cs)
simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of a single binary operation.
--
-- The first argument looks up the defining expression (and its
-- certificates) of a variable; the second argument is unused here.
-- The result is @Just (op, cs)@ with a replacement operation and the
-- certificates it depends on, or 'Nothing' when no rule applies.
--
-- Each equation handles one operator.  When every guard of an
-- equation fails, matching falls through to the final catch-all
-- equation, which returns 'Nothing'.
--
-- NOTE(review): the floating-point identities below (@FMul@, @FDiv@,
-- @FPow@) do not account for NaN or infinities -- confirm that this
-- is intended.
simplifyBinOp :: SimpleRule lore

-- Constant folding.
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res

simplifyBinOp look _ (BinOp Add{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1

  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub{} e2_a e2_b), cs) <- look v2,
    e2_b == e1 = Just (SubExp e2_a, cs)

simplifyBinOp _ _ (BinOp FAdd{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1

simplifyBinOp look _ (BinOp (Sub t _) e1 e2)
  | isCt0 e2 = subExpRes e1
  -- Cases for simplifying (a+b)-b and permutations.

  -- (a+b)-a => b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- look v1,
    e1_a == e2 = Just (SubExp e1_b, cs)
  -- (a+b)-b => a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- look v1,
    e1_b == e2 = Just (SubExp e1_a, cs)
  -- a-(a+b) => 0-b.  Note the sign: the result is the negation of
  -- the other addend, not the addend itself.
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add{} e2_a e2_b), cs) <- look v2,
    e2_a == e1 = Just (negated e2_b, cs)
  -- b-(a+b) => 0-a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add{} e2_a e2_b), cs) <- look v2,
    e2_b == e1 = Just (negated e2_a, cs)
  where
    -- Negation expressed as a wrapping subtraction from zero, so
    -- that the rewrite cannot introduce an overflow that the
    -- original expression did not have.
    negated = BinOp (Sub t OverflowWrap) (intConst t 0)

simplifyBinOp _ _ (BinOp FSub{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Mul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FMul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod m) mod m => x mod m
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod{} _ e4), v1_cs) <- look v1,
    e4 == e2 = Just (SubExp e1, v1_cs)

simplifyBinOp _ _ (BinOp SDiv{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  -- Division by a constant zero is never simplified.
  | isCt0 e2 = Nothing

simplifyBinOp _ _ (BinOp SDivUp{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing

simplifyBinOp _ _ (BinOp FDiv{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing

simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x rem x => 0 (the remainder of a value divided by itself), in
  -- agreement with the SMod rule above.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)

simplifyBinOp _ _ (BinOp SQuot{} e1 e2)
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing

simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0

simplifyBinOp _ _ (BinOp AShr{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = subExpRes $ intConst t 0
  | isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Or{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0

simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (!x) && x => False
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 = Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (!x) => False
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 = Just (SubExp $ Constant $ BoolValue False, v_cs)

simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (!x) || x => True
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 = Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (!x) => True
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 = Just (SubExp $ Constant $ BoolValue True, v_cs)

simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
      subExpRes e1
  -- The remaining cases drop a repeated operand from a nested
  -- maximum, e.g. max(max(a,b),a) => max(b,a).
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)

simplifyBinOp _ _ _ = Nothing

-- | A successful simplification whose result is the given constant.
-- No certificates are attached.
constRes :: PrimValue -> Maybe (BasicOp, Certificates)
constRes value = Just (SubExp (Constant value), mempty)

-- | A successful simplification whose result is the given
-- subexpression.  No certificates are attached.
subExpRes :: SubExp -> Maybe (BasicOp, Certificates)
subExpRes operand = Just (SubExp operand, mempty)

-- | Simplification of unary operators: a constant operand is
-- evaluated with 'doUnOp', and a double negation @!(!x)@ collapses
-- to @x@ (keeping the certificates of the inner negation).
simplifyUnOp :: SimpleRule lore
simplifyUnOp _ _ (UnOp op (Constant operand)) =
  doUnOp op operand >>= constRes
simplifyUnOp defOf _ (UnOp Not (Var negated))
  | Just (BasicOp (UnOp Not inner), cs) <- defOf negated =
      Just (SubExp inner, cs)
simplifyUnOp _ _ _ = Nothing

-- | Simplify a conversion operation.
--
-- The equations are tried from top to bottom:
--
-- * A conversion of a constant is evaluated with 'doConvOp'.
--
-- * A conversion whose two types, as reported by 'convOpType', are
--   equal is replaced by its operand.
--
-- * The remaining equations collapse a conversion applied to a
--   variable that is itself defined by a conversion: @SExt@ of @SExt@,
--   @ZExt@ of @ZExt@, @SIToFP@ of @SExt@, @UIToFP@ of @ZExt@, and
--   @FPConv@ of @FPConv@.  Each fires only when the first type
--   argument of the outer conversion (@t2@) is at least the first type
--   argument of the inner conversion (@t3@); the result converts the
--   inner operand directly from @t3@ to the outer target type @t1@,
--   and carries the certificates of the inner definition.
--
-- Anything else is left alone ('Nothing').
simplifyConvOp :: SimpleRule lore
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  constRes =<< doConvOp op v
simplifyConvOp _ _ (ConvOp op se)
  | (from, to) <- convOpType op,
    from == to =
      subExpRes se
simplifyConvOp lookupVar _ (ConvOp (SExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (ZExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (ZExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (SIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (UIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (UIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (FPConv t2 t1) (Var v))
  | Just (BasicOp (ConvOp (FPConv t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (FPConv t3 t1) se, v_cs)
simplifyConvOp _ _ _ =
  Nothing

-- | Simplify an assertion.  When the asserted condition is the
-- constant @True@, the assertion can never fail, so it is replaced by
-- the constant 'Checked'.  The error message and source location of
-- the assertion are discarded.  Every other operation is left alone
-- ('Nothing').
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ (Assert (Constant (BoolValue True)) _ _) =
  constRes Checked
simplifyAssert _ _ _ =
  Nothing

-- | Constant-fold an application of a built-in primitive function.
--
-- The rule fires on an @Apply@ statement only when all three of the
-- following hold:
--
-- * every argument is a constant (the 'Diet' component of each
--   argument is ignored);
--
-- * the function name is a key of 'primFuns'; and
--
-- * the evaluator stored in 'primFuns' returns a value for those
--   arguments.
--
-- In that case the original pattern is rebound to the resulting
-- constant, keeping the certificates and attributes of the original
-- statement.  Otherwise the rule is skipped.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $ BasicOp $ SubExp $ Constant result
  where
    -- Extract the value of a constant operand; variables yield Nothing.
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Bottom-up rule for simplifying an @Index@ operation that binds
-- exactly one (non-context) pattern element.
--
-- The actual analysis is delegated to 'simplifyIndexing'; this rule
-- only supplies its inputs and turns its answer back into a
-- statement:
--
-- * 'SubExpResult': the pattern is rebound to the given
--   subexpression.
--
-- * 'IndexResult': the pattern is rebound to an index into the given
--   (possibly different) array with the given slice.
--
-- In both cases the certificates of the original statement are
-- combined with those returned by 'simplifyIndexing', and the
-- attributes of the original statement are preserved.  The pattern is
-- rebound by name only ('letBindNames').
--
-- Statements of any other shape are skipped.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed =
      Simplify $ do
        res <- m
        attributing attrs $ case res of
          SubExpResult cs' se ->
            certifying (cs <> cs') $
              letBindNames (patternNames pat) $ BasicOp $ SubExp se
          IndexResult extra_cs idd' inds' ->
            certifying (cs <> extra_cs) $
              letBindNames (patternNames pat) $ BasicOp $ Index idd' inds'
  where
    -- Whether the result of this indexing is consumed according to
    -- the usage table.
    consumed = patElemName pe `UT.isConsumed` used

    -- Type of a subexpression: variables are looked up in the symbol
    -- table, constants carry their own primitive type.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- | The outcome of a successful 'simplifyIndexing'.
data IndexResult
  = -- | The indexing can be replaced by an indexing of the given
    -- array with the given slice.  The certificates must be attached
    -- to the replacement.
    IndexResult Certificates VName (Slice SubExp)
  | -- | The indexing can be replaced by the given subexpression.  The
    -- certificates must be attached to the replacement.
    SubExpResult Certificates SubExp

simplifyIndexing :: MonadBinder m =>
                    ST.SymbolTable (Lore m) -> TypeLookup
                 -> VName -> Slice SubExp -> Bool
                 -> Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType VName
idd Slice SubExp
inds Bool
consuming =
  case VName -> Maybe (BasicOp, Certificates)
defOf VName
idd of
    Maybe (BasicOp, Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
idd),
        Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
t [] ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd

      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e

      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.IndexedArray Certificates
cs VName
arr [PrimExp VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        (PrimExp VName -> Bool) -> [PrimExp VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining [PrimExp VName]
inds'',
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix ([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
          (PrimExp VName -> m SubExp) -> [PrimExp VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [PrimExp VName]
inds''

    Maybe (BasicOp, Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds

    Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
      | [DimFix SubExp
ii] <- Slice SubExp
inds,
        Just (Prim (IntType IntType
from_it)) <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType SubExp
ii ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
          (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
          IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
          PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
* PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
          PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+ PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
      | [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
            SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
            SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
            SubExp
i_offset'' <- String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
*
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
            SubExp
i_stride'' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                          BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int32 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
            (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it

    Just (Rotate [SubExp]
offsets VName
a, Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
      [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> m (TypeBase Shape NoUniqueness) -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
      let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
            SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int32 Overflow
OverflowWrap) SubExp
i SubExp
o
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SMod IntType
Int32 Safety
Unsafe) SubExp
i_p_o SubExp
d)
          adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
            SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
          adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
            SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
      Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)

    Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
      m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
      Slice (PrimExp VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice (PrimExp VName) -> m (Slice SubExp)
subExpSlice (Slice (PrimExp VName)
-> Slice (PrimExp VName) -> Slice (PrimExp VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (PrimExp VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (PrimExp VName)
primExpSlice Slice SubExp
inds))

    Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
      | [DimFix{}]   <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
      | DimFix{}:Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'

    Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
      | [DimFix{}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val

    Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
      | (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
        ([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
        [SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
          VName
arr <- String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
          IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
      where index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix{} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
            index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) SubExp
n SubExp
s)

    Just (Rearrange [Int]
perm VName
src, Certificates
cs)
       | [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
         let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
         in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
      where isIndex :: DimIndex d -> Bool
isIndex DimFix{} = Bool
True
            isIndex DimIndex d
_          = Bool
False

    Just (Copy VName
src, Certificates
cs)
      | Just [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
        Bool -> Bool
not Bool
consuming, VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

    Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        [Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
        [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

    Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
      | Just [SubExp
_] <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
v2) ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds

    Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
      | Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
        Just (Prim PrimType
res_t) <- (TypeBase Shape NoUniqueness
-> [SubExp] -> TypeBase Shape NoUniqueness
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds) (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
                             VName
-> SymbolTable (Lore m) -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
      SubExp
x_len <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
      [SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m (TypeBase Shape NoUniqueness) -> m SubExp)
-> (VName -> m (TypeBase Shape NoUniqueness)) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType) [VName]
xs

      let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
            SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int32 Overflow
OverflowWrap) SubExp
n SubExp
m
            (SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
      (SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
      let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts

      let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
          mkBranch ((VName
x', SubExp
start):[(VName, SubExp)]
xs_and_starts') = do
            SubExp
cmp <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int32) SubExp
start SubExp
i
            (SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
              SubExp
i' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int32 Overflow
OverflowWrap) SubExp
i SubExp
start
              String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
            BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
            (SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
            BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfDec (BranchType (Lore m))
-> Exp (Lore m)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfDec (BranchType (Lore m)) -> Exp (Lore m))
-> IfDec (BranchType (Lore m)) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
              [BranchType (Lore m)] -> IfSort -> IfDec (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
      Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts

    Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
cs)
      | DimFix (Constant (IntValue (Int32Value Int32
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
        Just SubExp
se <- Int32 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int32
i [SubExp]
ses ->
        case Slice SubExp
inds' of
          [] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
          Slice SubExp
_ | Var VName
v2 <- SubExp
se  -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
          Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    -- Indexing single-element arrays.  We know the index must be 0.
    Maybe (BasicOp, Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> SubExp -> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd, SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
t,
        DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds, Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$
          SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'

    Maybe (BasicOp, Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    where defOf :: VName -> Maybe (BasicOp, Certificates)
          -- Look up the expression bound to the given name in the
          -- symbol table, yielding it (with its certificates) only
          -- when it is a 'BasicOp'; anything else gives 'Nothing'.
          defOf name =
            case ST.lookupExp name vtable of
              Just (BasicOp bop, bop_cs) -> Just (bop, bop_cs)
              _ -> Nothing

          -- | A crude heuristic for deciding whether a PrimExp should
          -- be inlined rather than kept in an array and read back
          -- from memory.  Expressions for which @primExpSizeAtLeast 20@
          -- holds are always rejected (the bound is totally ad-hoc);
          -- the rest are judged by 'worthInlining''.
          worthInlining :: PrimExp v -> Bool
          worthInlining pe =
            not (primExpSizeAtLeast 20 pe) && worthInlining' pe
          worthInlining' :: PrimExp v -> Bool
          -- Structural part of the inlining heuristic: an expression
          -- is rejected as soon as it contains a 'Pow' or 'FPow'
          -- binary operation or a function call ('FunExp'); binary,
          -- comparison, conversion and unary operations are accepted
          -- exactly when all of their operands are; every other form
          -- of expression is accepted.
          worthInlining' pe =
            case pe of
              BinOpExp Pow{} _ _ -> False
              BinOpExp FPow{} _ _ -> False
              BinOpExp _ lhs rhs -> worthInlining' lhs && worthInlining' rhs
              CmpOpExp _ lhs rhs -> worthInlining' lhs && worthInlining' rhs
              ConvOpExp _ arg -> worthInlining' arg
              UnOpExp _ arg -> worthInlining' arg
              FunExp{} -> False
              _ -> True

-- | Bottom-up simplification rules for 'Concat'.  The equations are
-- tried in order, and the first one whose guard succeeds is used:
--
-- 1. A concatenation whose operands are all produced by the same
--    'Rearrange' is rewritten to a concatenation along dimension 0 of
--    the un-rearranged arrays, followed by a single 'Rearrange'.
--
-- 2. Operands that are themselves concatenations along the same
--    dimension are flattened into the outer concatenation.
--
-- 3. A concatenation along dimension 0 whose operands are all array
--    literals (or replicates with shape @[1]@) becomes one array
--    literal.
--
-- Any other expression is skipped.
simplifyConcat :: BinderOps lore => BottomUpRuleBasicOp lore

-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
--
-- and in general, for the permutation @perm@ constructed below,
--
-- concat@i(rearrange perm x, rearrange perm y)
--   == rearrange perm (concat@0(x,y))
simplifyConcat (vtable, _) pat _ (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    -- Dimension k of @Rearrange perm v@ is dimension @perm !! k@ of
    -- @v@.  Concatenating the rearranged arrays along dimension @i@
    -- is therefore the same as concatenating the underlying arrays
    -- along dimension 0 only when @perm !! i == 0@.  The permutation
    -- below is the one that moves dimension 0 to position @i@ and
    -- keeps the remaining dimensions in order.  (For @i <= 1@ it
    -- coincides with @[i] ++ [0..i-1] ++ [i+1..r-1]@, which is not
    -- valid for larger @i@.)
    let perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs = Simplify $ do
      -- The certificates of the rearrangements being looked through
      -- are attached to the new concatenation.
      concat_rearrange <-
        certifying (x_cs <> mconcat xs_cs) $
          letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
      letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- If the given array is defined as a 'Rearrange' with exactly
    -- the given permutation, return the array being rearranged along
    -- with the certificates of that definition.
    transposedBy :: [Int] -> VName -> Maybe (VName, Certificates)
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- concat xs (concat ys zs) == concat xs ys zs
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  -- Only fire when at least one operand was actually expanded;
  -- otherwise the rule would rewrite the expression to itself.
  | x' /= x || concat xs' /= xs =
      Simplify $
        certifying (cs <> x_cs <> mconcat xs_cs) $
          attributing attrs $
            letBind pat $
              BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    -- 'isConcat' always produces at least one name, so this pattern
    -- cannot fail.
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs

    -- The operands (and certificates) of the concatenation defining
    -- the given array, provided it concatenates along the same
    -- dimension @i@; otherwise just the array itself.
    isConcat :: VName -> ([VName], Certificates)
    isConcat v =
      case ST.lookupBasicOp v vtable of
        Just (Concat j y ys _, v_cs) | j == i -> (y : ys, v_cs)
        _ -> ([v], mempty)

-- If concatenating a bunch of array literals (or equivalent
-- replicate), just construct the array literal instead.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs _)
  | Just (vs, vcs) <- unzip <$> mapM isArrayLit (x : xs) = Simplify $ do
      rt <- rowType <$> lookupType x
      certifying (mconcat vcs) $
        auxing aux $
          letBind pat $ BasicOp $ ArrayLit (concat vs) rt
  where
    -- The elements (and certificates) of an array defined either as
    -- a replicate of shape @[1]@ or as an array literal.
    isArrayLit :: VName -> Maybe ([SubExp], Certificates)
    isArrayLit v
      | Just (Replicate shape se, v_cs) <- ST.lookupBasicOp v vtable,
        unitShape shape =
          Just ([se], v_cs)
      | Just (ArrayLit ses _, v_cs) <- ST.lookupBasicOp v vtable =
          Just (ses, v_cs)
      | otherwise =
          Nothing

    -- Is this the one-dimensional shape of constant size 1?
    unitShape :: Shape -> Bool
    unitShape = (== Shape [Constant $ IntValue $ Int32Value 1])

simplifyConcat _ _ _ _ = Skip

-- | The merged "super-rule" for @if@ expressions.  Equations are tried
-- top to bottom; the first one whose patterns and guards all hold
-- produces the rewrite:
--
--   1. If 'isCt1' or 'isCt0' recognises the condition, the
--      corresponding branch is spliced in place of the @if@.
--
--   2. @if c then True else v@ becomes @c || v@.
--
--   3. A single-result @if@ whose return types are all @bool@ and whose
--      branch statements are all 'safeExp' becomes
--      @(c && x) || (!c && y)@.
--
--   4. An 'IfFallback' @if@ with no context names in its pattern and a
--      true-branch made only of 'safeExp' statements is replaced by
--      that true-branch.
--
--   5. An @if@ whose branches return integer constants @t@ and @f@ with
--      @oneIshInt t && zeroIshInt f@ (or the other way around) becomes a
--      'BToI' conversion of the condition (or of its negation).
--
-- Anything else yields 'Skip'.
ruleIf :: BinderOps lore => TopDownRuleIf lore
-- Rule 1.  A fallback 'If' is only resolved here when the condition
-- picks the true branch (@isCt1 cond@); other sorts are resolved for
-- either branch.
ruleIf _ pat _ (cond, tbranch, fbranch, IfDec _ ifsort)
  | Just chosen <- decided,
    ifsort /= IfFallback || isCt1 cond =
      Simplify $ do
        -- Emit the statements of the chosen branch, then bind every
        -- pattern element to the matching branch result.
        addStms $ bodyStms chosen
        forM_ (zip (patternElements pat) (bodyResult chosen)) $ \(pe, res) ->
          letBindNames [patElemName pe] $ BasicOp $ SubExp res
  where
    decided
      | isCt1 cond = Just tbranch
      | isCt0 cond = Just fbranch
      | otherwise = Nothing

-- IMPROVE: the next two rules could be generalised to cover more
-- cases, in particular branches that contain bindings or return more
-- than one value.
--
-- Rule 2:  if c then True else v  ==  c || v
ruleIf _ pat _
  ( cond,
    Body _ tstms [Constant (BoolValue True)],
    Body _ fstms [fres],
    IfDec ts _
  )
  | null tstms && null fstms,
    map extTypeOf ts == [Prim Bool] =
      Simplify $ letBind pat $ BasicOp $ BinOp LogOr cond fres

-- Rule 3:  when type(x) == bool,
--   if c then x else y  ==  (c && x) || (!c && y)
-- Both branches end up being executed unconditionally, hence the
-- requirement that every statement in them is 'safeExp'.
ruleIf _ pat _ (cond, Body _ tstms [tres], Body _ fstms [fres], IfDec ts _)
  | all (safeExp . stmExp) tstms && all (safeExp . stmExp) fstms,
    all (\t -> extTypeOf t == Prim Bool) ts =
      Simplify $ do
        addStms tstms
        addStms fstms
        let condAndTrue = pure $ BasicOp $ BinOp LogAnd cond tres
            notCond = pure $ BasicOp $ UnOp Not cond
            falseRes = pure $ BasicOp $ SubExp fres
        letBind pat =<< eBinOp LogOr condAndTrue (eBinOp LogAnd notCond falseRes)

-- Rule 4.  Replace a fallback 'If' by its true-branch; the condition
-- and the false branch are ignored.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null (patternContextNames pat),
    all (safeExp . stmExp) (bodyStms tbranch) =
      Simplify $ do
        addStms $ bodyStms tbranch
        forM_ (zip (patternElements pat) (bodyResult tbranch)) $ \(pe, res) ->
          letBindNames [patElemName pe] $ BasicOp $ SubExp res

-- Rule 5.  Branches returning integer constants.  The statements of
-- the branches are not inspected.  When neither guard holds we fall
-- through to the final catch-all equation, i.e. 'Skip'.
ruleIf _ pat _
  ( cond,
    Body _ _ [Constant (IntValue t)],
    Body _ _ [Constant (IntValue f)],
    _
  )
  | oneIshInt t && zeroIshInt f =
      Simplify $ letBind pat $ boolToInt cond
  | zeroIshInt t && oneIshInt f =
      Simplify $ do
        cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
        letBind pat $ boolToInt cond_neg
  where
    -- The target integer type is taken from the true-branch constant.
    boolToInt se = BasicOp $ ConvOp (BToI (intValueType t)) se

-- No rule applies.
ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results in the
-- context), or the same in both branches.
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
  let tses :: [SubExp]
tses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
tb
      fses :: [SubExp]
fses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
fb
  ([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
    ([Either
    (Maybe (Int, SubExp))
    (PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
    [(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
  [(PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))])
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> ([Either
       (Maybe (Int, SubExp))
       (PatElemT (LetDec lore), Either Int (BranchType lore),
        (SubExp, SubExp))]
    -> ([Maybe (Int, SubExp)],
        [(PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))]))
-> [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
   lore
   [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
 -> RuleM
      lore
      ([Maybe (Int, SubExp)],
       ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
        [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$ ((PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> RuleM
      lore
      [Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))])
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
      [PatElemT (LetDec lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat)
           ((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0..Int
num_ctxInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
           ([SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
tses [SubExp]
fses)
  let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
      ([SubExp]
tses', [SubExp]
fses') = [(SubExp, SubExp)] -> ([SubExp], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
      tb' :: BodyT lore
tb' = BodyT lore
tb { bodyResult :: [SubExp]
bodyResult = [SubExp]
tses' }
      fb' :: BodyT lore
fb' = BodyT lore
fb { bodyResult :: [SubExp]
bodyResult = [SubExp]
fses' }
      ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
      ([PatElemT (LetDec lore)]
ctx_pes, [PatElemT (LetDec lore)]
val_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetDec lore)]
pes
  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings -- Was something hoisted?
     then do -- We may have to add some reshapes if we made the type
             -- less existential.
             BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
             BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
             Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
ctx_pes [PatElemT (LetDec lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
               SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ret' IfSort
ifsort)
     else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
  where num_ctx :: Int
num_ctx = [PatElemT (LetDec lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec lore)] -> Int)
-> [PatElemT (LetDec lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat
        bound_in_branches :: Names
bound_in_branches = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
                            BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
        mem_sizes :: Names
mem_sizes = [PatElemT (LetDec lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec lore)] -> Names)
-> [PatElemT (LetDec lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore) -> Bool)
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (TypeBase Shape NoUniqueness -> Bool)
-> (PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness)
-> PatElemT (LetDec lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType) ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
        invariant :: SubExp -> Bool
invariant Constant{} = Bool
True
        invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches

        isMem :: TypeBase shape u -> Bool
isMem Mem{} = Bool
True
        isMem TypeBase shape u
_ = Bool
False
        sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes

        branchInvariant :: (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
          -- Do both branches return the same value?
          | SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
              [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
              PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t

          -- Do both branches return values that are free in the
          -- branch, and are we not the only pattern element?  The
          -- latter is to avoid infinite application of this rule.
          | SubExp -> Bool
invariant SubExp
tse, SubExp -> Bool
invariant SubExp
fse, Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
            Prim PrimType
_ <- PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType PatElemT (LetDec lore)
pe, Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe = do
              [BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(ASTLore lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)
pe]
              [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
                (SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
tse]
                         RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
fse]
                         RuleM lore (IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (IfDec (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType lore) -> RuleM lore (IfDec (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
bt IfSort
ifsort))
              PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t

          | Bool
otherwise =
              Either
  (Maybe (Int, SubExp))
  (PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse,SubExp
fse))

        hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
        hoisted PatElemT dec
_ Right{}   = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing

        reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
          [SubExp]
ses <- BodyT (Lore m) -> m [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT (Lore m)
body
          let ([SubExp]
ctx_ses, [SubExp]
val_ses) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) [SubExp]
ses
          [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> m (BodyT (Lore m)))
-> ([SubExp] -> [SubExp]) -> [SubExp] -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([SubExp]
ctx_ses[SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++) ([SubExp] -> m (BodyT (Lore m)))
-> m [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp)
-> [SubExp] -> [ExtType] -> m [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult [SubExp]
val_ses [ExtType]
rets
        reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array{} = do
          TypeBase Shape NoUniqueness
v_t <- VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
          let newshape :: [SubExp]
newshape = TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ExtType
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
removeExistentials ExtType
t TypeBase Shape NoUniqueness
v_t
          if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
v_t
            then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> Exp (Lore m)
forall lore. [SubExp] -> VName -> Exp lore
shapeCoerce [SubExp]
newshape VName
v
            else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        reshapeResult SubExp
se ExtType
_ =
          SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Eliminate a 'Reshape' that does not change the shape: when the
-- requested dimensions are equal to the dimensions recorded for the
-- source array, the result is simply the source array itself.
--
-- The rule does not fire when the type of the source array cannot be
-- looked up, or when the operation is not a 'Reshape'.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ seType (Reshape newshape v)
  | Just t <- seType $ Var v,
    newDims newshape == arrayDims t = -- No-op reshape.
      subExpRes $ Var v
simplifyIdentityReshape _ _ _ = Nothing

-- | Collapse a 'Reshape' of a 'Reshape' into a single 'Reshape' of the
-- innermost array.  The two shape changes are combined with
-- 'fuseReshape', and the certificates attached to the inner reshape
-- are returned so that they carry over to the fused operation.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape defOf _ (Reshape newshape v)
  | Just (BasicOp (Reshape oldshape v2), v_cs) <- defOf v =
      Just (Reshape (fuseReshape oldshape newshape) v2, v_cs)
simplifyReshapeReshape _ _ _ = Nothing

-- | Replace a 'Reshape' of a 'Scratch' array with a 'Scratch' that is
-- created directly with the new dimensions.  The element type is kept,
-- the old dimensions are discarded, and the certificates of the
-- original 'Scratch' statement are returned alongside the new
-- operation.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch defOf _ (Reshape newshape v)
  | Just (BasicOp (Scratch bt _), v_cs) <- defOf v =
      Just (Scratch bt $ newDims newshape, v_cs)
simplifyReshapeScratch _ _ _ = Nothing

-- | Replace a 'Reshape' of a 'Replicate' with a single 'Replicate'.
--
-- This applies only when the shape of the replicated value is a suffix
-- of the new dimensions.  The leading new dimensions (everything
-- before that suffix) then become the shape of the new 'Replicate'.
-- The certificates of the original 'Replicate' statement are returned
-- alongside the new operation.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate defOf seType (Reshape newshape v)
  | Just (BasicOp (Replicate _ se), v_cs) <- defOf v,
    Just oldshape <- arrayShape <$> seType se,
    shapeDims oldshape `isSuffixOf` newDims newshape =
      -- Keep only the outer dimensions; the inner ones are already
      -- provided by the replicated value itself.
      let new =
            take (length newshape - shapeRank oldshape) $
              newDims newshape
       in Just (Replicate (Shape new) se, v_cs)
simplifyReshapeReplicate _ _ _ = Nothing

-- | Replace a 'Reshape' of an 'Iota' to a single dimension @n@ with an
-- 'Iota' of @n@ elements.  The offset, stride and integer type of the
-- original 'Iota' are kept, and its certificates are returned
-- alongside the new operation.  Reshapes to any other number of
-- dimensions are left alone.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota defOf _ (Reshape newshape v)
  | Just (BasicOp (Iota _ offset stride it), v_cs) <- defOf v,
    [n] <- newDims newshape =
      Just (Iota n offset stride it, v_cs)
simplifyReshapeIota _ _ _ = Nothing

-- | Refine the shape change of a 'Reshape' using the known dimensions
-- of the source array, by passing both to 'informReshape'.
--
-- The rule only fires when the refined shape change differs from the
-- original one, so an already-refined reshape is left untouched.  No
-- additional certificates are produced.
improveReshape :: SimpleRule lore
improveReshape _ seType (Reshape newshape v)
  | Just t <- seType $ Var v,
    let newshape' = informReshape (arrayDims t) newshape,
    newshape' /= newshape =
      Just (Reshape newshape' v, mempty)
improveReshape _ _ _ = Nothing

-- | If we are copying a scratch array (possibly indirectly), just turn
-- it into a scratch by itself.
--
-- "Indirectly" means that the copied array is reached from a 'Scratch'
-- through any chain of 'Rearrange' and 'Reshape' operations.  The new
-- 'Scratch' takes its element type and dimensions from the type of the
-- copied array; the rule does not fire if that type is unknown.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch defOf seType (Copy src) = do
  t <- seType $ Var src
  guard $ isActuallyScratch src
  Just (Scratch (elemType t) (arrayDims t), mempty)
  where
    -- Follow the definition of the array backwards through operations
    -- that only change its layout or shape, looking for a 'Scratch'.
    isActuallyScratch :: VName -> Bool
    isActuallyScratch v =
      case asBasicOp . fst =<< defOf v of
        Just Scratch {} -> True
        Just (Rearrange _ v') -> isActuallyScratch v'
        Just (Reshape _ v') -> isActuallyScratch v'
        _ -> False
copyScratchToScratch _ _ _ =
  Nothing

ruleBasicOp :: BinderOps lore => TopDownRuleBasicOp lore

-- Check all the simpleRules.
ruleBasicOp :: TopDownRuleBasicOp lore
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux BasicOp
op
  | Just (BasicOp
op', Certificates
cs) <- [Maybe (BasicOp, Certificates)] -> Maybe (BasicOp, Certificates)
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, MonadPlus m) =>
t (m a) -> m a
msum [ SimpleRule lore
rule VName -> Maybe (Exp lore, Certificates)
defOf SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType BasicOp
op | SimpleRule lore
rule <- [SimpleRule lore]
forall lore. [SimpleRule lore]
simpleRules ] =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec lore) -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux (ExpDec lore)
aux) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp BasicOp
op'
  where defOf :: VName -> Maybe (Exp lore, Certificates)
defOf = (VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
`ST.lookupExp` TopDown lore
vtable)
        seType :: SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (Var VName
v) = VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
v TopDown lore
vtable
        seType (Constant PrimValue
v) = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
 -> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v

ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
_ (Update VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch{}, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src

-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Update VName
src [DimSlice SubExp
i SubExp
n SubExp
s] (Var VName
v))
  | SubExp -> Bool
isCt1 SubExp
n, SubExp -> Bool
isCt1 SubExp
s,
    Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> TopDown lore -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
v [IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0] TopDown lore
vtable =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
        SubExp
e' <- String -> PrimExp VName -> RuleM lore SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"update_elem" PrimExp VName
e
        StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
          Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> SubExp -> BasicOp
Update VName
src [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i] SubExp
e'

ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
_ (Update VName
dest Slice SubExp
destis (Var VName
v))
  | Just (Exp lore
e, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable,
    Exp lore -> Bool
arrayFrom Exp lore
e =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dest
  where arrayFrom :: Exp lore -> Bool
arrayFrom (BasicOp (Copy VName
copy_v))
          | Just (Exp lore
e',Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
copy_v TopDown lore
vtable =
              Exp lore -> Bool
arrayFrom Exp lore
e'
        arrayFrom (BasicOp (Index VName
src Slice SubExp
srcis)) =
          VName
src VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest Bool -> Bool -> Bool
&& Slice SubExp
destis Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp
srcis
        arrayFrom (BasicOp (Replicate Shape
v_shape SubExp
v_se))
          | Just (Replicate Shape
dest_shape SubExp
dest_se, Certificates
_) <- VName -> TopDown lore -> Maybe (BasicOp, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp, Certificates)
ST.lookupBasicOp VName
dest TopDown lore
vtable,
            SubExp
v_se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
dest_se,
            Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
v_shape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape =
              Bool
True
        arrayFrom Exp lore
_ =
          Bool
False

-- | Turn in-place updates that replace an entire array into just
-- array literals.
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
_ (Update VName
dest Slice SubExp
is SubExp
se)
  | Just TypeBase Shape NoUniqueness
dest_t <- VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
dest TopDown lore