{-# LANGUAGE OverloadedStrings #-}

-- | Loop simplification rules.
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Transform.Rename

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
--
-- A merge parameter (and its body result and pattern element) is kept
-- when at least one of the following holds:
--
--   * its pattern element is used directly after the loop;
--   * it is in the set computed by 'findNecessaryForReturned' from the
--     body's data dependencies, seeded with the parameters that are used
--     after the loop or are free in the loop form;
--   * it is free in the annotation of some merge parameter;
--   * it is free in the loop form.
--
-- All other merge parameters are dropped.  The rule only fires when
-- there are no context merge parameters and at least one value merge
-- parameter is unused after the loop; it returns 'Skip' if nothing would
-- actually be removed.
removeRedundantMergeVariables :: BinderOps rep => BottomUpRuleDoLoop rep
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    null ctx -- FIXME: things get tricky if we can remove all vals
    -- but some ctxs are still used.  We take the easy way
    -- out for now.
    =
    let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body

        -- Merge parameters that 'findNecessaryForReturned' deems necessary,
        -- based on the data dependencies of the loop body.
        necessaryForReturned =
          findNecessaryForReturned
            usedAfterLoopOrInForm
            (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
            (dataDependencies body)

        -- Whether a ((param, initial value), body result) triple survives.
        resIsNecessary ((v, _), _) =
          usedAfterLoop v
            || paramName v `nameIn` necessaryForReturned
            || referencedInPat v
            || referencedInForm v

        (keep_ctx, discard_ctx) =
          partition resIsNecessary $ zip ctx ctx_es
        (keep_valpart, discard_valpart) =
          partition (resIsNecessary . snd) $
            zip (patternValueElements pat) $ zip val val_es

        (keep_valpatelems, keep_val) = unzip keep_valpart
        discard_val = map snd discard_valpart
        (ctx', ctx_es') = unzip keep_ctx
        (val', val_es') = unzip keep_val

        body' = body {bodyResult = ctx_es' ++ val_es'}
        free_in_keeps = freeIn keep_valpatelems

        -- A context pattern element is retained only if it is free in a
        -- kept value pattern element or in some other context element.
        stillUsedContext pat_elem =
          patElemName pat_elem
            `nameIn` ( free_in_keeps
                         <> freeIn (filter (/= pat_elem) $ patternContextElements pat)
                     )

        pat' =
          pat
            { patternValueElements = keep_valpatelems,
              patternContextElements =
                filter stillUsedContext $ patternContextElements pat
            }
     in if ctx' ++ val' == ctx ++ val
          then Skip
          else Simplify $ do
            -- We can't just remove the bindings in 'discard', since the loop
            -- body may still use their names in (now-dead) expressions.
            -- Hence, we add them inside the loop, fully aware that dead-code
            -- removal will eventually get rid of them.  Some care is
            -- necessary to handle unique bindings.
            body'' <- insertStmsM $ do
              mapM_ (uncurry letBindNames . dummyStm) $ discard_ctx ++ discard_val
              return body'
            auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat

    -- Names of the value merge parameters whose pattern element is used
    -- directly after the loop.  Stored as 'Names' so that membership
    -- tests do not scan a list.
    used_vals =
      namesFromList . map fst . filter snd $
        zip (map (paramName . fst) val) pat_used

    usedAfterLoop = (`nameIn` used_vals) . paramName

    -- Computed once; consulted for every merge parameter.
    form_names = freeIn form

    usedAfterLoopOrInForm p = usedAfterLoop p || referencedInForm p

    -- Names free in the annotations (e.g. types) of all merge parameters.
    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` form_names) . paramName

    -- Rebind a discarded merge parameter's name to its initial value.
    -- When the parameter is unique and the initial value is a variable,
    -- a 'Copy' is bound instead of a plain 'SubExp' (the uniqueness care
    -- mentioned in the comment above).
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
        ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps rep => TopDownRuleDoLoop rep
hoistLoopInvariantMergeVariables :: forall rep. BinderOps rep => TopDownRuleDoLoop rep
hoistLoopInvariantMergeVariables TopDown rep
vtable Pattern rep
pat StmAux (ExpDec rep)
aux ([(FParam rep, SubExp)]
ctx, [(FParam rep, SubExp)]
val, LoopForm rep
form, BodyT rep
loopbody) =
  -- Figure out which of the elements of loopresult are
  -- loop-invariant, and hoist them out.
  case ((VName, (FParam rep, SubExp), SubExp)
 -> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
     [(FParam rep, SubExp)], [SubExp])
 -> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
     [(FParam rep, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
    [(FParam rep, SubExp)], [SubExp])
-> [(VName, (FParam rep, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
    [(FParam rep, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VName, (FParam rep, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
    [(FParam rep, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
    [(FParam rep, SubExp)], [SubExp])
forall {dec} {dec}.
(DeclTyped dec, Typed dec, FreeIn dec, Typed dec) =>
(VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec rep), VName)]
explpat, [], []) ([(VName, (FParam rep, SubExp), SubExp)]
 -> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
     [(FParam rep, SubExp)], [SubExp]))
-> [(VName, (FParam rep, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec rep), VName)],
    [(FParam rep, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
    [VName]
-> [(FParam rep, SubExp)]
-> [SubExp]
-> [(VName, (FParam rep, SubExp), SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern rep -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern rep
pat) [(FParam rep, SubExp)]
merge [SubExp]
res of
    ([], [(PatElemT (LetDec rep), VName)]
_, [(FParam rep, SubExp)]
_, [SubExp]
_) ->
      -- Nothing is invariant.
      Rule rep
forall rep. Rule rep
Skip
    ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec rep), VName)]
explpat', [(FParam rep, SubExp)]
merge', [SubExp]
res') -> RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      -- We have moved something invariant out of the loop.
      let loopbody' :: BodyT rep
loopbody' = BodyT rep
loopbody {bodyResult :: [SubExp]
bodyResult = [SubExp]
res'}
          invariantShape :: (a, VName) -> Bool
          invariantShape :: forall a. (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) =
            VName
shapemerge
              VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((FParam rep, SubExp) -> VName)
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep -> VName
forall dec. Param dec -> VName
paramName (FParam rep -> VName)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
merge'
          ([(PatElemT (LetDec rep), VName)]
implpat', [(PatElemT (LetDec rep), VName)]
implinvariant) = ((PatElemT (LetDec rep), VName) -> Bool)
-> [(PatElemT (LetDec rep), VName)]
-> ([(PatElemT (LetDec rep), VName)],
    [(PatElemT (LetDec rep), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec rep), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec rep), VName)]
implpat
          implinvariant' :: [(Ident, SubExp)]
implinvariant' = [(PatElemT (LetDec rep) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec rep)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec rep)
p, VName
v) <- [(PatElemT (LetDec rep), VName)]
implinvariant]
          implpat'' :: [PatElemT (LetDec rep)]
implpat'' = ((PatElemT (LetDec rep), VName) -> PatElemT (LetDec rep))
-> [(PatElemT (LetDec rep), VName)] -> [PatElemT (LetDec rep)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec rep), VName) -> PatElemT (LetDec rep)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec rep), VName)]
implpat'
          explpat'' :: [PatElemT (LetDec rep)]
explpat'' = ((PatElemT (LetDec rep), VName) -> PatElemT (LetDec rep))
-> [(PatElemT (LetDec rep), VName)] -> [PatElemT (LetDec rep)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec rep), VName) -> PatElemT (LetDec rep)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec rep), VName)]
explpat'
          ([(FParam rep, SubExp)]
ctx', [(FParam rep, SubExp)]
val') = Int
-> [(FParam rep, SubExp)]
-> ([(FParam rep, SubExp)], [(FParam rep, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec rep), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec rep), VName)]
implpat') [(FParam rep, SubExp)]
merge'
      [(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM rep ()) -> RuleM rep ())
-> ((Ident, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1, SubExp
v2) ->
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
      StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBinder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Pattern rep
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec rep)]
implpat'' [PatElemT (LetDec rep)]
explpat'') (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
          [(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
forall rep.
[(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam rep, SubExp)]
ctx' [(FParam rep, SubExp)]
val' LoopForm rep
form BodyT rep
loopbody'
  where
    merge :: [(FParam rep, SubExp)]
merge = [(FParam rep, SubExp)]
ctx [(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> [(FParam rep, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam rep, SubExp)]
val
    res :: [SubExp]
res = BodyT rep -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult BodyT rep
loopbody

    implpat :: [(PatElemT (LetDec rep), VName)]
implpat =
      [PatElemT (LetDec rep)]
-> [VName] -> [(PatElemT (LetDec rep), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern rep
pat) ([VName] -> [(PatElemT (LetDec rep), VName)])
-> [VName] -> [(PatElemT (LetDec rep), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam rep, SubExp) -> VName)
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep -> VName
forall dec. Param dec -> VName
paramName (FParam rep -> VName)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
ctx
    explpat :: [(PatElemT (LetDec rep), VName)]
explpat =
      [PatElemT (LetDec rep)]
-> [VName] -> [(PatElemT (LetDec rep), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern rep
pat) ([VName] -> [(PatElemT (LetDec rep), VName)])
-> [VName] -> [(PatElemT (LetDec rep), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam rep, SubExp) -> VName)
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep -> VName
forall dec. Param dec -> VName
paramName (FParam rep -> VName)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
val

    namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam rep, SubExp) -> VName)
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep -> VName
forall dec. Param dec -> VName
paramName (FParam rep -> VName)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) ([(FParam rep, SubExp)] -> [VName])
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam rep, SubExp)]
ctx [(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> [(FParam rep, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam rep, SubExp)]
val

    removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, b
mergeInit) [(PatElemT dec, VName)]
explpat' =
      case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
        ([(PatElemT dec
patelem, VName
_)], [(PatElemT dec, VName)]
rest) ->
          ((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
        ([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
          (Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')

    checkInvariance :: (VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
checkInvariance
      (VName
pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps)
        | Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam))
            Bool -> Bool -> Bool
|| TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
          Bool
isInvariant,
          -- Also do not remove the condition in a while-loop.
          Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> Names -> Bool
`nameIn` LoopForm rep -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm rep
form =
          let (Maybe (Ident, SubExp)
bnd, [(PatElemT dec, VName)]
explpat'') =
                (Param dec, SubExp)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, SubExp), [(PatElemT dec, VName)])
forall {dec} {dec} {b}.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, SubExp
mergeInit) [(PatElemT dec, VName)]
explpat'
           in ( ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param dec -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent Param dec
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
                [(PatElemT dec, VName)]
explpat'',
                [(Param dec, SubExp)]
merge',
                [SubExp]
resExps
              )
        where
          -- A non-unique merge variable is invariant if one of the
          -- following is true:
          --
          -- (0) The result is a variable of the same name as the
          -- parameter, where all existential parameters are already
          -- known to be invariant
          isInvariant :: Bool
isInvariant
            | Var VName
v2 <- SubExp
resExp,
              Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
              Names -> Param dec -> Bool
forall {dec}. FreeIn dec => Names -> Param dec -> Bool
allExistentialInvariant
                ([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant)
                Param dec
mergeParam
            -- (1) The result is identical to the initial parameter value.
            | SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp = Bool
True
            -- (2) The initial parameter value is equal to an outer
            -- loop parameter 'P', where the initial value of 'P' is
            -- equal to 'resExp', AND 'resExp' ultimately becomes the
            -- new value of 'P'.  XXX: it's a bit clumsy that this
            -- only works for one level of nesting, and I think it
            -- would not be too hard to generalise.
            | Var VName
init_v <- SubExp
mergeInit,
              Just (SubExp
p_init, SubExp
p_res) <- VName -> TopDown rep -> Maybe (SubExp, SubExp)
forall rep. VName -> SymbolTable rep -> Maybe (SubExp, SubExp)
ST.lookupLoopParam VName
init_v TopDown rep
vtable,
              SubExp
p_init SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp,
              SubExp
p_res SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> SubExp
Var VName
pat_name =
              Bool
True
            | Bool
otherwise = Bool
False
    checkInvariance
      (VName
_pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps) =
        ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', (Param dec
mergeParam, SubExp
mergeInit) (Param dec, SubExp)
-> [(Param dec, SubExp)] -> [(Param dec, SubExp)]
forall a. a -> [a] -> [a]
: [(Param dec, SubExp)]
merge', SubExp
resExp SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
resExps)

    allExistentialInvariant :: Names -> Param dec -> Bool
allExistentialInvariant Names
namesOfInvariant Param dec
mergeParam =
      (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
        Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          Param dec -> Names
forall a. FreeIn a => a -> Names
freeIn Param dec
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam)
    invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
      Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams)
        Bool -> Bool -> Bool
|| VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant

-- | Hand a @for@-loop to 'loopClosedForm' when it has no context merge
-- parameters and no loop variables.  The loop counter is passed as the
-- set of names known to be the iteration variable.  Every other loop
-- shape is skipped.
simplifyClosedFormLoop :: BinderOps rep => TopDownRuleDoLoop rep
simplifyClosedFormLoop _ pat _ (ctx, val, form, body) =
  case (ctx, form) of
    ([], ForLoop i it bound []) ->
      Simplify $ loopClosedForm pat val (oneName i) it bound body
    _ ->
      Skip

-- | Try to get rid of (or cheapen) the indexing that each loop variable
-- of a @for@-loop implicitly performs.  For every loop variable
-- @(p, arr)@, 'simplifyIndexing' is consulted about @arr[i, ...]@, where
-- @i@ is the loop counter.  Depending on what it produces, the loop
-- variable is
--
--   * kept unchanged;
--   * dropped from the loop form and instead bound to a plain
--     subexpression at the start of the loop body ('SubExpResult'); or
--   * redirected to iterate over a slice of another array
--     ('IndexResult').
--
-- If every loop variable comes back unchanged, the rule gives up via
-- 'cannotSimplify'.
simplifyLoopVariables :: (BinderOps rep, Aliased rep) => TopDownRuleDoLoop rep
simplifyLoopVariables vtable pat aux (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | candidates <- map tryIndexing loop_vars,
    any isJust candidates =
    Simplify $ do
      (new_loop_vars, prefix_stms) <-
        localScope (scopeOf form) . fmap unzip $
          zipWithM rewriteLoopVar loop_vars candidates
      -- Check if the simplifications throw away more information than
      -- we are comfortable with at this stage: if every loop variable
      -- came back untouched, there is nothing to gain.
      if new_loop_vars == map Just loop_vars
        then cannotSimplify
        else do
          new_body <- buildBody_ $ do
            addStms $ mconcat prefix_stms
            bodyBind body
          let new_form = ForLoop i it num_iters (catMaybes new_loop_vars)
          auxing aux . letBind pat $ DoLoop ctx val new_form new_body
  where
    consumed_in_body = consumedInBody body

    -- Symbol table extended with the names bound by the loop form.
    inner_vtable = ST.fromScope (scopeOf form) <> vtable

    -- Type of a subexpression; the loop counter gets the loop's integer
    -- type, everything else is looked up in the outer symbol table.
    seType (Constant v) = Just $ Prim $ primValueType v
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable

    -- Ask 'simplifyIndexing' about @arr[i, <full slice of p's type>]@.
    -- The final argument is whether @p@ is consumed in the loop body.
    tryIndexing (p, arr) =
      simplifyIndexing
        inner_vtable
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        $ paramName p `nameIn` consumed_in_body

    -- Apply the simplification found for one loop variable, if any.
    -- Returns the loop variable to keep ('Nothing' if it is removed from
    -- the loop form) together with statements to prepend to the loop
    -- body.  We only want this simplification if the result does not
    -- refer to 'i' at all, or does not contain accesses.
    rewriteLoopVar loop_var Nothing = pure (Just loop_var, mempty)
    rewriteLoopVar loop_var@(p, _) (Just index_m) = do
      (res, res_stms) <- collectStms index_m
      case res of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) res_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- NOTE(review): 'slice' begins with @DimFix (Var i)@ at this
            -- point, so @i@ is always free in it and this guard can never
            -- hold; the branch is unreachable.  Checking 'slice'' may have
            -- been intended, but enabling the branch would add 'res_stms'
            -- before the loop rather than inside its body - confirm that
            -- is safe before changing it.
            not $ i `nameIn` freeIn slice -> do
            addStms res_stms
            w <- arraySize 0 <$> lookupType arr'
            for_in_partial <-
              certifying cs . letExp "for_in_partial" . BasicOp . Index arr' $
                DimSlice (intConst Int64 0) w (intConst Int64 1) : slice'
            pure (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) res_stms -> do
            body_prefix <- collectStms_ . certifying cs $ do
              addStms res_stms
              letBindNames [paramName p] $ BasicOp $ SubExp se
            pure (Nothing, body_prefix)
        _ ->
          pure (Just loop_var, mempty)

    -- True for every expression except an array index.
    notIndex e = case e of
      BasicOp Index {} -> False
      _ -> True
simplifyLoopVariables _ _ _ _ = Skip

-- | If a for-loop with no loop variables has a counter of type Int64,
-- and the bound is either an Int64 constant representable as an Int32
-- or a sign extension of an integer of some type, then change the loop
-- to iterate over that smaller type instead.  The sign extension of the
-- counter is moved inside the loop body, so the body still sees an
-- Int64 counter under the original name.  Any certificates attached to
-- the sign extension of the bound are attached to the new loop.  This
-- addresses loops of the form @for i in x..<y@ in the source language.
narrowLoopType :: (BinderOps rep) => TopDownRuleDoLoop rep
narrowLoopType vtable pat aux (ctx, val, ForLoop i Int64 n [], body)
  | Just (n', it', cs) <- smallerType =
    Simplify $ do
      i' <- newVName $ baseString i
      let form' = ForLoop i' it' n' []
      body' <- insertStmsM . inScopeOf form' $ do
        -- Rebind the original Int64 counter from the narrow one.
        letBindNames [i] $ BasicOp $ ConvOp (SExt it' Int64) (Var i')
        pure body
      auxing aux $
        certifying cs $
          letBind pat $ DoLoop ctx val form' body'
  where
    smallerType
      | Var n_v <- n,
        Just (ConvOp (SExt it' _) n'', cs) <- ST.lookupBasicOp n_v vtable =
        Just (n'', it', cs)
      -- A constant bound is narrowed only if it is representable as an
      -- Int32.  Otherwise 'intConst' would wrap it around, e.g. a large
      -- negative bound (zero iterations) could become a small positive
      -- one and change the iteration count.
      | Constant (IntValue (Int64Value n_c)) <- n,
        fitsInt32 n_c =
        Just (intConst Int32 (toInteger n_c), Int32, mempty)
      | otherwise =
        Nothing

    -- Whether an integral value lies within the Int32 range.
    fitsInt32 x =
      toInteger (minBound :: Int32) <= toInteger x
        && toInteger x <= toInteger (maxBound :: Int32)
narrowLoopType _ _ _ _ = Skip

unroll ::
  BinderOps rep =>
  Integer ->
  [(FParam rep, SubExp)] ->
  (VName, IntType, Integer) ->
  [(LParam rep, VName)] ->
  Body rep ->
  RuleM rep [SubExp]
unroll :: forall rep.
BinderOps rep =>
Integer
-> [(FParam rep, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam rep, VName)]
-> Body rep
-> RuleM rep [SubExp]
unroll Integer
n [(FParam rep, SubExp)]
merge (VName
iv, IntType
it, Integer
i) [(LParam rep, VName)]
loop_vars Body rep
body
  | Integer
i Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= Integer
n =
    [SubExp] -> RuleM rep [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return ([SubExp] -> RuleM rep [SubExp]) -> [SubExp] -> RuleM rep [SubExp]
forall a b. (a -> b) -> a -> b
$ ((FParam rep, SubExp) -> SubExp)
-> [(FParam rep, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam rep, SubExp)]
merge
  | Bool
otherwise = do
    Body rep
iter_body <- RuleM rep (Body (Rep (RuleM rep)))
-> RuleM rep (Body (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Rep m)) -> m (Body (Rep m))
insertStmsM (RuleM rep (Body (Rep (RuleM rep)))
 -> RuleM rep (Body (Rep (RuleM rep))))
-> RuleM rep (Body (Rep (RuleM rep)))
-> RuleM rep (Body (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ do
      [(FParam rep, SubExp)]
-> ((FParam rep, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(FParam rep, SubExp)]
merge (((FParam rep, SubExp) -> RuleM rep ()) -> RuleM rep ())
-> ((FParam rep, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(FParam rep
mergevar, SubExp
mergeinit) ->
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [FParam rep -> VName
forall dec. Param dec -> VName
paramName FParam rep
mergevar] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
mergeinit

      [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
iv] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
it Integer
i

      [(LParam rep, VName)]
-> ((LParam rep, VName) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(LParam rep, VName)]
loop_vars (((LParam rep, VName) -> RuleM rep ()) -> RuleM rep ())
-> ((LParam rep, VName) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(LParam rep
p, VName
arr) ->
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [LParam rep -> VName
forall dec. Param dec -> VName
paramName LParam rep
p] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$
            VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
              SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
i) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice (LParam rep -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType LParam rep
p) []

      -- Some of the sizes in the types here might be temporarily wrong
      -- until copy propagation fixes it up.
      Body rep -> RuleM rep (Body rep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Body rep
body

    Body rep
iter_body' <- Body rep -> RuleM rep (Body rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Body rep -> m (Body rep)
renameBody Body rep
iter_body
    Stms (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms (Rep (RuleM rep)) -> RuleM rep ())
-> Stms (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Body rep -> Stms rep
forall rep. BodyT rep -> Stms rep
bodyStms Body rep
iter_body'

    let merge' :: [(FParam rep, SubExp)]
merge' = [FParam rep] -> [SubExp] -> [(FParam rep, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((FParam rep, SubExp) -> FParam rep)
-> [(FParam rep, SubExp)] -> [FParam rep]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst [(FParam rep, SubExp)]
merge) ([SubExp] -> [(FParam rep, SubExp)])
-> [SubExp] -> [(FParam rep, SubExp)]
forall a b. (a -> b) -> a -> b
$ Body rep -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult Body rep
iter_body'
    Integer
-> [(FParam rep, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam rep, VName)]
-> Body rep
-> RuleM rep [SubExp]
forall rep.
BinderOps rep =>
Integer
-> [(FParam rep, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam rep, VName)]
-> Body rep
-> RuleM rep [SubExp]
unroll Integer
n [(FParam rep, SubExp)]
merge' (VName
iv, IntType
it, Integer
i Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ Integer
1) [(LParam rep, VName)]
loop_vars Body rep
body

-- | Eliminate a @for@-loop whose iteration bound is a constant by
-- unrolling it in place.
--
-- The rule only fires when the bound is a constant integer that is
-- zero or one, or when the loop statement carries the @unroll@
-- attribute (so an explicit request unrolls any constant bound).  The
-- context and value merge parameters are handed to 'unroll' together,
-- starting at iteration @0@, and every name bound by the loop pattern
-- is then rebound to the corresponding 'SubExp' that 'unroll' returns.
-- Every other loop is left untouched ('Skip').
simplifyKnownIterationLoop :: BinderOps rep => TopDownRuleDoLoop rep
simplifyKnownIterationLoop _ pat aux (ctx, val, ForLoop i it (Constant iters) loop_vars, body)
  | IntValue n <- iters,
    zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux =
    Simplify $ do
      res <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
      -- The loop statement disappears, so its pattern names must be
      -- bound explicitly to the unrolled results.
      forM_ (zip (patternNames pat) res) $ \(v, se) ->
        letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ =
  Skip

-- | Loop rules for the top-down simplification pass.
--
-- The 'Aliased' constraint is required by 'simplifyLoopVariables'.
-- The remaining rules only need 'BinderOps'.
topDownRules :: (BinderOps rep, Aliased rep) => [TopDownRule rep]
topDownRules =
  [ RuleDoLoop hoistLoopInvariantMergeVariables,
    RuleDoLoop simplifyClosedFormLoop,
    RuleDoLoop simplifyKnownIterationLoop,
    RuleDoLoop simplifyLoopVariables,
    RuleDoLoop narrowLoopType
  ]

-- | Loop rules for the bottom-up simplification pass.
--
-- At present this is only 'removeRedundantMergeVariables'.  That rule
-- consults the usage table to find loop results that are never used.
bottomUpRules :: BinderOps rep => [BottomUpRule rep]
bottomUpRules =
  [ RuleDoLoop removeRedundantMergeVariables
  ]

-- | Standard loop simplification rules.
--
-- This bundles 'topDownRules' and 'bottomUpRules' into a single
-- 'RuleBook'.  The 'Aliased' constraint is inherited from the
-- top-down rules.
loopRules :: (BinderOps rep, Aliased rep) => RuleBook rep
loopRules = ruleBook topDownRules bottomUpRules