{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.SOACS.Simplify
  ( simplifySOACS,
    simplifyLambda,
    simplifyFun,
    simplifyStms,
    simplifyConsts,
    simpleSOACS,
    simplifySOAC,
    soacRules,
    HasSOAC (..),
    simplifyKnownIterationSOAC,
    removeReplicateMapping,
    removeUnusedSOACInput,
    liftIdentityMapping,
    simplifyMapIota,
    SOACS,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util

-- | Simplification operations for 'SOACS': the generic builder-based
-- operations from 'Simplify.bindableSimpleOps', with 'simplifySOAC'
-- used to simplify the SOAC operations themselves.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC

-- | Simplify a whole 'SOACS' program using 'simpleSOACS' and
-- 'soacRules', with 'Engine.noExtraHoistBlockers' as hoist blockers.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single function definition in the context of the given
-- symbol table, using 'simpleSOACS' and 'soacRules'.
simplifyFun ::
  MonadFreshNames m =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single t'Lambda' in the scope provided by the monad,
-- using 'simpleSOACS' and 'soacRules'.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) => Lambda SOACS -> m (Lambda SOACS)
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a sequence of 'Stms'.  The current scope (from 'askScope')
-- is passed to the simplifier as the enclosing scope.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms

-- | Simplify a sequence of 'Stms' starting from an empty scope.
-- NOTE(review): presumably intended for top-level constants, as the
-- name suggests; confirm against callers.
simplifyConsts ::
  MonadFreshNames m => Stms SOACS -> m (Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms simpleSOACS soacRules Engine.noExtraHoistBlockers mempty

-- | Simplify a single 'SOAC'.
--
-- Operands are simplified with 'Engine.simplify' and every lambda with
-- 'Engine.simplifyLambda'.  The result pairs the rewritten SOAC with the
-- statements hoisted out of its lambdas.
--
-- The stream, scatter, histogram-operator, histogram-body and map
-- lambdas are simplified under 'Engine.enterLoop'.  The scan/reduce
-- operators and the VJP/JVP lambdas are not.
simplifySOAC ::
  Simplify.SimplifiableRep rep =>
  Simplify.SimplifyOp rep (SOAC (Wise rep))
simplifySOAC (VJP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (VJP lam' arr' vec', hoisted)
simplifySOAC (JVP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (JVP lam' arr' vec', hoisted)
simplifySOAC (Stream outerdim arr nes lam) = do
  outerdim' <- Engine.simplify outerdim
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  pure (Stream outerdim' arr' nes' lam', lam_hoisted)
simplifySOAC (Scatter w ivs lam as) = do
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  pure (Scatter w' ivs' lam' as', hoisted)
simplifySOAC (Hist w imgs ops bfun) = do
  w' <- Engine.simplify w
  -- Each histogram operator contributes its own hoisted statements.
  (ops', hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', op_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty op
      pure (HistOp dests_w' rf' dests' nes' op', op_hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty bfun
  pure (Hist w' imgs' ops' bfun', mconcat hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <-
    Engine.enterLoop $ Engine.simplifyLambda mempty map_lam

  -- The width and input arrays are simplified after the lambdas,
  -- matching the original order of effects.
  (,)
    <$> ( Screma
            <$> Engine.simplify w
            <*> Engine.simplify arrs
            <*> pure (ScremaForm scans' reds' map_lam')
        )
    <*> pure (mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted)

-- | 'Wise' 'SOACS' relies entirely on the default 'BuilderOps' methods.
instance BuilderOps (Wise SOACS) where

-- | Statements nested inside SOAC operations are traversed with
-- 'traverseSOACStms'.
instance TraverseOpStms (Wise SOACS) where
  traverseOpStms = traverseSOACStms

-- | Turn some of a lambda's parameters into local bindings.
--
-- The @i@th list element decides what happens to the @i@th parameter.
-- @Just x@ binds the parameter's name to @x@ with a statement built by
-- the body builder, and drops the parameter from 'lambdaParams'.
-- @Nothing@ keeps it as a parameter.  The list is padded with
-- 'Nothing', so parameters beyond its end are kept.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (Lambda (Rep m))
fixLambdaParams lam fixes = do
  body <- runBodyBuilder $
    localScope (scopeOfLParams params) $ do
      zipWithM_ maybeFix params fixes'
      pure $ lambdaBody lam
  pure
    lam
      { lambdaBody = body,
        lambdaParams = kept
      }
  where
    params = lambdaParams lam
    -- Pad so that every parameter has a corresponding entry.
    fixes' = fixes ++ repeat Nothing
    -- Only parameters that were not given a fixed value remain.
    kept = [p | (p, Nothing) <- zip params fixes']
    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = pure ()

-- | Remove lambda results, and the matching return types, according to
-- a mask.  The @i@th result is kept iff the @i@th mask element is
-- 'True'.  Results beyond the end of the mask are kept.
removeLambdaResults :: [Bool] -> Lambda rep -> Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body {bodyResult = keep' $ bodyResult lam_body},
      lambdaReturnType = keep' $ lambdaReturnType lam
    }
  where
    lam_body = lambdaBody lam
    -- Used at two element types (results and types), hence the
    -- signature.  The mask is padded with 'True' so unmentioned
    -- entries survive.
    keep' :: [a] -> [a]
    keep' xs = [x | (True, x) <- zip (keep ++ repeat True) xs]

-- | The rule book for simplifying 'SOACS': 'standardRules' combined with
-- the SOAC-specific 'topDownRules' and 'bottomUpRules'.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules

-- | Does this rep contain 'SOAC's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSOAC rep where
  -- | View an operation as a 'SOAC', or 'Nothing' if it is not one.
  asSOAC :: Op rep -> Maybe (SOAC rep)

  -- | Embed a 'SOAC' as an operation of this rep.
  soacOp :: SOAC rep -> Op rep

-- | In 'SOACS' the operation type is 'SOAC' itself, so both conversions
-- are trivial.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id

-- | SOAC-specific top-down simplification rules, all operating on 'Op's.
-- The list order is unchanged from the original definition.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  map
    RuleOp
    [ hoistCerts,
      removeReplicateMapping,
      removeReplicateWrite,
      removeUnusedSOACInput,
      simplifyClosedFormReduce,
      simplifyKnownIterationSOAC,
      liftIdentityMapping,
      removeDuplicateMapOutput,
      fuseConcatScatter,
      simplifyMapIota,
      moveTransformToInput
    ]

-- | SOAC-specific bottom-up simplification rules.  All but
-- 'removeUnnecessaryCopy' (a 'BasicOp' rule) operate on 'Op's.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  [ RuleOp removeDeadMapping,
    RuleOp removeDeadReduction,
    RuleOp removeDeadWrite,
    RuleBasicOp removeUnnecessaryCopy,
    RuleOp liftIdentityStreaming,
    RuleOp mapOpToOp
  ]

-- Any certificates attached to a trivial Stm in the body might as
-- well be applied to the SOAC itself.
--
-- This applies to every lambda of the SOAC.  For each statement of the
-- form @let pat = BasicOp (SubExp se)@, the certificates are split into
-- two groups:
--
-- * those whose names are in the outer symbol table, which are removed
--   from the statement and attached to the SOAC statement; and
-- * all others, which stay on the statement.
--
-- The rule only fires if at least one certificate was moved.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify $
        auxing aux $
          certifying hoisted $
            letBind pat $
              Op soac'
  where
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    onLambda :: Lambda (Wise SOACS) -> State Certs (Lambda (Wise SOACS))
    onLambda lam = do
      let body = lambdaBody lam
      stms' <- mapM onStm $ bodyStms body
      pure lam {lambdaBody = mkBody stms' $ bodyResult body}

    onStm :: Stm (Wise SOACS) -> State Certs (Stm (Wise SOACS))
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      let (invariant, variant) =
            partition (`ST.elem` vtable) $ unCerts $ stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certs variant}
      -- Accumulate the certificates that can be applied outside.
      modify (Certs invariant <>)
      pure $ Let se_pat se_aux' $ BasicOp $ SubExp se
    onStm stm = pure stm
hoistCerts _ _ _ _ = Skip

-- | Lift trivially computed results out of a map-only 'Screma'.
--
-- A result of the map lambda is lifted out of the map when it is either
--
-- * one of the lambda's parameters: the corresponding output is bound
--   to a 'Copy' of the matching input array or, for accumulator
--   outputs, to the input array itself; or
--
-- * a constant or a variable free in the lambda body: the corresponding
--   output is bound to a 'Replicate' of it over the map width.
--
-- The remaining results stay in the (smaller) map, which is bound under
-- the original statement's 'StmAux'.  The rule is skipped when nothing
-- can be lifted.
--
-- NOTE(review): the certificates of a lifted 'SubExpRes' are discarded;
-- confirm that this is intended.
liftIdentityMapping ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    Just fun <- isMapSOAC form = do
      let -- Lambda parameter name -> input array it is drawn from.
          inputMap = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
          free = freeIn $ lambdaBody fun
          rettype = lambdaReturnType fun
          ses = bodyResult $ lambdaBody fun

          freeOrConst (Var v) = v `nameIn` free
          freeOrConst Constant {} = True

          -- Sort each (output, result, type) triple into either an
          -- invariant binding or a result that stays in the map.
          checkInvariance (outId, SubExpRes _ (Var v), _) (invariant, mapresult, rettype')
            | Just inp <- M.lookup v inputMap =
                ( (Pat [outId], e inp) : invariant,
                  mapresult,
                  rettype'
                )
            where
              e inp = case patElemType outId of
                Acc {} -> BasicOp $ SubExp $ Var inp
                _ -> BasicOp (Copy inp)
          checkInvariance (outId, SubExpRes _ e, t) (invariant, mapresult, rettype')
            | freeOrConst e =
                ( (Pat [outId], BasicOp $ Replicate (Shape [w]) e) : invariant,
                  mapresult,
                  rettype'
                )
            | otherwise =
                ( invariant,
                  (outId, e) : mapresult,
                  t : rettype'
                )

      case foldr checkInvariance ([], [], []) $ zip3 (patElems pat) ses rettype of
        ([], _, _) -> Skip
        (invariant, mapresult, rettype') -> Simplify $ do
          let (pat', ses') = unzip mapresult
              fun' =
                fun
                  { lambdaBody = (lambdaBody fun) {bodyResult = subExpsRes ses'},
                    lambdaReturnType = rettype'
                  }
          mapM_ (uncurry letBind) invariant
          auxing aux $
            letBindNames (map patElemName pat') $
              Op $
                soacOp $
                  Screma w arrs (mapSOAC fun')
liftIdentityMapping _ _ _ _ = Skip

-- | Lift out of a 'Stream' every map result that simply returns one of
-- the lambda's array-chunk parameters unchanged.
--
-- Each such output is bound outside the stream as a 'Copy' of the
-- corresponding input array, and is removed from the stream's pattern,
-- lambda result and lambda return type. Fold results (the first
-- @length nes@ results) are never touched. The rule fires only when at
-- least one map result can be lifted.
--
-- NOTE(review): certificates attached to a lifted result are discarded
-- by the classifier below; confirm this is intended.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
  | (variant, invariant) <-
      partitionEithers $ zipWith3 classify map_ts map_pes map_res,
    not (null invariant) = Simplify $ do
      for_ invariant $ \(pe, arr) ->
        letBind (Pat [pe]) $ BasicOp $ Copy arr

      let (variant_ts, variant_pes, variant_res) = unzip3 variant
          body' = (lambdaBody lam) {bodyResult = fold_res ++ variant_res}
          lam' =
            lam
              { lambdaBody = body',
                lambdaReturnType = fold_ts ++ variant_ts
              }

      auxing aux . letBind (Pat $ fold_pes ++ variant_pes) . Op $
        Stream w arrs nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    (fold_res, map_res) = splitAt num_folds $ bodyResult $ lambdaBody lam

    -- The first @1 + num_folds@ lambda parameters are not array chunks
    -- (presumably the chunk size followed by one accumulator per neutral
    -- element); the remaining ones pair up with the input arrays.
    chunk_params_to_arrs =
      zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    -- 'Right' if the result is one of the chunk parameters (and thus
    -- equal to the corresponding input array), 'Left' otherwise.
    classify _ pe (SubExpRes _ (Var v))
      | Just arr <- lookup v chunk_params_to_arrs = Right (pe, arr)
    classify t pe res = Left (t, pe, res)
liftIdentityStreaming _ _ _ _ = Skip

-- | Remove all arguments to the map that are simply replicates.
-- These can be turned into free variables instead.
--
-- The bindings produced by 'removeReplicateInput' are emitted (under the
-- certificates of the original replicate) before the rewritten map, which
-- keeps the statement's original 'StmAux'.
removeReplicateMapping ::
  (Aliased rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeReplicateMapping vtable pat aux op
  | Just (Screma w inputs form) <- asSOAC op,
    Just map_fun <- isMapSOAC form,
    Just (hoisted, map_fun', inputs') <-
      removeReplicateInput vtable map_fun inputs =
      Simplify $ do
        mapM_ bindHoisted hoisted
        let soac = Screma w inputs' (mapSOAC map_fun')
        auxing aux . letBind pat . Op $ soacOp soac
  where
    -- Bind a former lambda parameter directly to its replicated value.
    bindHoisted (names, cs, e) = certifying cs $ letBindNames names e
removeReplicateMapping _ _ _ _ = Skip

-- | Like 'removeReplicateMapping', but for 'Scatter'.
--
-- Replicated index/value inputs are turned into bindings emitted before
-- the 'Scatter', and the 'Scatter' is rebuilt with the remaining inputs
-- and the trimmed lambda.
removeReplicateWrite :: TopDownRuleOp (Wise SOACS)
removeReplicateWrite vtable pat aux (Scatter w ivs lam dests)
  | Just (hoisted, lam', ivs') <- removeReplicateInput vtable lam ivs =
      Simplify $ do
        for_ hoisted $ \(names, cs, e) ->
          certifying cs $ letBindNames names e
        auxing aux . letBind pat . Op $ Scatter w ivs' lam' dests
removeReplicateWrite _ _ _ _ = Skip

-- | Find lambda inputs that are bound (according to the symbol table) to
-- a 'Replicate' with at least one dimension, and whose parameter is not
-- consumed by the lambda.
--
-- For every such input this returns a binding of the parameter name to
-- one row of the replicated array (a plain 'SubExp' when only a single
-- dimension was replicated), together with the certificates recorded for
-- the replicate. Those parameters and inputs are removed from the
-- returned lambda and input list. The leading
-- @length params - length arrs@ parameters are kept as they are.
-- Returns 'Nothing' when no input qualifies.
removeReplicateInput ::
  Aliased rep =>
  ST.SymbolTable rep ->
  Lambda rep ->
  [VName] ->
  Maybe
    ( [([VName], Certs, Exp rep)],
      Lambda rep,
      [VName]
    )
removeReplicateInput vtable fun arrs
  | null hoisted = Nothing
  | otherwise =
      let (kept_params, kept_arrs) = unzip kept
       in Just
            ( hoisted,
              fun {lambdaParams = lead_params <> kept_params},
              kept_arrs
            )
  where
    all_params = lambdaParams fun
    (lead_params, arr_params) =
      splitAt (length all_params - length arrs) all_params
    (kept, hoisted) = partitionEithers $ zipWith classify arr_params arrs
    consumed = consumedByLambda fun

    -- 'Right' with a replacement binding when the input can be turned
    -- into a free variable, 'Left' with the untouched pair otherwise.
    classify p v
      | Just (BasicOp (Replicate (Shape (_ : inner)) e), v_cs) <-
          ST.lookupExp v vtable,
        paramName p `notNameIn` consumed =
          Right ([paramName p], v_cs, BasicOp $ peel inner e)
      | otherwise = Left (p, v)

    -- One row of the replicated array: the outermost dimension removed.
    peel [] e = SubExp e
    peel inner e = Replicate (Shape inner) e

-- | Remove inputs that are not used inside the SOAC.
--
-- Applies to 'Screma' and 'Scatter': every lambda parameter that does not
-- occur free in the lambda body is dropped, together with its
-- corresponding input array. The rule is skipped when nothing can be
-- removed.
removeUnusedSOACInput ::
  forall rep.
  (Aliased rep, Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeUnusedSOACInput _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    ScremaForm scan reduce map_lam <- form,
    Just (used_arrs, map_lam') <- remove map_lam arrs =
      Simplify . auxing aux . letBind pat . Op $
        soacOp (Screma w used_arrs (ScremaForm scan reduce map_lam'))
  | Just (Scatter w arrs map_lam dests :: SOAC rep) <- asSOAC op,
    Just (used_arrs, map_lam') <- remove map_lam arrs =
      Simplify . auxing aux . letBind pat . Op $
        soacOp (Scatter w used_arrs map_lam' dests)
  where
    -- Drop every (parameter, input) pair whose parameter is not free in
    -- the lambda body. Returns 'Nothing' when every parameter is used.
    remove map_lam arrs =
      let -- Free variables of the body, computed once per lambda rather
          -- than once per parameter.
          body_names = freeIn $ lambdaBody map_lam
          isUsed (param, _) = paramName param `nameIn` body_names
          (used, unused) = partition isUsed $ zip (lambdaParams map_lam) arrs
          (used_params, used_arrs) = unzip used
       in if null unused
            then Nothing
            else Just (used_arrs, map_lam {lambdaParams = used_params})
removeUnusedSOACInput _ _ _ _ = Skip

-- | Remove map outputs of a 'Screma' whose pattern elements are unused.
--
-- Only the map results (those after the scan and reduce results) are
-- candidates; scan and reduce results are always kept. The rule fires
-- only if at least one map output is unused according to the usage table.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm scans reds lam))
  | (nonmap_pes, map_pes) <- splitAt num_nonmap_res pes,
    not $ null map_pes =
      let (nonmap_res, map_res) =
            splitAt num_nonmap_res $ bodyResult $ lambdaBody lam
          (nonmap_ts, map_ts) =
            splitAt num_nonmap_res $ lambdaReturnType lam
          isUsed (bindee, _, _) = patElemName bindee `UT.used` used
          -- Decide directly whether anything is dead, instead of
          -- structurally comparing the filtered pattern with the original.
          (live, dead) = partition isUsed $ zip3 map_pes map_res map_ts
          (map_pes', map_res', map_ts') = unzip3 live
          lam' =
            lam
              { lambdaBody = (lambdaBody lam) {bodyResult = nonmap_res <> map_res'},
                lambdaReturnType = nonmap_ts <> map_ts'
              }
       in if null dead
            then Skip
            else
              Simplify . auxing aux $
                letBind (Pat $ nonmap_pes <> map_pes') $
                  Op $
                    Screma w arrs $
                      ScremaForm scans reds lam'
  where
    num_nonmap_res = scanResults scans + redResults reds
removeDeadMapping _ _ _ _ = Skip

-- | Remove duplicate results from a map.
--
-- When several results of the map lambda are the same 'SubExp'
-- (certificates are not compared), only the first occurrence remains a
-- map output. Every later duplicate is instead bound, after the map, as a
-- 'Copy' of the array produced for the first occurrence.
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
      let ses = bodyResult $ lambdaBody fun
          ts = lambdaReturnType fun
          (ses_ts_pes', copies) = dedupResults $ zip3 ses ts pes
       in if null copies
            then Skip
            else Simplify $ do
              let (ses', ts', pes') = unzip3 ses_ts_pes'
                  fun' =
                    fun
                      { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                        lambdaReturnType = ts'
                      }
              auxing aux $ letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun'
              -- Bind each duplicate as a copy of the array kept for the
              -- first occurrence of the same result.
              forM_ copies $ \(from, to) ->
                letBind (Pat [to]) $ BasicOp $ Copy $ patElemName from
  where
    -- Keep the first occurrence of every distinct result 'SubExp' (in
    -- their original order), and pair each later duplicate's pattern
    -- element with the pattern element of that first occurrence (most
    -- recently seen duplicate first). A map keyed on the result makes
    -- this a single O(n log n) pass instead of a quadratic fold.
    dedupResults = go M.empty [] []
      where
        go _ kept copies [] = (reverse kept, copies)
        go seen kept copies (x@(se, _, pe) : rest) =
          case M.lookup (resSubExp se) seen of
            Just first_pe ->
              go seen kept ((first_pe, pe) : copies) rest
            Nothing ->
              go (M.insert (resSubExp se) pe seen) (x : kept) copies rest
removeDuplicateMapOutput _ _ _ _ = Skip

-- | Mapping some operations becomes an extension of that operation.
--
-- When 'isMapWithOp' recognises a map whose body is a 'Reshape',
-- 'Concat', 'Rearrange' or 'Rotate' of the lambda parameters, the map is
-- replaced by that operation applied directly to the mapped arrays, with
-- the mapped dimension as a new outermost dimension. The reshape,
-- rearrange and rotate cases additionally require that the operand is the
-- sole lambda parameter and that the map's output is not consumed.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp (_, used) pat aux1 e
  | Just (map_pe, cs, w, BasicOp (Reshape k newshape reshape_arr), [p], [arr]) <-
      map_with_op,
    paramName p == reshape_arr,
    notConsumed map_pe =
      liftOp cs $ Reshape k (Shape [w] <> newshape) arr
  | Just (_, cs, _, BasicOp (Concat d (arr :| arrs) dw), ps, outer_arr : outer_arrs) <-
      map_with_op,
    (arr : arrs) == map paramName ps =
      liftOp cs $ Concat (d + 1) (outer_arr :| outer_arrs) dw
  | Just (map_pe, cs, _, BasicOp (Rearrange perm rearrange_arr), [p], [arr]) <-
      map_with_op,
    paramName p == rearrange_arr,
    notConsumed map_pe =
      liftOp cs $ Rearrange (0 : map (1 +) perm) arr
  | Just (map_pe, cs, _, BasicOp (Rotate rots rotate_arr), [p], [arr]) <-
      map_with_op,
    paramName p == rotate_arr,
    notConsumed map_pe =
      liftOp cs $ Rotate (intConst Int64 0 : rots) arr
  where
    -- Shared between all guards; evaluated at most once.
    map_with_op = isMapWithOp pat e

    notConsumed pe = not $ UT.isConsumed (patElemName pe) used

    -- Bind the lifted operation to the original pattern, under both the
    -- statement's certificates and those of the inner statement.
    liftOp cs =
      Simplify . certifying (stmAuxCerts aux1 <> cs) . letBind pat . BasicOp
mapOpToOp _ _ _ _ = Skip

-- | Recognise a @map@ bound to a single pattern element whose lambda body
-- consists of exactly one statement binding exactly one name, where that
-- name is also the sole result of the lambda.
--
-- On success, returns (in order): the pattern element of the map, the
-- certificates of the inner statement, the width of the map, the inner
-- expression, the lambda parameters, and the input arrays of the map.
isMapWithOp ::
  Pat dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElem dec,
      Certs,
      SubExp,
      Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp pat soac = do
  Pat [map_pe] <- Just pat
  Screma w arrs form <- Just soac
  map_lam <- isMapSOAC form
  let body = lambdaBody map_lam
  [Let (Pat [inner_pe]) inner_aux inner_e] <- Just $ stmsToList $ bodyStms body
  -- NOTE(review): certificates attached to the lambda result itself are
  -- ignored here; only the inner statement's certificates are returned.
  [SubExpRes _ (Var res_v)] <- Just $ bodyResult body
  guard $ res_v == patElemName inner_pe
  Just (map_pe, stmAuxCerts inner_aux, w, inner_e, lambdaParams map_lam, arrs)

-- | Some of the results of a reduction (or really: Redomap) may be
-- dead.  We remove them here.  The trick is that we need to look at
-- the data dependencies to see that the "dead" result is not
-- actually used for computing one of the live ones.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form,
    -- Quick/cheap check before doing any dependency analysis.
    not $ all (`UT.used` used) $ patNames pat,
    let (red_pes, map_pes) = splitAt (length nes) $ patElems pat,
    let redlam_deps = dataDependencies $ lambdaBody redlam,
    let redlam_res = bodyResult $ lambdaBody redlam,
    let redlam_params = lambdaParams redlam,
    -- Operator parameters whose corresponding reduction result is used
    -- after this statement.
    let used_after =
          map snd . filter ((`UT.used` used) . patElemName . fst) $
            zip red_pes redlam_params,
    -- The operator has twice as many parameters as results; pairing the
    -- result list with itself relates both the accumulator and the
    -- element parameter at position i to result i.
    let necessary =
          findNecessaryForReturned
            (`elem` used_after)
            (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
            redlam_deps,
    let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
    -- Liveness of the first (length nes) parameters, i.e. one entry per
    -- reduction result.
    let red_alive = take (length nes) alive_mask,
    not $ and red_alive = Simplify $ do
      let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
          -- Dead positions have their parameters replaced by the neutral
          -- element; applied to both halves of the parameter list.
          dead_fix = zipWith fixDeadToNeutral red_alive nes
          (used_red_pes, _, used_nes) =
            unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $
              zip3 red_pes redlam_params nes
          maplam' = removeLambdaResults red_alive maplam

      redlam' <-
        removeLambdaResults red_alive
          <$> fixLambdaParams redlam (dead_fix ++ dead_fix)

      auxing aux $
        letBind (Pat $ used_red_pes ++ map_pes) $
          Op $
            Screma w arrs $
              redomapSOAC [Reduce comm redlam' used_nes] maplam'
removeDeadReduction _ _ _ _ = Skip

-- | If we are writing to an array that is never used, get rid of it.
--
-- A destination @(shape, n, arr)@ receives @n@ writes, and the pattern
-- has one element per destination.  The lambda results (and return
-- types) are therefore grouped per destination before being matched up
-- with the pattern, so that destinations with @n /= 1@ (as produced by
-- 'fuseConcatScatter') are handled correctly.
removeDeadWrite :: BottomUpRuleOp (Wise SOACS)
removeDeadWrite (_, used) pat aux (Scatter w arrs fun dests) =
  let perDest :: [a] -> [[([a], a)]]
      perDest = chunkWrites [n | (_, n, _) <- dests] . groupScatterResults' dests
      res_groups = perDest $ bodyResult $ lambdaBody fun
      ts_groups = perDest $ lambdaReturnType fun
      live = map ((`UT.used` used) . patElemName) $ patElems pat
      keepLive :: [b] -> [b]
      keepLive = map snd . filter fst . zip live
      pat' = keepLive $ patElems pat
      dests' = keepLive dests
      (i_ses', v_ses') = unzip $ concat $ keepLive res_groups
      (i_ts', v_ts') = unzip $ concat $ keepLive ts_groups
      -- Scatter lambdas return all index results first, then all values.
      fun' =
        fun
          { lambdaBody = (lambdaBody fun) {bodyResult = concat i_ses' ++ v_ses'},
            lambdaReturnType = concat i_ts' ++ v_ts'
          }
   in if pat /= Pat pat'
        then
          Simplify . auxing aux $
            letBind (Pat pat') $
              Op $
                Scatter w arrs fun' dests'
        else Skip
  where
    -- Split a list of writes into consecutive groups of the given sizes.
    chunkWrites :: [Int] -> [x] -> [[x]]
    chunkWrites [] _ = []
    chunkWrites (k : ks) xs =
      let (here, rest) = splitAt k xs
       in here : chunkWrites ks rest
removeDeadWrite _ _ _ _ = Skip

-- | Fuse a 'Scatter' whose input arrays are all (outer-dimension)
-- concatenations of pieces sharing a common width.  The scatter is
-- rewritten to run over a single piece, with the lambda duplicated once
-- per piece; each destination then receives @r@ times as many writes,
-- where @r@ is the number of pieces.  Any number of concatenated arrays
-- is supported.
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter vtable pat _ (Scatter _ arrs fun dests)
  | Just (ws@(w' : _), xss@(xs : _), css) <- unzip3 <$> mapM isConcat arrs,
    -- Every input must be split into the same number of pieces, or the
    -- transposition below would be ragged.
    all ((== length xs) . length) xss,
    all (w' ==) ws = Simplify $ do
      -- Element i of xivs holds the i'th piece of every input array.
      let xivs = transpose xss
          r = length xivs
      fun2s <- replicateM (r - 1) (renameLambda fun)
      let funs = fun : fun2s
          -- Interleave complete writes (index tuple together with its
          -- value) across the lambda copies, so that multi-dimensional
          -- index tuples are never split apart.
          (idx_res, val_res) =
            unzip . mix $
              map (groupScatterResults' dests . bodyResult . lambdaBody) funs
          (idx_ts, val_ts) =
            unzip . mix . replicate r $
              groupScatterResults' dests (lambdaReturnType fun)
          fun' =
            Lambda
              { lambdaParams = concatMap lambdaParams funs,
                lambdaBody =
                  mkBody (mconcat $ map (bodyStms . lambdaBody) funs) $
                    concat idx_res <> val_res,
                lambdaReturnType = concat idx_ts <> val_ts
              }
      certifying (mconcat css) . letBind pat . Op $
        Scatter w' (concat xivs) fun' $
          map (incWrites r) dests
  where
    sizeOf :: VName -> Maybe SubExp
    sizeOf x = arraySize 0 . typeOf <$> ST.lookup x vtable

    mix :: [[a]] -> [a]
    mix = concat . transpose

    -- Each of the r lambda copies performs the original n writes into
    -- this destination, so it now receives n * r writes.
    incWrites :: Int -> (Shape, Int, VName) -> (Shape, Int, VName)
    incWrites r (shape, n, a) = (shape, n * r, a)

    -- Recognise an array that is a concatenation along the outer
    -- dimension of equally-wide pieces, looking through coercing
    -- reshapes.  Returns the piece width, the pieces, and the
    -- certificates of the statements involved.
    isConcat :: VName -> Maybe (SubExp, [VName], Certs)
    isConcat v = case ST.lookupExp v vtable of
      Just (BasicOp (Concat 0 (x :| ys) _), cs) -> do
        x_w <- sizeOf x
        y_ws <- mapM sizeOf ys
        guard $ all (x_w ==) y_ws
        pure (x_w, x : ys, cs)
      Just (BasicOp (Reshape ReshapeCoerce _ arr), cs) -> do
        (a, b, cs') <- isConcat arr
        pure (a, b, cs <> cs')
      _ -> Nothing
fuseConcatScatter _ _ _ _ = Skip

-- | Simplify reductions whose result can be determined without running
-- them.  A redomap over a constant width of zero yields its neutral
-- elements directly; a pure reduction is handed to 'foldClosedForm'.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form)
  | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    zeroIsh w,
    -- Only fire when every pattern element is a reduction result; any
    -- additional (mapped) results would otherwise be left unbound.
    length (patNames pat) == length nes =
      Simplify . forM_ (zip (patNames pat) nes) $ \(v, ne) ->
        letBindNames [v] $ BasicOp $ SubExp ne
simplifyClosedFormReduce vtable pat _ (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
      Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip

-- | Inline SOACs whose width is the constant one.  For now only
-- 'Screma' and 'Stream' are handled.
--
-- For a 'Screma', each lambda parameter is bound to row 0 of its input
-- array, the map body is inlined, and the scan and reduce operators are
-- applied once to their neutral elements and the mapped values.  Scan
-- and map outputs become single-element arrays; reduction outputs are
-- bound directly.
--
-- For a 'Stream', the chunk size is bound to 1, the accumulators to the
-- neutral elements, the chunk parameters to the input arrays, and the
-- body is inlined.
simplifyKnownIterationSOAC ::
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyKnownIterationSOAC _ pat _ op
  | Just (Screma (Constant k) arrs (ScremaForm scans reds map_lam)) <- asSOAC op,
    oneIsh k = Simplify $ do
      let Reduce _ red_lam red_nes = singleReduce reds
          Scan scan_lam scan_nes = singleScan scans
          num_scan = length scan_nes
          num_red = length red_nes
          (scan_pes, red_pes, map_pes) = splitAt3 num_scan num_red $ patElems pat
          -- Bind a pattern element to a one-row array containing se.
          bindAsRow pe (SubExpRes cs se) =
            certifying cs . letBindNames [patElemName pe] . BasicOp $
              ArrayLit [se] $ rowType $ patElemType pe
          -- Bind a pattern element directly to se.
          bindAsValue pe (SubExpRes cs se) =
            certifying cs . letBindNames [patElemName pe] . BasicOp $
              SubExp se
          applyOp lam nes vals =
            eLambda lam $ map eSubExp $ nes ++ map resSubExp vals

      forM_ (zip (lambdaParams map_lam) arrs) $ \(p, arr) -> do
        arr_t <- lookupType arr
        letBindNames [paramName p] . BasicOp . Index arr $
          fullSlice arr_t [DimFix $ constant (0 :: Int64)]

      (to_scan, to_red, map_res) <-
        splitAt3 num_scan num_red <$> bodyBind (lambdaBody map_lam)
      scan_res <- applyOp scan_lam scan_nes to_scan
      red_res <- applyOp red_lam red_nes to_red

      zipWithM_ bindAsRow scan_pes scan_res
      zipWithM_ bindAsValue red_pes red_res
      zipWithM_ bindAsRow map_pes map_res
simplifyKnownIterationSOAC _ pat _ op
  | Just (Stream (Constant k) arrs nes fold_lam) <- asSOAC op,
    oneIsh k = Simplify $ do
      let (chunk_param, acc_params, slice_params) =
            partitionChunkedFoldParameters (length nes) (lambdaParams fold_lam)
          bindParamTo p se = letBindNames [paramName p] $ BasicOp $ SubExp se

      bindParamTo chunk_param $ intConst Int64 1
      zipWithM_ bindParamTo acc_params nes
      zipWithM_ bindParamTo slice_params $ map Var arrs

      res <- bodyBind $ lambdaBody fold_lam

      forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
        certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationSOAC _ _ _ _ = Skip

-- | A simple operation on an array.  Every constructor records the
-- certificates it was found under and the array it operates on.
data ArrayOp
  = -- | Indexing/slicing the array with the given slice.
    ArrayIndexing Certs VName (Slice SubExp)
  | -- | Permuting the dimensions of the array.
    ArrayRearrange Certs VName [Int]
  | -- | Rotating the array by the given amounts.
    ArrayRotate Certs VName [SubExp]
  | -- | Reshaping the array to the given shape.
    ArrayReshape Certs VName ReshapeKind Shape
  | -- | Copying the array.
    ArrayCopy Certs VName
  | -- | Never constructed.
    ArrayVar Certs VName
  deriving (Show, Eq, Ord)

-- | The array that an 'ArrayOp' is applied to.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op =
  case op of
    ArrayIndexing _ target _ -> target
    ArrayRearrange _ target _ -> target
    ArrayRotate _ target _ -> target
    ArrayReshape _ target _ _ -> target
    ArrayCopy _ target -> target
    ArrayVar _ target -> target

-- | The certificates recorded with an 'ArrayOp'.
arrayOpCerts :: ArrayOp -> Certs
arrayOpCerts op =
  case op of
    ArrayIndexing certs _ _ -> certs
    ArrayRearrange certs _ _ -> certs
    ArrayRotate certs _ _ -> certs
    ArrayReshape certs _ _ _ -> certs
    ArrayCopy certs _ -> certs
    ArrayVar certs _ -> certs

-- | Recognise an expression that is one of the basic array operations
-- 'Index', 'Rearrange', 'Rotate', 'Reshape' or 'Copy', and package it
-- as an 'ArrayOp' carrying the given certificates.  Any other
-- expression yields 'Nothing'.
isArrayOp :: Certs -> Exp rep -> Maybe ArrayOp
isArrayOp cs expr =
  case expr of
    BasicOp basic -> fromBasic basic
    _ -> Nothing
  where
    fromBasic (Index arr slice) = Just $ ArrayIndexing cs arr slice
    fromBasic (Rearrange perm arr) = Just $ ArrayRearrange cs arr perm
    fromBasic (Rotate rots arr) = Just $ ArrayRotate cs arr rots
    fromBasic (Reshape k new_shape arr) = Just $ ArrayReshape cs arr k new_shape
    fromBasic (Copy arr) = Just $ ArrayCopy cs arr
    fromBasic _ = Nothing

-- | Turn an 'ArrayOp' back into an expression.  The certificates the
-- resulting statement should be attached to are returned separately.
-- For the constructors produced by 'isArrayOp' this rebuilds the same
-- basic operation; 'ArrayVar' becomes a plain variable reference.
fromArrayOp :: ArrayOp -> (Certs, Exp rep)
fromArrayOp op =
  case op of
    ArrayIndexing cs arr slice -> withCerts cs $ Index arr slice
    ArrayRearrange cs arr perm -> withCerts cs $ Rearrange perm arr
    ArrayRotate cs arr rots -> withCerts cs $ Rotate rots arr
    ArrayReshape cs arr k new_shape -> withCerts cs $ Reshape k new_shape arr
    ArrayCopy cs arr -> withCerts cs $ Copy arr
    ArrayVar cs arr -> withCerts cs $ SubExp $ Var arr
  where
    withCerts certs basic = (certs, BasicOp basic)

-- | Collect the array operations (as recognised by 'isArrayOp') bound
-- in a body, each paired with the pattern it is bound to.
--
-- Each collected operation carries the certificates passed in, plus
-- those of every statement enclosing it, plus its own.  The search
-- descends into bodies nested inside expressions (as visited by
-- 'walkExpM') and into the lambdas of SOACs.  It does not look inside
-- 'Match' statements, and it drops copies found inside SOAC lambdas.
arrayOps ::
  forall rep.
  (Buildable rep, HasSOAC rep) =>
  Certs ->
  Body rep ->
  S.Set (Pat (LetDec rep), ArrayOp)
arrayOps cs body = foldMap fromStm (bodyStms body)
  where
    -- It is not safe to move everything out of branches (#1874);
    -- probably we need to put some more intelligence in here somehow.
    fromStm (Let _ _ Match {}) = mempty
    fromStm (Let pat aux e)
      | Just arr_op <- isArrayOp (cs <> stmAuxCerts aux) e =
          S.singleton (pat, arr_op)
      | otherwise =
          execState (walkExpM (nestedWalker (stmAuxCerts aux)) e) mempty

    fromOp extra_cs op =
      case asSOAC op of
        Just (soac :: SOAC rep) ->
          -- Copies are not safe to move out of nested ops (#1753).
          S.filter (not . isCopy . snd) . execWriter $
            mapSOACM
              identitySOACMapper {mapOnSOACLambda = fromLambda extra_cs}
              soac
        Nothing ->
          mempty

    fromLambda extra_cs lam = do
      tell $ arrayOps (cs <> extra_cs) (lambdaBody lam)
      pure lam

    nestedWalker extra_cs =
      (identityWalker @rep)
        { walkOnBody = \_ nested -> modify (arrayOps (cs <> extra_cs) nested <>),
          walkOnOp = \op -> modify (fromOp extra_cs op <>)
        }

    isCopy ArrayCopy {} = True
    isCopy _ = False

-- | Rewrite a body, replacing each statement whose pattern is a key of
-- the map with the array operation it maps to.  The certificates
-- carried by that 'ArrayOp' are added to the statement's own.
--
-- Every other statement is rebuilt with the same pattern and
-- auxiliary information.  Its nested bodies, and the lambdas of any
-- SOAC it contains, are rewritten recursively.
replaceArrayOps ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  M.Map (Pat (LetDec rep)) ArrayOp ->
  Body rep ->
  Body rep
replaceArrayOps substs (Body _ stms res) = mkBody (rewriteStm <$> stms) res
  where
    rewriteStm (Let pat aux e) =
      certify extra_cs $ mkLet' (patIdents pat) aux e'
      where
        (extra_cs, e') =
          case M.lookup pat substs of
            Just arr_op -> fromArrayOp arr_op
            Nothing -> (mempty, mapExp recurse e)

    recurse =
      (identityMapper @rep)
        { mapOnBody = \_ nested -> pure $ replaceArrayOps substs nested,
          mapOnOp = pure . rewriteOp
        }

    rewriteOp op =
      case asSOAC op of
        Just (soac :: SOAC rep) ->
          soacOp . runIdentity $
            mapSOACM identitySOACMapper {mapOnSOACLambda = pure . rewriteLambda} soac
        Nothing ->
          op

    rewriteLambda lam =
      lam {lambdaBody = replaceArrayOps substs (lambdaBody lam)}

-- Turn
--
--    map (\i -> ... xs[i] ...) (iota n)
--
-- into
--
--    map (\i x -> ... x ...) (iota n) xs
--
-- This is not because we want to encourage the map-iota pattern, but
-- it may be present in generated code.  This is an unfortunately
-- expensive simplification rule, since it requires multiple passes
-- over the entire lambda body.  It only handles the very simplest
-- case - if you find yourself planning to extend it to handle more
-- complex situations (rotate or whatnot), consider turning it into a
-- separate compiler pass instead.
--
-- The indexing may also be preceded by map-invariant fixed indices,
-- as in xs[j, i, ...].  In that case xs[j] is computed outside the map
-- and becomes the new input array.  Every statement this rule adds
-- outside the map is certified with the certificates of the indexing
-- it replaces, so that it keeps the same certificate dependencies.
simplifyMapIota ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyMapIota vtable screma_pat aux op
  | Just (Screma w arrs (ScremaForm scan reduce map_lam) :: SOAC rep) <- asSOAC op,
    Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs),
    indexings <-
      mapMaybe (indexesWith (paramName p)) . S.toList $
        arrayOps mempty $
          lambdaBody map_lam,
    not $ null indexings = Simplify $ do
      -- For each indexing with iota, add the corresponding array to
      -- the Screma, and construct a new lambda parameter.
      (more_arrs, more_params, replacements) <-
        unzip3 . catMaybes <$> mapM (mapOverArr w) indexings
      let substs = M.fromList replacements
          map_lam' =
            map_lam
              { lambdaParams = lambdaParams map_lam <> more_params,
                lambdaBody = replaceArrayOps substs $ lambdaBody map_lam
              }

      auxing aux . letBind screma_pat . Op . soacOp $
        Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam')
  where
    -- Is this Screma input bound to an 'Iota' whose start is a constant
    -- zero and whose stride is a constant one?
    isIota (_, arr) = case ST.lookupBasicOp arr vtable of
      Just (Iota _ (Constant o) (Constant s) _, _) ->
        zeroIsh o && oneIsh s
      _ -> False

    -- Find a 'DimFix i', optionally preceded by other DimFixes, and
    -- if so return those DimFixes.
    fixWith i (DimFix j : slice)
      | Var i == j = Just []
      | otherwise = (j :) <$> fixWith i slice
    fixWith _ _ = Nothing

    -- An indexing arr[js..., v, ...] is a candidate when arr, its
    -- certificates, and everything free in the leading indices js are
    -- bound in the symbol table, i.e. outside the map lambda.
    indexesWith v (pat, idx@(ArrayIndexing cs arr (Slice js)))
      | arr `ST.elem` vtable,
        all (`ST.elem` vtable) $ unCerts cs,
        Just js' <- fixWith v js,
        all (`ST.elem` vtable) $ namesToList $ freeIn js' =
          Just (pat, js', idx)
    indexesWith _ _ = Nothing

    -- Apply the leading (map-invariant) fixed indices to the array.
    properArr [] arr = pure arr
    properArr js arr = do
      arr_t <- lookupType arr
      letExp (baseString arr) $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix js

    -- Produce the new Screma input, the new lambda parameter, and the
    -- replacement for the original indexing statement.
    mapOverArr w (pat, js, ArrayIndexing cs arr slice) = do
      -- The hoisted partial index must keep the certificates of the
      -- indexing it is derived from, just like the prefix slice below.
      arr' <- certifying cs $ properArr js arr
      arr_t <- lookupType arr'
      -- If the outer size is not syntactically w, take the first w rows.
      arr'' <-
        if arraySize 0 arr_t == w
          then pure arr'
          else
            certifying cs . letExp (baseString arr ++ "_prefix") . BasicOp . Index arr' $
              fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)]
      arr_elem_param <- newParam (baseString arr ++ "_elem") (rowType arr_t)
      pure $
        Just
          ( arr'',
            arr_elem_param,
            ( pat,
              -- Drop the leading fixed indices and the iota index itself.
              ArrayIndexing cs (paramName arr_elem_param) (Slice (drop (length js + 1) (unSlice slice)))
            )
          )
    mapOverArr _ _ = pure Nothing
simplifyMapIota _ _ _ _ = Skip

-- If a Screma's map function contains a transformation
-- (e.g. transpose) on a parameter, create a new parameter
-- corresponding to that transformation performed on the rows of the
-- full array.
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput SymbolTable (Wise SOACS)
vtable Pat (LetDec (Wise SOACS))
screma_pat StmAux (ExpDec (Wise SOACS))
aux soac :: Op (Wise SOACS)
soac@(Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam))
  | [(Pat (VarWisdom, Type), ArrayOp)]
ops <- forall a. (a -> Bool) -> [a] -> [a]
filter (Pat (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam forall a b. (a -> b) -> a -> b
$ forall a. Set a -> [a]
S.toList forall a b. (a -> b) -> a -> b
$ forall rep.
(Buildable rep, HasSOAC rep) =>
Certs -> Body rep -> Set (Pat (LetDec rep), ArrayOp)
arrayOps forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam,
    Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Pat (VarWisdom, Type), ArrayOp)]
ops = forall rep. RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
      ([VName]
more_arrs, [Param Type]
more_params, [(Pat (VarWisdom, Type), ArrayOp)]
replacements) <-
        forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. [Maybe a] -> [a]
catMaybes forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Pat (VarWisdom, Type), ArrayOp)
-> RuleM
     (Wise SOACS)
     (Maybe (VName, Param Type, (Pat (VarWisdom, Type), ArrayOp)))
mapOverArr [(Pat (VarWisdom, Type), ArrayOp)]
ops

      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
more_arrs) forall rep a. RuleM rep a
cannotSimplify

      let map_lam' :: Lambda (Wise SOACS)
map_lam' =
            Lambda (Wise SOACS)
map_lam
              { lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam forall a. Semigroup a => a -> a -> a
<> [Param Type]
more_params,
                lambdaBody :: Body (Wise SOACS)
lambdaBody = forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
Map (Pat (LetDec rep)) ArrayOp -> Body rep -> Body rep
replaceArrayOps (forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Pat (VarWisdom, Type), ArrayOp)]
replacements) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam
              }

      forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec (Wise SOACS))
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Wise SOACS))
screma_pat forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
        forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
arrs forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) (forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam')
  where
    -- It is not safe to move the transform if the root array is being
    -- consumed by the Screma.  This is a bit too conservative - it's
    -- actually safe if we completely replace the original input, but
    -- this rule is not that precise.
    consumed :: Names
consumed = forall op. AliasedOp op => op -> Names
consumedInOp Op (Wise SOACS)
soac
    map_param_names :: [VName]
map_param_names = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName (forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam)
    topLevelPat :: Pat (VarWisdom, Type) -> Bool
topLevelPat = (forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall rep. Stm rep -> Pat (LetDec rep)
stmPat (forall rep. Body rep -> Stms rep
bodyStms (forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam)))
    onlyUsedOnce :: VName -> Bool
onlyUsedOnce VName
arr =
      case forall a. (a -> Bool) -> [a] -> [a]
filter ((VName
arr `nameIn`) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. FreeIn a => a -> Names
freeIn) forall a b. (a -> b) -> a -> b
$ forall rep. Stms rep -> [Stm rep]
stmsToList forall a b. (a -> b) -> a -> b
$ forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam of
        Stm (Wise SOACS)
_ : Stm (Wise SOACS)
_ : [Stm (Wise SOACS)]
_ -> Bool
False
        [Stm (Wise SOACS)]
_ -> Bool
True

    -- It's not just about whether the array is a parameter;
    -- everything else must be map-invariant.
    arrayIsMapParam :: (Pat (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam (Pat (VarWisdom, Type)
pat', ArrayIndexing Certs
cs VName
arr Slice SubExp
slice) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null Slice SubExp
slice)
        Bool -> Bool -> Bool
&& (Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Bool -> Bool -> Bool
|| (Pat (VarWisdom, Type) -> Bool
topLevelPat Pat (VarWisdom, Type)
pat' Bool -> Bool -> Bool
&& VName -> Bool
onlyUsedOnce VName
arr))
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayRearrange Certs
cs VName
arr [Int]
perm) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
perm)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayRotate Certs
cs VName
arr [SubExp]
rots) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn [SubExp]
rots)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayReshape Certs
cs VName
arr ReshapeKind
_ Shape
new_shape) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn Shape
new_shape)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayCopy Certs
cs VName
arr) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayVar {}) =
      Bool
False

    -- Try to hoist one array operation out of the map lambda, so that it is
    -- applied once to the whole Screma input rather than to each row.
    --
    -- It applies only when the operand of the operation is one of the map
    -- parameters (matched by name against 'map_param_names', which is paired
    -- positionally with 'arrs') and the corresponding input array is not in
    -- 'consumed'. On success the result contains:
    --
    --   * the newly bound, transformed input array (bound under the
    --     operation's certificates),
    --   * a fresh lambda parameter whose type is that array's row type, and
    --   * a replacement that binds @pat@ to the new parameter via 'ArrayVar'.
    --
    -- In every other case the result is 'Nothing'.
    mapOverArr ::
      (Pat (VarWisdom, Type), ArrayOp) ->
      RuleM (Wise SOACS) (Maybe (VName, Param Type, (Pat (VarWisdom, Type), ArrayOp)))
    mapOverArr (pat, op)
      | Just (_, input_arr) <- find ((== arrayOpArr op) . fst) (zip map_param_names arrs),
        input_arr `notNameIn` consumed = do
          input_t <- lookupType input_arr
          let base_name = baseString input_arr
              -- A slice covering the entire outermost dimension of the input.
              outer_slice =
                DimSlice (intConst Int64 0) (arraySize 0 input_t) (intConst Int64 1)
              -- The per-row operation restated on the whole input array; where
              -- the operation is shape-aware, the outermost dimension is
              -- prepended (slice, permutation, rotation amount, new shape).
              lifted = case op of
                ArrayIndexing _ _ (Slice dims) -> Index input_arr $ Slice $ outer_slice : dims
                ArrayRearrange _ _ perm -> Rearrange (0 : map (+ 1) perm) input_arr
                ArrayRotate _ _ rots -> Rotate (intConst Int64 0 : rots) input_arr
                ArrayReshape _ _ k new_shape -> Reshape k (Shape [w] <> new_shape) input_arr
                ArrayCopy {} -> Copy input_arr
                ArrayVar {} -> SubExp $ Var input_arr
          lifted_arr <-
            certifying (arrayOpCerts op) . letExp (base_name ++ "_transformed") $
              BasicOp lifted
          lifted_t <- lookupType lifted_arr
          row_name <- newVName $ base_name ++ "_transformed_row"
          let row_param = Param mempty row_name (rowType lifted_t)
          pure $ Just (lifted_arr, row_param, (pat, ArrayVar mempty row_name))
      | otherwise = pure Nothing
moveTransformToInput SymbolTable (Wise SOACS)
_ Pat (LetDec (Wise SOACS))
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ =
  forall rep. Rule rep
Skip