{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.IR.SOACS.Simplify
( simplifySOACS,
simplifyLambda,
simplifyFun,
simplifyStms,
simplifyConsts,
simpleSOACS,
simplifySOAC,
soacRules,
HasSOAC (..),
simplifyKnownIterationSOAC,
removeReplicateMapping,
removeUnusedSOACInput,
liftIdentityMapping,
simplifyMapIota,
SOACS,
)
where
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util
-- | Simplifier operations for the SOACS representation, obtained by
-- handing 'simplifySOAC' (the per-SOAC simplifier) to
-- 'Simplify.bindableSimpleOps'.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC
-- | Simplify an entire SOACS program.
--
-- Uses 'simpleSOACS' as the operation simplifier and 'soacRules' as the
-- rule book, with no hoist blockers beyond the engine's defaults.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a single SOACS function definition in the context of the
-- given symbol table, using 'simpleSOACS' and 'soacRules'.
simplifyFun ::
  (MonadFreshNames m) =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a SOACS lambda, using the scope available in the ambient
-- monad together with 'simpleSOACS' and 'soacRules'.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) => Lambda SOACS -> m (Lambda SOACS)
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers
-- | Simplify a sequence of SOACS statements.
--
-- The current scope is read with 'askScope' and passed on, so that
-- variables bound outside the statements are known to the simplifier.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms
-- | Simplify top-level constant statements.
--
-- Unlike 'simplifyStms', the simplifier is started with an empty scope
-- ('mempty').
simplifyConsts ::
  (MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    mempty
-- | Simplify the components of a single SOAC.
--
-- Every subexpression, array name and lambda of the SOAC is simplified.
-- The rebuilt SOAC is returned together with the statements produced by
-- 'Engine.simplifyLambda' for its lambdas.
--
-- The lambdas of 'Stream', 'Scatter', 'Hist' (operators and bucket
-- function) and the map part of 'Screma' are simplified under
-- 'Engine.enterLoop'. The lambdas of 'VJP', 'JVP' and of scan/reduce
-- operators are simplified directly.
simplifySOAC ::
  (Simplify.SimplifiableRep rep) =>
  Simplify.SimplifyOp rep (SOAC (Wise rep))
simplifySOAC (VJP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (VJP lam' arr' vec', hoisted)
simplifySOAC (JVP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (JVP lam' arr' vec', hoisted)
simplifySOAC (Stream outerdim arr nes lam) = do
  outerdim' <- Engine.simplify outerdim
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  pure (Stream outerdim' arr' nes' lam', lam_hoisted)
simplifySOAC (Scatter w ivs lam as) = do
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  pure (Scatter w' ivs' lam' as', hoisted)
simplifySOAC (Hist w imgs ops bfun) = do
  w' <- Engine.simplify w
  -- Simplify each histogram operator; collect the statements produced
  -- for each operator lambda separately and concatenate them below.
  (ops', ops_hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', op_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty op
      pure (HistOp dests_w' rf' dests' nes' op', op_hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty bfun
  pure (Hist w' imgs' ops' bfun', mconcat ops_hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Scan lam' nes', hoisted)
  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Reduce comm lam' nes', hoisted)
  (map_lam', map_lam_hoisted) <-
    Engine.enterLoop $ Engine.simplifyLambda mempty map_lam
  -- The width and input arrays are simplified after the lambdas, keeping
  -- the same effect order as the applicative formulation this replaces.
  w' <- Engine.simplify w
  arrs' <- Engine.simplify arrs
  pure
    ( Screma w' arrs' (ScremaForm scans' reds' map_lam'),
      mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted
    )
-- | 'BuilderOps' for the simplifier-annotated SOACS representation.
-- No methods are given here, so every method takes its class default.
instance BuilderOps (Wise SOACS)
-- | Statements nested inside a SOAC are traversed with
-- 'traverseSOACStms'; in this representation 'Op' is the SOAC itself.
instance TraverseOpStms (Wise SOACS) where
  traverseOpStms = traverseSOACStms
-- | Specialise a lambda by fixing some of its parameters to known values.
--
-- The list of fixes is aligned positionally with 'lambdaParams'. For each
-- parameter paired with @Just x@:
--
-- * the body is prefixed with a binding of that parameter's name to @x@;
-- * the parameter is removed from the parameter list.
--
-- Parameters paired with 'Nothing', or lying beyond the end of the list
-- of fixes, are kept unchanged.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (Lambda (Rep m))
fixLambdaParams lam fixes = do
  body <- runBodyBuilder $
    -- The parameters must be in scope while their bindings are emitted.
    localScope (scopeOfLParams params) $ do
      zipWithM_ maybeFix params fixes'
      pure $ lambdaBody lam
  pure
    lam
      { lambdaBody = body,
        lambdaParams = [p | (p, Nothing) <- zip params fixes']
      }
  where
    params = lambdaParams lam
    -- Pad with 'Nothing' so that a short fix list leaves the trailing
    -- parameters alone.
    fixes' = fixes ++ repeat Nothing
    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = pure ()
-- | Remove lambda results according to a list of flags.
--
-- The flags are aligned positionally with the body result and the return
-- type. An entry with 'False' drops that result and its return type.
-- Results beyond the end of the flag list are kept.
removeLambdaResults :: [Bool] -> Lambda rep -> Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body',
      lambdaReturnType = ret
    }
  where
    -- Pad with 'True' so that a short flag list keeps the remaining items.
    keep' :: [a] -> [a]
    keep' xs = [x | (True, x) <- zip (keep ++ repeat True) xs]
    lam_body = lambdaBody lam
    lam_body' = lam_body {bodyResult = keep' $ bodyResult lam_body}
    ret = keep' $ lambdaReturnType lam
-- | The simplification rules used for SOACS: the 'standardRules' extended
-- with this module's 'topDownRules' and 'bottomUpRules'.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules
-- | Representations whose operations can be viewed as (and built from)
-- 'SOAC's.  The rules in this module that are polymorphic in the
-- representation use these two methods to inspect and rebuild ops.
class HasSOAC r where
  -- | View an operation as a 'SOAC', if it is one.
  asSOAC :: Op r -> Maybe (SOAC r)

  -- | Embed a 'SOAC' as an operation.
  soacOp :: SOAC r -> Op r

-- | In 'Wise' 'SOACS' every operation already is a 'SOAC', so both
-- directions are trivial.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id
-- | Top-down SOAC simplification rules.  Every entry is an op-level
-- rule ('RuleOp'); they are listed in the same order as before.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  map
    RuleOp
    [ hoistCerts,
      removeReplicateMapping,
      removeReplicateWrite,
      removeUnusedSOACInput,
      simplifyClosedFormReduce,
      simplifyKnownIterationSOAC,
      liftIdentityMapping,
      removeDuplicateMapOutput,
      fuseConcatScatter,
      simplifyMapIota,
      moveTransformToInput
    ]
-- | Bottom-up SOAC simplification rules.  All are op-level rules except
-- 'removeUnnecessaryCopy', which is a basic-op rule; the overall order
-- is removeDeadMapping, removeDeadReduction, removeDeadWrite,
-- removeUnnecessaryCopy, liftIdentityStreaming, mapOpToOp.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  map RuleOp leading
    <> [RuleBasicOp removeUnnecessaryCopy]
    <> map RuleOp trailing
  where
    leading = [removeDeadMapping, removeDeadReduction, removeDeadWrite]
    trailing = [liftIdentityStreaming, mapOpToOp]
-- | For every lambda of a SOAC, look at the top-level statements of the
-- lambda body of the form @let p = se@ (a 'BasicOp' 'SubExp').  Any
-- certificates on such a statement whose names are in the symbol table
-- (i.e. bound outside the SOAC) are removed from that statement and
-- attached to the SOAC statement itself.  The rule only fires when at
-- least one certificate was moved.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify . auxing aux . certifying hoisted . letBind pat $ Op soac'
  where
    -- Only lambdas are rewritten; everything else is left untouched.
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    onLambda lam = do
      let body = lambdaBody lam
      stms' <- traverse onStm (bodyStms body)
      pure (lam {lambdaBody = mkBody stms' (bodyResult body)})

    -- Split the certificates into those visible outside the SOAC
    -- (accumulated in the state) and those that must stay put.
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      let (outer, inner) =
            partition (`ST.elem` vtable) . unCerts $ stmAuxCerts se_aux
      modify (Certs outer <>)
      pure $ Let se_pat (se_aux {stmAuxCerts = Certs inner}) (BasicOp (SubExp se))
    onStm stm = pure stm
hoistCerts _ _ _ _ = Skip
-- | For a map-only 'Screma', lift out results that do not need the map:
--
-- * a result that is one of the lambda's parameters is bound outside as
--   @replicate [] arr@ of the corresponding input array (or as a plain
--   'SubExp' when the pattern element has an 'Acc' type);
--
-- * a result that is free in the lambda body, or a constant, is bound
--   outside as @replicate [w] se@.
--
-- The remaining results stay in a new map.  Skips if nothing could be
-- lifted.  NOTE(review): certificates on lifted results ('SubExpRes'
-- certs) are discarded, and the lifted bindings are not wrapped in the
-- original statement's 'aux'; confirm this is intended.
liftIdentityMapping ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    Just fun <- isMapSOAC form =
      let body = lambdaBody fun
          param_arrs = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
          body_free = freeIn body

          isFreeOrConst (Var v) = v `nameIn` body_free
          isFreeOrConst Constant {} = True

          -- How to bind a result that is just a lambda parameter.
          fromInput pe inp =
            case patElemType pe of
              Acc {} -> BasicOp $ SubExp $ Var inp
              _ -> BasicOp $ Replicate mempty $ Var inp

          classify (pe, SubExpRes _ se, t) (lifted, kept, kept_ts)
            | Var v <- se,
              Just inp <- M.lookup v param_arrs =
                ((Pat [pe], fromInput pe inp) : lifted, kept, kept_ts)
            | isFreeOrConst se =
                ((Pat [pe], BasicOp $ Replicate (Shape [w]) se) : lifted, kept, kept_ts)
            | otherwise =
                (lifted, (pe, se) : kept, t : kept_ts)
       in case foldr classify ([], [], []) $
            zip3 (patElems pat) (bodyResult body) (lambdaReturnType fun) of
            ([], _, _) -> Skip
            (lifted, kept, kept_ts) -> Simplify $ do
              let (kept_pes, kept_ses) = unzip kept
                  fun' =
                    fun
                      { lambdaBody = body {bodyResult = subExpsRes kept_ses},
                        lambdaReturnType = kept_ts
                      }
              forM_ lifted (uncurry letBind)
              auxing aux . letBindNames (map patElemName kept_pes) . Op . soacOp $
                Screma w arrs (mapSOAC fun')
liftIdentityMapping _ _ _ _ = Skip
-- | In a 'Stream', a non-fold ("map") result that is directly one of the
-- lambda parameters paired with an input array is replaced by a copy
-- (@replicate [] arr@) of that input array, bound outside the stream.
-- The fold results (the first @length nes@ results) are never touched.
-- The rule only fires when at least one result can be lifted.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
  | (variant_map, invariant_map) <-
      partitionEithers $ zipWith3 classify map_ts map_pes map_res,
    not (null invariant_map) =
      Simplify $ do
        forM_ invariant_map $ \(pe, arr) ->
          letBind (Pat [pe]) . BasicOp . Replicate mempty $ Var arr
        let (variant_ts, variant_pes, variant_res) = unzip3 variant_map
            body = lambdaBody lam
            lam' =
              lam
                { lambdaBody = body {bodyResult = fold_res <> variant_res},
                  lambdaReturnType = fold_ts <> variant_ts
                }
        auxing aux . letBind (Pat $ fold_pes <> variant_pes) . Op $
          Stream w arrs nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    (fold_res, map_res) = splitAt num_folds . bodyResult $ lambdaBody lam

    -- The first @1 + num_folds@ lambda parameters are skipped; the rest
    -- are paired positionally with the input arrays.  (Presumably the
    -- chunk size followed by the fold accumulators; verify against the
    -- 'Stream' definition.)
    params_to_arrs =
      zip (map paramName . drop (1 + num_folds) $ lambdaParams lam) arrs

    classify t pe res
      | SubExpRes _ (Var v) <- res,
        Just arr <- lookup v params_to_arrs =
          Right (pe, arr)
      | otherwise =
          Left (t, pe, res)
liftIdentityStreaming _ _ _ _ = Skip
-- | For a map-only 'Screma', delegate to 'removeReplicateInput' (defined
-- elsewhere in this module).  When it succeeds, bind each returned
-- @(names, certs, exp)@ triple under its certificates, then rebuild the
-- map with the new lambda and the remaining input arrays.
removeReplicateMapping ::
  (Aliased rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeReplicateMapping vtable pat aux op
  | Just (Screma w arrs form) <- asSOAC op,
    Just fun <- isMapSOAC form,
    Just (stms, fun', arrs') <- removeReplicateInput vtable fun arrs =
      Simplify $ do
        mapM_ (\(vs, cs, e) -> certifying cs (letBindNames vs e)) stms
        auxing aux . letBind pat . Op . soacOp . Screma w arrs' $ mapSOAC fun'
removeReplicateMapping _ _ _ _ = Skip
-- | Top-down rule for 'Scatter'.
--
-- Asks 'removeReplicateInput' which input arrays are replicates. Each
-- returned binding is emitted as its own statement, under the certificates
-- that came with it. The 'Scatter' is then re-emitted under the original
-- statement's 'StmAux', with the reduced lambda and input list. Any other
-- op, or a 'Scatter' with no such inputs, is left alone ('Skip').
removeReplicateWrite :: TopDownRuleOp (Wise SOACS)
removeReplicateWrite vtable pat aux (Scatter w ivs lam dests) =
  case removeReplicateInput vtable lam ivs of
    Nothing -> Skip
    Just (bnds, lam', ivs') -> Simplify $ do
      mapM_ bindOne bnds
      auxing aux . letBind pat . Op $ Scatter w ivs' lam' dests
  where
    bindOne (names, cs, e) = certifying cs $ letBindNames names e
removeReplicateWrite _ _ _ _ = Skip
-- | Split off lambda inputs that the symbol table knows to be 'Replicate's.
--
-- Only the last @length arrs@ lambda parameters are paired with @arrs@.
-- Any leading parameters are kept unchanged.
--
-- An array input qualifies when both of these hold:
--
-- * 'ST.lookupExp' finds it bound to a 'Replicate' with a non-empty shape.
-- * The matching parameter is not in 'consumedByLambda'.
--
-- For each qualifying input we produce a binding of the parameter name to
-- the replicate with its outermost dimension dropped. When no dimensions
-- remain, the binding is a plain 'SubExp'. Each binding carries the
-- certificates returned by the lookup.
--
-- Returns 'Nothing' when no input qualifies. Otherwise returns the
-- bindings, the lambda with the qualifying parameters removed, and the
-- remaining arrays.
removeReplicateInput ::
  (Aliased rep) =>
  ST.SymbolTable rep ->
  Lambda rep ->
  [VName] ->
  Maybe
    ( [([VName], Certs, Exp rep)],
      Lambda rep,
      [VName]
    )
removeReplicateInput vtable fun arrs =
  case partitionEithers (zipWith classify arr_params arrs) of
    (_, []) -> Nothing
    (kept, bnds) ->
      let (kept_params, kept_arrs) = unzip kept
       in Just (bnds, fun {lambdaParams = lead_params ++ kept_params}, kept_arrs)
  where
    all_params = lambdaParams fun
    (lead_params, arr_params) =
      splitAt (length all_params - length arrs) all_params
    consumed = consumedByLambda fun

    -- Right: the parameter can be bound directly to the replicated value.
    -- Left: keep the (parameter, array) pair as a lambda input.
    -- A failing guard falls through to the catch-all alternative.
    classify p v = case ST.lookupExp v vtable of
      Just (BasicOp (Replicate (Shape (_ : inner)) x), cs)
        | paramName p `notNameIn` consumed ->
            Right ([paramName p], cs, dropOuter inner x)
      _ -> Left (p, v)

    -- Replicate minus its outermost dimension.
    dropOuter [] x = BasicOp $ SubExp x
    dropOuter inner x = BasicOp $ Replicate (Shape inner) x
-- | Drop unused inputs from a 'Screma' or 'Scatter' (via 'HasSOAC').
--
-- An input array is unused when its parameter in the map lambda does not
-- occur free in the lambda body. The SOAC is rebuilt with only the used
-- inputs and parameters, under the original pattern and 'StmAux'. Other
-- ops, and SOACs whose inputs are all used, are left alone ('Skip').
removeUnusedSOACInput ::
  forall rep.
  (Aliased rep, Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeUnusedSOACInput _ pat aux op =
  -- The annotation fixes 'rep'; 'Op' is a type family, so it cannot be
  -- inferred from 'op' alone.
  case asSOAC op :: Maybe (SOAC rep) of
    Just (Screma w arrs (ScremaForm scan reduce map_lam))
      | Just (kept_arrs, map_lam') <- pruneInputs map_lam arrs ->
          rebuild $ Screma w kept_arrs $ ScremaForm scan reduce map_lam'
    Just (Scatter w arrs map_lam dests)
      | Just (kept_arrs, map_lam') <- pruneInputs map_lam arrs ->
          rebuild $ Scatter w kept_arrs map_lam' dests
    _ -> Skip
  where
    rebuild soac = Simplify $ auxing aux $ letBind pat $ Op $ soacOp soac

    -- Pair each lambda parameter with its input and keep the pairs whose
    -- parameter is free in the body. 'Nothing' if nothing would be dropped.
    pruneInputs lam inputs
      | null dropped = Nothing
      | otherwise = Just (map snd kept, lam {lambdaParams = map fst kept})
      where
        body_free = freeIn $ lambdaBody lam
        (kept, dropped) =
          partition ((`nameIn` body_free) . paramName . fst) $
            zip (lambdaParams lam) inputs
-- | Bottom-up rule for 'Screma': drop map-part results that are never used.
--
-- A map-part result is dropped when its pattern element is not recorded
-- as used in the usage table. The leading scan/reduce results are always
-- kept. The lambda's body result and return types are trimmed to match the
-- remaining pattern elements, and the statement keeps its 'StmAux'. If
-- every map result is still used, or there are none, the rule does not
-- fire.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm scans reds lam))
  | null map_pes || map_pes == live_pes = Skip
  | otherwise =
      Simplify . auxing aux $
        letBind (Pat $ nonmap_pes <> live_pes) $
          Op $
            Screma w arrs $
              ScremaForm scans reds lam'
  where
    -- Number of leading results that belong to the scans and reductions.
    n_nonmap = scanResults scans + redResults reds

    (nonmap_pes, map_pes) = splitAt n_nonmap pes
    (nonmap_res, map_res) = splitAt n_nonmap $ bodyResult $ lambdaBody lam
    (nonmap_ts, map_ts) = splitAt n_nonmap $ lambdaReturnType lam

    -- Keep a map result together with its type when its pattern element
    -- is used.
    (live_pes, live_res, live_ts) =
      unzip3
        [ (pe, res, t)
          | (pe, res, t) <- zip3 map_pes map_res map_ts,
            patElemName pe `UT.used` used
        ]

    lam' =
      lam
        { lambdaBody = (lambdaBody lam) {bodyResult = nonmap_res <> live_res},
          lambdaReturnType = nonmap_ts <> live_ts
        }
removeDeadMapping _ _ _ _ = Skip
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
removeDuplicateMapOutput SymbolTable (Wise SOACS)
_ (Pat [PatElem (LetDec (Wise SOACS))]
pes) StmAux (ExpDec (Wise SOACS))
aux (Screma SubExp
w [VName]
arrs ScremaForm (Wise SOACS)
form)
| Just Lambda (Wise SOACS)
fun <- ScremaForm (Wise SOACS) -> Maybe (Lambda (Wise SOACS))
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm (Wise SOACS)
form =
let ses :: Result
ses = Body (Wise SOACS) -> Result
forall rep. Body rep -> Result
bodyResult (Body (Wise SOACS) -> Result) -> Body (Wise SOACS) -> Result
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
fun
ts :: [Type]
ts = Lambda (Wise SOACS) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda (Wise SOACS)
fun
ses_ts_pes :: [(SubExpRes, Type, PatElem (VarWisdom, Type))]
ses_ts_pes = Result
-> [Type]
-> [PatElem (VarWisdom, Type)]
-> [(SubExpRes, Type, PatElem (VarWisdom, Type))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
ses [Type]
ts [PatElem (VarWisdom, Type)]
[PatElem (LetDec (Wise SOACS))]
pes
([(SubExpRes, Type, PatElem (VarWisdom, Type))]
ses_ts_pes', [(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]
copies) =
(([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))])
-> (SubExpRes, Type, PatElem (VarWisdom, Type))
-> ([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]))
-> ([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))])
-> [(SubExpRes, Type, PatElem (VarWisdom, Type))]
-> ([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))])
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))])
-> (SubExpRes, Type, PatElem (VarWisdom, Type))
-> ([(SubExpRes, Type, PatElem (VarWisdom, Type))],
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))])
forall {b} {a}.
([(SubExpRes, b, a)], [(a, a)])
-> (SubExpRes, b, a) -> ([(SubExpRes, b, a)], [(a, a)])
checkForDuplicates ([(SubExpRes, Type, PatElem (VarWisdom, Type))]
forall a. Monoid a => a
mempty, [(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]
forall a. Monoid a => a
mempty) [(SubExpRes, Type, PatElem (VarWisdom, Type))]
ses_ts_pes
in if [(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]
copies
then Rule (Wise SOACS)
forall rep. Rule rep
Skip
else RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
let (Result
ses', [Type]
ts', [PatElem (VarWisdom, Type)]
pes') = [(SubExpRes, Type, PatElem (VarWisdom, Type))]
-> (Result, [Type], [PatElem (VarWisdom, Type)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExpRes, Type, PatElem (VarWisdom, Type))]
ses_ts_pes'
fun' :: Lambda (Wise SOACS)
fun' =
Lambda (Wise SOACS)
fun
{ lambdaBody :: Body (Wise SOACS)
lambdaBody = (Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
fun) {bodyResult :: Result
bodyResult = Result
ses'},
lambdaReturnType :: [Type]
lambdaReturnType = [Type]
ts'
}
StmAux (ExpWisdom, ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM (Wise SOACS))))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (VarWisdom, Type)] -> Pat (VarWisdom, Type)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (VarWisdom, Type)]
pes') (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Op (Rep (RuleM (Wise SOACS))) -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (RuleM (Wise SOACS))) -> Exp (Rep (RuleM (Wise SOACS))))
-> Op (Rep (RuleM (Wise SOACS))) -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm (Wise SOACS) -> SOAC (Wise SOACS)
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName]
arrs (ScremaForm (Wise SOACS) -> SOAC (Wise SOACS))
-> ScremaForm (Wise SOACS) -> SOAC (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> ScremaForm (Wise SOACS)
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda (Wise SOACS)
fun'
[(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]
-> ((PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))
-> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))]
copies (((PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))
-> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) ())
-> ((PatElem (VarWisdom, Type), PatElem (VarWisdom, Type))
-> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ \(PatElem (VarWisdom, Type)
from, PatElem (VarWisdom, Type)
to) ->
Pat (LetDec (Rep (RuleM (Wise SOACS))))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (VarWisdom, Type)] -> Pat (VarWisdom, Type)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (VarWisdom, Type)
to]) (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM (Wise SOACS))))
-> BasicOp -> Exp (Rep (RuleM (Wise SOACS)))
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem (VarWisdom, Type) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarWisdom, Type)
from
where
checkForDuplicates :: ([(SubExpRes, b, a)], [(a, a)])
-> (SubExpRes, b, a) -> ([(SubExpRes, b, a)], [(a, a)])
checkForDuplicates ([(SubExpRes, b, a)]
ses_ts_pes', [(a, a)]
copies) (SubExpRes
se, b
t, a
pe)
| Just (SubExpRes
_, b
_, a
pe') <- ((SubExpRes, b, a) -> Bool)
-> [(SubExpRes, b, a)] -> Maybe (SubExpRes, b, a)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (\(SubExpRes
x, b
_, a
_) -> SubExpRes -> SubExp
resSubExp SubExpRes
x SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExpRes -> SubExp
resSubExp SubExpRes
se) [(SubExpRes, b, a)]
ses_ts_pes' =
([(SubExpRes, b, a)]
ses_ts_pes', (a
pe', a
pe) (a, a) -> [(a, a)] -> [(a, a)]
forall a. a -> [a] -> [a]
: [(a, a)]
copies)
| Bool
otherwise = ([(SubExpRes, b, a)]
ses_ts_pes' [(SubExpRes, b, a)] -> [(SubExpRes, b, a)] -> [(SubExpRes, b, a)]
forall a. [a] -> [a] -> [a]
++ [(SubExpRes
se, b
t, a
pe)], [(a, a)]
copies)
removeDuplicateMapOutput SymbolTable (Wise SOACS)
_ Pat (LetDec (Wise SOACS))
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ = Rule (Wise SOACS)
forall rep. Rule rep
Skip
-- | Bottom-up rule that turns a single-result @map@, whose body is just one
-- @reshape@, @concat@, or @rearrange@ of the lambda parameter(s), into that
-- same operation applied directly to the mapped input array(s), with the
-- map width added as an extra outer dimension.
--
-- The @reshape@ and @rearrange@ cases only fire when the map's result is
-- not consumed later (per the usage table); the @concat@ case has no such
-- check.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp (_, used) pat aux1 e =
  case isMapWithOp pat e of
    Just (map_pe, inner_cs, w, inner_e, params, outer_arrs) ->
      rewriteInner map_pe inner_cs w inner_e params outer_arrs
    Nothing -> Skip
  where
    -- Is the map's single result consumed after this statement?
    resultConsumed :: PatElem (LetDec (Wise SOACS)) -> Bool
    resultConsumed pe = UT.isConsumed (patElemName pe) used

    -- Bind the original pattern to the given operation, certified by both
    -- the outer statement's and the inner statement's certificates.
    emitBasicOp :: Certs -> BasicOp -> Rule (Wise SOACS)
    emitBasicOp cs op =
      Simplify $
        certifying (stmAuxCerts aux1 <> cs) $
          letBind pat (BasicOp op)

    -- Dispatch on the shape of the single inner statement. An equation whose
    -- guard fails falls through to the next one, ending in 'Skip'.
    rewriteInner ::
      PatElem (LetDec (Wise SOACS)) ->
      Certs ->
      SubExp ->
      Exp (Wise SOACS) ->
      [Param Type] ->
      [VName] ->
      Rule (Wise SOACS)
    rewriteInner map_pe cs w (BasicOp (Reshape k newshape v)) [p] [arr]
      | paramName p == v,
        not (resultConsumed map_pe) =
          emitBasicOp cs $ Reshape k (Shape [w] <> newshape) arr
    rewriteInner _ cs _ (BasicOp (Concat d (v :| vs) dw)) ps (arr : arrs)
      | (v : vs) == map paramName ps =
          -- Concatenation moves one dimension inwards.
          emitBasicOp cs $ Concat (d + 1) (arr :| arrs) dw
    rewriteInner map_pe cs _ (BasicOp (Rearrange perm v)) [p] [arr]
      | paramName p == v,
        not (resultConsumed map_pe) =
          -- Keep the new outer dimension in place; shift the inner permutation.
          emitBasicOp cs $ Rearrange (0 : map (+ 1) perm) arr
    rewriteInner _ _ _ _ _ _ = Skip
-- | Recognise a single-result @map@ whose lambda body consists of exactly one
-- statement, where that statement binds a single name that is also the
-- lambda's sole (variable) result.
--
-- On success, returns the map's pattern element, the certificates of the
-- inner statement, the map width, the inner expression, the lambda
-- parameters, and the map's input arrays.
isMapWithOp ::
  Pat dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElem dec,
      Certs,
      SubExp,
      Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp pat soac = do
  -- Each refutable bind below yields 'Nothing' on mismatch (MonadFail Maybe).
  Pat [map_pe] <- Just pat
  Screma w arrs form <- Just soac
  map_lam <- isMapSOAC form
  let body = lambdaBody map_lam
  [Let (Pat [pe]) aux2 e'] <- Just $ stmsToList $ bodyStms body
  [SubExpRes _ (Var r)] <- Just $ bodyResult body
  guard $ r == patElemName pe
  pure (map_pe, stmAuxCerts aux2, w, e', lambdaParams map_lam, arrs)
-- | Bottom-up rule that prunes dead results from a @redomap@ or @scanomap@
-- with exactly one reduction/scan operator.
--
-- A reduction (or scan) result is kept only if its operator parameter is
-- needed, according to the operator lambda's data dependencies, to compute
-- some result that is used after this statement. Pruned results are removed
-- from the pattern, from the operator lambda (whose dead parameters are
-- fixed to the neutral elements), and from the map lambda.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam rednes], maplam) <- isRedomapSOAC form =
      pruneOperator redlam rednes maplam $ \lam nes' ->
        redomapSOAC [Reduce comm lam nes']
  | Just ([Scan scanlam scannes], maplam) <- isScanomapSOAC form =
      pruneOperator scanlam scannes maplam $ \lam nes' ->
        scanomapSOAC [Scan lam nes']
  | otherwise = Skip
  where
    -- Pruning logic shared by the reduction and scan cases. The final
    -- argument rebuilds the Screma form from the pruned operator lambda,
    -- the remaining neutral elements, and the pruned map lambda.
    pruneOperator ::
      Lambda (Wise SOACS) ->
      [SubExp] ->
      Lambda (Wise SOACS) ->
      ( Lambda (Wise SOACS) ->
        [SubExp] ->
        Lambda (Wise SOACS) ->
        ScremaForm (Wise SOACS)
      ) ->
      Rule (Wise SOACS)
    pruneOperator op_lam nes maplam rebuild
      -- Cheap check first: if every pattern element is used, nothing is dead.
      | any (not . (`UT.used` used)) (patNames pat),
        let num_red = length nes,
        let (red_pes, map_pes) = splitAt num_red (patElems pat),
        let op_params = lambdaParams op_lam,
        let op_res = map resSubExp (bodyResult (lambdaBody op_lam)),
        -- Leading operator parameters whose matching pattern element is used.
        let live_params =
              [p | (pe, p) <- zip red_pes op_params, patElemName pe `UT.used` used],
        -- The result list is doubled so that both halves of the parameter
        -- list are paired with a result.
        let necessary =
              findNecessaryForReturned
                (`elem` live_params)
                (zip op_params (op_res <> op_res))
                (dataDependencies (lambdaBody op_lam)),
        let alive_mask = map ((`nameIn` necessary) . paramName) op_params,
        let alive_results = take num_red alive_mask,
        not (and alive_results) =
          Simplify $ do
            -- A dead parameter is fixed to its neutral element.
            let dead_fix =
                  zipWith (\alive ne -> if alive then Nothing else Just ne) alive_mask nes
                kept =
                  filter (\(_, p, _) -> paramName p `nameIn` necessary) $
                    zip3 red_pes op_params nes
                used_red_pes = [pe | (pe, _, _) <- kept]
                used_nes = [ne | (_, _, ne) <- kept]
                maplam' = removeLambdaResults alive_results maplam
            op_lam' <-
              removeLambdaResults alive_results
                <$> fixLambdaParams op_lam (dead_fix ++ dead_fix)
            auxing aux . letBind (Pat (used_red_pes ++ map_pes)) . Op . Screma w arrs $
              rebuild op_lam' used_nes maplam'
      | otherwise = Skip
removeDeadReduction _ _ _ _ = Skip
-- | Bottom-up rule: remove every output of a 'Scatter' whose pattern
-- element is never used.
--
-- For each dropped output, the rule also removes the matching index
-- results, value result, return types and destination triple from the
-- lambda and the 'Scatter'.  The rule returns 'Skip' when every output is
-- used, or when the statement is not a 'Scatter'.
removeDeadWrite :: BottomUpRuleOp (Wise SOACS)
removeDeadWrite (_, used) pat aux (Scatter w arrs fun dests)
  | pat /= Pat live_pes =
      Simplify . auxing aux . letBind (Pat live_pes) . Op $
        Scatter w arrs fun' live_dests
  | otherwise = Skip
  where
    -- Results and types, grouped per destination as (indices, value).
    (idx_res, val_res) =
      unzip . groupScatterResults' dests . bodyResult $ lambdaBody fun
    (idx_ts, val_ts) =
      unzip . groupScatterResults' dests $ lambdaReturnType fun

    -- An output is live if its pattern name appears in the usage table.
    isLive (pe, _, _, _, _, _) = patElemName pe `UT.used` used

    (live_pes, live_idx_res, live_val_res, live_idx_ts, live_val_ts, live_dests) =
      unzip6 . filter isLive $
        zip6 (patElems pat) idx_res val_res idx_ts val_ts dests

    -- Same statements as before; only the surviving results are returned.
    fun' =
      fun
        { lambdaBody =
            mkBody (bodyStms (lambdaBody fun)) $
              concat live_idx_res ++ live_val_res,
          lambdaReturnType = concat live_idx_ts ++ live_val_ts
        }
removeDeadWrite _ _ _ _ = Skip
-- | Top-down rule: fuse a 'Scatter' whose inputs are all dimension-0
-- concatenations into a 'Scatter' over the concatenated pieces.
--
-- The rule applies when three conditions hold:
--
-- * every input array is a @Concat 0@, possibly behind a chain of
--   coercing reshapes, of pieces that all have the same outer size;
-- * all inputs agree on that piece size (@w'@);
-- * all inputs are split into the same number of pieces (@r@).
--
-- The new 'Scatter' iterates @w'@ times over the pieces, and its lambda is
-- the original lambda followed by @r - 1@ renamed copies.  Each copy handles
-- one group of pieces, so every destination receives @r@ times as many
-- writes per iteration.  The certificates of all the concats are attached
-- to the new statement.
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter vtable pat _ (Scatter _ arrs fun dests)
  | Just (ws@(w' : _), xss, css) <- unzip3 <$> mapM isConcat arrs,
    xivs <- transpose xss,
    -- Every input must be split into the same number of pieces; otherwise
    -- 'transpose' produces ragged groups and the parameter/input pairing
    -- below would be wrong.
    all ((== length arrs) . length) xivs,
    all (w' ==) ws = Simplify $ do
      let r = length xivs
      fun2s <- replicateM (r - 1) (renameLambda fun)
      let funs = fun : fun2s
          -- Interleave whole writes (index tuple together with its value)
          -- across the copies, rather than interleaving individual index
          -- components.  The latter would scramble the index tuples of
          -- destinations with rank > 1.  Writes stay ordered by destination
          -- because 'groupScatterResults'' yields them in that order.
          (fun_is, fun_vs) =
            unzip . concat . transpose $
              map (groupScatterResults' dests . bodyResult . lambdaBody) funs
          -- The types of each write, repeated once per copy, in the same
          -- order as the results above.
          (its, vts) =
            unzip . concatMap (replicate r) $
              groupScatterResults' dests (lambdaReturnType fun)
          new_stmts = mconcat $ map (bodyStms . lambdaBody) funs
          fun' =
            Lambda
              { -- Copy k takes the k-th piece of every input; this matches
                -- the order of 'concat xivs' below.
                lambdaParams = mconcat $ map lambdaParams funs,
                lambdaBody = mkBody new_stmts $ concat fun_is <> fun_vs,
                lambdaReturnType = concat its <> vts
              }
      certifying (mconcat css) . letBind pat . Op $
        Scatter w' (concat xivs) fun' $
          map (incWrites r) dests
  where
    -- Outer size of a variable, if the symbol table knows it.
    sizeOf :: VName -> Maybe SubExp
    sizeOf x = arraySize 0 . typeOf <$> ST.lookup x vtable

    -- Each iteration now performs r times as many writes per destination.
    incWrites r (shape, n, arr) = (shape, n * r, arr)

    -- Recognise an array as a dimension-0 concat of equally-sized pieces,
    -- looking through coercing reshapes.  Returns the piece size, the
    -- pieces, and the accumulated certificates.
    isConcat :: VName -> Maybe (SubExp, [VName], Certs)
    isConcat v = case ST.lookupExp v vtable of
      Just (BasicOp (Concat 0 (x :| ys) _), cs) -> do
        x_w <- sizeOf x
        y_ws <- mapM sizeOf ys
        guard $ all (x_w ==) y_ws
        pure (x_w, x : ys, cs)
      Just (BasicOp (Reshape ReshapeCoerce _ arr), cs) -> do
        (piece_w, pieces, cs') <- isConcat arr
        pure (piece_w, pieces, cs <> cs')
      _ -> Nothing
fuseConcatScatter _ _ _ _ = Skip
-- | Top-down rule that evaluates reductions in closed form where possible.
--
-- * A redomap whose width is a constant zero is replaced by bindings of its
--   reduction results to the neutral elements.  This only fires when the
--   pattern consists of the reduction results alone.  Any additional map
--   outputs of the redomap would otherwise be left unbound by the rewrite,
--   so in that case the rule falls through instead.
-- * A single 'Reduce' (as recognised by 'isReduceSOAC') is handed to
--   'foldClosedForm', using the symbol table to look up the defining
--   expressions of variables.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form)
  | zeroIsh w,
    Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    -- Only rewrite when every pattern element gets a binding below.
    length nes == length (patElems pat) =
      Simplify . forM_ (zip (patNames pat) nes) $ \(v, ne) ->
        letBindNames [v] $ BasicOp $ SubExp ne
simplifyClosedFormReduce vtable pat _ (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
      Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip
simplifyKnownIterationSOAC ::
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
simplifyKnownIterationSOAC :: forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
simplifyKnownIterationSOAC TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ OpC rep rep
op
| Just (Screma (Constant PrimValue
k) [VName]
arrs (ScremaForm [Scan rep]
scans [Reduce rep]
reds Lambda rep
map_lam)) <- OpC rep rep -> Maybe (SOAC rep)
forall rep. HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC OpC rep rep
op,
PrimValue -> Bool
oneIsh PrimValue
k = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let (Reduce Commutativity
_ Lambda rep
red_lam [SubExp]
red_nes) = [Reduce rep] -> Reduce rep
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce rep]
reds
(Scan Lambda rep
scan_lam [SubExp]
scan_nes) = [Scan rep] -> Scan rep
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan rep]
scans
([PatElem (LetDec rep)]
scan_pes, [PatElem (LetDec rep)]
red_pes, [PatElem (LetDec rep)]
map_pes) =
Int
-> Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
[PatElem (LetDec rep)])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes) ([PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
[PatElem (LetDec rep)]))
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)],
[PatElem (LetDec rep)])
forall a b. (a -> b) -> a -> b
$
Pat (LetDec rep) -> [PatElem (LetDec rep)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat
bindMapParam :: Param dec -> VName -> m ()
bindMapParam Param dec
p VName
a = do
Type
a_t <- VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
a
[VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
a (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
a_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)]
bindArrayResult :: PatElem dec -> SubExpRes -> m ()
bindArrayResult PatElem dec
pe (SubExpRes Certs
cs SubExp
se) =
Certs -> m () -> m ()
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m () -> m ()) -> (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
[SubExp] -> Type -> BasicOp
ArrayLit [SubExp
se] (Type -> BasicOp) -> Type -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$
PatElem dec -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem dec
pe
bindResult :: PatElem dec -> SubExpRes -> m ()
bindResult PatElem dec
pe (SubExpRes Certs
cs SubExp
se) =
Certs -> m () -> m ()
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
(Param Type -> VName -> RuleM rep ())
-> [Param Type] -> [VName] -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param Type -> VName -> RuleM rep ()
forall {m :: * -> *} {dec}.
MonadBuilder m =>
Param dec -> VName -> m ()
bindMapParam (Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
map_lam) [VName]
arrs
(Result
to_scan, Result
to_red, Result
map_res) <-
Int -> Int -> Result -> (Result, Result, Result)
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes) ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes)
(Result -> (Result, Result, Result))
-> RuleM rep Result -> RuleM rep (Result, Result, Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
map_lam)
Result
scan_res <- Lambda (Rep (RuleM rep))
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda rep
Lambda (Rep (RuleM rep))
scan_lam ([RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result)
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> RuleM rep (Exp (Rep (RuleM rep))))
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))])
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> a -> b
$ [SubExp]
scan_nes [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
to_scan
Result
red_res <- Lambda (Rep (RuleM rep))
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda rep
Lambda (Rep (RuleM rep))
red_lam ([RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result)
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (SubExp -> RuleM rep (Exp (Rep (RuleM rep))))
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))])
-> [SubExp] -> [RuleM rep (Exp (Rep (RuleM rep)))]
forall a b. (a -> b) -> a -> b
$ [SubExp]
red_nes [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
to_red
(PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> Result -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
(MonadBuilder m, Typed dec) =>
PatElem dec -> SubExpRes -> m ()
bindArrayResult [PatElem (LetDec rep)]
scan_pes Result
scan_res
(PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> Result -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
MonadBuilder m =>
PatElem dec -> SubExpRes -> m ()
bindResult [PatElem (LetDec rep)]
red_pes Result
red_res
(PatElem (LetDec rep) -> SubExpRes -> RuleM rep ())
-> [PatElem (LetDec rep)] -> Result -> RuleM rep ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatElem (LetDec rep) -> SubExpRes -> RuleM rep ()
forall {m :: * -> *} {dec}.
(MonadBuilder m, Typed dec) =>
PatElem dec -> SubExpRes -> m ()
bindArrayResult [PatElem (LetDec rep)]
map_pes Result
map_res
simplifyKnownIterationSOAC TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ OpC rep rep
op
| Just (Stream (Constant PrimValue
k) [VName]
arrs [SubExp]
nes Lambda rep
fold_lam) <- OpC rep rep -> Maybe (SOAC rep)
forall rep. HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC OpC rep rep
op,
PrimValue -> Bool
oneIsh PrimValue
k = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let (Param Type
chunk_param, [Param Type]
acc_params, [Param Type]
slice_params) =
Int -> [Param Type] -> (Param Type, [Param Type], [Param Type])
forall dec.
Int -> [Param dec] -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) (Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
fold_lam)
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
chunk_param] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1
[(Param Type, SubExp)]
-> ((Param Type, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [SubExp] -> [(Param Type, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
acc_params [SubExp]
nes) (((Param Type, SubExp) -> RuleM rep ()) -> RuleM rep ())
-> ((Param Type, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, SubExp
ne) ->
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
ne
[(Param Type, VName)]
-> ((Param Type, VName) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
slice_params [VName]
arrs) (((Param Type, VName) -> RuleM rep ()) -> RuleM rep ())
-> ((Param Type, VName) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
Result
res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
fold_lam
[(VName, SubExpRes)]
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec rep)
pat) Result
res) (((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
cs SubExp
se) ->
Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
simplifyKnownIterationSOAC TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux OpC rep rep
op
| Just (Screma (Constant (IntValue (Int64Value Int64
k))) [VName]
arrs (ScremaForm [] [] Lambda rep
map_lam)) <- OpC rep rep -> Maybe (SOAC rep)
forall rep. HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC OpC rep rep
op,
Attr
"unroll" Attr -> Attrs -> Bool
`inAttrs` StmAux (ExpDec rep) -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec rep)
aux = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
[Result]
arrs_elems <- ([Result] -> [Result]) -> RuleM rep [Result] -> RuleM rep [Result]
forall a b. (a -> b) -> RuleM rep a -> RuleM rep b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [Result] -> [Result]
forall a. [[a]] -> [[a]]
transpose (RuleM rep [Result] -> RuleM rep [Result])
-> ((Int64 -> RuleM rep Result) -> RuleM rep [Result])
-> (Int64 -> RuleM rep Result)
-> RuleM rep [Result]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Int64] -> (Int64 -> RuleM rep Result) -> RuleM rep [Result]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Int64
0 .. Int64
k Int64 -> Int64 -> Int64
forall a. Num a => a -> a -> a
- Int64
1] ((Int64 -> RuleM rep Result) -> RuleM rep [Result])
-> (Int64 -> RuleM rep Result) -> RuleM rep [Result]
forall a b. (a -> b) -> a -> b
$ \Int64
i -> do
Lambda rep
map_lam' <- Lambda rep -> RuleM rep (Lambda rep)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
map_lam
Lambda (Rep (RuleM rep))
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda rep
Lambda (Rep (RuleM rep))
map_lam' ([RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result)
-> [RuleM rep (Exp (Rep (RuleM rep)))] -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (VName -> RuleM rep (Exp rep)) -> [VName] -> [RuleM rep (Exp rep)]
forall a b. (a -> b) -> [a] -> [b]
map (VName
-> [RuleM rep (Exp (Rep (RuleM rep)))]
-> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
`eIndex` [SubExp -> RuleM rep (Exp (Rep (RuleM rep)))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant Int64
i)]) [VName]
arrs
[(VName, Result, Type)]
-> ((VName, Result, Type) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [Result] -> [Type] -> [(VName, Result, Type)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec rep)
pat) [Result]
arrs_elems (Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
map_lam)) (((VName, Result, Type) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, Result, Type) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
\(VName
v, Result
arr_elems, Type
t) ->
Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ([Certs] -> Certs
forall a. Monoid a => [a] -> a
mconcat ((SubExpRes -> Certs) -> Result -> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
arr_elems)) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp rep -> RuleM rep ())
-> (BasicOp -> Exp rep) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> RuleM rep ()) -> BasicOp -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
[SubExp] -> Type -> BasicOp
ArrayLit ((SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
arr_elems) Type
t
simplifyKnownIterationSOAC TopDown rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ OpC rep rep
_ = Rule rep
forall rep. Rule rep
Skip
-- | An array operation recognised by the simplifier, together with the
-- certificates ('Certs') it is bound under.  Every constructor names the
-- array ('VName') that the operation reads.
data ArrayOp
  = -- | Indexing/slicing of an array ('Index').
    ArrayIndexing Certs VName (Slice SubExp)
  | -- | Dimension permutation of an array ('Rearrange').
    ArrayRearrange Certs VName [Int]
  | -- | Reshaping of an array ('Reshape').
    ArrayReshape Certs VName ReshapeKind Shape
  | -- | A 'Replicate' with an empty shape of an array variable.
    ArrayCopy Certs VName
  | -- | A plain reference to an array variable.  Never constructed by
    -- 'isArrayOp'.
    ArrayVar Certs VName
  deriving (Eq, Ord, Show)
-- | The array that an 'ArrayOp' operates on.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op =
  case op of
    ArrayIndexing _ arr _ -> arr
    ArrayRearrange _ arr _ -> arr
    ArrayReshape _ arr _ _ -> arr
    ArrayCopy _ arr -> arr
    ArrayVar _ arr -> arr
-- | The certificates attached to an 'ArrayOp'.
arrayOpCerts :: ArrayOp -> Certs
arrayOpCerts op =
  case op of
    ArrayIndexing certs _ _ -> certs
    ArrayRearrange certs _ _ -> certs
    ArrayReshape certs _ _ _ -> certs
    ArrayCopy certs _ -> certs
    ArrayVar certs _ -> certs
-- | Recognise an expression as an 'ArrayOp', attaching the given
-- certificates.
--
-- The recognised expressions are 'Index', 'Rearrange', 'Reshape', and a
-- 'Replicate' of an array variable with an empty shape (an 'ArrayCopy').
-- Every other expression yields 'Nothing'.
isArrayOp :: Certs -> Exp rep -> Maybe ArrayOp
isArrayOp cs (BasicOp bop) =
  case bop of
    Index arr slice -> Just (ArrayIndexing cs arr slice)
    Rearrange perm arr -> Just (ArrayRearrange cs arr perm)
    Reshape k new_shape arr -> Just (ArrayReshape cs arr k new_shape)
    Replicate (Shape []) (Var arr) -> Just (ArrayCopy cs arr)
    _ -> Nothing
isArrayOp _ _ = Nothing
-- | Turn an 'ArrayOp' back into an expression, paired with the
-- certificates it carries.
--
-- 'ArrayCopy' becomes a 'Replicate' with an empty shape.  'ArrayVar'
-- becomes a plain variable reference.
fromArrayOp :: ArrayOp -> (Certs, Exp rep)
fromArrayOp = fmap BasicOp . toBasicOp
  where
    toBasicOp :: ArrayOp -> (Certs, BasicOp)
    toBasicOp (ArrayIndexing cs arr slice) = (cs, Index arr slice)
    toBasicOp (ArrayRearrange cs arr perm) = (cs, Rearrange perm arr)
    toBasicOp (ArrayReshape cs arr k new_shape) = (cs, Reshape k new_shape arr)
    toBasicOp (ArrayCopy cs arr) = (cs, Replicate mempty (Var arr))
    toBasicOp (ArrayVar cs arr) = (cs, SubExp (Var arr))
-- | Collect the array operations (as recognised by 'isArrayOp') bound by the
-- statements of a body.  Each operation is paired with the pattern of the
-- statement that binds it.
--
-- Each collected operation carries the given certificates plus the
-- certificates of its own statement.
--
-- Statements that are not array operations are searched recursively:
--
-- * nested bodies are visited via 'walkExpM';
-- * lambdas of SOACs are visited via 'mapSOACM';
-- * in both cases, the enclosing statement's certificates are added.
--
-- 'Match' and 'Loop' statements are not searched at all.  'ArrayCopy'
-- operations found inside SOAC lambdas are discarded.
arrayOps ::
  forall rep.
  (Buildable rep, HasSOAC rep) =>
  Certs ->
  Body rep ->
  S.Set (Pat (LetDec rep), ArrayOp)
arrayOps cs body = foldMap onStm $ bodyStms body
  where
    -- NOTE(review): presumably moving array operations out of branches and
    -- loops is unsafe, hence nothing inside them is collected.
    onStm (Let _ _ Match {}) = S.empty
    onStm (Let _ _ Loop {}) = S.empty
    onStm (Let pat aux e) =
      let stm_cs = stmAuxCerts aux
       in case isArrayOp (cs <> stm_cs) e of
            Just op -> S.singleton (pat, op)
            Nothing -> execState (walkExpM (walker stm_cs) e) S.empty

    onOp more_cs op =
      case asSOAC op of
        -- The annotation pins the representation of the SOAC, and with it
        -- the representation used by the recursive call in 'onLambda'.
        Just (soac :: SOAC rep) ->
          S.filter (notCopy . snd) . execWriter $
            mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda more_cs} soac
        Nothing -> S.empty

    -- Record the operations of the lambda body and return the lambda as is.
    onLambda more_cs lam =
      lam <$ tell (arrayOps (cs <> more_cs) (lambdaBody lam))

    -- Accumulate the results of nested bodies and ops into the state.
    walker more_cs =
      (identityWalker @rep)
        { walkOnBody = \_ nested -> modify (arrayOps (cs <> more_cs) nested <>),
          walkOnOp = \op -> modify (onOp more_cs op <>)
        }

    notCopy ArrayCopy {} = False
    notCopy _ = True
-- | Rewrite a body according to a map from statement patterns to array
-- operations.
--
-- When a statement's pattern is a key of the map:
--
-- * its expression is replaced by the one rebuilt with 'fromArrayOp';
-- * the operation's certificates are added to the statement;
-- * its original expression is not searched further.
--
-- Every other statement is searched recursively with the same map.  This
-- covers nested bodies (via 'mapExp') and the lambdas of SOACs.
--
-- Statements are rebuilt with 'mkLet'' from the identifiers of their
-- pattern.  The body is rebuilt with 'mkBody' around the original result.
replaceArrayOps ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  M.Map (Pat (LetDec rep)) ArrayOp ->
  Body rep ->
  Body rep
replaceArrayOps substs (Body _ stms res) = mkBody (onStm <$> stms) res
  where
    onStm (Let pat aux e) =
      let (cs', e') =
            case M.lookup pat substs of
              Just op -> fromArrayOp op
              Nothing -> (mempty, mapExp mapper e)
       in certify cs' $ mkLet' (patIdents pat) aux e'

    -- Recurse into nested bodies and into the lambdas of SOAC ops.
    mapper =
      (identityMapper @rep)
        { mapOnBody = \_ nested -> pure (replaceArrayOps substs nested),
          mapOnOp = pure . onOp
        }

    onOp op =
      case asSOAC op of
        Just (soac :: SOAC rep) ->
          soacOp . runIdentity $
            mapSOACM identitySOACMapper {mapOnSOACLambda = pure . onLambda} soac
        Nothing -> op

    onLambda lam = lam {lambdaBody = replaceArrayOps substs (lambdaBody lam)}
simplifyMapIota ::
forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
simplifyMapIota :: forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
simplifyMapIota TopDown rep
vtable Pat (LetDec rep)
screma_pat StmAux (ExpDec rep)
aux Op rep
op
| Just (Screma SubExp
w [VName]
arrs (ScremaForm [Scan rep]
scan [Reduce rep]
reduce Lambda rep
map_lam) :: SOAC rep) <- Op rep -> Maybe (SOAC rep)
forall rep. HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC Op rep
op,
Just (Param Type
p, VName
_) <- ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> Maybe (Param Type, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (Param Type, VName) -> Bool
isIota ([Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
map_lam) [VName]
arrs),
[(Pat (LetDec rep), [SubExp], ArrayOp)]
indexings <-
((Pat (LetDec rep), ArrayOp)
-> Maybe (Pat (LetDec rep), [SubExp], ArrayOp))
-> [(Pat (LetDec rep), ArrayOp)]
-> [(Pat (LetDec rep), [SubExp], ArrayOp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (VName
-> (Pat (LetDec rep), ArrayOp)
-> Maybe (Pat (LetDec rep), [SubExp], ArrayOp)
indexesWith (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p)) ([(Pat (LetDec rep), ArrayOp)]
-> [(Pat (LetDec rep), [SubExp], ArrayOp)])
-> (Set (Pat (LetDec rep), ArrayOp)
-> [(Pat (LetDec rep), ArrayOp)])
-> Set (Pat (LetDec rep), ArrayOp)
-> [(Pat (LetDec rep), [SubExp], ArrayOp)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Set (Pat (LetDec rep), ArrayOp) -> [(Pat (LetDec rep), ArrayOp)]
forall a. Set a -> [a]
S.toList (Set (Pat (LetDec rep), ArrayOp)
-> [(Pat (LetDec rep), [SubExp], ArrayOp)])
-> Set (Pat (LetDec rep), ArrayOp)
-> [(Pat (LetDec rep), [SubExp], ArrayOp)]
forall a b. (a -> b) -> a -> b
$
Certs -> Body rep -> Set (Pat (LetDec rep), ArrayOp)
forall rep.
(Buildable rep, HasSOAC rep) =>
Certs -> Body rep -> Set (Pat (LetDec rep), ArrayOp)
arrayOps Certs
forall a. Monoid a => a
mempty (Body rep -> Set (Pat (LetDec rep), ArrayOp))
-> Body rep -> Set (Pat (LetDec rep), ArrayOp)
forall a b. (a -> b) -> a -> b
$
Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
map_lam,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(Pat (LetDec rep), [SubExp], ArrayOp)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Pat (LetDec rep), [SubExp], ArrayOp)]
indexings = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
([VName]
more_arrs, [Param Type]
more_params, [(Pat (LetDec rep), ArrayOp)]
replacements) <-
[(VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> ([VName], [Param Type], [(Pat (LetDec rep), ArrayOp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> ([VName], [Param Type], [(Pat (LetDec rep), ArrayOp)]))
-> ([Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> [(VName, Param Type, (Pat (LetDec rep), ArrayOp))])
-> [Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> ([VName], [Param Type], [(Pat (LetDec rep), ArrayOp)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> [(VName, Param Type, (Pat (LetDec rep), ArrayOp))]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> ([VName], [Param Type], [(Pat (LetDec rep), ArrayOp)]))
-> RuleM
rep [Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
-> RuleM rep ([VName], [Param Type], [(Pat (LetDec rep), ArrayOp)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Pat (LetDec rep), [SubExp], ArrayOp)
-> RuleM
rep (Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))))
-> [(Pat (LetDec rep), [SubExp], ArrayOp)]
-> RuleM
rep [Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (SubExp
-> (Pat (LetDec rep), [SubExp], ArrayOp)
-> RuleM
rep (Maybe (VName, Param Type, (Pat (LetDec rep), ArrayOp)))
forall {m :: * -> *} {a}.
MonadBuilder m =>
SubExp
-> (a, [SubExp], ArrayOp)
-> m (Maybe (VName, Param Type, (a, ArrayOp)))
mapOverArr SubExp
w) [(Pat (LetDec rep), [SubExp], ArrayOp)]
indexings
let substs :: Map (Pat (LetDec rep)) ArrayOp
substs = [(Pat (LetDec rep), ArrayOp)] -> Map (Pat (LetDec rep)) ArrayOp
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Pat (LetDec rep), ArrayOp)]
replacements
map_lam' :: Lambda rep
map_lam' =
Lambda rep
map_lam
{ lambdaParams :: [LParam rep]
lambdaParams = Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
map_lam [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
more_params,
lambdaBody :: Body rep
lambdaBody = Map (Pat (LetDec rep)) ArrayOp -> Body rep -> Body rep
forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
Map (Pat (LetDec rep)) ArrayOp -> Body rep -> Body rep
replaceArrayOps Map (Pat (LetDec rep)) ArrayOp
substs (Body rep -> Body rep) -> Body rep -> Body rep
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
map_lam
}
StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ())
-> (SOAC rep -> RuleM rep ()) -> SOAC rep -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
screma_pat (Exp rep -> RuleM rep ())
-> (SOAC rep -> Exp rep) -> SOAC rep -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op rep -> Exp rep
forall rep. Op rep -> Exp rep
Op (Op rep -> Exp rep) -> (SOAC rep -> Op rep) -> SOAC rep -> Exp rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SOAC rep -> Op rep
forall rep. HasSOAC rep => SOAC rep -> Op rep
soacOp (SOAC rep -> RuleM rep ()) -> SOAC rep -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
arrs [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) ([Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan rep]
scan [Reduce rep]
reduce Lambda rep
map_lam')
where
isIota :: (Param Type, VName) -> Bool
isIota (Param Type
_, VName
arr) = case VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
arr TopDown rep
vtable of
Just (Iota SubExp
_ (Constant PrimValue
o) (Constant PrimValue
s) IntType
_, Certs
_) ->
PrimValue -> Bool
zeroIsh PrimValue
o Bool -> Bool -> Bool
&& PrimValue -> Bool
oneIsh PrimValue
s
Maybe (BasicOp, Certs)
_ -> Bool
False
fixWith :: VName -> [DimIndex SubExp] -> Maybe [SubExp]
fixWith VName
i (DimFix SubExp
j : [DimIndex SubExp]
slice)
| VName -> SubExp
Var VName
i SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
j = [SubExp] -> Maybe [SubExp]
forall a. a -> Maybe a
Just []
| Bool
otherwise = (SubExp
j :) ([SubExp] -> [SubExp]) -> Maybe [SubExp] -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> [DimIndex SubExp] -> Maybe [SubExp]
fixWith VName
i [DimIndex SubExp]
slice
fixWith VName
_ [DimIndex SubExp]
_ = Maybe [SubExp]
forall a. Maybe a
Nothing
indexesWith :: VName
-> (Pat (LetDec rep), ArrayOp)
-> Maybe (Pat (LetDec rep), [SubExp], ArrayOp)
indexesWith VName
v (Pat (LetDec rep)
pat, idx :: ArrayOp
idx@(ArrayIndexing Certs
cs VName
arr (Slice [DimIndex SubExp]
js)))
| VName
arr VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Certs -> [VName]
unCerts Certs
cs,
Just [SubExp]
js' <- VName -> [DimIndex SubExp] -> Maybe [SubExp]
fixWith VName
v [DimIndex SubExp]
js,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
js' =
(Pat (LetDec rep), [SubExp], ArrayOp)
-> Maybe (Pat (LetDec rep), [SubExp], ArrayOp)
forall a. a -> Maybe a
Just (Pat (LetDec rep)
pat, [SubExp]
js', ArrayOp
idx)
indexesWith VName
_ (Pat (LetDec rep), ArrayOp)
_ = Maybe (Pat (LetDec rep), [SubExp], ArrayOp)
forall a. Maybe a
Nothing
properArr :: [SubExp] -> VName -> f VName
properArr [] VName
arr = VName -> f VName
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
arr
properArr [SubExp]
js VName
arr = do
Type
arr_t <- VName -> f Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
[Char] -> Exp (Rep f) -> f VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
arr) (Exp (Rep f) -> f VName) -> Exp (Rep f) -> f VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep f)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep f)) -> BasicOp -> Exp (Rep f)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
js
mapOverArr :: SubExp
-> (a, [SubExp], ArrayOp)
-> m (Maybe (VName, Param Type, (a, ArrayOp)))
mapOverArr SubExp
w (a
pat, [SubExp]
js, ArrayIndexing Certs
cs VName
arr Slice SubExp
slice) = do
VName
arr' <- [SubExp] -> VName -> m VName
forall {f :: * -> *}.
MonadBuilder f =>
[SubExp] -> VName -> f VName
properArr [SubExp]
js VName
arr
Type
arr_t <- VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr'
VName
arr'' <-
if Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
arr_t SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
w
then VName -> m VName
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
arr'
else
Certs -> m VName -> m VName
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m VName -> m VName)
-> (Slice SubExp -> m VName) -> Slice SubExp -> m VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char] -> Exp (Rep m) -> m VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
arr [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_prefix") (Exp (Rep m) -> m VName)
-> (Slice SubExp -> Exp (Rep m)) -> Slice SubExp -> m VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr' (Slice SubExp -> m VName) -> Slice SubExp -> m VName
forall a b. (a -> b) -> a -> b
$
Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
Param Type
arr_elem_param <- [Char] -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam (VName -> [Char]
baseString VName
arr [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_elem") (Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
arr_t)
Maybe (VName, Param Type, (a, ArrayOp))
-> m (Maybe (VName, Param Type, (a, ArrayOp)))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (VName, Param Type, (a, ArrayOp))
-> m (Maybe (VName, Param Type, (a, ArrayOp))))
-> Maybe (VName, Param Type, (a, ArrayOp))
-> m (Maybe (VName, Param Type, (a, ArrayOp)))
forall a b. (a -> b) -> a -> b
$
(VName, Param Type, (a, ArrayOp))
-> Maybe (VName, Param Type, (a, ArrayOp))
forall a. a -> Maybe a
Just
( VName
arr'',
Param Type
arr_elem_param,
( a
pat,
Certs -> VName -> Slice SubExp -> ArrayOp
ArrayIndexing Certs
cs (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
arr_elem_param) ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice (Int -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
js Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) (Slice SubExp -> [DimIndex SubExp]
forall d. Slice d -> [DimIndex d]
unSlice Slice SubExp
slice)))
)
)
mapOverArr SubExp
_ (a, [SubExp], ArrayOp)
_ = Maybe (VName, Param Type, (a, ArrayOp))
-> m (Maybe (VName, Param Type, (a, ArrayOp)))
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (VName, Param Type, (a, ArrayOp))
forall a. Maybe a
Nothing
simplifyMapIota TopDown rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ Op rep
_ = Rule rep
forall rep. Rule rep
Skip
-- | Hoist array transformations out of a map lambda and into the
-- inputs of the enclosing 'Screma'.
--
-- The rule inspects the array operations found in the map lambda's
-- body (as collected by 'arrayOps') that index, rearrange, reshape or
-- copy one of the lambda's own parameters.  An operation is a candidate
-- when everything else it mentions (certificates, slice indices, new
-- shape) is already bound in the outer symbol table.  For each
-- candidate whose corresponding input array is not consumed by the
-- 'Screma', the transformation is applied once to the whole input
-- array (keeping its outer, mapped dimension).  The result is passed
-- as an extra 'Screma' input, and the operation inside the lambda is
-- replaced by a reference to the new row parameter.
--
-- Yields 'Skip' when there are no candidates (or the op is not a
-- 'Screma'), and 'cannotSimplify' when candidates exist but none of
-- them could be moved.
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm scan reduce map_lam))
  | let candidates = filter isMovable $ S.toList $ arrayOps mempty $ lambdaBody map_lam,
    not (null candidates) =
      Simplify $ do
        hoisted <- catMaybes <$> mapM hoistOp candidates
        let (new_arrs, new_params, substitutions) = unzip3 hoisted
        -- Every candidate may have been rejected (e.g. consumed input).
        when (null new_arrs) cannotSimplify
        let body' = replaceArrayOps (M.fromList substitutions) (lambdaBody map_lam)
            map_lam' =
              map_lam
                { lambdaParams = lambdaParams map_lam <> new_params,
                  lambdaBody = body'
                }
        auxing aux . letBind screma_pat . Op $
          Screma w (arrs <> new_arrs) (ScremaForm scan reduce map_lam')
  where
    -- Moving the transform is unsafe if the root array is consumed by
    -- the Screma.  This is conservative: fully replacing the original
    -- input would be fine, but the rule does not track that.
    consumed = consumedInOp soac

    param_names = map paramName (lambdaParams map_lam)

    lam_stms = bodyStms (lambdaBody map_lam)

    -- Is the pattern bound by a statement directly in the lambda body?
    topLevelPat p = p `elem` fmap stmPat lam_stms

    -- Is the array mentioned by at most one statement of the lambda body?
    onlyUsedOnce arr =
      null . drop 1 $ filter ((arr `nameIn`) . freeIn) (stmsToList lam_stms)

    -- All given names must be map-invariant, i.e. bound outside the map.
    invariant names = all (`ST.elem` vtable) (namesToList names)

    isParam arr = arr `elem` param_names

    -- Being a parameter is not enough; everything else the operation
    -- refers to must be map-invariant as well.
    isMovable (pat', ArrayIndexing cs arr slice) =
      isParam arr
        && invariant (freeIn cs <> freeIn slice)
        && not (null slice)
        && (not (null (sliceDims slice)) || (topLevelPat pat' && onlyUsedOnce arr))
    isMovable (_, ArrayRearrange cs arr perm) =
      isParam arr && invariant (freeIn cs) && not (null perm)
    isMovable (_, ArrayReshape cs arr _ new_shape) =
      isParam arr && invariant (freeIn cs <> freeIn new_shape)
    isMovable (_, ArrayCopy cs arr) =
      isParam arr && invariant (freeIn cs)
    isMovable (_, ArrayVar {}) = False

    -- Build the transformed input array, its row parameter, and the
    -- substitution for the in-lambda operation; Nothing if the
    -- parameter's input array is consumed by the Screma.
    hoistOp (pat, op) =
      case lookup (arrayOpArr op) (zip param_names arrs) of
        Just arr
          | arr `notNameIn` consumed -> do
              arr_t <- lookupType arr
              let full_outer =
                    DimSlice (intConst Int64 0) (arraySize 0 arr_t) (intConst Int64 1)
                  transform = BasicOp $ case op of
                    ArrayIndexing _ _ (Slice dims) ->
                      Index arr $ Slice $ full_outer : dims
                    ArrayRearrange _ _ perm ->
                      Rearrange (0 : map (+ 1) perm) arr
                    ArrayReshape _ _ k new_shape ->
                      Reshape k (Shape [w] <> new_shape) arr
                    ArrayCopy {} ->
                      Replicate mempty (Var arr)
                    ArrayVar {} ->
                      SubExp (Var arr)
              transformed <-
                certifying (arrayOpCerts op) $
                  letExp (baseString arr ++ "_transformed") transform
              transformed_t <- lookupType transformed
              row <- newVName (baseString arr ++ "_transformed_row")
              let row_param = Param mempty row (rowType transformed_t)
              pure $ Just (transformed, row_param, (pat, ArrayVar mempty row))
        _ -> pure Nothing
moveTransformToInput _ _ _ _ = Skip