{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Data.Either
import Data.List (find, foldl', isSuffixOf, partition, sort)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.ClosedForm
import Futhark.Optimise.Simplify.Rule
import Futhark.Transform.Rename
import Futhark.Util

-- Rules registered as top-down simplification rules.  Each rule
-- function is wrapped in the 'SimplificationRule' constructor that
-- selects which kind of statement it is tried on: 'RuleDoLoop' for
-- loops, 'RuleIf' for branches, 'RuleBasicOp' for basic operations,
-- and 'RuleGeneric' for 'constantFoldPrimFun'.
--
-- The 'Aliased' constraint is required by 'simplifyLoopVariables';
-- the remaining rules only need 'BinderOps'.
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules =
  [ RuleDoLoop hoistLoopInvariantMergeVariables,
    RuleDoLoop simplifyClosedFormLoop,
    RuleDoLoop simplifyKnownIterationLoop,
    RuleDoLoop simplifyLoopVariables,
    RuleDoLoop narrowLoopType,
    RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant,
    RuleBasicOp ruleBasicOp
  ]

-- Rules registered as bottom-up simplification rules.  As in
-- 'topDownRules', each rule function is wrapped in the
-- 'SimplificationRule' constructor that selects the kind of
-- statement it is tried on.  None of these rules need 'Aliased'.
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules =
  [ RuleDoLoop removeRedundantMergeVariables,
    RuleIf removeDeadBranchResult,
    RuleBasicOp simplifyIndex,
    RuleBasicOp simplifyConcat
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book is built from 'topDownRules' and 'bottomUpRules'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules = ruleBook topDownRules bottomUpRules

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
--
-- The rule fires only when at least one value merge parameter is not
-- used directly after the loop, and the loop has no context merge
-- parameters.  A merge parameter (with its initial value and body
-- result) is kept when any of the following holds:
--
--   * its pattern element is used directly after the loop;
--   * it is in the set computed by 'findNecessaryForReturned' from the
--     body's data dependencies, seeded with parameters used after the
--     loop or free in the loop form;
--   * it is free in the annotations of the merge parameters;
--   * it is free in the loop form.
--
-- Everything else is dropped from the loop, its result, and the
-- value part of the pattern.  If nothing can be dropped, the rule
-- returns 'Skip'.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    null ctx -- FIXME: things get tricky if we can remove all vals
    -- but some ctxs are still used.  We take the easy way
    -- out for now.
    =
    let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body
        necessaryForReturned =
          findNecessaryForReturned
            usedAfterLoopOrInForm
            (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
            (dataDependencies body)

        resIsNecessary ((v, _), _) =
          usedAfterLoop v
            || paramName v `nameIn` necessaryForReturned
            || referencedInPat v
            || referencedInForm v

        (keep_ctx, discard_ctx) =
          partition resIsNecessary $ zip ctx ctx_es
        -- Value pattern elements are partitioned together with their
        -- merge parameter and body result, so the three stay aligned.
        (keep_valpart, discard_valpart) =
          partition (resIsNecessary . snd) $
            zip (patternValueElements pat) $ zip val val_es

        (keep_valpatelems, keep_val) = unzip keep_valpart
        (_discard_valpatelems, discard_val) = unzip discard_valpart
        (ctx', ctx_es') = unzip keep_ctx
        (val', val_es') = unzip keep_val

        body' = body {bodyResult = ctx_es' ++ val_es'}
        free_in_keeps = freeIn keep_valpatelems

        -- A context pattern element survives only if it is still
        -- referenced by a kept value pattern element or by another
        -- context pattern element.
        stillUsedContext pat_elem =
          patElemName pat_elem
            `nameIn` ( free_in_keeps
                         <> freeIn (filter (/= pat_elem) $ patternContextElements pat)
                     )

        pat' =
          pat
            { patternValueElements = keep_valpatelems,
              patternContextElements =
                filter stillUsedContext $ patternContextElements pat
            }
     in if ctx' ++ val' == ctx ++ val
          then Skip
          else Simplify $ do
            -- We can't just remove the bindings in 'discard', since the loop
            -- body may still use their names in (now-dead) expressions.
            -- Hence, we add them inside the loop, fully aware that dead-code
            -- removal will eventually get rid of them.  Some care is
            -- necessary to handle unique bindings.
            body'' <- insertStmsM $ do
              mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
              mapM_ (uncurry letBindNames) $ dummyStms discard_val
              return body'
            auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat
    -- Names of value merge parameters whose pattern element is used
    -- directly after the loop.  Kept as a 'Names' set so membership
    -- tests are not linear in the number of merge parameters.
    used_vals =
      namesFromList . map fst . filter snd $
        zip (map (paramName . fst) val) pat_used
    usedAfterLoop = (`nameIn` used_vals) . paramName
    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` freeIn form
    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` freeIn form) . paramName

    -- Rebind a discarded merge parameter to its initial value.  A
    -- unique parameter initialised from a variable gets a 'Copy', so
    -- the original variable is not consumed through the new binding.
    dummyStms = map dummyStm
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
        ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
  -- Figure out which of the elements of loopresult are
  -- loop-invariant, and hoist them out.
  case ((VName, (FParam lore, SubExp), SubExp)
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp])
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([(VName, (FParam lore, SubExp), SubExp)]
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
    [VName]
-> [(FParam lore, SubExp)]
-> [SubExp]
-> [(VName, (FParam lore, SubExp), SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) [(FParam lore, SubExp)]
merge [SubExp]
res of
    ([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
      -- Nothing is invariant.
      Rule lore
forall lore. Rule lore
Skip
    ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      -- We have moved something invariant out of the loop.
      let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody {bodyResult :: [SubExp]
bodyResult = [SubExp]
res'}
          invariantShape :: (a, VName) -> Bool
          invariantShape :: (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) =
            VName
shapemerge
              VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
          ([(PatElemT (LetDec lore), VName)]
implpat', [(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
    [(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
          implinvariant' :: [(Ident, SubExp)]
implinvariant' = [(PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p, VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant]
          implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
          explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
          ([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
      [(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1, SubExp
v2) ->
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
      StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
          [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
  where
    merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
    res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody

    implpat :: [(PatElemT (LetDec lore), VName)]
implpat =
      [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
    explpat :: [(PatElemT (LetDec lore), VName)]
explpat =
      [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val

    namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val

    removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, b
mergeInit) [(PatElemT dec, VName)]
explpat' =
      case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
        ([(PatElemT dec
patelem, VName
_)], [(PatElemT dec, VName)]
rest) ->
          ((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
        ([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
          (Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')

    checkInvariance :: (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
checkInvariance
      (VName
pat_name, (FParam lore
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
        | Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam))
            Bool -> Bool -> Bool
|| TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
          Bool
isInvariant,
          -- Also do not remove the condition in a while-loop.
          Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
          let (Maybe (Ident, SubExp)
bnd, [(PatElemT (LetDec lore), VName)]
explpat'') =
                (FParam lore, SubExp)
-> [(PatElemT (LetDec lore), VName)]
-> (Maybe (Ident, SubExp), [(PatElemT (LetDec lore), VName)])
forall dec dec b.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (FParam lore
mergeParam, SubExp
mergeInit) [(PatElemT (LetDec lore), VName)]
explpat'
           in ( ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (FParam lore -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent FParam lore
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
                [(PatElemT (LetDec lore), VName)]
explpat'',
                [(FParam lore, SubExp)]
merge',
                [SubExp]
resExps
              )
        where
          -- A non-unique merge variable is invariant if one of the
          -- following is true:
          --
          -- (0) The result is a variable of the same name as the
          -- parameter, where all existential parameters are already
          -- known to be invariant
          isInvariant :: Bool
isInvariant
            | Var VName
v2 <- SubExp
resExp,
              FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
              Names -> FParam lore -> Bool
allExistentialInvariant
                ([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant)
                FParam lore
mergeParam
            -- (1) The result is identical to the initial parameter value.
            | SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp = Bool
True
            -- (2) The initial parameter value is equal to an outer
            -- loop parameter 'P', where the initial value of 'P' is
            -- equal to 'resExp', AND 'resExp' ultimately becomes the
            -- new value of 'P'.  XXX: it's a bit clumsy that this
            -- only works for one level of nesting, and I think it
            -- would not be too hard to generalise.
            | Var VName
init_v <- SubExp
mergeInit,
              Just (SubExp
p_init, SubExp
p_res) <- VName -> TopDown lore -> Maybe (SubExp, SubExp)
forall lore. VName -> SymbolTable lore -> Maybe (SubExp, SubExp)
ST.lookupLoopParam VName
init_v TopDown lore
vtable,
              SubExp
p_init SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp,
              SubExp
p_res SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> SubExp
Var VName
pat_name =
              Bool
True
            | Bool
otherwise = Bool
False
    checkInvariance
      (VName
_pat_name, (FParam lore
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps) =
        ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', (FParam lore
mergeParam, SubExp
mergeInit) (FParam lore, SubExp)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. a -> [a] -> [a]
: [(FParam lore, SubExp)]
merge', SubExp
resExp SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
resExps)

    allExistentialInvariant :: Names -> FParam lore -> Bool
allExistentialInvariant Names
namesOfInvariant FParam lore
mergeParam =
      (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
        Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          FParam lore -> Names
forall a. FreeIn a => a -> Names
freeIn FParam lore
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam)
    invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
      Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams)
        Bool -> Bool -> Bool
|| VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant

-- | Looks up the type of a subexpression, yielding 'Nothing' when no
-- type is available for it.
type TypeLookup =
  SubExp ->
  Maybe Type

-- | A top-down rule that can be written as a pure function.  Given a
-- 'VarLookup' and a 'TypeLookup', it maps a 'BasicOp' either to a
-- replacement 'BasicOp' paired with 'Certificates', or to 'Nothing'
-- when it does not apply.
type SimpleRule lore =
  VarLookup lore ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certificates)

-- | All pure 'SimpleRule's, in a fixed order.  The rules on scalar
-- operations and assertions come first, followed by the rules on
-- copies and reshapes.
simpleRules :: [SimpleRule lore]
simpleRules = scalarRules ++ arrayRules
  where
    scalarRules =
      [ simplifyBinOp,
        simplifyCmpOp,
        simplifyUnOp,
        simplifyConvOp,
        simplifyAssert
      ]
    arrayRules =
      [ copyScratchToScratch,
        simplifyIdentityReshape,
        simplifyReshapeReshape,
        simplifyReshapeScratch,
        simplifyReshapeReplicate,
        simplifyReshapeIota,
        improveReshape
      ]

-- Applies only to a loop that has no context merge parameters and is a
-- for-loop without loop variables.  Such a loop is handed to
-- 'loopClosedForm' (from "Futhark.Optimise.Simplify.ClosedForm"),
-- with the loop counter as the only name in the 'Names' argument.
-- Every other loop is skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ pat _ loop =
  case loop of
    ([], val, ForLoop i it bound [], body) ->
      Simplify $ loopClosedForm pat val (oneName i) it bound body
    _ ->
      Skip

-- Simplify the array loop variables (@for p in arr@) of a for-loop.
--
-- For each loop variable @(p, arr)@ we try to simplify the indexing
-- @arr[i]@ that binding @p@ implicitly performs, using
-- 'simplifyIndexing'.  Two outcomes are exploited (see 'onLoopVar'):
--
-- * The indexing becomes @arr'[i, slice']@, where neither @slice'@ nor
--   the statements needed to compute it mention @i@.  The loop variable
--   then iterates over @arr'[0:w:1, slice']@ instead, and those
--   statements are placed before the loop.
--
-- * The indexing becomes a plain subexpression, computed by statements
--   that contain no 'Index'.  The loop variable is removed, and @p@ is
--   bound by those statements at the start of the loop body.
--
-- If no loop variable changes, the rule reports 'cannotSimplify'.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat aux (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    any isJust simplifiable = Simplify $ do
    -- Check if the simplifications throw away more information than
    -- we are comfortable with at this stage.
    (maybe_loop_vars, body_prefix_stms) <-
      localScope (scopeOf form) $
        unzip <$> zipWithM onLoopVar loop_vars simplifiable
    if maybe_loop_vars == map Just loop_vars
      then cannotSimplify
      else do
        body' <- insertStmsM $ do
          addStms $ mconcat body_prefix_stms
          resultBodyM =<< bodyBind body
        auxing aux $
          letBind pat $
            DoLoop
              ctx
              val
              (ForLoop i it num_iters $ catMaybes maybe_loop_vars)
              body'
  where
    -- Types of subexpressions.  The loop counter is handled specially
    -- because it is bound by the loop form, not by 'vtable'.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Try to simplify the implicit @arr[i]@ of one loop variable.
    -- Whether the loop variable is consumed in the body is passed as
    -- the final argument to 'simplifyIndexing'.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        $ paramName p `nameIn` consumed_in_body

    -- Returns the (possibly replaced) loop variable, or 'Nothing' if it
    -- was eliminated, together with statements to prepend to the loop
    -- body.  We only want a simplification if the result does not
    -- refer to 'i' at all, or does not contain array indexing.
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- Only the remaining dimensions must be free of 'i'.  The
            -- head of 'slice' is 'i' itself, so testing the whole of
            -- 'slice' could never succeed.
            not $ i `nameIn` freeIn slice' -> do
            addStms x_stms
            w <- arraySize 0 <$> lookupType arr'
            for_in_partial <-
              certifying cs $
                letExp "for_in_partial" $
                  BasicOp $
                    Index arr' $
                      DimSlice (intConst Int64 0) w (intConst Int64 1) : slice'
            return (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
            x_stms' <- collectStms_ $
              certifying cs $ do
                addStms x_stms
                letBindNames [paramName p] $ BasicOp $ SubExp se
            return (Nothing, x_stms')
        _ -> return (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ = Skip

-- If a for-loop with no loop variables has a counter of type Int64,
-- and the bound is either a sign extension of a smaller integer or a
-- constant that fits in an Int32, then change the loop to iterate over
-- the smaller type instead.  We then move the sign extension inside the
-- loop: the original counter name is rebound to the sign-extended
-- narrow counter, so the body itself is unchanged.  This addresses
-- loops of the form @for i in x..<y@ in the source language.
narrowLoopType :: (BinderOps lore) => TopDownRuleDoLoop lore
narrowLoopType vtable pat aux (ctx, val, ForLoop i Int64 n [], body)
  | Just (n', it', cs) <- smallerType =
    Simplify $ do
      i' <- newVName $ baseString i
      let form' = ForLoop i' it' n' []
      body' <- insertStmsM $
        inScopeOf form' $ do
          -- Rebind the original 64-bit counter from the narrow one.
          letBindNames [i] $ BasicOp $ ConvOp (SExt it' Int64) (Var i')
          pure body
      -- The certificates found on the sign extension of the bound are
      -- attached to the new loop, whose bound now uses the extension's
      -- operand directly.
      auxing aux $
        certifying cs $
          letBind pat $ DoLoop ctx val form' body'
  where
    smallerType
      | Var n' <- n,
        Just (ConvOp (SExt it' _) n'', cs) <- ST.lookupBasicOp n' vtable =
        Just (n'', it', cs)
      | Constant (IntValue (Int64Value n')) <- n,
        fitsInInt32 n' =
        Just (intConst Int32 (toInteger n'), Int32, mempty)
      | otherwise =
        Nothing

    -- Both ends of the range must be checked.  A large negative constant
    -- bound (meaning zero iterations) would otherwise wrap around when
    -- converted to Int32, possibly to a positive bound, and make the
    -- loop run.
    fitsInInt32 x =
      toInteger x >= toInteger (minBound :: Int32)
        && toInteger x <= toInteger (maxBound :: Int32)
narrowLoopType _ _ _ _ = Skip

-- | Unroll the remaining iterations @i@, @i+1@, ..., @n-1@ of a @for@-loop
-- into the enclosing 'RuleM' context, returning the final merge values.
--
-- For every iteration, the merge parameters are bound to their current
-- values, the loop index @iv@ is bound to the constant @i@ (of type @it@),
-- and every loop-variable parameter is bound to row @i@ of its array.  A
-- freshly renamed copy of the body is then added, and its results become
-- the merge values of the next iteration.  Once @i >= n@ the current merge
-- values are returned unchanged.
unroll ::
  BinderOps lore =>
  Integer ->
  [(FParam lore, SubExp)] ->
  (VName, IntType, Integer) ->
  [(LParam lore, VName)] ->
  Body lore ->
  RuleM lore [SubExp]
unroll n merge (iv, it, i) loop_vars body
  | i < n = do
    renamed <- renameBody =<< insertStmsM bindIteration
    addStms $ bodyStms renamed
    let next_merge = zip (map fst merge) (bodyResult renamed)
    unroll n next_merge (iv, it, i + 1) loop_vars body
  | otherwise =
    pure $ map snd merge
  where
    -- Bind everything the body of iteration @i@ refers to, then hand back
    -- the (not yet renamed) body itself.
    bindIteration = do
      forM_ merge $ \(param, initial) ->
        letBindNames [paramName param] $ BasicOp $ SubExp initial
      letBindNames [iv] $ BasicOp $ SubExp $ intConst it i
      forM_ loop_vars $ \(p, arr) ->
        letBindNames [paramName p] . BasicOp . Index arr $
          DimFix (intConst Int64 i) : fullSlice (paramType p) []
      -- Some of the sizes in these types may be temporarily wrong until
      -- copy propagation fixes them up.
      pure body

-- | Fully unroll a @for@-loop whose iteration count is a constant integer,
-- provided that count is zero or one, or the loop statement carries the
-- @unroll@ attribute.  Each name in the loop's pattern is then bound to
-- the corresponding merge value produced by the unrolled iterations.
simplifyKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyKnownIterationLoop _ pat aux (ctx, val, form, body)
  | ForLoop i it (Constant (IntValue n)) loop_vars <- form,
    zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux =
    Simplify $ do
      res <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
      zipWithM_
        (\v se -> letBindNames [v] $ BasicOp $ SubExp se)
        (patternNames pat)
        res
simplifyKnownIterationLoop _ _ _ _ =
  Skip

-- | Replace @copy(x)@ by plain @x@ when the copy can be elided.
--
-- The rewrite fires only when @x@ is not consumed afterwards, and in
-- addition either the result of the copy is never consumed, or @x@ has no
-- further uses at all and is itself consumable.
--
-- This simplistic rule is only valid before we introduce memory.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    canReuseOriginal || not copyIsConsumed =
    Simplify $ letBindNames [dest] $ BasicOp $ SubExp $ Var v
  where
    dest = patElemName d
    copyIsConsumed = dest `UT.isConsumed` used
    canReuseOriginal = not (v `UT.used` used) && consumable
    -- We need to make sure we can even consume the original.  This is a
    -- hacky, much too conservative check, because we do not have the
    -- information conveniently available: only function parameters with a
    -- unique declared type are accepted.
    consumable
      | Just (FParamName info) <- M.lookup v (ST.toScope vtable) =
        unique $ declTypeOf info
      | otherwise = False
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Simplify comparison operators.
--
-- * Comparing an operand with itself is folded to a constant, but only for
--   comparisons whose result does not depend on the operand's value.
--   Floating-point equality and @<=@ are deliberately not folded, because
--   both are false when the operand is NaN.
-- * Comparisons between two constants are evaluated with 'doCmpOp'.
-- * @x == btoi(b)@ for a constant integer @x@ becomes @b@ when @x@ is 1,
--   @not b@ when @x@ is 0, and @false@ otherwise.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2,
    Just res <- selfComparison cmp =
    constRes $ BoolValue res
  where
    -- The result of comparing a value with itself, when that result is the
    -- same for every value.  NaN is unequal to (and not <=) itself, so
    -- float '==' and float '<=' yield 'Nothing' and are left to the
    -- remaining equations (e.g. constant folding via 'doCmpOp').
    selfComparison (CmpEq (FloatType _)) = Nothing
    selfComparison CmpEq {} = Just True
    selfComparison CmpSlt {} = Just False
    selfComparison CmpUlt {} = Just False
    selfComparison CmpSle {} = Just True
    selfComparison CmpUle {} = Just True
    selfComparison FCmpLt {} = Just False
    selfComparison FCmpLe {} = Nothing
    selfComparison CmpLlt = Just False
    selfComparison CmpLle = Just True
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  constRes . BoolValue =<< doCmpOp cmp v1 v2
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
    case valueIntegral x :: Int of
      1 -> Just (SubExp b, cs)
      0 -> Just (UnOp Not b, cs)
      _ -> Just (SubExp (Constant (BoolValue False)), cs)
simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of binary operators: constant folding,
-- identity and absorbing elements, and a few cancellation patterns that
-- look through the definitions of variable operands.  When every guard of
-- an equation fails, the next equation is tried; the final catch-all
-- returns 'Nothing'.
simplifyBinOp :: SimpleRule lore
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
    constRes res
simplifyBinOp look _ (BinOp Add {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
    Just (SubExp e2_a, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
simplifyBinOp look _ (BinOp (Sub t ov) e1 e2)
  | isCt0 e2 = subExpRes e1
  -- Cases for simplifying (a+b)-b and permutations.
  -- (a+b)-a => b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
    Just (SubExp e1_b, cs)
  -- (a+b)-b => a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
    Just (SubExp e1_a, cs)
  -- a-(a+b) => 0-b  (previously, and wrongly, simplified to b)
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_a == e1 =
    Just (BinOp (Sub t ov) (intConst t 0) e2_b, cs)
  -- b-(a+b) => 0-a  (previously looked up e1 instead of e2 and never fired)
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
    Just (BinOp (Sub t ov) (intConst t 0) e2_a, cs)
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod y) mod y => x mod y
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
    Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x rem x == 0 (previously, and wrongly, folded to 1).
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  -- NOTE(review): folding 0 ** y to 0 assumes y > 0; IEEE pow(0, y) is
  -- +inf for negative y.  Confirm this is acceptable.
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = subExpRes $ intConst t 0
  | isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0
simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (not x) && x => false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) => false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x => true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) => true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
    subExpRes e1
  -- smax(smax(a, b), a) => smax(b, a)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
    Just (BinOp (SMax it) e1_2 e2, v1_cs)
  -- smax(smax(a, b), b) => smax(a, b)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
    Just (BinOp (SMax it) e1_1 e2, v1_cs)
  -- smax(a, smax(a, b)) => smax(b, a)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
    Just (BinOp (SMax it) e2_2 e1, v2_cs)
  -- smax(b, smax(a, b)) => smax(a, b)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
    Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing

-- | A simple-rule result that replaces the simplified expression by the
-- given constant, adding no certificates.
constRes :: PrimValue -> Maybe (BasicOp, Certificates)
constRes v = Just (SubExp (Constant v), mempty)

-- | A simple-rule result that replaces the simplified expression by the
-- given subexpression, adding no certificates.
subExpRes :: SubExp -> Maybe (BasicOp, Certificates)
subExpRes se = Just (SubExp se, mempty)

-- | Simplify a unary operator.
--
-- An operator applied to a constant is folded with 'doUnOp' (no rewrite
-- happens if that yields nothing).  A double logical negation
-- @not (not x)@ is reduced to @x@, carrying over the certificates of the
-- inner negation.
simplifyUnOp :: SimpleRule lore
simplifyUnOp defOf _ op =
  case op of
    UnOp unop (Constant v) ->
      constRes =<< doUnOp unop v
    UnOp Not (Var v)
      | Just (BasicOp (UnOp Not inner), cs) <- defOf v ->
        Just (SubExp inner, cs)
    _ -> Nothing

-- | Simplify a type conversion.
--
-- * A conversion of a constant is folded with 'doConvOp' (no rewrite
--   happens if that yields nothing).
-- * A conversion whose source and target types coincide is replaced by
--   its operand.
-- * A conversion applied to the result of another conversion is fused
--   into a single conversion of the inner operand.  This only happens when
--   the outer conversion's source type is not smaller, according to the
--   type's 'Ord' instance, than the inner conversion's source type.  The
--   fused expression keeps the certificates of the inner conversion.  The
--   fused pairs are: 'SExt' of 'SExt', 'ZExt' of 'ZExt', 'SIToFP' of
--   'SExt', 'UIToFP' of 'ZExt', and 'FPConv' of 'FPConv'.
simplifyConvOp :: SimpleRule lore
simplifyConvOp lookupVar _ op =
  case op of
    ConvOp conv (Constant v) ->
      constRes =<< doConvOp conv v
    ConvOp conv se
      | (from, to) <- convOpType conv,
        from == to ->
        subExpRes se
    ConvOp (SExt t2 t1) (Var v)
      | Just (SExt t3 _, se, cs) <- producer v,
        t2 >= t3 ->
        Just (ConvOp (SExt t3 t1) se, cs)
    ConvOp (ZExt t2 t1) (Var v)
      | Just (ZExt t3 _, se, cs) <- producer v,
        t2 >= t3 ->
        Just (ConvOp (ZExt t3 t1) se, cs)
    ConvOp (SIToFP t2 t1) (Var v)
      | Just (SExt t3 _, se, cs) <- producer v,
        t2 >= t3 ->
        Just (ConvOp (SIToFP t3 t1) se, cs)
    ConvOp (UIToFP t2 t1) (Var v)
      | Just (ZExt t3 _, se, cs) <- producer v,
        t2 >= t3 ->
        Just (ConvOp (UIToFP t3 t1) se, cs)
    ConvOp (FPConv t2 t1) (Var v)
      | Just (FPConv t3 _, se, cs) <- producer v,
        t2 >= t3 ->
        Just (ConvOp (FPConv t3 t1) se, cs)
    _ -> Nothing
  where
    -- The conversion (if any) that defines the given variable, together
    -- with its operand and the certificates of its definition.
    producer v = case lookupVar v of
      Just (BasicOp (ConvOp conv se), cs) -> Just (conv, se, cs)
      _ -> Nothing

-- | An assertion whose condition is the literal constant @true@ always
-- holds, so it is replaced by the constant 'Checked'.
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ op
  | Assert (Constant (BoolValue True)) _ _ <- op = constRes Checked
  | otherwise = Nothing

-- | Constant-fold a call to a built-in primitive function whose arguments
-- are all constants.
--
-- The function name is looked up in 'primFuns'.  If it is found, and the
-- primitive's implementation produces a value for the constant arguments,
-- the statement's pattern is rebound to that constant.  The statement's
-- certificates and attributes are preserved.  In every other case the
-- rule is skipped.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _)) =
  maybe Skip bindConstant folded
  where
    -- The folded value, if every step of the evaluation succeeds.
    folded = do
      consts <- mapM (constantOf . fst) args
      (_, _, fun) <- M.lookup (nameToString fname) primFuns
      fun consts

    bindConstant v =
      Simplify . certifying cs . attributing attrs . letBind pat . BasicOp . SubExp $
        Constant v

    constantOf (Constant v) = Just v
    constantOf _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Simplify an 'Index' statement whose pattern has no context elements and
-- exactly one value element, by delegating to 'simplifyIndexing'.
--
-- Whatever 'simplifyIndexing' produces is bound to the original pattern
-- names.  The binding uses the statement's certificates combined with
-- those of the result, and keeps the statement's attributes.  All other
-- statements are skipped.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just simplified <- simplifyIndexing vtable seType idd inds consumed =
    Simplify $ do
      res <- simplified
      let (res_cs, e) = case res of
            SubExpResult cs' se -> (cs', SubExp se)
            IndexResult cs' arr slice -> (cs', Index arr slice)
      attributing attrs . certifying (cs <> res_cs) $
        letBindNames (patternNames pat) (BasicOp e)
  where
    -- Whether the bound name is marked as consumed in the usage table.
    consumed = UT.isConsumed (patElemName pe) used

    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just (Prim (primValueType v))
simplifyIndex _ _ _ _ = Skip

-- | The outcome of 'simplifyIndexing'.
data IndexResult
  = -- | Index the given array with the given slice instead, under the
    -- given extra certificates.
    IndexResult Certificates VName (Slice SubExp)
  | -- | The indexing reduces to the given subexpression, valid under the
    -- given certificates.
    SubExpResult Certificates SubExp

simplifyIndexing ::
  MonadBinder m =>
  ST.SymbolTable (Lore m) ->
  TypeLookup ->
  VName ->
  Slice SubExp ->
  Bool ->
  Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType VName
idd Slice SubExp
inds Bool
consuming =
  case VName -> Maybe (BasicOp, Certificates)
defOf VName
idd of
    Maybe (BasicOp, Certificates)
_
      | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
idd),
        Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
t [] ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd
      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e
      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
        Just (ST.IndexedArray Certificates
cs VName
arr [TPrimExp Int64 VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
        (TPrimExp Int64 VName -> Bool) -> [TPrimExp Int64 VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining (PrimExp VName -> Bool)
-> (TPrimExp Int64 VName -> PrimExp VName)
-> TPrimExp Int64 VName
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped) [TPrimExp Int64 VName]
inds'',
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
          Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix
            ([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName -> m SubExp)
-> [TPrimExp Int64 VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> TPrimExp Int64 VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [TPrimExp Int64 VName]
inds''
    Maybe (BasicOp, Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
    Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds
    Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
      | [DimFix SubExp
ii] <- Slice SubExp
inds,
        Just (Prim (IntType IntType
from_it)) <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType SubExp
ii ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
          let mul :: PrimExp VName -> PrimExp VName -> PrimExp VName
mul = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
              add :: PrimExp VName -> PrimExp VName -> PrimExp VName
add = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
           in (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
                String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
                  ( IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
                      PrimExp VName -> PrimExp VName -> PrimExp VName
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
                  )
                    PrimExp VName -> PrimExp VName -> PrimExp VName
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
      | [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
          SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
          SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
          let mul :: PrimExp VName -> PrimExp VName -> PrimExp VName
mul = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
              add :: PrimExp VName -> PrimExp VName -> PrimExp VName
add = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
          SubExp
i_offset'' <-
            String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
              ( PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
                  PrimExp VName -> PrimExp VName -> PrimExp VName
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
              )
                PrimExp VName -> PrimExp VName -> PrimExp VName
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
          SubExp
i_stride'' <-
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
          (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
            String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it

    -- A rotate cannot be simplified away if we are slicing a rotated dimension.
    Just (Rotate [SubExp]
offsets VName
a, Certificates
cs)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp -> Bool)
-> [SubExp] -> Slice SubExp -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> DimIndex SubExp -> Bool
forall d. SubExp -> DimIndex d -> Bool
rotateAndSlice [SubExp]
offsets Slice SubExp
inds -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
        [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> m (TypeBase Shape NoUniqueness) -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
        let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
              SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
o
              String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SMod IntType
Int64 Safety
Unsafe) SubExp
i_p_o SubExp
d)
            adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
              SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
            adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
              SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
        Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)
      where
        rotateAndSlice :: SubExp -> DimIndex d -> Bool
rotateAndSlice SubExp
r DimSlice {} = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
r
        rotateAndSlice SubExp
_ DimIndex d
_ = Bool
False
    Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
      m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
        Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa
          (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice (Slice (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
inds))
    Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
      | [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
      | DimFix {} : Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'
    Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
      | [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val
    Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
      | (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
        ([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
        [SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
          VName
arr <- String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
          IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
      where
        index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix {} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
        index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
n SubExp
s)
    Just (Rearrange [Int]
perm VName
src, Certificates
cs)
      | [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
        let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
         in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
      where
        isIndex :: DimIndex d -> Bool
isIndex DimFix {} = Bool
True
        isIndex DimIndex d
_ = Bool
False
    Just (Copy VName
src, Certificates
cs)
      | Just [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
        Bool -> Bool
not Bool
consuming,
        VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
    Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        [Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
      | Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
        Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
        ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
        [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
    Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
      | Just [SubExp
_] <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
v2) ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds
    Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
      | -- HACK: simplifying the indexing of an N-array concatenation
        -- is going to produce an N-deep if expression, which is bad
        -- when N is large.  To try to avoid that, we use the
        -- heuristic not to simplify as long as any of the operands
        -- are themselves Concats.  The hope is that this will give
        -- simplification some time to cut down the concatenation to
        -- something smaller, before we start inlining.
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any VName -> Bool
isConcat ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs,
        Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
        Just (Prim PrimType
res_t) <-
          (TypeBase Shape NoUniqueness
-> [SubExp] -> TypeBase Shape NoUniqueness
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds)
            (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> SymbolTable (Lore m) -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
        SubExp
x_len <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
        [SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m (TypeBase Shape NoUniqueness) -> m SubExp)
-> (VName -> m (TypeBase Shape NoUniqueness)) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType) [VName]
xs

        let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
              SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
n SubExp
m
              (SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
        (SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
        let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts

        let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
              String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
            mkBranch ((VName
x', SubExp
start) : [(VName, SubExp)]
xs_and_starts') = do
              SubExp
cmp <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int64) SubExp
start SubExp
i
              (SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
                SubExp
i' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
start
                String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
              BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
              (SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
              BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
              String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfDec (BranchType (Lore m))
-> Exp (Lore m)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfDec (BranchType (Lore m)) -> Exp (Lore m))
-> IfDec (BranchType (Lore m)) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
                  [BranchType (Lore m)] -> IfSort -> IfDec (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
        Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts
    Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
cs)
      | DimFix (Constant (IntValue (Int64Value Int64
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
        Just SubExp
se <- Int64 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int64
i [SubExp]
ses ->
        case Slice SubExp
inds' of
          [] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
          Slice SubExp
_ | Var VName
v2 <- SubExp
se -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
          Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
    -- Indexing single-element arrays.  We know the index must be 0.
    Maybe (BasicOp, Certificates)
_
      | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> SubExp -> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd,
        SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
t,
        DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
        m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
          IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$
            Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$
              SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'
    Maybe (BasicOp, Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
  where
    defOf :: VName -> Maybe (BasicOp, Certificates)
defOf VName
v = do
      (BasicOp BasicOp
op, Certificates
def_cs) <- VName -> SymbolTable (Lore m) -> Maybe (Exp (Lore m), Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable (Lore m)
vtable
      (BasicOp, Certificates) -> Maybe (BasicOp, Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp
op, Certificates
def_cs)
    worthInlining :: PrimExp v -> Bool
worthInlining PrimExp v
e
      | Int -> PrimExp v -> Bool
forall v. Int -> PrimExp v -> Bool
primExpSizeAtLeast Int
20 PrimExp v
e = Bool
False -- totally ad-hoc.
      | Bool
otherwise = PrimExp v -> Bool
forall v. PrimExp v -> Bool
worthInlining' PrimExp v
e
    worthInlining' :: PrimExp v -> Bool
worthInlining' (BinOpExp Pow {} PrimExp v
_ PrimExp v
_) = Bool
False
    worthInlining' (BinOpExp FPow {} PrimExp v
_ PrimExp v
_) = Bool
False
    worthInlining' (BinOpExp BinOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
    worthInlining' (CmpOpExp CmpOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
    worthInlining' (ConvOpExp ConvOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
    worthInlining' (UnOpExp UnOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
    worthInlining' FunExp {} = Bool
False
    worthInlining' PrimExp v
_ = Bool
True

    isConcat :: VName -> Bool
isConcat VName
v
      | Just (Concat {}, Certificates
_) <- VName -> Maybe (BasicOp, Certificates)
defOf VName
v =
        Bool
True
      | Bool
otherwise =
        Bool
False

-- | A simplified view of one operand of a 'Concat', as produced by
-- 'toConcatArg', merged by 'fuseConcatArg' and turned back into an
-- array by 'fromConcatArg'.
data ConcatArg
  = -- | An array literal with the given elements.
    ArgArrayLit [SubExp]
  | -- | A replicate of the given value.  The list holds the outer widths
    -- of one or more replicates of that same value that have been fused
    -- together; the width of the rebuilt replicate is their sum.
    ArgReplicate [SubExp] SubExp
  | -- | Any other array, referenced by name and left untouched.
    ArgVar VName

-- | Classify the concatenation operand @v@ by looking up its defining
-- basic operation in the symbol table.
--
-- Array literals and replicates are exposed, together with the
-- certificates of their defining statement, so that neighbouring
-- operands can later be merged.  Only the outermost width of a
-- replicate is recorded.  Anything else becomes an opaque 'ArgVar'
-- carrying no certificates.
toConcatArg :: ST.SymbolTable lore -> VName -> (ConcatArg, Certificates)
toConcatArg vtable v =
  maybe opaque classify $ ST.lookupBasicOp v vtable
  where
    opaque :: (ConcatArg, Certificates)
    opaque = (ArgVar v, mempty)

    classify :: (BasicOp, Certificates) -> (ConcatArg, Certificates)
    classify (ArrayLit ses _, cs) = (ArgArrayLit ses, cs)
    classify (Replicate shape se, cs) = (ArgReplicate [shapeSize 0 shape] se, cs)
    classify _ = opaque

-- | Materialise a 'ConcatArg' as an array variable that can be passed to
-- a 'Concat'.
--
-- The 'Type' argument is expected to be the full array type of a
-- concatenation operand: its row type becomes the element type of a
-- rebuilt array literal, and its outer dimensions (with the first one
-- replaced by the combined width) become the shape of a rebuilt
-- replicate.  The certificates paired with the argument are attached to
-- every statement created here.  'ArgVar' operands are returned as-is.
fromConcatArg ::
  MonadBinder m =>
  Type ->
  (ConcatArg, Certificates) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg elem_type (ArgReplicate ws se, cs) = do
  -- 'Replicate' prepends its shape to the type of the replicated value,
  -- and that value may itself be an array.  Only the outermost
  -- (rank elem_type - rank se) dimensions of the operand type therefore
  -- belong to the replicate shape.  Using all of them, as before, would
  -- add the dimensions of @se@ a second time and yield an ill-typed
  -- replicate whenever @se@ is not a scalar.
  se_rank <- arrayRank <$> subExpType se
  let rep_dims = take (arrayRank elem_type - se_rank) (arrayDims elem_type)
  certifying cs $ do
    -- Total width of all replicates that were fused into this argument.
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    letExp "concat_rep" $ BasicOp $ Replicate (setDim 0 (Shape rep_dims) w) se
fromConcatArg _ (ArgVar v, _) =
  pure v

-- | Add one more concatenation operand to an accumulator of operands
-- kept in reverse order: the head is the operand immediately to the
-- left of the incoming one.  The incoming operand is merged with that
-- head where possible:
--
-- * empty array literals, and replicates whose single width satisfies
--   'isCt0', are dropped;
--
-- * a replicate whose single width satisfies 'isCt1' is re-added as a
--   one-element array literal;
--
-- * two adjacent array literals are joined, and two adjacent replicates
--   of the same value have their width lists appended.
--
-- Certificates of merged operands are combined.  Anything else is
-- simply pushed onto the accumulator.
fuseConcatArg ::
  [(ConcatArg, Certificates)] ->
  (ConcatArg, Certificates) ->
  [(ConcatArg, Certificates)]
fuseConcatArg acc incoming@(incoming_arg, incoming_cs) =
  case (acc, incoming_arg) of
    -- Operands contributing no elements vanish.
    (_, ArgArrayLit []) ->
      acc
    (_, ArgReplicate [w] se)
      | isCt0 w ->
        acc
      -- A single-row replicate is just a one-element literal, which may
      -- then merge with a preceding literal.
      | isCt1 w ->
        fuseConcatArg acc (ArgArrayLit [se], incoming_cs)
    ((ArgArrayLit prev_ses, prev_cs) : rest, ArgArrayLit new_ses) ->
      (ArgArrayLit (prev_ses ++ new_ses), prev_cs <> incoming_cs) : rest
    ((ArgReplicate prev_ws prev_se, prev_cs) : rest, ArgReplicate new_ws new_se)
      | prev_se == new_se ->
        (ArgReplicate (prev_ws ++ new_ws) prev_se, prev_cs <> incoming_cs) : rest
    _ ->
      incoming : acc

-- | Bottom-up simplification rules for 'Concat'.  The equations are
-- tried in order:
--
-- 1. A concatenation along dimension @i@ of arrays that are all the
--    same 'Rearrange' of other arrays is rewritten into a concatenation
--    of those other arrays along the outer dimension, followed by a
--    single 'Rearrange'.
-- 2. A concatenation of a single array becomes a 'Copy'.
-- 3. Nested concatenations along the same dimension are flattened.
-- 4. Along the outer dimension, adjacent arguments that are array
--    literals or replicates are fused (see 'fuseConcatArg').
simplifyConcat :: BinderOps lore => BottomUpRuleBasicOp lore
-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
--
-- More generally, with perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1] we
-- have perm !! i == 0, so dimension i of (rearrange perm v) is
-- dimension 0 of v, and
--
--   concat@i(rearrange perm x, rearrange perm y)
--     == rearrange perm (concat@0(x, y))
simplifyConcat (vtable, _) pat aux (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    -- The permutation must send dimension i to dimension 0 (as
    -- 'Rearrange' picks input dimension perm !! k for output
    -- dimension k).  The previously used [i] ++ [0 .. i - 1] ++ ...
    -- only did so for i <= 1 and produced a concatenation along the
    -- wrong dimension otherwise; both forms agree for i <= 1.
    let perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs =
    Simplify $
      -- Keep the certificates and attributes of the original
      -- statement, as the other equations of this rule do.
      auxing aux $ do
        concat_rearrange <-
          certifying (x_cs <> mconcat xs_cs) $
            letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
        letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- If v is bound to (rearrange perm1 v'), return v' together with
    -- the certificates of that binding.
    transposedBy :: [Int] -> VName -> Maybe (VName, Certificates)
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- Removing a concatenation that involves only a single array.  This
-- may be produced as a result of other simplification rules.
simplifyConcat _ pat aux (Concat _ x [] _) =
  Simplify $
    -- Still need a copy because Concat produces a fresh array.
    auxing aux $ letBind pat $ BasicOp $ Copy x
-- concat xs (concat ys zs) == concat xs ys zs
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  | x' /= x || concat xs' /= xs =
    Simplify $
      certifying (cs <> x_cs <> mconcat xs_cs) $
        attributing attrs $
          letBind pat $
            BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    -- isConcat always returns a non-empty list, so this lazy pattern
    -- cannot fail.
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs
    -- The arguments of v if it is itself a concatenation along the
    -- same dimension i (with its certificates), otherwise just [v].
    isConcat :: VName -> ([VName], Certificates)
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j y ys _, v_cs) | j == i -> (y : ys, v_cs)
      _ -> ([v], mempty)

-- Fusing arguments to the concat when possible.  Only done when
-- concatenating along the outer dimension for now.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs outer_w)
  | -- We produce the to-be-concatenated arrays in reverse order, so
    -- reverse them back.
    y : ys <-
      reverse $
        foldl' fuseConcatArg mempty $
          map (toConcatArg vtable) $ x : xs,
    -- Only rewrite if fusion actually reduced the number of arguments.
    length xs /= length ys =
    Simplify $ do
      elem_type <- lookupType x
      y' <- fromConcatArg elem_type y
      ys' <- mapM (fromConcatArg elem_type) ys
      auxing aux $ letBind pat $ BasicOp $ Concat 0 y' ys' outer_w
simplifyConcat _ _ _ _ = Skip

-- | Top-down simplification rules for 'If' expressions.  The equations
-- are tried in order and the first applicable one is used:
--
-- * When the condition is a constant, the selected branch is inlined
--   (for an 'IfFallback' only when the condition is constantly true).
-- * @if c then true else v@, with no statements in either branch and a
--   single boolean result, becomes @c || v@.
-- * When both branches return a single boolean and all their
--   statements are safe, @if c then x else y@ becomes
--   @(c && x) || (!c && y)@, with both branches' statements hoisted.
-- * An 'IfFallback' whose pattern has no context names and whose true
--   branch contains only safe statements is replaced by its true branch.
-- * When the branches return integer constants that are one/zero (as
--   judged by 'oneIshInt' and 'zeroIshInt') or zero/one, the expression
--   becomes a bool-to-int conversion of the condition or its negation.
ruleIf :: BinderOps lore => TopDownRuleIf lore
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just chosen <- constantBranch,
    ifsort /= IfFallback || isCt1 e1 =
    Simplify $ do
      addStms $ bodyStms chosen
      zipWithM_
        (\p se -> letBindNames [patElemName p] $ BasicOp $ SubExp se)
        (patternElements pat)
        (bodyResult chosen)
  where
    constantBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [Constant (BoolValue True)],
    Body _ fstms [se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
      Simplify $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
ruleIf _ pat _ (cond, Body _ tstms [tres], Body _ fstms [fres], IfDec ts _)
  | all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts =
    Simplify $ do
      -- Both branches are evaluated unconditionally, which is why
      -- their statements must all be safe.
      addStms tstms
      addStms fstms
      e <-
        eBinOp LogOr (pure $ BasicOp $ BinOp LogAnd cond tres) $
          eBinOp
            LogAnd
            (pure $ BasicOp $ UnOp Not cond)
            (pure $ BasicOp $ SubExp fres)
      letBind pat e
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null $ patternContextNames pat,
    all (safeExp . stmExp) $ bodyStms tbranch =
    Simplify $ do
      addStms $ bodyStms tbranch
      zipWithM_
        (\p se -> letBindNames [patElemName p] $ BasicOp $ SubExp se)
        (patternElements pat)
        (bodyResult tbranch)
ruleIf
  _
  pat
  _
  ( cond,
    Body _ _ [Constant (IntValue t)],
    Body _ _ [Constant (IntValue f)],
    _
    )
    | oneIshInt t,
      zeroIshInt f =
      Simplify $
        letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
    | zeroIshInt t,
      oneIshInt f =
      Simplify $ do
        cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
        letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
-- Anything else (including integer branches that are neither one/zero
-- nor zero/one) is left alone.
ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results in the
-- context), or the same in both branches.
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
  let tses :: [SubExp]
tses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
tb
      fses :: [SubExp]
fses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
fb
  ([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
    ([Either
    (Maybe (Int, SubExp))
    (PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
    [(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
  [(PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))])
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> ([Either
       (Maybe (Int, SubExp))
       (PatElemT (LetDec lore), Either Int (BranchType lore),
        (SubExp, SubExp))]
    -> ([Maybe (Int, SubExp)],
        [(PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))]))
-> [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
   lore
   [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
 -> RuleM
      lore
      ([Maybe (Int, SubExp)],
       ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
        [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$
      ((PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> RuleM
      lore
      [Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))])
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
        [PatElemT (LetDec lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
          (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat)
          ((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0 .. Int
num_ctx Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
          ([SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
tses [SubExp]
fses)
  let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
      ([SubExp]
tses', [SubExp]
fses') = [(SubExp, SubExp)] -> ([SubExp], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
      tb' :: BodyT lore
tb' = BodyT lore
tb {bodyResult :: [SubExp]
bodyResult = [SubExp]
tses'}
      fb' :: BodyT lore
fb' = BodyT lore
fb {bodyResult :: [SubExp]
bodyResult = [SubExp]
fses'}
      ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
      ([PatElemT (LetDec lore)]
ctx_pes, [PatElemT (LetDec lore)]
val_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetDec lore)]
pes
  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings -- Was something hoisted?
    then do
      -- We may have to add some reshapes if we made the type
      -- less existential.
      BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
      BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
ctx_pes [PatElemT (LetDec lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ret' IfSort
ifsort)
    else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
  where
    num_ctx :: Int
num_ctx = [PatElemT (LetDec lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec lore)] -> Int)
-> [PatElemT (LetDec lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat
    bound_in_branches :: Names
bound_in_branches =
      [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
        (Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
          BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
    mem_sizes :: Names
mem_sizes = [PatElemT (LetDec lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec lore)] -> Names)
-> [PatElemT (LetDec lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore) -> Bool)
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (TypeBase Shape NoUniqueness -> Bool)
-> (PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness)
-> PatElemT (LetDec lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType) ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
    invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
    invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches

    isMem :: TypeBase shape u -> Bool
isMem Mem {} = Bool
True
    isMem TypeBase shape u
_ = Bool
False
    sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes

    branchInvariant :: (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
      -- Do both branches return the same value?
      | SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
        PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t

      -- Do both branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | SubExp -> Bool
invariant SubExp
tse,
        SubExp -> Bool
invariant SubExp
fse,
        Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
        Prim PrimType
_ <- PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType PatElemT (LetDec lore)
pe,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe = do
        [BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(ASTLore lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)
pe]
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe]
          (ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
tse]
                  RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
fse]
                  RuleM lore (IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (IfDec (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType lore) -> RuleM lore (IfDec (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
bt IfSort
ifsort)
              )
        PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
      | Bool
otherwise =
        Either
  (Maybe (Int, SubExp))
  (PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))

    hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
    hoisted PatElemT dec
_ Right {} = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing

    reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
      [SubExp]
ses <- BodyT (Lore m) -> m [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT (Lore m)
body
      let ([SubExp]
ctx_ses, [SubExp]
val_ses) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) [SubExp]
ses
      [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> m (BodyT (Lore m)))
-> ([SubExp] -> [SubExp]) -> [SubExp] -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([SubExp]
ctx_ses [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++) ([SubExp] -> m (BodyT (Lore m)))
-> m [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp)
-> [SubExp] -> [ExtType] -> m [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult [SubExp]
val_ses [ExtType]
rets
    reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array {} = do
      TypeBase Shape NoUniqueness
v_t <- VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
      let newshape :: [SubExp]
newshape = TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ExtType
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
removeExistentials ExtType
t TypeBase Shape NoUniqueness
v_t
      if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
v_t
        then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> Exp (Lore m)
forall lore. [SubExp] -> VName -> Exp lore
shapeCoerce [SubExp]
newshape VName
v
        else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
    reshapeResult SubExp
se ExtType
_ =
      SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Drop a reshape that does not change anything: when the requested
-- dimensions are exactly the current dimensions of the array, the
-- reshape is rewritten (via 'subExpRes') to the array itself.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ seType (Reshape newshape v) = do
  arr_t <- seType (Var v)
  -- Only a no-op reshape qualifies.
  guard $ newDims newshape == arrayDims arr_t
  subExpRes (Var v)
simplifyIdentityReshape _ _ _ = Nothing

-- | Collapse a reshape of a reshape into one reshape of the innermost
-- array, combining the two shape changes with 'fuseReshape'.  The
-- certificates attached to the inner reshape are carried over.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape defOf _ (Reshape outer_shape arr) =
  case defOf arr of
    Just (BasicOp (Reshape inner_shape inner_arr), inner_cs) ->
      Just (Reshape (fuseReshape inner_shape outer_shape) inner_arr, inner_cs)
    _ -> Nothing
simplifyReshapeReshape _ _ _ = Nothing

-- | A reshape of a 'Scratch' array becomes a 'Scratch' of the same
-- element type with the reshape's target dimensions.  The
-- certificates of the original scratch statement are kept.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch defOf _ (Reshape newshape v) =
  case defOf v of
    Just (BasicOp (Scratch elem_t _), scratch_cs) ->
      Just (Scratch elem_t (newDims newshape), scratch_cs)
    _ -> Nothing
simplifyReshapeScratch _ _ _ = Nothing

-- | Push a reshape into a 'Replicate'.  If the dimensions of the
-- replicated value form a suffix of the reshape's target dimensions,
-- the result is a replicate of that same value.  It uses only the
-- leading target dimensions that are not covered by the value's own
-- shape.  The certificates of the replicate statement are kept.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate defOf seType (Reshape newshape v) =
  case defOf v of
    Just (BasicOp (Replicate _ se), rep_cs) -> do
      elem_shape <- arrayShape <$> seType se
      let target_dims = newDims newshape
      guard $ shapeDims elem_shape `isSuffixOf` target_dims
      -- Non-negative: the suffix check above ensures the value's
      -- rank does not exceed the target rank.
      let outer_dims =
            take (length newshape - shapeRank elem_shape) target_dims
      Just (Replicate (Shape outer_dims) se, rep_cs)
    _ -> Nothing
simplifyReshapeReplicate _ _ _ = Nothing

-- | A reshape of an 'Iota' into a one-dimensional shape becomes an
-- 'Iota' of the new length.  The offset, stride and integer type are
-- unchanged, and the iota statement's certificates are kept.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota defOf _ (Reshape newshape v) =
  case (defOf v, newDims newshape) of
    (Just (BasicOp (Iota _ offset stride it), iota_cs), [len]) ->
      Just (Iota len offset stride it, iota_cs)
    _ -> Nothing
simplifyReshapeIota _ _ _ = Nothing

-- | Refine a reshape's shape change using the known dimensions of the
-- array being reshaped (via 'informReshape').  The rule fires only if
-- the refined shape change differs from the original.  It returns no
-- certificates of its own.
improveReshape :: SimpleRule lore
improveReshape _ seType (Reshape newshape v) = do
  arr_t <- seType $ Var v
  let informed = informReshape (arrayDims arr_t) newshape
  guard $ informed /= newshape
  Just (Reshape informed v, mempty)
improveReshape _ _ _ = Nothing

-- | If we are copying a scratch array (possibly indirectly, through a
-- chain of 'Rearrange' and 'Reshape' statements), replace the copy
-- with a fresh 'Scratch'.  The new scratch has the element type and
-- dimensions of the copied array and carries no certificates.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch defOf seType (Copy src)
  | derivesFromScratch src = do
    src_t <- seType $ Var src
    Just (Scratch (elemType src_t) (arrayDims src_t), mempty)
  | otherwise = Nothing
  where
    -- Follow the defining statements of a variable back through
    -- rearranges and reshapes, and report whether they end at a
    -- 'Scratch'.
    derivesFromScratch :: VName -> Bool
    derivesFromScratch name =
      case defOf name of
        Just (BasicOp Scratch {}, _) -> True
        Just (BasicOp (Rearrange _ name'), _) -> derivesFromScratch name'
        Just (BasicOp (Reshape _ name'), _) -> derivesFromScratch name'
        _ -> False
copyScratchToScratch _ _ _ =
  Nothing

ruleBasicOp :: BinderOps lore => TopDownRuleBasicOp lore
-- Check all the simpleRules.
ruleBasicOp :: TopDownRuleBasicOp lore
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux BasicOp
op
  | Just (BasicOp
op', Certificates
cs) <- [Maybe (BasicOp, Certificates)] -> Maybe (BasicOp, Certificates)
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, MonadPlus m) =>
t (m a) -> m a
msum [SimpleRule lore
rule VName -> Maybe (Exp lore, Certificates)
defOf SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType BasicOp
op | SimpleRule lore
rule <- [SimpleRule lore]
forall lore. [SimpleRule lore]
simpleRules] =
    RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec lore) -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux (ExpDec lore)
aux) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp BasicOp
op'
  where
    defOf :: VName -> Maybe (Exp lore, Certificates)
defOf = (VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
`ST.lookupExp` TopDown lore
vtable)
    seType :: SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (Var VName
v) = VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
v TopDown lore
vtable
    seType (Constant PrimValue
v) = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
 -> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
_ (Update VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch {}, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable =
    RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp vtable pat aux (Update src [DimSlice i n s] (Var v))
  | isCt1 n,
    isCt1 s,
    Just (ST.Indexed cs e) <- ST.index v [intConst Int64 0] vtable =
    -- Length and stride are both the constant checked by 'isCt1'
    -- (presumably 1), so the slice is the single element at @i@.
    -- Element 0 of @v@ is known as a PrimExp, so write that scalar
    -- directly instead of the one-element array.
    Simplify $ do
      elem_se <- toSubExp "update_elem" e
      let single_update = BasicOp $ Update src [DimFix i] elem_se
      auxing aux . certifying cs $ letBind pat single_update
ruleBasicOp vtable pat _ (Update dest destis (Var v))
  | Just (v_exp, _) <- ST.lookupExp v vtable,
    sameAsDest v_exp =
    -- Writing @v@ into @dest@ at @destis@ does not change @dest@, so
    -- the update is replaced by @dest@ itself.
    Simplify . letBind pat . BasicOp . SubExp $ Var dest
  where
    -- Is the expression known to produce values that @dest@ already
    -- holds at @destis@?  Copies are looked through recursively; an
    -- index of @dest@ at exactly @destis@ qualifies; so does a
    -- replicate of the same element that @dest@ itself replicates,
    -- provided its shape is a suffix of @dest@'s shape.
    sameAsDest expr = case expr of
      BasicOp (Copy copied)
        | Just (copied_exp, _) <- ST.lookupExp copied vtable ->
          sameAsDest copied_exp
      BasicOp (Index arr arr_is) ->
        arr == dest && destis == arr_is
      BasicOp (Replicate v_shape v_se)
        | Just (Replicate dest_shape dest_se, _) <- ST.lookupBasicOp dest vtable ->
          v_se == dest_se && shapeDims v_shape `isSuffixOf` shapeDims dest_shape
      _ -> False
ruleBasicOp vtable pat _ (Update dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is =
    -- The slice covers all of @dest@, so the result is built from @se@
    -- alone: an array value is reshaped to @dest@'s shape and copied;
    -- anything else becomes a one-row array literal.
    Simplify $ case se of
      Var v
        | (_ : _) <- sliceDims is -> do
          v_reshaped <-
            letExp (baseString v ++ "_reshaped") $
              BasicOp $ Reshape (map DimNew (arrayDims dest_t)) v
          letBind pat $ BasicOp $ Copy v_reshaped
      _ ->
        letBind pat . BasicOp . ArrayLit [se] $ rowType dest_t
ruleBasicOp vtable pat (StmAux cs1 attrs _) (Update dest1 is1 (Var v1))
  | Just (Update dest2 is2 se2, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Copy v3, cs3) <- ST.lookupBasicOp dest2 vtable,
    Just (Index v4 is4, cs4) <- ST.lookupBasicOp v3 vtable,
    is4 == is1,
    v4 == dest1 =
    -- @v1@ is a copy of @dest1[is1]@ that was updated at @is2@ with
    -- @se2@.  Writing it back into @dest1@ at @is1@ is the same as
    -- writing @se2@ into @dest1@ at the composition of the two slices.
    Simplify $
      certifying (cs1 <> cs2 <> cs3 <> cs4) $ do
        let composed = sliceSlice (primExpSlice is1) (primExpSlice is2)
        composed_is <- subExpSlice composed
        attributing attrs . letBind pat . BasicOp $ Update dest1 composed_is se2
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    -- If one operand @v@ is the result of @if p then .. y .. else .. z ..@,
    -- rewrite @x == v@ into @(p && x == y) || (!p && x == z)@.  This is
    -- only done when @y@ and @z@ do not depend on names bound inside
    -- their branch bodies, since they are used outside the branches.
    simplifyWith (Var v) x
      | Just bnd <- ST.lookupStm v vtable,
        If p tbranch fbranch _ <- stmExp bnd,
        Just (y, z) <-
          returns v (stmPattern bnd) tbranch fbranch,
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
        eq_x_y <-
          letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
        eq_x_z <-
          letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
        p_and_eq_x_y <-
          letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
        not_p <-
          letSubExp "not_p" $ BasicOp $ UnOp Not p
        -- The name hint previously duplicated "p_and_eq_x_y"; use a hint
        -- that matches what this binding actually computes.
        not_p_and_eq_x_z <-
          letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
        letBind pat $
          BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- The pair of branch results (true, false) that the if-pattern
    -- binds to @v@, if @v@ is one of its value elements.
    returns v ifpat tbranch fbranch =
      fmap snd $
        find ((== v) . patElemName . fst) $
          zip (patternValueElements ifpat) $
            zip (bodyResult tbranch) (bodyResult fbranch)
ruleBasicOp _ pat _ (Replicate (Shape []) se@Constant {}) =
  -- A replicate over the empty shape of a constant is just the constant.
  Simplify . letBind pat . BasicOp $ SubExp se
ruleBasicOp _ pat _ (Replicate (Shape []) (Var v)) = Simplify $ do
  -- A replicate over the empty shape yields @v@ itself: a scalar is
  -- used directly, while an array is copied.
  v_t <- lookupType v
  let result
        | primType v_t = SubExp $ Var v
        | otherwise = Copy v
  letBind pat $ BasicOp result
ruleBasicOp vtable pat _ (Replicate shape (Var v))
  | Just (BasicOp (Replicate inner_shape inner_se), v_cs) <- ST.lookupExp v vtable =
    -- Replicating a replicate collapses into a single replicate over
    -- the outer shape followed by the inner shape.
    Simplify . certifying v_cs . letBind pat . BasicOp $
      Replicate (shape <> inner_shape) inner_se
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses =
    -- A non-empty array literal whose elements are all equal is a
    -- replicate of that element, with the literal's length.
    Simplify . letBind pat . BasicOp $ Replicate (Shape [lit_len]) se
  where
    lit_len = constant (fromIntegral (length (se : ses)) :: Int64)
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
    -- Fully indexing a reshaped array: index the pre-reshape array
    -- instead, translating the indices when the reshape is not a mere
    -- coercion.
    let bindIndex idx =
          certifying idd_cs . auxing aux . letBind pat . BasicOp $ Index idd2 idx
     in Simplify $ case shapeCoercion newshape of
          Just _ -> bindIndex slice
          Nothing -> do
            -- Linearise indices and map to old index space.
            old_dims <- arrayDims <$> lookupType idd2
            let translated =
                  reshapeIndex
                    (map pe64 old_dims)
                    (map pe64 (newDims newshape))
                    (map pe64 inds)
            translated_ses <- mapM (toSubExp "new_index") translated
            bindIndex $ map DimFix translated_ses
ruleBasicOp _ pat _ (BinOp (Pow t) e1 e2)
  | e1 == intConst t 2 =
    -- Integer @2 ** e2@ is rewritten as the shift @1 << e2@.
    -- NOTE(review): the two forms may disagree for a negative @e2@;
    -- confirm what 'Pow' is defined to do there.
    Simplify . letBind pat . BasicOp $ BinOp (Shl t) (intConst t 1) e2
-- Handle identity permutation.
ruleBasicOp _ pat _ (Rearrange perm v)
  | sort perm == perm =
    -- A permutation that is already in sorted order is the identity,
    -- so the rearrangement is just @v@.
    Simplify . letBind pat . BasicOp . SubExp $ Var v
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange perm2 e), v_cs) <- ST.lookupExp v vtable =
    -- Rearranging a rearranging: compose the permutations.
    let composed = perm `rearrangeCompose` perm2
     in Simplify . certifying v_cs . auxing aux . letBind pat . BasicOp $
          Rearrange composed e
-- Rearranging a rotation of a rearranging:
--
--   rearrange perm (rotate offsets (rearrange perm3 v3))
--
-- The rotation is applied directly to v3 and the two rearrangements are
-- composed into one.  The rotation now happens before the perm3
-- rearrangement instead of after it, so its offsets are permuted by the
-- inverse of perm3.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
    Simplify $ do
      rotated <-
        letExp "rearrange_rotate" . BasicOp $
          Rotate (rearrangeShape (rearrangeInverse perm3) offsets) v3
      certifying (v_cs <> v2_cs) . auxing aux . letBind pat . BasicOp $
        Rearrange (perm `rearrangeCompose` perm3) rotated

-- Rearranging a replicate where the outer dimension is left untouched.
-- When the permutation keeps every replicated (outer) dimension in
-- place and moves at least one inner dimension, the rearrangement is
-- pushed inside the replicate.  There it applies only to the inner
-- dimensions, renumbered by subtracting the number of replicated ones.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    num_dims <- shapeRank dims,
    (rep_perm, rest_perm) <- splitAt num_dims perm,
    not $ null rest_perm,
    rep_perm == [0 .. length rep_perm - 1] =
    Simplify . certifying v1_cs . auxing aux $ do
      let inner_perm = map (subtract num_dims) rest_perm
      inner <-
        letSubExp "rearrange_replicate" . BasicOp $ Rearrange inner_perm v2
      letBind pat . BasicOp $ Replicate dims inner

-- A zero-rotation is identity: when every offset is the constant zero,
-- the result is simply the input array.
ruleBasicOp _ pat _ (Rotate offsets v)
  | all isCt0 offsets =
    Simplify . letBind pat . BasicOp . SubExp $ Var v
-- Rotating a rearranging of a rotation:
--
--   rotate offsets (rearrange perm (rotate offsets2 v3))
--
-- becomes a single rotation of @rearrange perm v3@.  The inner rotation
-- moves outside the rearrangement, so its offsets must be re-expressed
-- in terms of the rearranged array's dimensions.  A Rearrange by perm
-- lays out its input's dimensions as @rearrangeShape perm@, so the
-- offsets are permuted the same way: @rearrangeShape perm offsets2@.
-- They are then added elementwise to the outer offsets.
--
-- This must NOT use @rearrangeInverse perm@.  The Rearrange-of-Rotate
-- rule uses the inverse because it moves offsets in the opposite
-- direction (from the rearranged array back to the inner array).  Using
-- the inverse in both directions is only correct when perm is its own
-- inverse (e.g. 2-D transposition) and miscompiles cyclic permutations
-- such as [1,2,0].
ruleBasicOp vtable pat aux (Rotate offsets v)
  | Just (BasicOp (Rearrange perm v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate offsets2 v3), v2_cs) <- ST.lookupExp v2 vtable = Simplify $ do
    let offsets2' = rearrangeShape perm offsets2
        addOffsets x y =
          letSubExp "summed_offset" $ BasicOp $ BinOp (Add Int64 OverflowWrap) x y
    offsets' <- zipWithM addOffsets offsets offsets2'
    rotate_rearrange <-
      auxing aux $ letExp "rotate_rearrange" $ BasicOp $ Rearrange perm v3
    certifying (v_cs <> v2_cs) $
      letBind pat $ BasicOp $ Rotate offsets' rotate_rearrange

-- Combining Rotates: a rotation of a rotation is a single rotation of
-- the inner array, whose offsets are the elementwise (wrapping Int64)
-- sums of the two offset lists.
ruleBasicOp vtable pat aux (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable =
    Simplify $ do
      let sumOffsets a b =
            letSubExp "offset" . BasicOp $ BinOp (Add Int64 OverflowWrap) a b
      offsets <- zipWithM sumOffsets offsets1 offsets2
      certifying v_cs . auxing aux . letBind pat . BasicOp $ Rotate offsets v2

-- If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
--
-- Both slices must consist solely of fixed indices.  The last index of
-- each is turned into a unit-sized DimSlice.  If slice_y had a DimSlice
-- before its final DimFix, the unit dimension in the value read (v')
-- would come after the sliced dimensions.  The Update target (slice_x')
-- expects it before them, so the rewritten Update would be ill-typed
-- unless those slices happen to have size one.
ruleBasicOp vtable pat aux (Update arr_x slice_x (Var v))
  | Just _ <- sliceIndices slice_x,
    Just (Index arr_y slice_y, cs_y) <- ST.lookupBasicOp v vtable,
    Just _ <- sliceIndices slice_y,
    ST.available arr_y vtable,
    -- XXX: we should check for proper aliasing here instead.
    arr_y /= arr_x,
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y =
    Simplify $ do
      let unitSliceAt k = DimSlice k (intConst Int64 1) (intConst Int64 1)
          slice_x' = slice_x_bef ++ [unitSliceAt i]
          slice_y' = slice_y_bef ++ [unitSliceAt j]
      v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
      certifying cs_y $
        auxing aux $
          letBind pat $ BasicOp $ Update arr_x slice_x' $ Var v'

-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
-- Only fires when the left operand is the literal Int64 zero.
ruleBasicOp vtable pat aux (CmpOp CmpSle {} x y)
  | Constant (IntValue (Int64Value 0)) <- x,
    Var v <- y,
    isJust $ ST.lookupLoopVar v vtable =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away i<n when 'i' is from a loop of form 'for i < n'.
-- The right operand must be exactly the loop's bound as recorded in
-- the symbol table.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} x y)
  | Var v <- x,
    Just bound <- ST.lookupLoopVar v vtable,
    y == bound =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away x<0 when 'x' has been used as array size.  The symbol
-- table entry for 'x' must be flagged as a size ('ST.entryIsSize').
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var x) y)
  | isCt0 y,
    Just entry <- ST.lookup x vtable,
    ST.entryIsSize entry =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant False
-- Fallback: none of the equations above matched, so this rule does not
-- apply to the statement.
ruleBasicOp _ _ _ _ = Skip

-- | Remove the return values of a branch, that are not actually used
-- after a branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise.
--
-- The statement's auxiliary information (certificates and attributes)
-- is carried over to the rewritten branch via 'auxing', as the other
-- rewriting rules in this module do.  Dropping it would detach the
-- branch from the certificates it depended on.
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult (_, used) pat aux (e1, tb, fb, IfDec rettype ifsort)
  | -- Only if there is no existential context...
    patternSize pat == length rettype,
    -- Figure out which of the names in 'pat' are used...
    patused <- map (`UT.isUsedDirectly` used) $ patternNames pat,
    -- If they are not all used, then this rule applies.
    not (and patused) =
    -- Remove the parts of the branch-results that correspond to dead
    -- return value bindings.  Note that this leaves dead code in the
    -- branch bodies, but that will be removed later.
    let pick :: [a] -> [a]
        pick = map snd . filter fst . zip patused
        tb' = tb {bodyResult = pick $ bodyResult tb}
        fb' = fb {bodyResult = pick $ bodyResult fb}
        pat' = pick $ patternElements pat
        rettype' = pick rettype
     in Simplify $
          auxing aux $
            letBind (Pattern [] pat') $ If e1 tb' fb' $ IfDec rettype' ifsort
  | otherwise = Skip

-- Some helper functions

-- | True only for a constant subexpression whose value is 'oneIsh';
-- variables are never considered one.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False

-- | True only for a constant subexpression whose value is 'zeroIsh';
-- variables are never considered zero.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False