{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules
  , removeUnnecessaryCopy
  )
where

import Control.Monad
import Data.Either
import Data.List (find, isSuffixOf, partition, sort)
import Data.Maybe
import qualified Data.Map.Strict as M

import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.DataDependencies
import Futhark.Optimise.Simplify.ClosedForm
import Futhark.Optimise.Simplify.Rule
import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.AST
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Construct
import Futhark.Util

-- | The top-down half of 'standardRules'.  Each rule function is
-- wrapped in the 'SimplificationRule' constructor ('RuleDoLoop',
-- 'RuleGeneric', 'RuleIf' or 'RuleBasicOp') that matches its type.
--
-- The 'Aliased' constraint is needed by 'simplifyLoopVariables'.
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules = [ RuleDoLoop hoistLoopInvariantMergeVariables
               , RuleDoLoop simplifyClosedFormLoop
               , RuleDoLoop simplifKnownIterationLoop
               , RuleDoLoop simplifyLoopVariables
               , RuleGeneric constantFoldPrimFun
               , RuleIf ruleIf
               , RuleIf hoistBranchInvariant
               , RuleBasicOp ruleBasicOp
               ]

-- | The bottom-up half of 'standardRules'.  Each rule function is
-- wrapped in the 'SimplificationRule' constructor ('RuleDoLoop',
-- 'RuleIf' or 'RuleBasicOp') that matches its type.
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules = [ RuleDoLoop removeRedundantMergeVariables
                , RuleIf removeDeadBranchResult
                , RuleBasicOp simplifyIndex
                , RuleBasicOp simplifyConcat
                ]

-- | Convert an integer-typed 'PrimExp' to 'Int32' by wrapping it in an
-- 'SExt' conversion.  An expression whose type is already 'Int32', or
-- whose type is not an 'IntType' at all, is returned unchanged.
asInt32PrimExp :: PrimExp v -> PrimExp v
asInt32PrimExp pe
  | IntType it <- primExpType pe, it /= Int32 =
      ConvOpExp (SExt it Int32) pe
  | otherwise =
      pe

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- This is simply 'topDownRules' and 'bottomUpRules' combined with
-- 'ruleBook'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules = ruleBook topDownRules bottomUpRules

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
--
-- A merge parameter is kept if any of the following holds:
--   * its pattern element is used directly after the loop;
--   * it is among the names 'findNecessaryForReturned' reports for the
--     loop body's data dependencies, seeded with the parameters that
--     are used after the loop or in the loop form;
--   * it is free in one of the merge parameters;
--   * it is free in the loop form.
-- Every other merge parameter, its result, and its value pattern
-- element are removed.  The rule only fires when there are no context
-- merge parameters and some value merge parameter is unused after the
-- loop.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat _ (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    -- FIXME: things get tricky if we can remove all vals but some
    -- ctxs are still used.  We take the easy way out for now.
    null ctx =
  let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body

      -- Parameter names that 'findNecessaryForReturned' deems necessary,
      -- computed from the body's data dependencies.
      necessaryForReturned =
        findNecessaryForReturned usedAfterLoopOrInForm
        (zip (map fst $ ctx++val) $ ctx_es++val_es) (dataDependencies body)

      resIsNecessary ((v,_), _) =
        usedAfterLoop v ||
        paramName v `nameIn` necessaryForReturned ||
        referencedInPat v ||
        referencedInForm v

      (keep_ctx, discard_ctx) =
        partition resIsNecessary $ zip ctx ctx_es
      (keep_valpart, discard_valpart) =
        partition (resIsNecessary . snd) $
        zip (patternValueElements pat) $ zip val val_es

      (keep_valpatelems, keep_val) = unzip keep_valpart
      (_discard_valpatelems, discard_val) = unzip discard_valpart
      (ctx', ctx_es') = unzip keep_ctx
      (val', val_es') = unzip keep_val

      body' = body { bodyResult = ctx_es' ++ val_es' }
      free_in_keeps = freeIn keep_valpatelems

      -- A context pattern element survives if its name is free in a
      -- kept value pattern element or in another context element.
      stillUsedContext pat_elem =
        patElemName pat_elem `nameIn`
        (free_in_keeps <>
         freeIn (filter (/=pat_elem) $ patternContextElements pat))

      pat' = pat { patternValueElements = keep_valpatelems
                 , patternContextElements =
                     filter stillUsedContext $ patternContextElements pat }
  in if ctx' ++ val' == ctx ++ val
     then Skip
     else Simplify $ do
       -- We can't just remove the bindings in 'discard', since the loop
       -- body may still use their names in (now-dead) expressions.
       -- Hence, we add them inside the loop, fully aware that dead-code
       -- removal will eventually get rid of them.  Some care is
       -- necessary to handle unique bindings.
       body'' <- insertStmsM $ do
         mapM_ (uncurry letBindNames) $ dummyStms $ discard_ctx ++ discard_val
         return body'
       letBind_ pat' $ DoLoop ctx' val' form body''
  where
    -- Names of the value merge parameters whose pattern element is
    -- used directly after the loop.  A 'Names' set, so lookups are not
    -- linear in the number of merge parameters.
    used_vals =
      namesFromList [ paramName p
                    | ((p, _), pat_name) <- zip val $ patternValueNames pat
                    , pat_name `UT.isUsedDirectly` used ]
    usedAfterLoop = (`nameIn` used_vals) . paramName

    -- Free variables of the loop form, computed once.
    form_names = freeIn form
    referencedInForm = (`nameIn` form_names) . paramName
    usedAfterLoopOrInForm p = usedAfterLoop p || referencedInForm p

    -- Names free in the merge parameters themselves; a parameter whose
    -- name occurs there must be kept.
    patAnnotNames = freeIn $ map fst $ ctx++val
    referencedInPat = (`nameIn` patAnnotNames) . paramName

    -- Rebind a discarded parameter to its initial value.  A unique
    -- parameter initialised from a variable gets a 'Copy' instead.
    dummyStms = map dummyStm
    dummyStm ((p,e), _)
      | unique (paramDeclType p),
        Var v <- e            = ([paramName p], BasicOp $ Copy v)
      | otherwise             = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
_ Pattern lore
pat StmAux (ExpAttr lore)
_ ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
    -- Figure out which of the elements of loopresult are
    -- loop-invariant, and hoist them out.
  case (((FParam lore, SubExp), SubExp)
 -> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
     [(FParam lore, SubExp)], [SubExp])
 -> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetAttr lore), VName)]
explpat, [], []) ([((FParam lore, SubExp), SubExp)]
 -> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
       [(FParam lore, SubExp)]
-> [SubExp] -> [((FParam lore, SubExp), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(FParam lore, SubExp)]
merge [SubExp]
res of
    ([], [(PatElemT (LetAttr lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
      -- Nothing is invariant.
      Rule lore
forall lore. Rule lore
Skip
    ([(Ident, SubExp)]
invariant, [(PatElemT (LetAttr lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      -- We have moved something invariant out of the loop.
      let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody { bodyResult :: [SubExp]
bodyResult = [SubExp]
res' }
          invariantShape :: (a, VName) -> Bool
          invariantShape :: (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) = VName
shapemerge VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem`
                                           ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall attr. Param attr -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
          ([(PatElemT (LetAttr lore), VName)]
implpat',[(PatElemT (LetAttr lore), VName)]
implinvariant) = ((PatElemT (LetAttr lore), VName) -> Bool)
-> [(PatElemT (LetAttr lore), VName)]
-> ([(PatElemT (LetAttr lore), VName)],
    [(PatElemT (LetAttr lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetAttr lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetAttr lore), VName)]
implpat
          implinvariant' :: [(Ident, SubExp)]
implinvariant' = [ (PatElemT (LetAttr lore) -> Ident
forall attr. Typed attr => PatElemT attr -> Ident
patElemIdent PatElemT (LetAttr lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetAttr lore)
p,VName
v) <- [(PatElemT (LetAttr lore), VName)]
implinvariant ]
          implpat'' :: [PatElemT (LetAttr lore)]
implpat'' = ((PatElemT (LetAttr lore), VName) -> PatElemT (LetAttr lore))
-> [(PatElemT (LetAttr lore), VName)] -> [PatElemT (LetAttr lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetAttr lore), VName) -> PatElemT (LetAttr lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetAttr lore), VName)]
implpat'
          explpat'' :: [PatElemT (LetAttr lore)]
explpat'' = ((PatElemT (LetAttr lore), VName) -> PatElemT (LetAttr lore))
-> [(PatElemT (LetAttr lore), VName)] -> [PatElemT (LetAttr lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetAttr lore), VName) -> PatElemT (LetAttr lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetAttr lore), VName)]
explpat'
          ([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetAttr lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetAttr lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
      [(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1,SubExp
v2) ->
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp SubExp
v2
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ ([PatElemT (LetAttr lore)]
-> [PatElemT (LetAttr lore)] -> Pattern lore
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [PatElemT (LetAttr lore)]
implpat'' [PatElemT (LetAttr lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
  where merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
        res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody

        implpat :: [(PatElemT (LetAttr lore), VName)]
implpat = [PatElemT (LetAttr lore)]
-> [VName] -> [(PatElemT (LetAttr lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetAttr lore), VName)])
-> [VName] -> [(PatElemT (LetAttr lore), VName)]
forall a b. (a -> b) -> a -> b
$
                  ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall attr. Param attr -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
        explpat :: [(PatElemT (LetAttr lore), VName)]
explpat = [PatElemT (LetAttr lore)]
-> [VName] -> [(PatElemT (LetAttr lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetAttr lore), VName)])
-> [VName] -> [(PatElemT (LetAttr lore), VName)]
forall a b. (a -> b) -> a -> b
$
                  ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall attr. Param attr -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val

        namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall attr. Param attr -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val

        removeFromResult :: (Param attr, b)
-> [(PatElemT attr, VName)]
-> (Maybe (Ident, b), [(PatElemT attr, VName)])
removeFromResult (Param attr
mergeParam,b
mergeInit) [(PatElemT attr, VName)]
explpat' =
          case ((PatElemT attr, VName) -> Bool)
-> [(PatElemT attr, VName)]
-> ([(PatElemT attr, VName)], [(PatElemT attr, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==Param attr -> VName
forall attr. Param attr -> VName
paramName Param attr
mergeParam) (VName -> Bool)
-> ((PatElemT attr, VName) -> VName)
-> (PatElemT attr, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT attr, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT attr, VName)]
explpat' of
            ([(PatElemT attr
patelem,VName
_)], [(PatElemT attr, VName)]
rest) ->
              ((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT attr -> Ident
forall attr. Typed attr => PatElemT attr -> Ident
patElemIdent PatElemT attr
patelem, b
mergeInit), [(PatElemT attr, VName)]
rest)
            ([(PatElemT attr, VName)]
_,      [(PatElemT attr, VName)]
_) ->
              (Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT attr, VName)]
explpat')

        checkInvariance :: ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetAttr lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance
          ((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp)
          ([(Ident, SubExp)]
invariant, [(PatElemT (LetAttr lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
          | Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (FParam lore -> TypeBase Shape Uniqueness
forall attr.
DeclTyped attr =>
Param attr -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam)) Bool -> Bool -> Bool
||
            TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (FParam lore -> TypeBase Shape Uniqueness
forall attr.
DeclTyped attr =>
Param attr -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
            SubExp -> Bool
isInvariant SubExp
resExp,
            -- Also do not remove the condition in a while-loop.
            Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ FParam lore -> VName
forall attr. Param attr -> VName
paramName FParam lore
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
          let (Maybe (Ident, SubExp)
bnd, [(PatElemT (LetAttr lore), VName)]
explpat'') =
                (FParam lore, SubExp)
-> [(PatElemT (LetAttr lore), VName)]
-> (Maybe (Ident, SubExp), [(PatElemT (LetAttr lore), VName)])
forall attr attr b.
Typed attr =>
(Param attr, b)
-> [(PatElemT attr, VName)]
-> (Maybe (Ident, b), [(PatElemT attr, VName)])
removeFromResult (FParam lore
mergeParam,SubExp
mergeInit) [(PatElemT (LetAttr lore), VName)]
explpat'
          in (([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (FParam lore -> Ident
forall attr. Typed attr => Param attr -> Ident
paramIdent FParam lore
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
              [(PatElemT (LetAttr lore), VName)]
explpat'', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
          where
            -- A non-unique merge variable is invariant if the corresponding
            -- subexp in the result is EITHER:
            --
            --  (0) a variable of the same name as the parameter, where
            --  all existential parameters are already known to be
            --  invariant
            isInvariant :: SubExp -> Bool
isInvariant (Var VName
v2)
              | FParam lore -> VName
forall attr. Param attr -> VName
paramName FParam lore
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
                Names -> FParam lore -> Bool
allExistentialInvariant
                ([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant) FParam lore
mergeParam
            --  (1) or identical to the initial value of the parameter.
            isInvariant SubExp
_ = SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp

        checkInvariance ((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp) ([(Ident, SubExp)]
invariant, [(PatElemT (LetAttr lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps) =
          ([(Ident, SubExp)]
invariant, [(PatElemT (LetAttr lore), VName)]
explpat', (FParam lore
mergeParam,SubExp
mergeInit)(FParam lore, SubExp)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. a -> [a] -> [a]
:[(FParam lore, SubExp)]
merge', SubExp
resExpSubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:[SubExp]
resExps)

        allExistentialInvariant :: Names -> FParam lore -> Bool
allExistentialInvariant Names
namesOfInvariant FParam lore
mergeParam =
          (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          FParam lore -> Names
forall a. FreeIn a => a -> Names
freeIn FParam lore
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (FParam lore -> VName
forall attr. Param attr -> VName
paramName FParam lore
mergeParam)
        invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
          Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams) Bool -> Bool -> Bool
||
          VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant

-- | A function that, given a subexpression, returns its type, or
-- 'Nothing' when no type is available for it.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function: no monadic binding, just an optional rewrite of a single
-- 'BasicOp'.  'Nothing' means the rule does not apply; otherwise the
-- replacement operation is returned together with 'Certificates'
-- (presumably those the replacement must carry — confirm at the use
-- site).
type SimpleRule lore =
  VarLookup lore                          -- lookup of variables
  -> TypeLookup                           -- types of subexpressions
  -> BasicOp lore                         -- the operation to rewrite
  -> Maybe (BasicOp lore, Certificates)   -- the rewrite, if any

-- | All pure 'BasicOp' rewrites, in the order they are listed here.
-- NOTE(review): whether the order matters (e.g. first applicable rule
-- wins) is decided at the use site, which is not visible here.
simpleRules :: [SimpleRule lore]
simpleRules =
  [ simplifyBinOp
  , simplifyCmpOp
  , simplifyUnOp
  , simplifyConvOp
  , simplifyAssert
  , copyScratchToScratch
  , simplifyIdentityReshape
  , simplifyReshapeReshape
  , simplifyReshapeScratch
  , simplifyReshapeReplicate
  , simplifyReshapeIota
  , improveReshape
  ]

-- | Delegate a @for@ loop to 'loopClosedForm' (from
-- "Futhark.Optimise.Simplify.ClosedForm").  Only loops with no context
-- merge parameters and no array loop variables are handed over; every
-- other loop is skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ pat _ loop =
  case loop of
    ([], val, ForLoop i _ bound [], body) ->
      -- The loop index is passed along as a singleton name set.
      Simplify $ loopClosedForm pat val (oneName i) bound body
    _ -> Skip

-- | Simplify the array loop variables of a @for@ loop.
--
-- For every loop variable @(p, arr)@, 'simplifyIndexing' is asked
-- whether the element @arr[i]@ (with @i@ the loop index) can be
-- computed more directly.  Two kinds of answer are exploited:
--
--  * A 'SubExpResult' whose statements contain no 'Index': the loop
--    variable is dropped, and @p@ is instead bound to that
--    subexpression by statements inserted at the start of the loop
--    body.
--
--  * An 'IndexResult' @arr'[i, slice']@ where neither @slice'@ nor the
--    statements that produced it mention @i@: the loop instead iterates
--    over @arr'[0:w:1, slice']@ (with @w@ the outer size of @arr'@),
--    which is computed once, before the loop.
--
-- Every other loop variable is kept as-is.  The rule gives up with
-- 'cannotSimplify' if no loop variable changed, and skips anything
-- that is not a @for@ loop.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat _ (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    not $ all isNothing simplifiable = Simplify $ do
      -- Check if the simplifications throw away more information than
      -- we are comfortable with at this stage.
      (maybe_loop_vars, body_prefix_stms) <-
        localScope (scopeOf form) $
          unzip <$> zipWithM onLoopVar loop_vars simplifiable
      if maybe_loop_vars == map Just loop_vars
        then cannotSimplify
        else do
          body' <- insertStmsM $ do
            addStms $ mconcat body_prefix_stms
            resultBodyM =<< bodyBind body
          letBind_ pat $
            DoLoop ctx val
              (ForLoop i it num_iters $ catMaybes maybe_loop_vars)
              body'
  where
    -- Types of subexpressions as seen inside the loop: the loop index
    -- has type @it@; other variables are looked up in 'vtable'.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    -- The symbol table extended with the names bound by the loop form.
    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Ask whether @arr[i]@, taken as a full slice of p's type, can be
    -- simplified.  The final argument tells 'simplifyIndexing' whether
    -- @p@ is consumed in the loop body.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing vtable' seType arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        (paramName p `nameIn` consumed_in_body)

    -- Returns the new loop variable (if any) and the statements to put
    -- at the start of the loop body.  We only want this simplification
    -- if the result does not refer to 'i' at all, or does not contain
    -- accesses.
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- for_in_partial below is computed once, outside the loop,
            -- so the remaining indices must not depend on 'i'.  This
            -- must test slice' rather than slice: the head of slice is
            -- 'DimFix (Var i)', so 'i' is always free in it and the
            -- guard could never hold.
            not $ i `nameIn` freeIn slice' -> do
              addStms x_stms
              w <- arraySize 0 <$> lookupType arr'
              for_in_partial <-
                certifying cs $ letExp "for_in_partial" $ BasicOp $
                  Index arr' $
                    DimSlice (intConst Int32 0) w (intConst Int32 1) : slice'
              return (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
              x_stms' <- collectStms_ $ certifying cs $ do
                addStms x_stms
                letBindNames_ [paramName p] $ BasicOp $ SubExp se
              return (Nothing, x_stms')
        _ -> return (Just (p, arr), mempty)

    notIndex (BasicOp Index{}) = False
    notIndex _                 = True
simplifyLoopVariables _ _ _ _ = Skip

simplifKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifKnownIterationLoop :: TopDownRuleDoLoop lore
simplifKnownIterationLoop TopDown lore
_ Pattern lore
pat StmAux (ExpAttr lore)
_ ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, ForLoop VName
i IntType
it (Constant PrimValue
iters) [(LParam lore, VName)]
loop_vars, BodyT lore
body)
  | PrimValue -> Bool
zeroIsh PrimValue
iters = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      let bindResult :: PatElemT attr -> SubExp -> m [Ident]
bindResult PatElemT attr
p SubExp
r = [VName] -> Exp (Lore m) -> m [Ident]
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m [Ident]
letBindNames [PatElemT attr -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT attr
p] (Exp (Lore m) -> m [Ident]) -> Exp (Lore m) -> m [Ident]
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> Exp (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> Exp (Lore m))
-> BasicOp (Lore m) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp (Lore m)
forall lore. SubExp -> BasicOp lore
SubExp SubExp
r
      (PatElemT (LetAttr lore) -> SubExp -> RuleM lore [Ident])
-> [PatElemT (LetAttr lore)] -> [SubExp] -> RuleM lore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElemT (LetAttr lore) -> SubExp -> RuleM lore [Ident]
forall (m :: * -> *) attr.
MonadBinder m =>
PatElemT attr -> SubExp -> m [Ident]
bindResult (Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements Pattern lore
pat) (((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam lore, SubExp)]
ctx)
      (PatElemT (LetAttr lore) -> SubExp -> RuleM lore [Ident])
-> [PatElemT (LetAttr lore)] -> [SubExp] -> RuleM lore ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElemT (LetAttr lore) -> SubExp -> RuleM lore [Ident]
forall (m :: * -> *) attr.
MonadBinder m =>
PatElemT attr -> SubExp -> m [Ident]
bindResult (Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternValueElements Pattern lore
pat) (((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam lore, SubExp)]
val)

  | PrimValue -> Bool
oneIsh PrimValue
iters = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do

  [(FParam lore, SubExp)]
-> ((FParam lore, SubExp) -> RuleM lore [Ident]) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val) (((FParam lore, SubExp) -> RuleM lore [Ident]) -> RuleM lore ())
-> ((FParam lore, SubExp) -> RuleM lore [Ident]) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(FParam lore
mergevar, SubExp
mergeinit) ->
    [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore [Ident]
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m [Ident]
letBindNames [FParam lore -> VName
forall attr. Param attr -> VName
paramName FParam lore
mergevar] (Exp (Lore (RuleM lore)) -> RuleM lore [Ident])
-> Exp (Lore (RuleM lore)) -> RuleM lore [Ident]
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp SubExp
mergeinit

  [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [VName
i] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp (SubExp -> BasicOp lore) -> SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
it Integer
0

  [(LParam lore, VName)]
-> ((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(LParam lore, VName)]
loop_vars (((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ())
-> ((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(LParam lore
p,VName
arr) ->
    [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [LParam lore -> VName
forall attr. Param attr -> VName
paramName LParam lore
p] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp lore
forall lore. VName -> Slice SubExp -> BasicOp lore
Index VName
arr (Slice SubExp -> BasicOp lore) -> Slice SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$
    SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice (LParam lore -> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
Param attr -> TypeBase Shape NoUniqueness
paramType LParam lore
p) []

  -- Some of the sizes in the types here might be temporarily wrong
  -- until copy propagation fixes it up.
  [SubExp]
res <- Body (Lore (RuleM lore)) -> RuleM lore [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT lore
Body (Lore (RuleM lore))
body
  [(VName, SubExp)]
-> ((VName, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [VName]
forall attr. PatternT attr -> [VName]
patternNames Pattern lore
pat) [SubExp]
res) (((VName, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((VName, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExp
se) ->
    [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [VName
v] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp SubExp
se
simplifKnownIterationLoop TopDown lore
_ Pattern lore
_ StmAux (ExpAttr lore)
_ ([(FParam lore, SubExp)], [(FParam lore, SubExp)], LoopForm lore,
 BodyT lore)
_ =
  Rule lore
forall lore. Rule lore
Skip

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
--
-- Applies only to a pattern with no context elements and a single
-- value element @d@.  It requires that @x@ is not consumed afterwards.
-- It also requires one of the following:
--
--   (a) @x@ is not used at all afterwards and may itself be consumed, or
--   (b) @d@ is not consumed afterwards.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | srcNotConsumed,
    srcDeadAndConsumable || dstNotConsumed =
      Simplify $ letBindNames_ [dst] $ BasicOp $ SubExp $ Var v
  where
    dst = patElemName d
    srcNotConsumed = not $ v `UT.isConsumed` used
    dstNotConsumed = not $ dst `UT.isConsumed` used
    srcDeadAndConsumable = not (v `UT.used` used) && consumable
    -- We need to make sure we can even consume the original.  This is
    -- currently a hacky check, much too conservative, because we don't
    -- have the information conveniently available: only function
    -- parameters whose declared type is unique are accepted.
    consumable =
      case M.lookup v (ST.toScope vtable) of
        Just (FParamInfo info) -> unique (declTypeOf info)
        _ -> False
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Simplify comparison operators.
--
-- Comparing an operand with itself is folded to a constant.  The
-- exceptions are floating-point @==@ and @<=@: both evaluate to false
-- when the operand is NaN, so folding them to @true@ would change the
-- result of e.g. a hand-written NaN test @x != x@.  Floating-point @<@
-- is false for NaN too, so it is still folded to @false@.
--
-- Comparisons between two constants are evaluated with 'doCmpOp'.
--
-- @c == btoi(b)@ for an integer constant @c@ becomes @b@ when @c@ is 1,
-- @not b@ when @c@ is 0, and @false@ otherwise.  In each case the
-- certificates of the conversion are kept.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2,
    Just res <- selfCmp cmp =
      constRes $ BoolValue res
  where
    -- Result of comparing a value with itself, when it is known
    -- statically for every possible value (including NaN).
    selfCmp :: CmpOp -> Maybe Bool
    selfCmp (CmpEq (FloatType _)) = Nothing
    selfCmp CmpEq{}  = Just True
    selfCmp CmpSlt{} = Just False
    selfCmp CmpUlt{} = Just False
    selfCmp CmpSle{} = Just True
    selfCmp CmpUle{} = Just True
    selfCmp FCmpLt{} = Just False
    selfCmp FCmpLe{} = Nothing
    selfCmp CmpLlt   = Just False
    selfCmp CmpLle   = Just True

simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  constRes =<< BoolValue <$> doCmpOp cmp v1 v2

simplifyCmpOp look _ (CmpOp CmpEq{} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI{} b), cs) <- look v =
      case valueIntegral x :: Int of
        1 -> Just (SubExp b, cs)
        0 -> Just (UnOp Not b, cs)
        _ -> Just (SubExp (Constant (BoolValue False)), cs)

simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of binary operators.
--
-- This folds operations on two constants and applies identity,
-- annihilator and self-cancellation laws.  Rewrites that look through
-- a variable's definition return that definition's certificates
-- together with the new expression.  Returns 'Nothing' when no rewrite
-- applies.
--
-- NOTE(review): some floating-point rules below are not exact under
-- IEEE 754 when an operand is NaN, infinite, negative or @-0.0@, and
-- they are kept as-is.  Examples are @0 + x = x@, @0 * x = 0@,
-- @0 / x = 0@ and @0 ** x = 0@.
simplifyBinOp :: SimpleRule lore

simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res

simplifyBinOp _ _ (BinOp Add{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FAdd{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1

simplifyBinOp look _ (BinOp Sub{} e1 e2)
  | isCt0 e2 = subExpRes e1
  -- (a+b)-a = b and (a+b)-b = a.  The mirrored forms a-(a+b) and
  -- b-(a+b) equal -b and -a respectively, so they must NOT be
  -- rewritten to b or a.
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- look v1,
    e1_a == e2 = Just (SubExp e1_b, cs)
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- look v1,
    e1_b == e2 = Just (SubExp e1_a, cs)

simplifyBinOp _ _ (BinOp FSub{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Mul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FMul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp look _ (BinOp (SMod t) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod m) mod m = x mod m
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod{} _ e4), v1_cs) <- look v1,
    e4 == e2 = Just (SubExp e1, v1_cs)

simplifyBinOp _ _ (BinOp SDiv{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FDiv{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (SRem t) e1 e2)
  -- Both x rem 1 and x rem x are 0 (x rem x is 0, not 1).
  | isCt1 e2 || e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)

simplifyBinOp _ _ (BinOp SQuot{} e1 e2)
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0

simplifyBinOp _ _ (BinOp AShr{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 || isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Or{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0

simplifyBinOp look _ (BinOp LogAnd e1 e2)
  | isCt0 e1 || isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (not x) && x = false and x && (not x) = false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- look v,
    e1' == e2 = Just (SubExp $ Constant $ BoolValue False, v_cs)
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- look v,
    e2' == e1 = Just (SubExp $ Constant $ BoolValue False, v_cs)

simplifyBinOp look _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 || isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x = true and x || (not x) = true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- look v,
    e1' == e2 = Just (SubExp $ Constant $ BoolValue True, v_cs)
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- look v,
    e2' == e1 = Just (SubExp $ Constant $ BoolValue True, v_cs)

simplifyBinOp look _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
      subExpRes e1
  -- max(max(a,b), a) = max(b, a), and the other permutations.
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- look v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- look v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- look v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- look v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)

simplifyBinOp _ _ _ = Nothing

-- | A successful rewrite to the given constant, with no certificates.
constRes :: PrimValue -> Maybe (BasicOp lore, Certificates)
constRes v = Just (SubExp (Constant v), mempty)

-- | A successful rewrite to the given subexpression, with no
-- certificates.
subExpRes :: SubExp -> Maybe (BasicOp lore, Certificates)
subExpRes se = Just (SubExp se, mempty)

-- | Simplify unary operators.  A unary operator applied to a constant
-- is evaluated with 'doUnOp'.  A double negation @not (not x)@ becomes
-- @x@, keeping the certificates of the inner negation.
simplifyUnOp :: SimpleRule lore
simplifyUnOp look _ e =
  case e of
    UnOp op (Constant v) ->
      constRes =<< doUnOp op v
    UnOp Not (Var v)
      | Just (BasicOp (UnOp Not inner), cs) <- look v ->
          Just (SubExp inner, cs)
    _ ->
      Nothing

-- | Local simplification of conversion operators.
--
-- * A conversion of a constant is evaluated at compile time, whenever
--   'doConvOp' is able to evaluate it.
--
-- * A conversion whose source and target types coincide is replaced by
--   its operand.
--
-- * A conversion of a variable that is itself bound to a conversion is
--   fused into a single conversion of the inner operand.  This covers
--   the (inner, outer) pairs 'SExt'/'SExt', 'ZExt'/'ZExt',
--   'SExt'/'SIToFP', 'ZExt'/'UIToFP' and 'FPConv'/'FPConv'.  It only
--   fires when the inner conversion's source type @t3@ is no larger
--   than the outer conversion's source type @t2@.  The certificates of
--   the inner binding are kept.
simplifyConvOp :: SimpleRule lore
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  doConvOp op v >>= constRes
simplifyConvOp _ _ (ConvOp op se)
  | (from, to) <- convOpType op, from == to =
      subExpRes se
simplifyConvOp look _ (ConvOp outer (Var x))
  | Just (BasicOp (ConvOp inner se), x_cs) <- look x,
    Just fused <- fuse inner outer =
      Just (ConvOp fused se, x_cs)
  where -- Combine an inner conversion with the outer one applied to its
        -- result, if the pair is one we know how to fuse.
        fuse :: ConvOp -> ConvOp -> Maybe ConvOp
        fuse (SExt t3 _) (SExt t2 t1)
          | t2 >= t3 = Just $ SExt t3 t1
        fuse (ZExt t3 _) (ZExt t2 t1)
          | t2 >= t3 = Just $ ZExt t3 t1
        fuse (SExt t3 _) (SIToFP t2 t1)
          | t2 >= t3 = Just $ SIToFP t3 t1
        fuse (ZExt t3 _) (UIToFP t2 t1)
          | t2 >= t3 = Just $ UIToFP t3 t1
        fuse (FPConv t3 _) (FPConv t2 t1)
          | t2 >= t3 = Just $ FPConv t3 t1
        fuse _ _ = Nothing
simplifyConvOp _ _ _ =
  Nothing

-- | An assertion whose condition is the literal @true@ is replaced by
-- the constant 'Checked'.  Every other operator is left alone.
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ op =
  case op of
    Assert (Constant (BoolValue True)) _ _ -> constRes Checked
    _ -> Nothing

-- | Evaluate a call to a function found in 'primFuns' at compile time.
--
-- The rule fires only when all three hold:
--
-- * every argument is a constant;
-- * the function name is a key of 'primFuns';
-- * the function yields a result for those arguments.
--
-- The call is then replaced by a binding of the constant result to the
-- same pattern, under the statement's certificates.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs _) (Apply fname args _ _))
  | Just result <- folded =
      Simplify $ certifying cs $
      letBind_ pat $ BasicOp $ SubExp $ Constant result
  where folded = do
          vs <- traverse (constantOf . fst) args
          (_, _, fun) <- M.lookup (nameToString fname) primFuns
          fun vs

        constantOf (Constant v) = Just v
        constantOf _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Simplify an 'Index' expression that is bound to a single-element
-- pattern, by delegating to 'simplifyIndexing'.
--
-- The outcome is bound to the original pattern names under the union
-- of the statement's certificates and those produced by the
-- simplification.  It is bound either as a plain 'SubExp' or as a new
-- 'Index' expression.
--
-- 'simplifyIndexing' is also told whether the usage table marks the
-- bound name as consumed.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable typeOfSubExp idd inds isConsumed =
      Simplify $ m >>= bindResult
  where isConsumed = patElemName pe `UT.isConsumed` used

        typeOfSubExp (Constant v) = Just $ Prim $ primValueType v
        typeOfSubExp (Var v) = ST.lookupType v vtable

        -- Bind an expression to the pattern, adding extra certificates.
        bindTo extra_cs e =
          certifying (cs <> extra_cs) $
          letBindNames_ (patternNames pat) $ BasicOp e

        bindResult (SubExpResult extra_cs se) =
          bindTo extra_cs $ SubExp se
        bindResult (IndexResult extra_cs arr slice) =
          bindTo extra_cs $ Index arr slice

simplifyIndex _ _ _ _ = Skip

-- | The outcome of simplifying an indexing operation.
data IndexResult
  = IndexResult Certificates VName (Slice SubExp)
    -- ^ Index the given array with the given slice instead, under the
    -- given additional certificates.
  | SubExpResult Certificates SubExp
    -- ^ The indexing is equivalent to the given subexpression, under the
    -- given additional certificates.

simplifyIndexing :: MonadBinder m =>
                    ST.SymbolTable (Lore m) -> TypeLookup
                 -> VName -> Slice SubExp -> Bool
                 -> Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType VName
idd Slice SubExp
inds Bool
consuming =
  case VName -> Maybe (BasicOp (Lore m), Certificates)
defOf VName
idd of
    Maybe (BasicOp (Lore m), Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
idd),
        Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
t [] ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd

      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
Attributes lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_primexp" (ExpT (Lore m) -> m SubExp) -> m (ExpT (Lore m)) -> m SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimExp VName -> m (ExpT (Lore m))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp PrimExp VName
e)

      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.IndexedArray Certificates
cs VName
arr [PrimExp VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
Attributes lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        (PrimExp VName -> Bool) -> [PrimExp VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining [PrimExp VName]
inds'',
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix ([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
          (PrimExp VName -> m SubExp) -> [PrimExp VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_primexp" (ExpT (Lore m) -> m SubExp)
-> (PrimExp VName -> m (ExpT (Lore m)))
-> PrimExp VName
-> m SubExp
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> m (ExpT (Lore m))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp) [PrimExp VName]
inds''

    Maybe (BasicOp (Lore m), Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds

    Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
      | [DimFix SubExp
ii] <- Slice SubExp
inds,
        Just (Prim (IntType IntType
from_it)) <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType SubExp
ii ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
          (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_iota" (ExpT (Lore m) -> m SubExp)
-> (PrimExp VName -> m (ExpT (Lore m)))
-> PrimExp VName
-> m SubExp
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> m (ExpT (Lore m))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
          ConvOp -> PrimExp VName -> PrimExp VName
forall v. ConvOp -> PrimExp v -> PrimExp v
ConvOpExp (IntType -> IntType -> ConvOp
SExt IntType
from_it IntType
to_it) (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
          PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
* PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
          PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+ PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
      | [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
            SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
            SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
            SubExp
i_offset'' <- String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (ExpT (Lore m) -> m SubExp)
-> (PrimExp VName -> m (ExpT (Lore m)))
-> PrimExp VName
-> m SubExp
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> m (ExpT (Lore m))
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
*
                          PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
            SubExp
i_stride'' <- String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                          BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
Mul IntType
Int32) SubExp
s SubExp
i_stride'
            (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
              BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp (Lore m)
forall lore. SubExp -> SubExp -> SubExp -> IntType -> BasicOp lore
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it

    Just (Rotate [SubExp]
offsets VName
a, Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
      [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> m (TypeBase Shape NoUniqueness) -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
      let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
            SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> Exp (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> Exp (Lore m))
-> BasicOp (Lore m) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
Add IntType
Int32) SubExp
i SubExp
o
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp (Lore m) -> Exp (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> Exp (Lore m))
-> BasicOp (Lore m) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
SMod IntType
Int32) SubExp
i_p_o SubExp
d)
          adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
            SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
          adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
            SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
      Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)

    Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
      m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Slice SubExp -> Slice SubExp -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice SubExp -> Slice SubExp -> m (Slice SubExp)
sliceSlice Slice SubExp
ais Slice SubExp
inds

    Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
      | [DimFix{}]   <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
      | DimFix{}:Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'

    Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
      | [DimFix{}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val

    Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
      | (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
        ([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
        [SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
          VName
arr <- String -> ExpT (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (ExpT (Lore m) -> m VName) -> ExpT (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp (Lore m)
forall lore. Shape -> SubExp -> BasicOp lore
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
          IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
      where index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix{} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
            index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) SubExp
n SubExp
s)

    Just (Rearrange [Int]
perm VName
src, Certificates
cs)
       | [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
         let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
         in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
      where isIndex :: DimIndex d -> Bool
isIndex DimFix{} = Bool
True
            isIndex DimIndex d
_          = Bool
False

    Just (Copy VName
src, Certificates
cs)
      | Just [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
        Bool -> Bool
not Bool
consuming, VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

    Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        [Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
        [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds

    Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
      | Just [SubExp
_] <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
v2) ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds

    Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
      | Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
        Just (Prim PrimType
res_t) <- (TypeBase Shape NoUniqueness
-> [SubExp] -> TypeBase Shape NoUniqueness
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds) (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
                             VName
-> SymbolTable (Lore m) -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
Attributes lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
      SubExp
x_len <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
      [SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m (TypeBase Shape NoUniqueness) -> m SubExp)
-> (VName -> m (TypeBase Shape NoUniqueness)) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType) [VName]
xs

      let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
            SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> Exp (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> Exp (Lore m))
-> BasicOp (Lore m) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
Add IntType
Int32) SubExp
n SubExp
m
            (SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
      (SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
      let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts

      let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
            String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp (Lore m)
forall lore. VName -> Slice SubExp -> BasicOp lore
Index VName
x (Slice SubExp -> BasicOp (Lore m))
-> Slice SubExp -> BasicOp (Lore m)
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
          mkBranch ((VName
x', SubExp
start):[(VName, SubExp)]
xs_and_starts') = do
            SubExp
cmp <- String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. CmpOp -> SubExp -> SubExp -> BasicOp lore
CmpOp (IntType -> CmpOp
CmpSle IntType
Int32) SubExp
start SubExp
i
            (SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
              SubExp
i' <- String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp (Lore m)
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
Sub IntType
Int32) SubExp
i SubExp
start
              String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp (Lore m) -> ExpT (Lore m)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Lore m) -> ExpT (Lore m))
-> BasicOp (Lore m) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp (Lore m)
forall lore. VName -> Slice SubExp -> BasicOp lore
Index VName
x' (Slice SubExp -> BasicOp (Lore m))
-> Slice SubExp -> BasicOp (Lore m)
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
            BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
            (SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
            BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
            String -> ExpT (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (ExpT (Lore m) -> m SubExp) -> ExpT (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfAttr (BranchType (Lore m))
-> ExpT (Lore m)
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfAttr (BranchType (Lore m)) -> ExpT (Lore m))
-> IfAttr (BranchType (Lore m)) -> ExpT (Lore m)
forall a b. (a -> b) -> a -> b
$
              [BranchType (Lore m)] -> IfSort -> IfAttr (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfAttr rt
IfAttr [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
      Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts

    Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
cs)
      | DimFix (Constant (IntValue (Int32Value Int32
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
        Just SubExp
se <- Int32 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int32
i [SubExp]
ses ->
        case Slice SubExp
inds' of
          [] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
          Slice SubExp
_ | Var VName
v2 <- SubExp
se  -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
          Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    -- Indexing single-element arrays.  We know the index must be 0.
    Maybe (BasicOp (Lore m), Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> SubExp -> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd, SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
t,
        DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds, Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$
          SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'

    Maybe (BasicOp (Lore m), Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing

    where defOf :: VName -> Maybe (BasicOp (Lore m), Certificates)
defOf VName
v = do (BasicOp BasicOp (Lore m)
op, Certificates
def_cs) <- VName
-> SymbolTable (Lore m) -> Maybe (ExpT (Lore m), Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable (Lore m)
vtable
                       (BasicOp (Lore m), Certificates)
-> Maybe (BasicOp (Lore m), Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp (Lore m)
op, Certificates
def_cs)

          -- | A crude heuristic for determining when a PrimExp is
          -- worth inlining over keeping it in an array and reading it
          -- from memory.
          worthInlining :: PrimExp v -> Bool
worthInlining PrimExp v
e
            | Int -> PrimExp v -> Bool
forall v. Int -> PrimExp v -> Bool
primExpSizeAtLeast Int
20 PrimExp v
e = Bool
False -- totally ad-hoc.
            | Bool
otherwise = PrimExp v -> Bool
forall v. PrimExp v -> Bool
worthInlining' PrimExp v
e
          worthInlining' :: PrimExp v -> Bool
worthInlining' (BinOpExp Pow{} PrimExp v
_ PrimExp v
_) = Bool
False
          worthInlining' (BinOpExp FPow{} PrimExp v
_ PrimExp v
_) = Bool
False
          worthInlining' (BinOpExp BinOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
          worthInlining' (CmpOpExp CmpOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
          worthInlining' (ConvOpExp ConvOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
          worthInlining' (UnOpExp UnOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
          worthInlining' FunExp{} = Bool
False
          worthInlining' PrimExp v
_ = Bool
True

-- | Compose two slices.  Given the slice @outer@ that produced an array
-- @b@ from an array @a@, and a slice @inner@ applied to @b@, compute a
-- single slice that selects the same elements directly from @a@.
--
-- The outer slice is walked dimension by dimension:
--
-- * A 'DimFix' in the outer slice removed that dimension from @b@, so it
--   is copied through unchanged and consumes nothing from @inner@.
--
-- * @'DimSlice' j _ s@ (outer) against @'DimFix' i@ (inner) picks
--   element @i@ of the outer slice, i.e. index @j + i*s@ of @a@.
--
-- * @'DimSlice' j _ s0@ (outer) against @'DimSlice' i n s1@ (inner)
--   picks @n@ elements where element @m@ is
--   @j + s0*(i + m*s1) = (j + s0*i) + m*(s0*s1)@, i.e. a slice starting
--   at @j + s0*i@ with stride @s0*s1@.
--
-- Any other combination (including running out of either slice) ends
-- the result with @[]@.  All index arithmetic is emitted as 'Int32'
-- statements via 'letSubExp'.
sliceSlice :: MonadBinder m =>
              [DimIndex SubExp] -> [DimIndex SubExp] -> m [DimIndex SubExp]
sliceSlice (DimFix j:js') is' = (DimFix j:) <$> sliceSlice js' is'
sliceSlice (DimSlice j _ s:js') (DimFix i:is') = do
  i_t_s <- letSubExp "j_t_s" $ BasicOp $ BinOp (Mul Int32) i s
  j_p_i_t_s <- letSubExp "j_p_i_t_s" $ BasicOp $ BinOp (Add Int32) j i_t_s
  (DimFix j_p_i_t_s:) <$> sliceSlice js' is'
sliceSlice (DimSlice j _ s0:js') (DimSlice i n s1:is') = do
  s0_t_i <- letSubExp "s0_t_i" $ BasicOp $ BinOp (Mul Int32) s0 i
  j_p_s0_t_i <- letSubExp "j_p_s0_t_i" $ BasicOp $ BinOp (Add Int32) j s0_t_i
  -- Strides compose multiplicatively: one step of the inner slice moves
  -- s1 elements along the outer slice, and those are s0 apart in the
  -- source array.  Using s1 alone is only correct when s0 == 1.
  s0_t_s1 <- letSubExp "s0_t_s1" $ BasicOp $ BinOp (Mul Int32) s0 s1
  (DimSlice j_p_s0_t_i n s0_t_s1:) <$> sliceSlice js' is'
sliceSlice _ _ = return []


simplifyConcat :: BinderOps lore => BottomUpRuleBasicOp lore

-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
simplifyConcat :: BottomUpRuleBasicOp lore
simplifyConcat (SymbolTable lore
vtable, UsageTable
_) Pattern lore
pat StmAux (ExpAttr lore)
_ (Concat Int
i VName
x [VName]
xs SubExp
new_d)
  | Just Int
r <- TypeBase Shape NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (TypeBase Shape NoUniqueness -> Int)
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe Int
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
Attributes lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable lore
vtable,
    let perm :: [Int]
perm = [Int
i] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
0..Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1..Int
rInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1],
    Just (VName
x',Certificates
x_cs) <- [Int] -> VName -> Maybe (VName, Certificates)
transposedBy [Int]
perm VName
x,
    Just ([VName]
xs',[Certificates]
xs_cs) <- [(VName, Certificates)] -> ([VName], [Certificates])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Certificates)] -> ([VName], [Certificates]))
-> Maybe [(VName, Certificates)] -> Maybe ([VName], [Certificates])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Maybe (VName, Certificates))
-> [VName] -> Maybe [(VName, Certificates)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Int] -> VName -> Maybe (VName, Certificates)
transposedBy [Int]
perm) [VName]
xs = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      VName
concat_rearrange <-
        Certificates -> RuleM lore VName -> RuleM lore VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
x_csCertificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<>[Certificates] -> Certificates
forall a. Monoid a => [a] -> a
mconcat [Certificates]
xs_cs) (RuleM lore VName -> RuleM lore VName)
-> RuleM lore VName -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$
        String -> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"concat_rearrange" (Exp (Lore (RuleM lore)) -> RuleM lore VName)
-> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp lore
forall lore. Int -> VName -> [VName] -> SubExp -> BasicOp lore
Concat Int
0 VName
x' [VName]
xs' SubExp
new_d
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp lore
forall lore. [Int] -> VName -> BasicOp lore
Rearrange [Int]
perm VName
concat_rearrange
  where transposedBy :: [Int] -> VName -> Maybe (VName, Certificates)
transposedBy [Int]
perm1 VName
v =
          case VName -> SymbolTable lore -> Maybe (ExpT lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable lore
vtable of
            Just (BasicOp (Rearrange [Int]
perm2 VName
v'), Certificates
vcs)
              | [Int]
perm1 [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
== [Int]
perm2 -> (VName, Certificates) -> Maybe (VName, Certificates)
forall a. a -> Maybe a
Just (VName
v', Certificates
vcs)
            Maybe (ExpT lore, Certificates)
_ -> Maybe (VName, Certificates)
forall a. Maybe a
Nothing

-- concat xs (concat ys zs) == concat xs ys zs
simplifyConcat (SymbolTable lore
vtable, UsageTable
_) Pattern lore
pat (StmAux Certificates
cs ExpAttr lore
_) (Concat Int
i VName
x [VName]
xs SubExp
new_d)
  | VName
x' VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= VName
x Bool -> Bool -> Bool
|| [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
xs' [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
/= [VName]
xs = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$
      Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
csCertificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<>Certificates
x_csCertificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<>[Certificates] -> Certificates
forall a. Monoid a => [a] -> a
mconcat [Certificates]
xs_cs) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp lore
forall lore. Int -> VName -> [VName] -> SubExp -> BasicOp lore
Concat Int
i VName
x' ([VName]
zs[VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++[[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
xs') SubExp
new_d
  where (VName
x':[VName]
zs, Certificates
x_cs) = VName -> ([VName], Certificates)
isConcat VName
x
        ([[VName]]
xs', [Certificates]
xs_cs) = [([VName], Certificates)] -> ([[VName]], [Certificates])
forall a b. [(a, b)] -> ([a], [b])
unzip ([([VName], Certificates)] -> ([[VName]], [Certificates]))
-> [([VName], Certificates)] -> ([[VName]], [Certificates])
forall a b. (a -> b) -> a -> b
$ (VName -> ([VName], Certificates))
-> [VName] -> [([VName], Certificates)]
forall a b. (a -> b) -> [a] -> [b]
map VName -> ([VName], Certificates)
isConcat [VName]
xs
        isConcat :: VName -> ([VName], Certificates)
isConcat VName
v = case VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
ST.lookupBasicOp VName
v SymbolTable lore
vtable of
                       Just (Concat Int
j VName
y [VName]
ys SubExp
_, Certificates
v_cs) | Int
j Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
i -> (VName
y VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
ys, Certificates
v_cs)
                       Maybe (BasicOp lore, Certificates)
_ -> ([VName
v], Certificates
forall a. Monoid a => a
mempty)

-- If concatenating a bunch of array literals (or equivalent
-- replicate), just construct the array literal instead.
simplifyConcat (SymbolTable lore
vtable, UsageTable
_) Pattern lore
pat (StmAux Certificates
cs ExpAttr lore
_) (Concat Int
0 VName
x [VName]
xs SubExp
_)
  | Just ([[SubExp]]
vs, [Certificates]
vcs) <- [([SubExp], Certificates)] -> ([[SubExp]], [Certificates])
forall a b. [(a, b)] -> ([a], [b])
unzip ([([SubExp], Certificates)] -> ([[SubExp]], [Certificates]))
-> Maybe [([SubExp], Certificates)]
-> Maybe ([[SubExp]], [Certificates])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Maybe ([SubExp], Certificates))
-> [VName] -> Maybe [([SubExp], Certificates)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> Maybe ([SubExp], Certificates)
isArrayLit (VName
xVName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
:[VName]
xs) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      TypeBase Shape NoUniqueness
rt <- TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> RuleM lore (TypeBase Shape NoUniqueness)
-> RuleM lore (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> RuleM lore (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
      Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> [Certificates] -> Certificates
forall a. Monoid a => [a] -> a
mconcat [Certificates]
vcs) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ [SubExp] -> TypeBase Shape NoUniqueness -> BasicOp lore
forall lore.
[SubExp] -> TypeBase Shape NoUniqueness -> BasicOp lore
ArrayLit ([[SubExp]] -> [SubExp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[SubExp]]
vs) TypeBase Shape NoUniqueness
rt
      where isArrayLit :: VName -> Maybe ([SubExp], Certificates)
isArrayLit VName
v
              | Just (Replicate Shape
shape SubExp
se, Certificates
vcs) <- VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
ST.lookupBasicOp VName
v SymbolTable lore
vtable,
                Shape -> Bool
unitShape Shape
shape = ([SubExp], Certificates) -> Maybe ([SubExp], Certificates)
forall a. a -> Maybe a
Just ([SubExp
se], Certificates
vcs)
              | Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
vcs) <- VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
ST.lookupBasicOp VName
v SymbolTable lore
vtable =
                  ([SubExp], Certificates) -> Maybe ([SubExp], Certificates)
forall a. a -> Maybe a
Just ([SubExp]
ses, Certificates
vcs)
              | Bool
otherwise =
                  Maybe ([SubExp], Certificates)
forall a. Maybe a
Nothing

            unitShape :: Shape -> Bool
unitShape = (Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
==[SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
1])

simplifyConcat (SymbolTable lore, UsageTable)
_ Pattern lore
_ StmAux (ExpAttr lore)
_  BasicOp lore
_ = Rule lore
forall lore. Rule lore
Skip

-- | Top-down simplification rules for 'If' expressions.  Each equation
-- below is an independent rewrite; when an equation's patterns or
-- guards do not match, matching falls through to the next equation,
-- and the final catch-all gives up with 'Skip'.
ruleIf :: BinderOps lore => TopDownRuleIf lore

-- The condition is recognised as a constant by 'isCt1' or 'isCt0':
-- splice the statements of the branch that is taken into the
-- surrounding scope and bind its results to the pattern elements.  An
-- 'IfFallback' conditional is only rewritten when 'isCt1' holds, i.e.
-- when the branch taken is the true branch.
ruleIf _ pat _ (e1, tb, fb, IfAttr _ ifsort)
  | ifsort /= IfFallback || isCt1 e1,
    -- The first pair whose key is True wins, mirroring the priority
    -- of 'isCt1' over 'isCt0'.
    Just taken <- lookup True [(isCt1 e1, tb), (isCt0 e1, fb)] =
      Simplify $ do
        addStms $ bodyStms taken
        zipWithM_ bindResult (patternElements pat) (bodyResult taken)
  where bindResult pe se =
          letBindNames_ [patElemName pe] $ BasicOp $ SubExp se

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v  ==>  c || v
ruleIf _ pat _ ( cond
               , Body _ tstms [Constant (BoolValue True)]
               , Body _ fstms [se]
               , IfAttr ts _ )
  | [Prim Bool] <- bodyTypeValues ts,
    null tstms && null fstms =
      Simplify $ letBind_ pat $ BasicOp $ BinOp LogOr cond se

-- For a single boolean result:
--   if c then x else y  ==>  (c && x) || (!c && y)
-- The statements of both branches are added unconditionally, so this
-- only fires when every one of them satisfies 'safeExp'.
ruleIf _ pat _ (cond, Body _ tstms [tres], Body _ fstms [fres], IfAttr ts _)
  | all (== Prim Bool) (bodyTypeValues ts),
    all (safeExp . stmExp) (tstms <> fstms) = Simplify $ do
      addStms tstms
      addStms fstms
      let whenTrue  = pure $ BasicOp $ BinOp LogAnd cond tres
          whenFalse = eBinOp LogAnd (pure $ BasicOp $ UnOp Not cond)
                                    (pure $ BasicOp $ SubExp fres)
      letBind_ pat =<< eBinOp LogOr whenTrue whenFalse

-- A fallback conditional that binds no context names, and whose true
-- branch contains only 'safeExp' statements, is replaced by that
-- branch.
ruleIf _ pat _ (_, tbranch, _, IfAttr _ IfFallback)
  | null (patternContextNames pat),
    all (safeExp . stmExp) (bodyStms tbranch) = Simplify $ do
      addStms $ bodyStms tbranch
      zipWithM_ bindResult (patternElements pat) (bodyResult tbranch)
  where bindResult pe se =
          letBindNames_ [patElemName pe] $ BasicOp $ SubExp se

-- Choosing between the integer constants one and zero is a
-- boolean-to-integer conversion of the (possibly negated) condition.
-- The branch statements are ignored because both results are
-- constants.  Any other pair of constants falls through to 'Skip'.
ruleIf _ pat _ ( cond
               , Body _ _ [Constant (IntValue t)]
               , Body _ _ [Constant (IntValue f)]
               , _ )
  | oneIshInt t && zeroIshInt f = Simplify $
      letBind_ pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
  | zeroIshInt t && oneIshInt f = Simplify $ do
      cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
      letBind_ pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg

ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results in the
-- context), or the same in both branches.
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpAttr lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfAttr [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
  let tses :: [SubExp]
tses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
tb
      fses :: [SubExp]
fses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
fb
  ([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetAttr lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
    ([Either
    (Maybe (Int, SubExp))
    (PatElemT (LetAttr lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetAttr lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetAttr lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetAttr lore), Either Int (BranchType lore),
  (SubExp, SubExp))]
-> ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
    [(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
  [(PatElemT (LetAttr lore), Either Int (BranchType lore),
    (SubExp, SubExp))])
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> ([Either
       (Maybe (Int, SubExp))
       (PatElemT (LetAttr lore), Either Int (BranchType lore),
        (SubExp, SubExp))]
    -> ([Maybe (Int, SubExp)],
        [(PatElemT (LetAttr lore), Either Int (BranchType lore),
          (SubExp, SubExp))]))
-> [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetAttr lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (Maybe (Int, SubExp))
   (PatElemT (LetAttr lore), Either Int (BranchType lore),
    (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetAttr lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
   lore
   [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetAttr lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
 -> RuleM
      lore
      ([Maybe (Int, SubExp)],
       ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
        [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetAttr lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$ ((PatElemT (LetAttr lore), Either Int (BranchType lore),
  (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetAttr lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> [(PatElemT (LetAttr lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetAttr lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant ([(PatElemT (LetAttr lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> RuleM
      lore
      [Either
         (Maybe (Int, SubExp))
         (PatElemT (LetAttr lore), Either Int (BranchType lore),
          (SubExp, SubExp))])
-> [(PatElemT (LetAttr lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
      [PatElemT (LetAttr lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetAttr lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements Pattern lore
pat)
           ((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0..Int
num_ctxInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
           ([SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
tses [SubExp]
fses)
  let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
      ([SubExp]
tses', [SubExp]
fses') = [(SubExp, SubExp)] -> ([SubExp], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
      tb' :: BodyT lore
tb' = BodyT lore
tb { bodyResult :: [SubExp]
bodyResult = [SubExp]
tses' }
      fb' :: BodyT lore
fb' = BodyT lore
fb { bodyResult :: [SubExp]
bodyResult = [SubExp]
fses' }
      ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
      ([PatElemT (LetAttr lore)]
ctx_pes, [PatElemT (LetAttr lore)]
val_pes) = Int
-> [PatElemT (LetAttr lore)]
-> ([PatElemT (LetAttr lore)], [PatElemT (LetAttr lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetAttr lore)]
pes
  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings -- Was something hoisted?
     then do -- We may have to add some reshapes if we made the type
             -- less existential.
             BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
             BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
             Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ ([PatElemT (LetAttr lore)]
-> [PatElemT (LetAttr lore)] -> Pattern lore
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [PatElemT (LetAttr lore)]
ctx_pes [PatElemT (LetAttr lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
               SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfAttr (BranchType lore)
forall rt. [rt] -> IfSort -> IfAttr rt
IfAttr [BranchType lore]
ret' IfSort
ifsort)
     else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
  where num_ctx :: Int
num_ctx = [PatElemT (LetAttr lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetAttr lore)] -> Int)
-> [PatElemT (LetAttr lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternContextElements Pattern lore
pat
        bound_in_branches :: Names
bound_in_branches = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall attr. PatternT attr -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
                            BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
        mem_sizes :: Names
mem_sizes = [PatElemT (LetAttr lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetAttr lore)] -> Names)
-> [PatElemT (LetAttr lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetAttr lore) -> Bool)
-> [PatElemT (LetAttr lore)] -> [PatElemT (LetAttr lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (TypeBase Shape NoUniqueness -> Bool)
-> (PatElemT (LetAttr lore) -> TypeBase Shape NoUniqueness)
-> PatElemT (LetAttr lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetAttr lore) -> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
PatElemT attr -> TypeBase Shape NoUniqueness
patElemType) ([PatElemT (LetAttr lore)] -> [PatElemT (LetAttr lore)])
-> [PatElemT (LetAttr lore)] -> [PatElemT (LetAttr lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetAttr lore)]
forall attr. PatternT attr -> [PatElemT attr]
patternElements Pattern lore
pat
        invariant :: SubExp -> Bool
invariant Constant{} = Bool
True
        invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches

        isMem :: TypeBase shape u -> Bool
isMem Mem{} = Bool
True
        isMem TypeBase shape u
_ = Bool
False
        sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes

        branchInvariant :: (PatElemT (LetAttr lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant (PatElemT (LetAttr lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
          -- Do both branches return the same value?
          | SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
              [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [PatElemT (LetAttr lore) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> ExpT lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> ExpT lore) -> BasicOp lore -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp SubExp
tse
              PatElemT (LetAttr lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) attr a b b.
Monad m =>
PatElemT attr -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetAttr lore)
pe Either Int (BranchType lore)
t

          -- Do both branches return values that are free in the
          -- branch, and are we not the only pattern element?  The
          -- latter is to avoid infinite application of this rule.
          | SubExp -> Bool
invariant SubExp
tse, SubExp -> Bool
invariant SubExp
fse, Pattern lore -> Int
forall attr. PatternT attr -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
            Prim PrimType
_ <- PatElemT (LetAttr lore) -> TypeBase Shape NoUniqueness
forall attr.
Typed attr =>
PatElemT attr -> TypeBase Shape NoUniqueness
patElemType PatElemT (LetAttr lore)
pe, Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetAttr lore) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
pe = do
              [BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(Attributes lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetAttr lore)]
-> [PatElemT (LetAttr lore)] -> Pattern lore
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT (LetAttr lore)
pe]
              [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ [PatElemT (LetAttr lore) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
pe] (ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
                (SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfAttr (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfAttr (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
tse]
                         RuleM lore (BodyT lore -> IfAttr (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfAttr (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
fse]
                         RuleM lore (IfAttr (BranchType lore) -> ExpT lore)
-> RuleM lore (IfAttr (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfAttr (BranchType lore) -> RuleM lore (IfAttr (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfAttr (BranchType lore)
forall rt. [rt] -> IfSort -> IfAttr rt
IfAttr [BranchType lore]
bt IfSort
ifsort))
              PatElemT (LetAttr lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) attr a b b.
Monad m =>
PatElemT attr -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetAttr lore)
pe Either Int (BranchType lore)
t

          | Bool
otherwise =
              Either
  (Maybe (Int, SubExp))
  (PatElemT (LetAttr lore), Either Int (BranchType lore),
   (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
   (Maybe (Int, SubExp))
   (PatElemT (LetAttr lore), Either Int (BranchType lore),
    (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetAttr lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetAttr lore), Either Int (BranchType lore),
      (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetAttr lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetAttr lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetAttr lore), Either Int (BranchType lore),
      (SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetAttr lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse,SubExp
fse))

        hoisted :: PatElemT attr -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT attr
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT attr -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT attr
pe)
        hoisted PatElemT attr
_ Right{}   = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing

        reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
          [SubExp]
ses <- BodyT (Lore m) -> m [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT (Lore m)
body
          let ([SubExp]
ctx_ses, [SubExp]
val_ses) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) [SubExp]
ses
          [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> m (BodyT (Lore m)))
-> ([SubExp] -> [SubExp]) -> [SubExp] -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([SubExp]
ctx_ses[SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++) ([SubExp] -> m (BodyT (Lore m)))
-> m [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp)
-> [SubExp] -> [ExtType] -> m [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult [SubExp]
val_ses [ExtType]
rets
        reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array{} = do
          TypeBase Shape NoUniqueness
v_t <- VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
          let newshape :: [SubExp]
newshape = TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ExtType
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
removeExistentials ExtType
t TypeBase Shape NoUniqueness
v_t
          if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
v_t
            then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> Exp (Lore m)
forall lore. [SubExp] -> VName -> Exp lore
shapeCoerce [SubExp]
newshape VName
v
            else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        reshapeResult SubExp
se ExtType
_ =
          SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | A reshape whose target dimensions are exactly the current
-- dimensions of the array (as reported by the type lookup) does
-- nothing, so the statement is replaced by the array itself.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ typeOf (Reshape shape_change arr) =
  case typeOf (Var arr) of
    Just arr_t
      | arrayDims arr_t == newDims shape_change ->
          -- No-op reshape.
          subExpRes (Var arr)
    _ -> Nothing
simplifyIdentityReshape _ _ _ = Nothing

-- | A reshape of an array that was itself produced by a reshape is
-- collapsed into a single reshape of the inner array.  The two shape
-- changes are combined with 'fuseReshape', and the certificates of
-- the inner reshape are kept.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape lookupDef _ (Reshape outer_shape arr) =
  case lookupDef arr of
    Just (BasicOp (Reshape inner_shape inner_arr), inner_cs) ->
      let fused = fuseReshape inner_shape outer_shape
      in Just (Reshape fused inner_arr, inner_cs)
    _ -> Nothing
simplifyReshapeReshape _ _ _ = Nothing

-- | Reshaping a scratch array yields a fresh scratch array of the new
-- shape.  It has the same primitive element type and the same
-- certificates as the original scratch.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch lookupDef _ (Reshape shape_change arr) =
  case lookupDef arr of
    Just (BasicOp (Scratch elem_pt _), scratch_cs) ->
      let target_dims = newDims shape_change
      in Just (Scratch elem_pt target_dims, scratch_cs)
    _ -> Nothing
simplifyReshapeScratch _ _ _ = Nothing

-- | A reshape of a replicate can become a replicate directly.  This
-- applies when the dimensions of the replicated value form a suffix of
-- the reshape's target dimensions.  The new replicate uses the
-- remaining leading target dimensions as its shape and keeps the
-- certificates of the original replicate.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate lookupDef typeOf (Reshape shape_change arr)
  | Just (BasicOp (Replicate _ elem_se), rep_cs) <- lookupDef arr,
    Just elem_shape <- arrayShape <$> typeOf elem_se =
      let target = newDims shape_change
          -- The leading dimensions not already covered by the value.
          outer = take (length target - shapeRank elem_shape) target
      in if shapeDims elem_shape `isSuffixOf` target
           then Just (Replicate (Shape outer) elem_se, rep_cs)
           else Nothing
simplifyReshapeReplicate _ _ _ = Nothing

-- | Reshaping an iota into a single dimension @n@ is rewritten as an
-- iota of length @n@.  The new iota keeps the original offset, stride,
-- integer type and certificates.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota lookupDef _ (Reshape shape_change arr) =
  case lookupDef arr of
    Just (BasicOp (Iota _ start step int_t), iota_cs) ->
      case newDims shape_change of
        [len] -> Just (Iota len start step int_t, iota_cs)
        _     -> Nothing
    _ -> Nothing
simplifyReshapeIota _ _ _ = Nothing

-- | Refine a reshape's shape change by running 'informReshape' with the
-- array's currently known dimensions.  The rule fires only when that
-- actually changes the shape change.  The rewritten reshape carries no
-- certificates ('mempty').
improveReshape :: SimpleRule lore
improveReshape _ typeOf (Reshape shape_change arr) = do
  arr_t <- typeOf (Var arr)
  let informed = informReshape (arrayDims arr_t) shape_change
  guard (informed /= shape_change)
  Just (Reshape informed arr, mempty)
improveReshape _ _ _ = Nothing

-- | If we are copying an array that is really a scratch array, replace
-- the copy with a fresh scratch of the same element type and
-- dimensions.  The source may also be a scratch seen through any chain
-- of rearranges and reshapes.  Certificates attached to the
-- intermediate definitions are not carried over; the result uses
-- 'mempty'.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch lookupDef typeOf (Copy src) = do
  src_t <- typeOf (Var src)
  guard (scratchOrigin src)
  Just (Scratch (elemType src_t) (arrayDims src_t), mempty)
  where
    -- Walk back through rearranges/reshapes until we reach either a
    -- scratch (success) or any other definition (failure).
    scratchOrigin name =
      case lookupDef name >>= asBasicOp . fst of
        Just Scratch{}           -> True
        Just (Rearrange _ name') -> scratchOrigin name'
        Just (Reshape _ name')   -> scratchOrigin name'
        _                        -> False
copyScratchToScratch _ _ _ = Nothing

ruleBasicOp :: BinderOps lore => TopDownRuleBasicOp lore

-- Check all the simpleRules.
ruleBasicOp :: TopDownRuleBasicOp lore
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpAttr lore)
aux BasicOp lore
op
  | Just (BasicOp lore
op', Certificates
cs) <- [Maybe (BasicOp lore, Certificates)]
-> Maybe (BasicOp lore, Certificates)
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, MonadPlus m) =>
t (m a) -> m a
msum [ SimpleRule lore
rule VName -> Maybe (Exp lore, Certificates)
defOf SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType BasicOp lore
op | SimpleRule lore
rule <- [SimpleRule lore]
forall lore. [SimpleRule lore]
simpleRules ] =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpAttr lore) -> Certificates
forall attr. StmAux attr -> Certificates
stmAuxCerts StmAux (ExpAttr lore)
aux) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp BasicOp lore
op'
  where defOf :: VName -> Maybe (Exp lore, Certificates)
defOf = (VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
`ST.lookupExp` TopDown lore
vtable)
        seType :: SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (Var VName
v) = VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
Attributes lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
v TopDown lore
vtable
        seType (Constant PrimValue
v) = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
 -> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v

ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpAttr lore)
_ (Update VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch{}, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp (SubExp -> BasicOp lore) -> SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src

ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpAttr lore)
_ (Update VName
dest Slice SubExp
destis (Var VName
v))
  | Just (Exp lore
e, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable,
    Exp lore -> Bool
arrayFrom Exp lore
e =
      RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp lore
forall lore. SubExp -> BasicOp lore
SubExp (SubExp -> BasicOp lore) -> SubExp -> BasicOp lore
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dest
  where arrayFrom :: Exp lore -> Bool
arrayFrom (BasicOp (Copy VName
copy_v))
          | Just (Exp lore
e',Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
copy_v TopDown lore
vtable =
              Exp lore -> Bool
arrayFrom Exp lore
e'
        arrayFrom (BasicOp (Index VName
src Slice SubExp
srcis)) =
          VName
src VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest Bool -> Bool -> Bool
&& Slice SubExp
destis Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp
srcis
        arrayFrom (BasicOp (Replicate Shape
v_shape SubExp
v_se))
          | Just (Replicate Shape
dest_shape SubExp
dest_se, Certificates
_) <- VName -> TopDown lore -> Maybe (BasicOp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp lore, Certificates)
ST.lookupBasicOp VName
dest TopDown lore
vtable,
            SubExp
v_se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
dest_se,
            Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
v_shape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape =
              Bool
True
        arrayFrom Exp lore
_ =
          Bool
False

-- | Turn in-place updates that replace an entire array into just
-- array literals.
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpAttr lore)
_ (Update VName
dest Slice SubExp
is SubExp
se)
  | Just TypeBase Shape NoUniqueness
dest_t <- VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
Attributes lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
dest TopDown lore
vtable,
    Shape -> Slice SubExp -> Bool
isFullSlice (TypeBase Shape NoUniqueness -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase Shape NoUniqueness
dest_t) Slice SubExp
is = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$
      case SubExp
se of
        Var VName
v | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
is -> do
                  VName
v_reshaped <- String -> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
v String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_reshaped") (Exp (Lore (RuleM lore)) -> RuleM lore VName)
-> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$
                                BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp lore
forall lore. ShapeChange SubExp -> VName -> BasicOp lore
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew ([SubExp] -> ShapeChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
dest_t) VName
v
                  Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp lore
forall lore. VName -> BasicOp lore
Copy VName
v_reshaped

        SubExp
_ -> Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp lore -> Exp lore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp lore -> Exp lore) -> BasicOp lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ [SubExp] -> TypeBase Shape NoUniqueness -> BasicOp lore
forall lore.
[SubExp] -> TypeBase Shape NoUniqueness -> BasicOp lore
ArrayLit [SubExp
se] (TypeBase Shape NoUniqueness -> BasicOp lore)
-> TypeBase Shape NoUniqueness -> BasicOp lore
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType TypeBase Shape NoUniqueness
dest_t

-- Collapse a chain of in-place updates and copies into a single
-- update.  Such chains are often produced by in-place lowering.  The
-- matched shape is
--
--   v3    = dest1[is1]
--   dest2 = copy v3
--   v1    = dest2 with [is2] = se2
--   pat   = dest1 with [is1] = v1
--
-- and it is rewritten to @pat = dest1 with [is5] = se2@, where @is5@
-- is the slice computed by @sliceSlice is1 is2@.  The certificates of
-- every statement in the chain are attached to the new statement.
ruleBasicOp vtable pat (StmAux cs1 _) (Update dest1 is1 (Var v1))
  | Just (Update dest2 is2 se2, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Copy v3, cs3) <- ST.lookupBasicOp dest2 vtable,
    Just (Index v4 is4, cs4) <- ST.lookupBasicOp v3 vtable,
    v4 == dest1,
    is4 == is1 =
      let chain_cs = cs1 <> cs2 <> cs3 <> cs4
      in Simplify $ certifying chain_cs $
           sliceSlice is1 is2 >>= \is5 ->
             letBind_ pat $ BasicOp $ Update dest1 is5 se2

-- If we are comparing X against the result of a branch of the form
-- @if P then Y else Z@, then replace the comparison with
-- @(P && X == Y) || (!P && X == Z)@.  This may allow us to get rid of
-- a branch, and the extra comparisons may be constant-folded out.
-- Question: maybe we should have some more checks to ensure that we
-- only do this if that is actually the case, such as if we will obtain
-- at least one constant-to-constant comparison?
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where -- 'simplifyWith v x' fires when 'v' is bound by an 'If' whose
        -- branches return 'y' and 'z' for 'v', and neither 'y' nor 'z'
        -- refers to names bound inside its own branch (so they are
        -- usable outside the branch).
        simplifyWith (Var v) x
          | Just bnd <- ST.entryStm =<< ST.lookup v vtable,
            If p tbranch fbranch _ <- stmExp bnd,
            Just (y, z) <-
              returns v (stmPattern bnd) tbranch fbranch,
            not $ boundInBody tbranch `namesIntersect` freeIn y,
            not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
                eq_x_y <-
                  letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
                eq_x_z <-
                  letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
                p_and_eq_x_y <-
                  letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
                not_p <-
                  letSubExp "not_p" $ BasicOp $ UnOp Not p
                -- Name hint previously copy-pasted as "p_and_eq_x_y".
                not_p_and_eq_x_z <-
                  letSubExp "not_p_and_eq_x_z" $
                  BasicOp $ BinOp LogAnd not_p eq_x_z
                letBind_ pat $
                  BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
        simplifyWith _ _ =
          Nothing

        -- The pair of branch results (true branch, false branch)
        -- corresponding to the value pattern element named 'v', if any.
        returns v ifpat tbranch fbranch =
          fmap snd $
          find ((==v) . patElemName . fst) $
          zip (patternValueElements ifpat) $
          zip (bodyResult tbranch) (bodyResult fbranch)

-- A replicate with an empty shape of a constant is just that constant.
ruleBasicOp _ pat _ (Replicate (Shape []) se@Constant{}) =
  Simplify (letBind_ pat (BasicOp (SubExp se)))
-- A replicate with an empty shape of a variable becomes the variable
-- itself when its type is primitive, and a copy of it otherwise.
ruleBasicOp _ pat _ (Replicate (Shape []) (Var v)) =
  Simplify $ lookupType v >>= letBind_ pat . BasicOp . scalarOrCopy
  where scalarOrCopy v_t
          | primType v_t = SubExp (Var v)
          | otherwise = Copy v
-- Flatten a replicate of a replicate into a single replicate whose
-- shape is the outer shape followed by the inner shape.
ruleBasicOp vtable pat _ (Replicate shape (Var v))
  | Just (BasicOp (Replicate inner_shape inner_se), inner_cs) <-
      ST.lookupExp v vtable =
      let merged = Replicate (shape <> inner_shape) inner_se
      in Simplify $ certifying inner_cs $ letBind_ pat $ BasicOp merged

-- Turn array literals whose elements are all identical into a
-- replicate of the first element, sized by the number of elements.
ruleBasicOp _ pat _ (ArrayLit (se:ses) _)
  | all (== se) ses =
      let count = fromIntegral (1 + length ses) :: Int32
      in Simplify $ letBind_ pat $ BasicOp $
           Replicate (Shape [constant count]) se

-- Indexing a reshaped array with one scalar index per dimension is
-- turned into indexing the array that was reshaped.  If the reshape is
-- a coercion (per 'shapeCoercion') the slice is reused unchanged;
-- otherwise the indices are translated with 'reshapeIndex'.
ruleBasicOp vtable pat (StmAux cs _) (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
      let asInt32 = map (primExpFromSubExp int32)
          bindIndexOf = certifying (cs <> idd_cs) . letBind_ pat . BasicOp . Index idd2
          viaOldSpace = do
            -- Linearise the indices and map them back into the index
            -- space of the array that was reshaped.
            oldshape <- arrayDims <$> lookupType idd2
            let old_inds = reshapeIndex (asInt32 oldshape)
                                        (asInt32 $ newDims newshape)
                                        (asInt32 inds)
            old_inds' <-
              mapM (letSubExp "new_index" <=< toExp . asInt32PrimExp) old_inds
            bindIndexOf $ map DimFix old_inds'
      in Simplify $
           if isJust (shapeCoercion newshape)
             then bindIndexOf slice
             else viaOldSpace

-- Integer exponentiation with base two becomes a left shift of one.
ruleBasicOp _ pat _ (BinOp (Pow t) e1 e2)
  | intConst t 2 == e1 =
      let one = intConst t 1
      in Simplify (letBind_ pat (BasicOp (BinOp (Shl t) one e2)))

-- Handle identity permutation: a rearrange whose permutation is already
-- in ascending order is a no-op, so bind the input directly.
ruleBasicOp _ pat _ (Rearrange perm v)
  | and (zipWith (<=) perm (drop 1 perm)) =
      let unchanged = BasicOp (SubExp (Var v))
      in Simplify (letBind_ pat unchanged)

-- Rearranging the result of a rearrange: compose the two permutations
-- and rearrange the innermost array directly.
ruleBasicOp vtable pat (StmAux cs _) (Rearrange perm v)
  | Just (BasicOp (Rearrange inner_perm src), inner_cs) <- ST.lookupExp v vtable =
      let composed = rearrangeCompose perm inner_perm
      in Simplify $ certifying (cs <> inner_cs) $
           letBind_ pat $ BasicOp $ Rearrange composed src

-- A rearrange of a rotate of a rearrange: move the rotation in front of
-- the inner rearrange (permuting its offsets by the inverse of the
-- inner permutation), so that the two rearranges can be composed into
-- one.  Only the final rearrange carries the combined certificates.
ruleBasicOp vtable pat (StmAux cs _) (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
      Simplify $
        letExp "rearrange_rotate"
          (BasicOp $ Rotate (rearrangeShape (rearrangeInverse perm3) offsets) v3)
        >>= \rotated ->
          certifying (cs <> v_cs <> v2_cs) $
            letBind_ pat $ BasicOp $ Rearrange (rearrangeCompose perm perm3) rotated

-- Rearranging a replicate where the replicated (outer) dimensions are
-- left untouched: rearrange the replicated value instead, with the
-- remaining permutation shifted down by the number of replicated
-- dimensions, and replicate the result.
ruleBasicOp vtable pat (StmAux cs _) (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    let num_dims = shapeRank dims,
    (rep_perm, rest_perm@(_:_)) <- splitAt num_dims perm,
    and (zipWith (==) rep_perm [0..]) =
      Simplify $ certifying (cs <> v1_cs) $
        letSubExp "rearrange_replicate"
          (BasicOp $ Rearrange [ i - num_dims | i <- rest_perm ] v2)
        >>= letBind_ pat . BasicOp . Replicate dims

-- Rotating by a constant zero along every dimension is the identity.
ruleBasicOp _ pat _ (Rotate offsets v)
  | all isCt0 offsets =
      let unchanged = BasicOp (SubExp (Var v))
      in Simplify (letBind_ pat unchanged)

-- Rotating a rearranged rotation: push the inner rotation through the
-- rearrangement so that the two rotations can be summed, i.e.
--
--   rotate o (rearrange p (rotate o2 a))
--     == rotate (o + rearrangeShape p o2) (rearrange p a)
--
-- Because @rearrange p@ makes result dimension @i@ be input dimension
-- @p !! i@, the inner offsets must be permuted by @p@ itself, not by its
-- inverse.  The two only coincide for involutive permutations (such as
-- a 2D transpose); e.g. for @[1,2,0]@ using the inverse would rotate
-- the wrong dimensions.
ruleBasicOp vtable pat (StmAux cs _) (Rotate offsets v)
  | Just (BasicOp (Rearrange perm v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate offsets2 v3), v2_cs) <- ST.lookupExp v2 vtable =
      Simplify $ do
        -- Inner offsets, expressed in the coordinate system of the
        -- rearranged array.
        let offsets2' = rearrangeShape perm offsets2
            addOffsets x y =
              letSubExp "summed_offset" $ BasicOp $ BinOp (Add Int32) x y
        offsets' <- zipWithM addOffsets offsets offsets2'
        rotate_rearrange <-
          certifying cs $
            letExp "rotate_rearrange" $ BasicOp $ Rearrange perm v3
        certifying (v_cs <> v2_cs) $
          letBind_ pat $ BasicOp $ Rotate offsets' rotate_rearrange

-- Combining Rotates: a Rotate whose input is itself a Rotate becomes a
-- single Rotate of the innermost array, with the two offset lists
-- summed elementwise.  The certificates of both statements are kept.
ruleBasicOp vtable pat (StmAux cs _) (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable =
      Simplify $ do
        let sumOffset o1 o2 =
              letSubExp "offset" (BasicOp (BinOp (Add Int32) o1 o2))
        combined <- zipWithM sumOffset offsets1 offsets2
        certifying (cs <> v_cs) $
          letBind_ pat (BasicOp (Rotate combined v2))

-- If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
--
-- The innermost fixed index on each side is turned into a one-element
-- slice.  The new Index reads exactly the element the original Index
-- did, so it carries the original Index's certificates (cs_y); without
-- them it would not depend on whatever those certificates guard
-- (presumably bounds checks) and could be moved ahead of them.
ruleBasicOp vtable pat (StmAux cs_x _) (Update arr_x slice_x (Var v))
  | Just _ <- sliceIndices slice_x,
    Just (Index arr_y slice_y, cs_y) <- ST.lookupBasicOp v vtable,
    ST.available arr_y vtable,
    -- XXX: we should check for proper aliasing here instead.
    arr_y /= arr_x,
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y =
      Simplify $ do
        let one = intConst Int32 1
            slice_x' = slice_x_bef ++ [DimSlice i one one]
            slice_y' = slice_y_bef ++ [DimSlice j one one]
        v' <-
          certifying cs_y $
            letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
        certifying (cs_x <> cs_y) $
          letBind_ pat $ BasicOp $ Update arr_x slice_x' $ Var v'

-- Catch-all: none of the BasicOp simplifications above applied.
ruleBasicOp _vtable _pat _aux _op = Skip

-- | Remove the return values of a branch, that are not actually used
-- after a branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise.
--
-- The rule only fires when the pattern has no existential context (its
-- size equals the number of branch return types) and at least one
-- pattern name is not directly used.  Dead results are dropped from
-- both branch bodies, the pattern and the return types.  This leaves
-- dead code inside the branch bodies, which is removed later.
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfAttr rettype ifsort)
  | patternSize pat == length rettype,
    not (and live) =
      Simplify $
        letBind_ (Pattern [] (keepLive (patternElements pat))) $
          If e1 (pruneBody tb) (pruneBody fb) $
            IfAttr (keepLive rettype) ifsort
  | otherwise = Skip
  where
    -- One flag per pattern name: is that result used after the branch?
    live = map (`UT.isUsedDirectly` used) $ patternNames pat

    -- Keep only the elements whose corresponding result is live.  The
    -- signature is needed so this is usable at several element types.
    keepLive :: [a] -> [a]
    keepLive xs = [x | (True, x) <- zip live xs]

    pruneBody body = body { bodyResult = keepLive $ bodyResult body }


-- Some helper functions

-- | Is this subexpression a constant for which 'oneIsh' holds?
-- Variables are never considered constant.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False

-- | Is this subexpression a constant for which 'zeroIsh' holds?
-- Variables are never considered constant.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False