{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Control.Monad.State
import Data.List (insert, unzip4, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Match
import Futhark.Util

-- Generic rules run in the simplifier's top-down pass (they are the
-- first argument to 'ruleBook' in 'standardRules'): constant folding of
-- calls to primitive functions, and hoisting of trivial results out of
-- 'WithAcc' statements.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleGeneric withAccTopDown
  ]

-- Rules run in the simplifier's bottom-up pass (the second argument to
-- 'ruleBook' in 'standardRules'): dead-result elimination for
-- 'WithAcc', and simplification of 'Index' operations.
bottomUpRules :: (BuilderOps rep, TraverseOpStms rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book combines the generic top-down and bottom-up rules of
-- this module with 'loopRules', 'basicOpRules' and 'matchRules'.
-- Note that 'removeUnnecessaryCopy' is /not/ included; it is exported
-- separately so callers can add it where it is valid.
standardRules :: (BuilderOps rep, TraverseOpStms rep, Aliased rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules
    <> loopRules
    <> basicOpRules
    <> matchRules

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
--
-- Only single-element patterns bound to a 'Copy' are considered; every
-- other statement yields 'Skip'.
removeUnnecessaryCopy :: BuilderOps rep => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    -- The first two clauses below are too conservative, but the
    -- problem is that 'v' might not look like it has been consumed if
    -- it is consumed in an outer scope.  This is because the
    -- simplifier applies bottom-up rules in a kind of deepest-first
    -- order.
    not (d_name `UT.isInResult` used)
      || (d_name `UT.isConsumed` used)
      -- Always OK to remove the copy if 'v' has no aliases and is never
      -- used again.
      || (v_is_fresh && v_not_used_again),
    (v_not_used_again && consumable) || not (d_name `UT.isConsumed` used) =
      Simplify $ letBindNames [d_name] $ BasicOp $ SubExp $ Var v
  where
    -- Name bound by the copy statement.
    d_name = patElemName d
    v_not_used_again = not (v `UT.used` used)
    -- 'v' has no aliases according to the symbol table.
    v_is_fresh = v `ST.lookupAliases` vtable == mempty
    -- We need to make sure we can even consume the original.  The big
    -- missing piece here is that we cannot do copy removal inside of
    -- 'map' and other SOACs, but that is handled by SOAC-specific rules.
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      -- Only when 'v' is bound at the current loop depth.
      guard $ ST.entryDepth e == ST.loopDepth vtable
      consumableStm e `mplus` consumableFParam e
    -- A function parameter is consumable iff its declared type is unique.
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam
    -- A statement-bound variable is consumable iff it has no aliases.
    consumableStm e = do
      void $ ST.entryStm e -- Must be a stm.
      guard v_is_fresh
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip

-- Constant-fold a function call whose name is found in 'primFuns' and
-- whose arguments are all constants.  If the primitive function yields a
-- result for those arguments, the call is replaced by that constant,
-- keeping the original statement's certificates and attributes.  Any
-- other statement (or a primitive that returns 'Nothing') yields 'Skip'.
constantFoldPrimFun :: BuilderOps rep => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $
              BasicOp $
                SubExp $
                  Constant result
  where
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- Simplify an 'Index' bound to a single pattern element by delegating to
-- 'simplifyIndexing'.  Depending on the returned 'IndexResult', the
-- statement is replaced either by a plain 'SubExp' or by an 'Index' into
-- a (possibly different) array.  The original attributes are kept, and
-- the original certificates are combined with any extra certificates the
-- simplification produced.  Whether the result is consumed (according to
-- the usage table) is passed on to 'simplifyIndexing'.
simplifyIndex :: BuilderOps rep => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed = Simplify $ do
      res <- m
      attributing attrs $ case res of
        SubExpResult cs' se ->
          certifying (cs <> cs') $
            letBindNames (patNames pat) $
              BasicOp $
                SubExp se
        IndexResult extra_cs idd' inds' ->
          certifying (cs <> extra_cs) $
            letBindNames (patNames pat) $
              BasicOp $
                Index idd' inds'
  where
    consumed = patElemName pe `UT.isConsumed` used
    -- Type of a subexpression: variables are looked up in the symbol
    -- table; constants have the primitive type of their value.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- Top-down simplification of 'WithAcc' statements.
--
-- * A 'WithAcc' with no inputs is removed entirely: its lambda body is
--   inlined and each pattern name is bound to the corresponding result
--   (under that result's certificates).
--
-- * Otherwise, results that do not actually depend on the lambda are
--   moved out.  An accumulator result that is exactly the unmodified
--   accumulator parameter (with no certificates) is replaced by the
--   input arrays, and that input is dropped from the 'WithAcc'.  A
--   non-accumulator result that is a constant, or a variable already in
--   the outer symbol table (with no certificates), is bound outside.
--   If nothing can be moved, the rule does not apply.
withAccTopDown :: BuilderOps rep => TopDownRuleGeneric rep
-- A WithAcc with no accumulators is sent to Valhalla.
withAccTopDown _ (Let pat aux (WithAcc [] lam)) = Simplify . auxing aux $ do
  lam_res <- bodyBind $ lambdaBody lam
  forM_ (zip (patNames pat) lam_res) $ \(v, SubExpRes cs se) ->
    certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
-- Identify those results in 'lam' that are free and move them out.
withAccTopDown vtable (Let pat aux (WithAcc inputs lam)) = Simplify . auxing aux $ do
  -- The lambda takes one certificate parameter per input, followed by
  -- the accumulator parameters.
  let (cert_params, acc_params) =
        splitAt (length inputs) $ lambdaParams lam
      (acc_res, nonacc_res) =
        splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
      (acc_pes, nonacc_pes) =
        splitFromEnd num_nonaccs $ patElems pat

  -- Look at accumulator results.  Each input owns one pattern element
  -- per array it covers, hence the 'chunks'.
  (acc_pes', inputs', params', acc_res') <-
    fmap (unzip4 . catMaybes) . mapM tryMoveAcc $
      zip4
        (chunks (map inputArrs inputs) acc_pes)
        inputs
        (zip cert_params acc_params)
        acc_res
  let (cert_params', acc_params') = unzip params'

  -- Look at non-accumulator results.
  (nonacc_pes', nonacc_res') <-
    unzip . catMaybes <$> mapM tryMoveNonAcc (zip nonacc_pes nonacc_res)

  -- Nothing was moved out, so the statement is unchanged.
  when (concat acc_pes' == acc_pes && nonacc_pes' == nonacc_pes) cannotSimplify

  lam' <-
    mkLambda (cert_params' ++ acc_params') $
      bodyBind $
        (lambdaBody lam) {bodyResult = acc_res' <> nonacc_res'}

  letBind (Pat (concat acc_pes' <> nonacc_pes')) $ WithAcc inputs' lam'
  where
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    inputArrs (_, arrs, _) = length arrs

    -- An accumulator returned unchanged: bind its pattern elements to
    -- the input arrays and drop it ('Nothing').
    tryMoveAcc (pes, (_, arrs, _), (_, acc_p), SubExpRes cs (Var v))
      | paramName acc_p == v,
        cs == mempty = do
          forM_ (zip pes arrs) $ \(pe, arr) ->
            letBindNames [patElemName pe] $ BasicOp $ SubExp $ Var arr
          pure Nothing
    tryMoveAcc x =
      pure $ Just x

    -- A non-accumulator result that is free in the lambda (or constant)
    -- is bound outside and dropped ('Nothing').
    tryMoveNonAcc (pe, SubExpRes cs (Var v))
      | v `ST.elem` vtable,
        cs == mempty = do
          letBindNames [patElemName pe] $ BasicOp $ SubExp $ Var v
          pure Nothing
    tryMoveNonAcc (pe, SubExpRes cs (Constant v))
      | cs == mempty = do
          letBindNames [patElemName pe] $ BasicOp $ SubExp $ Constant v
          pure Nothing
    tryMoveNonAcc x =
      pure $ Just x
withAccTopDown _ _ = Skip

-- Rewrite every single-result @UpdateAcc@ statement whose accumulator
-- type carries a certificate listed in @get_rid_of@ into a plain binding
-- of the input accumulator.  This happens in the given body and,
-- recursively, in all nested bodies and op statements.
--
-- Returns the rewritten body together with the certificates for which an
-- update was actually eliminated.  That list is sorted (it is built with
-- 'insert') and contains one entry per eliminated update, so it may have
-- duplicates.  Callers only need to know whether it is empty.
elimUpdates :: (ASTRep rep, TraverseOpStms rep) => [VName] -> Body rep -> (Body rep, [VName])
elimUpdates get_rid_of = flip runState mempty . onBody
  where
    onBody body = do
      stms' <- onStms $ bodyStms body
      pure body {bodyStms = stms'}
    onStms = traverse onStm
    onStm (Let pat@(Pat [PatElem _ dec]) aux (BasicOp (UpdateAcc acc _ _)))
      | Acc c _ _ _ <- typeOf dec,
        c `elem` get_rid_of = do
          modify (insert c)
          pure $ Let pat aux $ BasicOp $ SubExp $ Var acc
    onStm (Let pat aux e) = Let pat aux <$> onExp e
    -- Recurse into nested bodies and into statements inside ops.
    onExp = mapExpM mapper
      where
        mapper =
          identityMapper
            { mapOnOp = traverseOpStms (\_ stms -> onStms stms),
              mapOnBody = \_ body -> onBody body
            }

-- Bottom-up dead-result elimination for 'WithAcc'; applies only when at
-- least one pattern name is unused according to the usage table.
--
-- Accumulator results: if none of an input's pattern elements is used,
-- every update of that accumulator in the body is turned into a plain
-- copy via 'elimUpdates', so later dead code elimination can shrink the
-- body.  The accumulator result itself is kept.  Non-accumulator results:
-- unused ones are removed from both the pattern and the lambda result.
-- If neither step changes anything, the rule does not apply.
--
-- See Note [Dead Code Elimination for WithAcc].
withAccBottomUp :: (TraverseOpStms rep, BuilderOps rep) => BottomUpRuleGeneric rep
-- Eliminate dead results.  See Note [Dead Code Elimination for WithAcc]
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | not $ all (`UT.used` utable) $ patNames pat = Simplify $ do
      let (acc_res, nonacc_res) =
            splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
          (acc_pes, nonacc_pes) =
            splitFromEnd num_nonaccs $ patElems pat
          (cert_params, acc_params) =
            splitAt (length inputs) $ lambdaParams lam

      -- Eliminate unused accumulator results: collect the certificate
      -- names of inputs none of whose pattern elements are used.
      let get_rid_of =
            map snd . filter getRidOf $
              zip
                (chunks (map inputArrs inputs) acc_pes)
                $ map paramName cert_params

      -- Eliminate unused non-accumulator results
      let (nonacc_pes', nonacc_res') =
            unzip $ filter keepNonAccRes $ zip nonacc_pes nonacc_res

      -- Cheap early exit before traversing the body.
      when (null get_rid_of && nonacc_pes' == nonacc_pes) cannotSimplify

      let (body', eliminated) = elimUpdates get_rid_of $ lambdaBody lam

      -- No update was actually removed and no result was dropped.
      when (null eliminated && nonacc_pes' == nonacc_pes) cannotSimplify

      let pes' = acc_pes ++ nonacc_pes'

      lam' <- mkLambda (cert_params ++ acc_params) $ do
        void $ bodyBind body'
        pure $ acc_res ++ nonacc_res'

      auxing aux $ letBind (Pat pes') $ WithAcc inputs lam'
  where
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    inputArrs (_, arrs, _) = length arrs
    getRidOf (pes, _) = not $ any ((`UT.used` utable) . patElemName) pes
    keepNonAccRes (pe, _) = patElemName pe `UT.used` utable
withAccBottomUp _ _ = Skip

-- Note [Dead Code Elimination for WithAcc]
--
-- Our static semantics for accumulators are basically those of linear
-- types.  This makes dead code elimination somewhat tricky.  First,
-- what we consider dead code is when we have a WithAcc where at least
-- one of the array results (which internally correspond to an
-- accumulator) is unused.  E.g.:
--
-- let {X',Y'} =
--   with_acc {X, Y} (\X_p Y_p X_acc Y_acc -> ... {X_acc', Y_acc'})
--
-- where X' is not used later.  Note that Y' is still used later.  If
-- none of the results of the WithAcc are used, then the Stm as a
-- whole is dead and can be removed.  That's the trivial case, done
-- implicitly by the simplifier.  The interesting case is exactly when
-- some of the results are unused.  How do we get rid of them?
--
-- Naively, we might just remove them:
--
-- let Y' =
--   with_acc Y (\Y_p Y_acc -> ... Y_acc')
--
-- This is safe *only* if X_acc is used *only* in the result (i.e. an
-- "identity" WithAcc).  Otherwise we end up with references to X_acc,
-- which no longer exists.  This simple case is actually handled in
-- the withAccTopDown rule, and is easy enough.
--
-- What we actually do when we decide to eliminate X_acc is that we
-- inspect the body of the WithAcc and eliminate all UpdateAcc
-- operations that refer to the same accumulator as X_acc (identified
-- by the X_p token).  I.e. we turn every
--
-- let B = update_acc(A, ...)
--
-- where 'A' is ultimately derived from X_acc into
--
-- let B = A
--
-- That's it!  We then let ordinary dead code elimination eventually
-- simplify the body enough that we have an "identity" WithAcc.  There
-- is no _guarantee_ that this will happen, but our general dead code
-- elimination tends to be pretty good.