{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.SOACS.Simplify
  ( simplifySOACS,
    simplifyLambda,
    simplifyFun,
    simplifyStms,
    simplifyConsts,
    simpleSOACS,
    simplifySOAC,
    soacRules,
    HasSOAC (..),
    simplifyKnownIterationSOAC,
    removeReplicateMapping,
    removeUnusedSOACInput,
    liftIdentityMapping,
    simplifyMapIota,
    SOACS,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util

-- | The simplifier operations for 'SOACS', built with
-- 'Simplify.bindableSimpleOps' and using 'simplifySOAC' to simplify the
-- operands of each SOAC.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC

-- | Simplify a whole 'SOACS' program with 'simpleSOACS' and the rules in
-- 'soacRules', using 'Engine.noExtraHoistBlockers'.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single function definition in the context of the given
-- symbol table, using 'simpleSOACS', 'soacRules' and
-- 'Engine.noExtraHoistBlockers'.
simplifyFun ::
  MonadFreshNames m =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single lambda, using 'simpleSOACS', 'soacRules' and
-- 'Engine.noExtraHoistBlockers'.  The ambient scope is taken from the
-- 'HasScope' constraint on @m@.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) => Lambda SOACS -> m (Lambda SOACS)
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a sequence of statements in the scope returned by
-- 'askScope', using 'simpleSOACS', 'soacRules' and
-- 'Engine.noExtraHoistBlockers'.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) => Stms SOACS -> m (Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms

-- | Simplify a sequence of statements starting from an empty ('mempty')
-- scope.  Unlike 'simplifyStms', no ambient scope is consulted.
-- Presumably this is meant for top-level constants, which have no
-- enclosing bindings; confirm at call sites.
simplifyConsts ::
  MonadFreshNames m => Stms SOACS -> m (Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    mempty

-- | Simplify the operands of a single SOAC.
--
-- Every size, input array, neutral element and lambda of the SOAC is
-- passed through the simplification engine.  The result is the rebuilt
-- SOAC together with all statements hoisted out of its lambdas.
--
-- The following lambdas are simplified under 'Engine.enterLoop':
--
-- * the lambda of a 'Stream' or 'Scatter';
-- * every operator and the bucket function of a 'Hist';
-- * the map lambda of a 'Screma'.
--
-- The lambdas of 'VJP' / 'JVP' and the scan/reduce operators of a
-- 'Screma' are simplified without it.
simplifySOAC ::
  Simplify.SimplifiableRep rep =>
  Simplify.SimplifyOp rep (SOAC (Wise rep))
simplifySOAC (VJP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (VJP lam' arr' vec', hoisted)
simplifySOAC (JVP lam arr vec) = do
  (lam', hoisted) <- Engine.simplifyLambda lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (JVP lam' arr' vec', hoisted)
simplifySOAC (Stream outerdim arr nes lam) = do
  outerdim' <- Engine.simplify outerdim
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda lam
  pure (Stream outerdim' arr' nes' lam', lam_hoisted)
simplifySOAC (Scatter w ivs lam as) = do
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  pure (Scatter w' ivs' lam' as', hoisted)
simplifySOAC (Hist w imgs ops bfun) = do
  w' <- Engine.simplify w
  -- Each histogram operator is simplified independently; the statements
  -- hoisted out of their lambdas are concatenated below.
  (ops', ops_hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', op_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda op
      pure (HistOp dests_w' rf' dests' nes' op', op_hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda bfun
  pure (Hist w' imgs' ops' bfun', mconcat ops_hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda lam
      nes' <- Engine.simplify nes
      pure (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda lam
      nes' <- Engine.simplify nes
      pure (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <-
    Engine.enterLoop $ Engine.simplifyLambda map_lam

  -- The width and input arrays are simplified after the lambdas; this
  -- ordering of effects matches the original definition.
  (,)
    <$> ( Screma
            <$> Engine.simplify w
            <*> Engine.simplify arrs
            <*> pure (ScremaForm scans' reds' map_lam')
        )
    <*> pure (mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted)

-- | 'Wise' 'SOACS' can be used with the builder; all methods use their
-- class defaults.
instance BuilderOps (Wise SOACS)

-- | Traversal of the statements nested inside SOACs is delegated to
-- 'traverseSOACStms'.
instance TraverseOpStms (Wise SOACS) where
  traverseOpStms = traverseSOACStms

-- | Specialise a lambda by fixing some of its parameters to known values.
--
-- @fixes@ is matched positionally against 'lambdaParams':
--
-- * For every parameter @p@ whose entry is @Just x@, a binding
--   @p = x@ (a 'BasicOp' 'SubExp') is emitted inside a 'runBodyBuilder'
--   block that then returns the original body.  @p@ is removed from the
--   parameter list.
-- * Parameters with a 'Nothing' entry are kept.
-- * If @fixes@ is shorter than the parameter list, it is padded with
--   'Nothing'.  Surplus entries beyond the number of parameters are
--   ignored.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (Lambda (Rep m))
fixLambdaParams lam fixes = do
  body <- runBodyBuilder $
    -- The parameters must be in scope while the fixing bindings are built.
    localScope (scopeOfLParams $ lambdaParams lam) $ do
      zipWithM_ maybeFix (lambdaParams lam) fixes'
      pure $ lambdaBody lam
  pure
    lam
      { lambdaBody = body,
        lambdaParams =
          map fst $
            filter (isNothing . snd) $
              zip (lambdaParams lam) fixes'
      }
  where
    -- Infinite padding so every parameter has a corresponding entry.
    fixes' = fixes ++ repeat Nothing

    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = pure ()

-- | Drop the results of a lambda whose corresponding flag in @keep@ is
-- 'False'.
--
-- Both 'bodyResult' and 'lambdaReturnType' are filtered by the same
-- positional mask.  Positions beyond the end of @keep@ are kept, because
-- the mask is padded with 'True'.  Only the result list of the body is
-- replaced; its statements are left untouched.
removeLambdaResults :: [Bool] -> Lambda rep -> Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body',
      lambdaReturnType = ret
    }
  where
    keep' :: [a] -> [a]
    keep' = map snd . filter fst . zip (keep ++ repeat True)

    lam_body = lambdaBody lam
    lam_body' = lam_body {bodyResult = keep' $ bodyResult lam_body}
    ret = keep' $ lambdaReturnType lam

-- | All simplification rules used for 'SOACS'.  This is 'standardRules'
-- extended with the SOAC-specific 'topDownRules' and 'bottomUpRules'.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules

-- | Does this rep contain 'SOAC's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSOAC rep where
  -- | View an operation as a SOAC, if it is one.
  asSOAC :: Op rep -> Maybe (SOAC rep)

  -- | Embed a SOAC as an operation of this rep.
  soacOp :: SOAC rep -> Op rep

-- | In 'Wise' 'SOACS' the operation type is the SOAC itself, so both
-- conversions are trivial.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id

-- | Top-down simplification rules specific to 'SOACS'; combined with the
-- standard rules in 'soacRules'.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  [ RuleOp hoistCerts,
    RuleOp removeReplicateMapping,
    RuleOp removeReplicateWrite,
    RuleOp removeUnusedSOACInput,
    RuleOp simplifyClosedFormReduce,
    RuleOp simplifyKnownIterationSOAC,
    RuleOp liftIdentityMapping,
    RuleOp removeDuplicateMapOutput,
    RuleOp fuseConcatScatter,
    RuleOp simplifyMapIota,
    RuleOp moveTransformToInput
  ]

-- | Bottom-up simplification rules specific to 'SOACS'; combined with the
-- standard rules in 'soacRules'.  All entries are 'RuleOp' rules except
-- 'removeUnnecessaryCopy', which is a 'RuleBasicOp' rule.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  [ RuleOp removeDeadMapping,
    RuleOp removeDeadReduction,
    RuleOp removeDeadWrite,
    RuleBasicOp removeUnnecessaryCopy,
    RuleOp liftIdentityStreaming,
    RuleOp mapOpToOp
  ]

-- | Any certificates attached to a trivial Stm in the body might as
-- well be applied to the SOAC itself.
--
-- The rule visits every lambda of the SOAC via 'mapSOACM'.  For each
-- statement of the form @let pat = se@ (a 'BasicOp' 'SubExp'):
--
-- * Its certificates are split into those bound in the outer symbol
--   table @vtable@ and the rest.
-- * The former are removed from the statement and collected.
-- * The SOAC is then re-bound under the original @aux@ plus the collected
--   certificates.
--
-- If no certificate was collected, the rule returns 'Skip'.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify $ auxing aux $ certifying hoisted $ letBind pat $ Op soac'
  where
    mapper :: SOACMapper (Wise SOACS) (Wise SOACS) (State Certs)
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    onLambda :: Lambda (Wise SOACS) -> State Certs (Lambda (Wise SOACS))
    onLambda lam = do
      stms' <- mapM onStm $ bodyStms $ lambdaBody lam
      -- 'mkBody' rebuilds the body so its decoration reflects the
      -- rewritten statements.
      pure
        lam
          { lambdaBody =
              mkBody stms' $ bodyResult $ lambdaBody lam
          }

    onStm :: Stm (Wise SOACS) -> State Certs (Stm (Wise SOACS))
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      -- Only certificates already bound outside the SOAC can be moved
      -- onto it; the others must stay on the inner statement.
      let (invariant, variant) =
            partition (`ST.elem` vtable) $
              unCerts $
                stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certs variant}
      modify (Certs invariant <>)
      pure $ Let se_pat se_aux' $ BasicOp $ SubExp se
    onStm stm = pure stm
hoistCerts _ _ _ _ =
  Skip

liftIdentityMapping ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
liftIdentityMapping :: forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
TopDownRuleOp rep
liftIdentityMapping TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux Op rep
op
  | Just (Screma SubExp
w [VName]
arrs ScremaForm rep
form :: SOAC rep) <- forall {k} (rep :: k). HasSOAC rep => Op rep -> Maybe (SOAC rep)
asSOAC Op rep
op,
    Just Lambda rep
fun <- forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm rep
form = do
      let inputMap :: Map VName VName
inputMap = forall k a. Ord k => [(k, a)] -> Map k a
M.fromList forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
fun) [VName]
arrs
          free :: Names
free = forall a. FreeIn a => a -> Names
freeIn forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
fun
          rettype :: [Type]
rettype = forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda rep
fun
          ses :: Result
ses = forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
fun

          freeOrConst :: SubExp -> Bool
freeOrConst (Var VName
v) = VName
v VName -> Names -> Bool
`nameIn` Names
free
          freeOrConst Constant {} = Bool
True

          checkInvariance :: (PatElem (LetDec rep), SubExpRes, Type)
-> ([(Pat (LetDec rep), Exp rep)],
    [(PatElem (LetDec rep), SubExp)], [Type])
-> ([(Pat (LetDec rep), Exp rep)],
    [(PatElem (LetDec rep), SubExp)], [Type])
checkInvariance (PatElem (LetDec rep)
outId, SubExpRes Certs
_ (Var VName
v), Type
_) ([(Pat (LetDec rep), Exp rep)]
invariant, [(PatElem (LetDec rep), SubExp)]
mapresult, [Type]
rettype')
            | Just VName
inp <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName VName
inputMap =
                ( (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)
outId], VName -> Exp rep
e VName
inp) forall a. a -> [a] -> [a]
: [(Pat (LetDec rep), Exp rep)]
invariant,
                  [(PatElem (LetDec rep), SubExp)]
mapresult,
                  [Type]
rettype'
                )
            where
              e :: VName -> Exp rep
e VName
inp = case forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
outId of
                Acc {} -> forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
inp
                Type
_ -> forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp (VName -> BasicOp
Copy VName
inp)
          checkInvariance (PatElem (LetDec rep)
outId, SubExpRes Certs
_ SubExp
e, Type
t) ([(Pat (LetDec rep), Exp rep)]
invariant, [(PatElem (LetDec rep), SubExp)]
mapresult, [Type]
rettype')
            | SubExp -> Bool
freeOrConst SubExp
e =
                ( (forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)
outId], forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
e) forall a. a -> [a] -> [a]
: [(Pat (LetDec rep), Exp rep)]
invariant,
                  [(PatElem (LetDec rep), SubExp)]
mapresult,
                  [Type]
rettype'
                )
            | Bool
otherwise =
                ( [(Pat (LetDec rep), Exp rep)]
invariant,
                  (PatElem (LetDec rep)
outId, SubExp
e) forall a. a -> [a] -> [a]
: [(PatElem (LetDec rep), SubExp)]
mapresult,
                  Type
t forall a. a -> [a] -> [a]
: [Type]
rettype'
                )

      case forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (PatElem (LetDec rep), SubExpRes, Type)
-> ([(Pat (LetDec rep), Exp rep)],
    [(PatElem (LetDec rep), SubExp)], [Type])
-> ([(Pat (LetDec rep), Exp rep)],
    [(PatElem (LetDec rep), SubExp)], [Type])
checkInvariance ([], [], []) forall a b. (a -> b) -> a -> b
$
        forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat) Result
ses [Type]
rettype of
        ([], [(PatElem (LetDec rep), SubExp)]
_, [Type]
_) -> forall {k} (rep :: k). Rule rep
Skip
        ([(Pat (LetDec rep), Exp rep)]
invariant, [(PatElem (LetDec rep), SubExp)]
mapresult, [Type]
rettype') -> forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
          let ([PatElem (LetDec rep)]
pat', [SubExp]
ses') = forall a b. [(a, b)] -> ([a], [b])
unzip [(PatElem (LetDec rep), SubExp)]
mapresult
              fun' :: Lambda rep
fun' =
                Lambda rep
fun
                  { lambdaBody :: Body rep
lambdaBody = (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
fun) {bodyResult :: Result
bodyResult = [SubExp] -> Result
subExpsRes [SubExp]
ses'},
                    lambdaReturnType :: [Type]
lambdaReturnType = [Type]
rettype'
                  }
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind) [(Pat (LetDec rep), Exp rep)]
invariant
          forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux forall a b. (a -> b) -> a -> b
$
            forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames (forall a b. (a -> b) -> [a] -> [b]
map forall dec. PatElem dec -> VName
patElemName [PatElem (LetDec rep)]
pat') forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
                forall {k} (rep :: k). HasSOAC rep => SOAC rep -> Op rep
soacOp forall a b. (a -> b) -> a -> b
$
                  forall {k} (rep :: k).
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName]
arrs (forall {k} (rep :: k). Lambda rep -> ScremaForm rep
mapSOAC Lambda rep
fun')
liftIdentityMapping TopDown rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ Op rep
_ = forall {k} (rep :: k). Rule rep
Skip

-- | Lift identity map-results out of a 'Stream'.
--
-- A map-part result (one following the accumulator results) that is
-- literally one of the lambda's array parameters is dropped from the
-- stream. Its pattern element is instead bound to a 'Copy' of the
-- corresponding input array. Accumulator results are left untouched.
-- The rule fires only when at least one such result exists.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
  | (variant_map, invariant_map@(_ : _)) <-
      partitionEithers $ zipWith3 classify map_ts map_pes map_res =
      Simplify $ do
        mapM_ copyInput invariant_map
        let (variant_ts, variant_pes, variant_res) = unzip3 variant_map
            body = lambdaBody lam
            lam' =
              lam
                { lambdaBody = body {bodyResult = fold_res ++ variant_res},
                  lambdaReturnType = fold_ts ++ variant_ts
                }
        auxing aux $
          letBind (Pat (fold_pes ++ variant_pes)) $
            Op (Stream w arrs nes lam')
  where
    -- One accumulator result per neutral element comes first.
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    (fold_res, map_res) = splitAt num_folds $ bodyResult $ lambdaBody lam

    -- Pair each array parameter name with its input array. The skipped
    -- parameters are the leading one plus one per neutral element
    -- (presumably the chunk size and the accumulators; not visible here).
    array_params =
      zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    -- 'Right' when the result is exactly an array parameter (so the
    -- output is just the input array), 'Left' otherwise.
    classify t pe res
      | Var v <- resSubExp res,
        Just arr <- lookup v array_params =
          Right (pe, arr)
      | otherwise =
          Left (t, pe, res)

    copyInput (pe, arr) = letBind (Pat [pe]) $ BasicOp $ Copy arr
liftIdentityStreaming _ _ _ _ = Skip

-- | Remove all arguments to the map that are simply replicates.
-- These can be turned into free variables instead. Each such parameter
-- is bound before the map, under the certificates that
-- 'removeReplicateInput' reports for it, and the map stops taking the
-- replicated array as input.
removeReplicateMapping ::
  (Aliased rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeReplicateMapping vtable pat aux op
  | Just (Screma w arrs form) <- asSOAC op,
    Just fun <- isMapSOAC form,
    Just (hoisted, fun', arrs') <- removeReplicateInput vtable fun arrs =
      Simplify $ do
        mapM_ bindHoisted hoisted
        auxing aux . letBind pat . Op . soacOp . Screma w arrs' $ mapSOAC fun'
  where
    bindHoisted (names, cs, rhs) = certifying cs $ letBindNames names rhs
removeReplicateMapping _ _ _ _ = Skip

-- | Like 'removeReplicateMapping', but for 'Scatter'. Inputs bound to
-- a replicate are turned into bindings placed before the scatter. The
-- scatter is rebuilt with the reduced lambda and input list, and its
-- destination specification is unchanged.
removeReplicateWrite :: TopDownRuleOp (Wise SOACS)
removeReplicateWrite vtable pat aux (Scatter w ivs lam dests) =
  case removeReplicateInput vtable lam ivs of
    Nothing -> Skip
    Just (hoisted, lam', ivs') -> Simplify $ do
      mapM_ (\(names, certs, rhs) -> certifying certs $ letBindNames names rhs) hoisted
      auxing aux . letBind pat . Op $ Scatter w ivs' lam' dests
removeReplicateWrite _ _ _ _ = Skip

-- | Find lambda array inputs that are replicates and can be hoisted out.
--
-- An input array qualifies when the symbol table says it is bound to a
-- 'Replicate' and its parameter is not consumed by the lambda. For each
-- qualifying input this produces a binding of the parameter name to the
-- replicated value with its outermost dimension stripped, together with
-- the certificates the symbol table records for that binding.
--
-- Returns 'Nothing' when no input qualifies. Otherwise it returns:
--
-- * the new bindings,
-- * the lambda with those parameters removed, and
-- * the remaining input arrays.
removeReplicateInput ::
  Aliased rep =>
  ST.SymbolTable rep ->
  Lambda rep ->
  [VName] ->
  Maybe
    ( [([VName], Certs, Exp rep)],
      Lambda rep,
      [VName]
    )
removeReplicateInput vtable fun arrs =
  case partitionEithers $ zipWith classify arr_params arrs of
    (_, []) -> Nothing
    (kept, hoisted) ->
      let (kept_params, kept_arrs) = unzip kept
          fun' = fun {lambdaParams = acc_params <> kept_params}
       in Just (hoisted, fun', kept_arrs)
  where
    params = lambdaParams fun
    -- The trailing parameters correspond to 'arrs'; any leading ones are
    -- kept as they are.
    (acc_params, arr_params) = splitAt (length params - length arrs) params
    consumed = consumedByLambda fun

    -- 'Right' for a hoistable replicate input, 'Left' for one to keep.
    classify p v
      | Just (BasicOp (Replicate (Shape (_ : inner)) e), v_cs) <- ST.lookupExp v vtable,
        paramName p `notNameIn` consumed =
          Right ([paramName p], v_cs, dropOuter inner e)
      | otherwise =
          Left (p, v)

    -- With only one replicated dimension, the element is the value itself.
    dropOuter [] e = BasicOp $ SubExp e
    dropOuter inner e = BasicOp $ Replicate (Shape inner) e

-- | Remove inputs that are not used inside the SOAC.
--
-- An input array of a 'Screma' is dropped when its map-lambda parameter
-- does not occur free in the lambda body. Scans and reductions are kept
-- unchanged. The rule fires only if at least one input is dropped.
removeUnusedSOACInput ::
  forall rep.
  (Aliased rep, Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeUnusedSOACInput _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    ScremaForm scans reds map_lam <- form,
    let body_free = freeIn $ lambdaBody map_lam,
    (live, dead) <-
      partition ((`nameIn` body_free) . paramName . fst) $
        zip (lambdaParams map_lam) arrs,
    not $ null dead =
      Simplify $ do
        let (live_params, live_arrs) = unzip live
            map_lam' = map_lam {lambdaParams = live_params}
        auxing aux . letBind pat . Op . soacOp . Screma w live_arrs $
          ScremaForm scans reds map_lam'
removeUnusedSOACInput _ _ _ _ = Skip

-- | Remove map results of a 'Screma' that are never used.
--
-- The scan and reduction results come first in both the pattern and the
-- lambda result, and they are always kept. Each later (map) result whose
-- pattern element the usage table does not mark as used is removed from
-- the pattern, the lambda result and the lambda return type. The rule
-- skips when nothing would be removed.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm scans reds lam))
  | (nonmap_pes, map_pes@(_ : _)) <- splitAt num_nonmap pes =
      let (nonmap_res, map_res) = splitAt num_nonmap $ bodyResult $ lambdaBody lam
          (nonmap_ts, map_ts) = splitAt num_nonmap $ lambdaReturnType lam
          (live_pes, live_res, live_ts) =
            unzip3
              [ triple
                | triple@(pe, _, _) <- zip3 map_pes map_res map_ts,
                  patElemName pe `UT.used` used
              ]
          lam' =
            lam
              { lambdaBody = (lambdaBody lam) {bodyResult = nonmap_res <> live_res},
                lambdaReturnType = nonmap_ts <> live_ts
              }
       in if map_pes == live_pes
            then Skip
            else
              Simplify . auxing aux $
                letBind (Pat $ nonmap_pes <> live_pes) $
                  Op $
                    Screma w arrs $
                      ScremaForm scans reds lam'
  where
    num_nonmap = scanResults scans + redResults reds
removeDeadMapping _ _ _ _ = Skip

-- | Remove duplicate results from a map.
--
-- If a map lambda returns the same 'SubExp' in several result positions,
-- only the first such result is kept in the map. Each later pattern
-- element is then bound to a 'Copy' of the pattern element produced by
-- the first occurrence. Only 'resSubExp' is compared; the certificates
-- attached to the individual results are not.
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
      let ses = bodyResult $ lambdaBody fun
          ts = lambdaReturnType fun
          -- Strict left fold. The kept results are accumulated in reverse
          -- (prepending) rather than with a repeated '++ [x]', which
          -- re-copied the accumulator on every element.
          (rev_ses_ts_pes', copies) =
            foldl' checkForDuplicates (mempty, mempty) $ zip3 ses ts pes
       in if null copies
            then Skip
            else Simplify $ do
              let (ses', ts', pes') = unzip3 $ reverse rev_ses_ts_pes'
                  fun' =
                    fun
                      { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                        lambdaReturnType = ts'
                      }
              auxing aux $ letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun'
              forM_ copies $ \(from, to) ->
                letBind (Pat [to]) $ BasicOp $ Copy $ patElemName from
  where
    -- The accumulator holds the distinct results seen so far (most recent
    -- first) and the (first occurrence, duplicate) pattern-element pairs
    -- found so far. A result is added only when its 'SubExp' is new, so
    -- at most one accumulated entry can match. Searching the reversed
    -- list therefore finds the same entry an in-order search would.
    checkForDuplicates (kept, copies) (se, t, pe)
      | Just (_, _, pe') <- find (\(x, _, _) -> resSubExp x == resSubExp se) kept =
          -- This result has been returned before, producing the array pe'.
          (kept, (pe', pe) : copies)
      | otherwise =
          ((se, t, pe) : kept, copies)
removeDuplicateMapOutput _ _ _ _ = Skip

-- | Mapping some operations becomes an extension of that operation.
--
-- This applies when 'isMapWithOp' recognises the map as wrapping a single
-- reshape, concatenation, rearrangement or rotation of its parameter(s).
-- The map is then replaced by that operation applied directly to the
-- input array(s), adjusted for the extra outer dimension:
--
-- * reshape: the map width is prepended to the new shape;
-- * concat: the concatenation dimension is incremented;
-- * rearrange: dimension 0 is fixed and the other indices are shifted by one;
-- * rotate: a zero rotation is prepended for the outer dimension.
--
-- The statement's certificates are combined with those reported by
-- 'isMapWithOp'.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp (_, used) pat aux1 e
  | Just (map_pe, cs, w, BasicOp (Reshape k newshape reshape_arr), [p], [arr]) <-
      isMapWithOp pat e,
    paramName p == reshape_arr,
    notConsumed map_pe =
      replaceWith cs $ Reshape k (Shape [w] <> newshape) arr
  | Just (_, cs, _, BasicOp (Concat d (arr :| arrs) dw), ps, outer_arr : outer_arrs) <-
      isMapWithOp pat e,
    (arr : arrs) == map paramName ps =
      -- NOTE(review): unlike the other cases, no consumption check is made
      -- here; presumably because Concat yields a fresh array. Confirm.
      replaceWith cs $ Concat (d + 1) (outer_arr :| outer_arrs) dw
  | Just (map_pe, cs, _, BasicOp (Rearrange perm rearrange_arr), [p], [arr]) <-
      isMapWithOp pat e,
    paramName p == rearrange_arr,
    notConsumed map_pe =
      replaceWith cs $ Rearrange (0 : map (1 +) perm) arr
  | Just (map_pe, cs, _, BasicOp (Rotate rots rotate_arr), [p], [arr]) <-
      isMapWithOp pat e,
    paramName p == rotate_arr,
    notConsumed map_pe =
      replaceWith cs $ Rotate (intConst Int64 0 : rots) arr
  where
    -- The map's output must not be consumed according to the usage table.
    notConsumed pe = not $ UT.isConsumed (patElemName pe) used

    -- Bind the whole pattern to the given basic operation, under the
    -- statement's certificates plus the extra ones supplied.
    replaceWith :: Certs -> BasicOp -> Rule (Wise SOACS)
    replaceWith cs =
      Simplify . certifying (stmAuxCerts aux1 <> cs) . letBind pat . BasicOp
mapOpToOp _ _ _ _ = Skip

-- | Recognise a @map@ with a single-element pattern, whose lambda body
-- consists of exactly one statement binding a single name, and whose
-- single result is that name.  On success returns:
--
-- * the sole element of the outer pattern,
-- * the certificates that must be kept when the inner expression is
--   lifted out of the map (those of the inner statement together with
--   those attached to the lambda result),
-- * the width of the map,
-- * the inner expression,
-- * the lambda parameters, and
-- * the input arrays of the map.
isMapWithOp ::
  Pat dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElem dec,
      Certs,
      SubExp,
      Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp pat e
  | Pat [map_pe] <- pat,
    Screma w arrs form <- e,
    Just map_lam <- isMapSOAC form,
    [Let (Pat [pe]) aux2 e'] <- stmsToList $ bodyStms $ lambdaBody map_lam,
    [SubExpRes res_cs (Var r)] <- bodyResult $ lambdaBody map_lam,
    r == patElemName pe =
      -- The result certificates guard the returned value just like the
      -- statement's own certificates, so they must not be lost when a
      -- caller replaces the map by a whole-array operation.
      Just
        ( map_pe,
          stmAuxCerts aux2 <> res_cs,
          w,
          e',
          lambdaParams map_lam,
          arrs
        )
  | otherwise = Nothing

-- | Some of the results of a reduction (or really: Redomap) may be
-- dead.  We remove them here.  The trick is that we need to look at
-- the data dependencies to see that the "dead" result is not
-- actually used for computing one of the live ones.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form,
    not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check
    let (red_pes, map_pes) = splitAt (length nes) $ patElems pat,
    let redlam_deps = dataDependencies $ lambdaBody redlam,
    let redlam_res = bodyResult $ lambdaBody redlam,
    let redlam_params = lambdaParams redlam,
    -- Operator parameters (among the first @length nes@) whose
    -- corresponding reduction result is used after this statement.
    let used_after =
          map snd . filter ((`UT.used` used) . patElemName . fst) $
            zip red_pes redlam_params,
    -- The operator's results are paired with the parameters twice over,
    -- so a parameter is necessary if a used result depends on it.
    let necessary =
          findNecessaryForReturned
            (`elem` used_after)
            (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
            redlam_deps,
    let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
    -- Liveness of each reduction result, one entry per neutral element.
    let red_alive = take (length nes) alive_mask,
    not $ and red_alive = Simplify $ do
      let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
          -- Parameters marked 'Just' are handed to 'fixLambdaParams'
          -- together with their neutral element; the same mask is used
          -- for both halves of the operator's parameters.
          dead_fix = zipWith fixDeadToNeutral alive_mask nes
          (used_red_pes, _, used_nes) =
            unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $
              zip3 red_pes redlam_params nes

      let maplam' = removeLambdaResults red_alive maplam
      redlam' <-
        removeLambdaResults red_alive
          <$> fixLambdaParams redlam (dead_fix ++ dead_fix)

      auxing aux $
        letBind (Pat $ used_red_pes ++ map_pes) $
          Op $
            Screma w arrs $
              redomapSOAC [Reduce comm redlam' used_nes] maplam'
removeDeadReduction _ _ _ _ = Skip

-- | If we are writing to an array that is never used, get rid of it.
--
-- A destination may receive several writes per iteration (the 'Int'
-- component of its specification), so the lambda results are grouped
-- per destination before deciding what to keep.  Pairing per-write
-- groups directly with pattern elements and destinations would
-- misalign them as soon as any destination has more than one write.
removeDeadWrite :: BottomUpRuleOp (Wise SOACS)
removeDeadWrite (_, used) pat aux (Scatter w arrs fun dests)
  | null dropped = Skip
  | otherwise =
      Simplify . auxing aux . letBind (Pat pat') . Op $
        Scatter w arrs fun' dests'
  where
    -- Split a flat list laid out like the lambda results into, for each
    -- destination, the index lists and the values of each of its writes.
    perDest :: [a] -> [([[a]], [a])]
    perDest =
      map unzip
        . chunks (map (\(_, n, _) -> n) dests)
        . groupScatterResults' dests

    (i_ses, v_ses) = unzip $ perDest $ bodyResult $ lambdaBody fun
    (i_ts, v_ts) = unzip $ perDest $ lambdaReturnType fun

    isUsed (bindee, _, _, _, _, _) = patElemName bindee `UT.used` used
    (kept, dropped) =
      partition isUsed $ zip6 (patElems pat) i_ses v_ses i_ts v_ts dests
    (pat', i_ses', v_ses', i_ts', v_ts', dests') = unzip6 kept

    -- Keep only the results (and their types) feeding live destinations,
    -- in the original layout: all indices first, then all values.
    fun' =
      fun
        { lambdaBody =
            (lambdaBody fun)
              { bodyResult = concat (concat i_ses') ++ concat v_ses'
              },
          lambdaReturnType = concat (concat i_ts') ++ concat v_ts'
        }
removeDeadWrite _ _ _ _ = Skip

-- | Fuse a 'Scatter' whose inputs are all outer-dimension concatenations
-- of the same number of parts, every part having the same size.  The
-- scatter is rewritten to iterate over a single part, with the lambda
-- replicated once per part so that each iteration performs the writes
-- of all parts.  Handles concatenations of any number of arrays.
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter vtable pat _ (Scatter _ arrs fun dests)
  | Just (ws@(w' : _), xss, css) <- unzip3 <$> mapM isConcat arrs,
    xivs <- transpose xss,
    -- Every input must be split into equally many parts; otherwise
    -- 'transpose' would produce ragged groups of lambda inputs.
    all ((== length xivs) . length) xss,
    all (w' ==) ws = Simplify $ do
      let r = length xivs
      fun2s <- replicateM (r - 1) (renameLambda fun)
      let (fun_is, fun_vs) =
            unzip . map (splitScatterResults dests . bodyResult . lambdaBody) $
              fun : fun2s
          (its, vts) =
            unzip . replicate r . splitScatterResults dests $
              lambdaReturnType fun
          new_stmts = mconcat $ map (bodyStms . lambdaBody) (fun : fun2s)
      let fun' =
            Lambda
              { lambdaParams = mconcat $ map lambdaParams (fun : fun2s),
                lambdaBody = mkBody new_stmts $ mixWrites fun_is <> mix fun_vs,
                lambdaReturnType = mixWrites its <> mix vts
              }
      certifying (mconcat css) . letBind pat . Op $
        Scatter w' (concat xivs) fun' $
          map (incWrites r) dests
  where
    sizeOf :: VName -> Maybe SubExp
    sizeOf x = arraySize 0 . typeOf <$> ST.lookup x vtable

    -- Interleave the copies' lists element by element: the first
    -- element of every copy, then the second of every copy, and so on.
    -- Used for values, where each write has exactly one entry.
    mix :: [[a]] -> [a]
    mix = concat . transpose

    -- Number of index results consumed by each individual write, in
    -- the order the writes appear in the lambda results.
    write_ranks :: [Int]
    write_ranks = concatMap (\(shp, n, _) -> replicate n (shapeRank shp)) dests

    -- Like 'mix', but interleaves whole writes rather than single
    -- elements, so that a multi-dimensional index stays together and
    -- the j'th index group still matches the j'th value from 'mix'.
    mixWrites :: [[a]] -> [a]
    mixWrites = concat . concat . transpose . map (chunks write_ranks)

    -- Each of the r copies of the lambda performs n writes to the
    -- destination, so the fused lambda performs n * r of them.
    incWrites :: Int -> (Shape, Int, VName) -> (Shape, Int, VName)
    incWrites r (shp, n, arr) = (shp, n * r, arr)

    -- Decompose an array into its outer-dimension concatenation parts
    -- (all of equal size), looking through coercing reshapes.
    isConcat :: VName -> Maybe (SubExp, [VName], Certs)
    isConcat v = case ST.lookupExp v vtable of
      Just (BasicOp (Concat 0 (x :| ys) _), cs) -> do
        x_w <- sizeOf x
        y_ws <- mapM sizeOf ys
        guard $ all (x_w ==) y_ws
        pure (x_w, x : ys, cs)
      Just (BasicOp (Reshape ReshapeCoerce _ arr), cs) -> do
        (a, b, cs') <- isConcat arr
        pure (a, b, cs <> cs')
      _ -> Nothing
fuseConcatScatter _ _ _ _ = Skip

-- | Simplify reductions whose result can be given directly.  A redomap
-- over a constant zero number of elements yields its neutral elements,
-- and a plain reduction is handed to 'foldClosedForm'.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form)
  | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    zeroIsh w,
    -- Only the reduction results are rebound below.  If the redomap
    -- also has map outputs, replacing the statement would leave those
    -- names unbound, so the rule must not fire.
    length (patNames pat) == length nes =
      Simplify . forM_ (zip (patNames pat) nes) $ \(v, ne) ->
        letBindNames [v] $ BasicOp $ SubExp ne
simplifyClosedFormReduce vtable pat _ (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
      Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip

-- | Inline SOACs whose iteration count is the constant one.  For now
-- only singleton 'Screma' and 'Stream' are handled.
simplifyKnownIterationSOAC ::
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyKnownIterationSOAC _ pat _ op
  | Just (Screma (Constant k) arrs (ScremaForm scans reds map_lam)) <- asSOAC op,
    oneIsh k = Simplify $ do
      let Reduce _ red_op red_neutral = singleReduce reds
          Scan scan_op scan_neutral = singleScan scans
          num_scan = length scan_neutral
          num_red = length red_neutral
          (scan_out, red_out, map_out) =
            splitAt3 num_scan num_red (patElems pat)
          -- Scan and map outputs are arrays holding the single result.
          bindAsArray pe (SubExpRes cs se) =
            certifying cs . letBindNames [patElemName pe] . BasicOp $
              ArrayLit [se] (rowType (patElemType pe))
          -- Reduction outputs are the result itself.
          bindAsScalar pe (SubExpRes cs se) =
            certifying cs . letBindNames [patElemName pe] . BasicOp $
              SubExp se

      -- Each map parameter is row 0 of its input array.
      forM_ (zip (lambdaParams map_lam) arrs) $ \(param, arr) -> do
        arr_t <- lookupType arr
        letBindNames [paramName param] . BasicOp . Index arr $
          fullSlice arr_t [DimFix $ constant (0 :: Int64)]

      body_res <- bodyBind $ lambdaBody map_lam
      let (scan_in, red_in, map_res) = splitAt3 num_scan num_red body_res

      -- One application of each operator, starting from its neutral elements.
      scan_res <-
        eLambda scan_op . map eSubExp $ scan_neutral ++ map resSubExp scan_in
      red_res <-
        eLambda red_op . map eSubExp $ red_neutral ++ map resSubExp red_in

      forM_ (zip scan_out scan_res) $ uncurry bindAsArray
      forM_ (zip red_out red_res) $ uncurry bindAsScalar
      forM_ (zip map_out map_res) $ uncurry bindAsArray
simplifyKnownIterationSOAC _ pat _ op
  | Just (Stream (Constant k) arrs nes fold_lam) <- asSOAC op,
    oneIsh k = Simplify $ do
      let (chunk_p, acc_ps, slice_ps) =
            partitionChunkedFoldParameters (length nes) (lambdaParams fold_lam)
          bindTo p = letBindNames [paramName p] . BasicOp . SubExp

      -- A single chunk of size one covers the whole input.
      bindTo chunk_p $ intConst Int64 1
      -- Accumulators start out as the neutral elements...
      zipWithM_ bindTo acc_ps nes
      -- ...and each slice parameter is the entire corresponding array.
      zipWithM_ bindTo slice_ps (map Var arrs)

      fold_res <- bodyBind $ lambdaBody fold_lam

      forM_ (zip (patNames pat) fold_res) $ \(v, SubExpRes cs se) ->
        certifying cs . letBindNames [v] . BasicOp $ SubExp se
simplifyKnownIterationSOAC _ _ _ _ = Skip

-- | An operation applied to a named array, together with the
-- certificates attached to it.
data ArrayOp
  = -- | Indexing/slicing of the array.
    ArrayIndexing Certs VName (Slice SubExp)
  | -- | Permutation of the array's dimensions.
    ArrayRearrange Certs VName [Int]
  | -- | Rotation of the array, one offset per dimension.
    ArrayRotate Certs VName [SubExp]
  | -- | Reshape of the array to the given shape.
    ArrayReshape Certs VName ReshapeKind Shape
  | -- | Copy of the array.
    ArrayCopy Certs VName
  | -- | Never constructed.
    ArrayVar Certs VName
  deriving (Eq, Ord, Show)

-- | The array that an 'ArrayOp' reads from.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op =
  case op of
    ArrayIndexing _ v _ -> v
    ArrayRearrange _ v _ -> v
    ArrayRotate _ v _ -> v
    ArrayReshape _ v _ _ -> v
    ArrayCopy _ v -> v
    ArrayVar _ v -> v

-- | The certificates recorded with an 'ArrayOp'.
arrayOpCerts :: ArrayOp -> Certs
arrayOpCerts op =
  case op of
    ArrayIndexing certs _ _ -> certs
    ArrayRearrange certs _ _ -> certs
    ArrayRotate certs _ _ -> certs
    ArrayReshape certs _ _ _ -> certs
    ArrayCopy certs _ -> certs
    ArrayVar certs _ -> certs

-- | Recognise an expression as one of the array operations tracked by
-- 'ArrayOp', attaching the given certificates.  Only 'Index',
-- 'Rearrange', 'Rotate', 'Reshape' and 'Copy' are recognised; every
-- other expression yields 'Nothing'.
isArrayOp :: Certs -> Exp rep -> Maybe ArrayOp
isArrayOp cs (BasicOp basic) =
  case basic of
    Index arr slice -> Just $ ArrayIndexing cs arr slice
    Rearrange perm arr -> Just $ ArrayRearrange cs arr perm
    Rotate rots arr -> Just $ ArrayRotate cs arr rots
    Reshape k new_shape arr -> Just $ ArrayReshape cs arr k new_shape
    Copy arr -> Just $ ArrayCopy cs arr
    _ -> Nothing
isArrayOp _ _ = Nothing

-- | Turn an 'ArrayOp' back into an expression, paired with the
-- certificates it should be bound under.  This inverts 'isArrayOp';
-- 'ArrayVar' becomes a plain variable reference.
fromArrayOp :: ArrayOp -> (Certs, Exp rep)
fromArrayOp op =
  case op of
    ArrayIndexing cs arr slice -> withCerts cs $ Index arr slice
    ArrayRearrange cs arr perm -> withCerts cs $ Rearrange perm arr
    ArrayRotate cs arr rots -> withCerts cs $ Rotate rots arr
    ArrayReshape cs arr k new_shape -> withCerts cs $ Reshape k new_shape arr
    ArrayCopy cs arr -> withCerts cs $ Copy arr
    ArrayVar cs arr -> withCerts cs $ SubExp $ Var arr
  where
    withCerts cs basic = (cs, BasicOp basic)

-- | Collect every array operation recognised by 'isArrayOp' that is
-- bound anywhere in the given body, each paired with the pattern it is
-- bound to.  Statements that are not array operations are searched
-- recursively: nested bodies are visited with a 'Walker', and the
-- lambdas of nested SOACs are visited through 'mapSOACM'.
arrayOps ::
  forall rep.
  (Buildable rep, HasSOAC rep) =>
  Body rep ->
  S.Set (Pat (LetDec rep), ArrayOp)
arrayOps body = foldMap fromStm $ bodyStms body
  where
    fromStm :: Stm rep -> S.Set (Pat (LetDec rep), ArrayOp)
    fromStm (Let pat aux e) =
      maybe
        (execState (walkExpM walker e) mempty)
        (\found -> S.singleton (pat, found))
        (isArrayOp (stmAuxCerts aux) e)

    -- Only SOAC operations are inspected; other ops contribute nothing.
    fromOp :: Op rep -> S.Set (Pat (LetDec rep), ArrayOp)
    fromOp op =
      case asSOAC op of
        Just (soac :: SOAC rep) ->
          execWriter $
            mapSOACM identitySOACMapper {mapOnSOACLambda = fromLambda} soac
        Nothing -> mempty

    -- Record the array operations of a lambda body, leaving the lambda
    -- itself unchanged.
    fromLambda ::
      Lambda rep ->
      Writer (S.Set (Pat (LetDec rep), ArrayOp)) (Lambda rep)
    fromLambda lam = lam <$ tell (arrayOps (lambdaBody lam))

    walker :: Walker rep (State (S.Set (Pat (LetDec rep), ArrayOp)))
    walker =
      identityWalker
        { walkOnBody = \_ nested -> modify (arrayOps nested <>),
          walkOnOp = \op -> modify (fromOp op <>)
        }

-- | Rewrite a body by replacing array operations according to a
-- substitution map.  A statement is replaced when its pattern,
-- together with the 'ArrayOp' recognised from its expression, is a
-- key of the map.  The replacement's certificates are added to the
-- statement.  The rewrite recurses into nested bodies and into the
-- lambdas of nested SOACs.
replaceArrayOps ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  M.Map (Pat (LetDec rep), ArrayOp) ArrayOp ->
  Body rep ->
  Body rep
replaceArrayOps substs (Body _ stms res) =
  mkBody (rewriteStm <$> stms) res
  where
    rewriteStm :: Stm rep -> Stm rep
    rewriteStm (Let pat aux e) =
      certify new_cs $ mkLet' (patIdents pat) aux new_e
      where
        (new_cs, new_e) = rewriteExp pat (stmAuxCerts aux) e

    -- Substitute the expression itself if it is a known array
    -- operation; otherwise descend into its sub-components.
    rewriteExp :: Pat (LetDec rep) -> Certs -> Exp rep -> (Certs, Exp rep)
    rewriteExp pat cs e =
      case isArrayOp cs e >>= \found -> M.lookup (pat, found) substs of
        Just replacement -> fromArrayOp replacement
        Nothing -> (cs, mapExp mapper e)

    mapper :: Mapper rep rep Identity
    mapper =
      identityMapper
        { mapOnBody = \_ -> pure . replaceArrayOps substs,
          mapOnOp = pure . rewriteOp
        }

    -- Only SOAC operations are rewritten; other ops are left alone.
    rewriteOp :: Op rep -> Op rep
    rewriteOp op =
      case asSOAC op of
        Just (soac :: SOAC rep) ->
          soacOp $
            runIdentity $
              mapSOACM
                identitySOACMapper {mapOnSOACLambda = pure . rewriteLambda}
                soac
        Nothing -> op

    rewriteLambda :: Lambda rep -> Lambda rep
    rewriteLambda lam =
      lam {lambdaBody = replaceArrayOps substs (lambdaBody lam)}

-- Turn
--
--    map (\i -> ... xs[i] ...) (iota n)
--
-- into
--
--    map (\i x -> ... x ...) (iota n) xs
--
-- This is not because we want to encourage the map-iota pattern, but
-- it may be present in generated code.  This is an unfortunately
-- expensive simplification rule, since it requires multiple passes
-- over the entire lambda body.  It only handles the very simplest
-- case - if you find yourself planning to extend it to handle more
-- complex situations (rotate or whatnot), consider turning it into a
-- separate compiler pass instead.
simplifyMapIota ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
simplifyMapIota vtable screma_pat aux op
  | Just (Screma w arrs (ScremaForm scan reduce map_lam) :: SOAC rep) <- asSOAC op,
    -- A lambda parameter bound to a 0-offset, 1-stride iota input.
    Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs),
    -- Indexings in the lambda body that use that parameter, and whose
    -- array, certificates and leading indices are bound outside.
    indexings <-
      mapMaybe (indexesWith (paramName p)) . S.toList $
        arrayOps $
          lambdaBody map_lam,
    not $ null indexings = Simplify $ do
      -- For each indexing with iota, add the corresponding array to
      -- the Screma, and construct a new lambda parameter.
      (more_arrs, more_params, replacements) <-
        unzip3 . catMaybes <$> mapM (mapOverArr w) indexings
      let substs = M.fromList replacements
          map_lam' =
            map_lam
              { lambdaParams = lambdaParams map_lam <> more_params,
                lambdaBody = replaceArrayOps substs $ lambdaBody map_lam
              }

      auxing aux . letBind screma_pat . Op . soacOp $
        Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam')
  where
    -- Is this Screma input bound to an iota with constant offset 0 and
    -- constant stride 1?
    isIota (_, arr) = case ST.lookupBasicOp arr vtable of
      Just (Iota _ (Constant o) (Constant s) _, _) ->
        zeroIsh o && oneIsh s
      _ -> False

    -- Find a 'DimFix i', optionally preceded by other DimFixes, and
    -- if so return those DimFixes.
    fixWith i (DimFix j : slice)
      | Var i == j = Just []
      | otherwise = (j :) <$> fixWith i slice
    fixWith _ _ = Nothing

    -- Accept an indexing of an array bound outside the lambda, whose
    -- certificates and the indices preceding the iota parameter are
    -- also bound outside the lambda.
    indexesWith v (pat, idx@(ArrayIndexing cs arr (Slice js)))
      | arr `ST.elem` vtable,
        all (`ST.elem` vtable) $ unCerts cs,
        Just js' <- fixWith v js,
        all (`ST.elem` vtable) $ namesToList $ freeIn js' =
          Just (pat, js', idx)
    indexesWith _ _ = Nothing

    -- Index away the leading (map-invariant) dimensions, if any.
    properArr [] arr = pure arr
    properArr js arr = do
      arr_t <- lookupType arr
      letExp (baseString arr) $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix js

    -- Produce the new Screma input, the new lambda parameter, and the
    -- substitution that replaces the original indexing with an
    -- indexing of that parameter.
    mapOverArr w (pat, js, ArrayIndexing cs arr slice) = do
      -- The leading indexing is part of the original certified
      -- indexing, so it must carry the same certificates (as the
      -- prefix slice below already does).
      arr' <- certifying cs $ properArr js arr
      arr_t <- lookupType arr'
      arr'' <-
        if arraySize 0 arr_t == w
          then pure arr'
          else
            certifying cs . letExp (baseString arr ++ "_prefix") . BasicOp . Index arr' $
              fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)]
      arr_elem_param <- newParam (baseString arr ++ "_elem") (rowType arr_t)
      pure $
        Just
          ( arr'',
            arr_elem_param,
            ( (pat, ArrayIndexing cs arr slice),
              ArrayIndexing cs (paramName arr_elem_param) (Slice (drop (length js + 1) (unSlice slice)))
            )
          )
    mapOverArr _ _ = pure Nothing
simplifyMapIota _ _ _ _ = Skip

-- If a Screma's map function contains a transformation
-- (e.g. transpose) on a parameter, create a new parameter
-- corresponding to that transformation performed on the rows of the
-- full array.
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput SymbolTable (Wise SOACS)
vtable Pat (LetDec (Wise SOACS))
screma_pat StmAux (ExpDec (Wise SOACS))
aux soac :: Op (Wise SOACS)
soac@(Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam))
  | [(Pat (VarWisdom, Type), ArrayOp)]
ops <- forall a. (a -> Bool) -> [a] -> [a]
filter (Pat (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam forall a b. (a -> b) -> a -> b
$ forall a. Set a -> [a]
S.toList forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
(Buildable rep, HasSOAC rep) =>
Body rep -> Set (Pat (LetDec rep), ArrayOp)
arrayOps forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam,
    Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Pat (VarWisdom, Type), ArrayOp)]
ops = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
      ([VName]
more_arrs, [Param Type]
more_params, [((Pat (VarWisdom, Type), ArrayOp), ArrayOp)]
replacements) <-
        forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. [Maybe a] -> [a]
catMaybes forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Pat (VarWisdom, Type), ArrayOp)
-> RuleM
     (Wise SOACS)
     (Maybe
        (VName, Param Type, ((Pat (VarWisdom, Type), ArrayOp), ArrayOp)))
mapOverArr [(Pat (VarWisdom, Type), ArrayOp)]
ops

      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
more_arrs) forall {k} (rep :: k) a. RuleM rep a
cannotSimplify

      let map_lam' :: Lambda (Wise SOACS)
map_lam' =
            Lambda (Wise SOACS)
map_lam
              { lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam forall a. Semigroup a => a -> a -> a
<> [Param Type]
more_params,
                lambdaBody :: Body (Wise SOACS)
lambdaBody = forall rep.
(Buildable rep, BuilderOps rep, HasSOAC rep) =>
Map (Pat (LetDec rep), ArrayOp) ArrayOp -> Body rep -> Body rep
replaceArrayOps (forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [((Pat (VarWisdom, Type), ArrayOp), ArrayOp)]
replacements) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam
              }

      forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec (Wise SOACS))
aux forall a b. (a -> b) -> a -> b
$
        forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Wise SOACS))
screma_pat forall a b. (a -> b) -> a -> b
$
          forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k).
SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
arrs forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) (forall {k} (rep :: k).
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam')
  where
    -- It is not safe to move the transform if the root array is being
    -- consumed by the Screma.  This is a bit too conservative - it's
    -- actually safe if we completely replace the original input, but
    -- this rule is not that precise.
    consumed :: Names
consumed = forall op. AliasedOp op => op -> Names
consumedInOp Op (Wise SOACS)
soac
    map_param_names :: [VName]
map_param_names = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName (forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam)
    topLevelPat :: Pat (VarWisdom, Type) -> Bool
topLevelPat = (forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam)))
    onlyUsedOnce :: VName -> Bool
onlyUsedOnce VName
arr =
      case forall a. (a -> Bool) -> [a] -> [a]
filter ((VName
arr `nameIn`) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. FreeIn a => a -> Names
freeIn) forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stms rep -> [Stm rep]
stmsToList forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda (Wise SOACS)
map_lam of
        Stm (Wise SOACS)
_ : Stm (Wise SOACS)
_ : [Stm (Wise SOACS)]
_ -> Bool
False
        [Stm (Wise SOACS)]
_ -> Bool
True

    -- It's not just about whether the array is a parameter;
    -- everything else must be map-invariant.
    arrayIsMapParam :: (Pat (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam (Pat (VarWisdom, Type)
pat', ArrayIndexing Certs
cs VName
arr Slice SubExp
slice) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall {k} (rep :: k). VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null Slice SubExp
slice)
        Bool -> Bool -> Bool
&& (Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Bool -> Bool -> Bool
|| (Pat (VarWisdom, Type) -> Bool
topLevelPat Pat (VarWisdom, Type)
pat' Bool -> Bool -> Bool
&& VName -> Bool
onlyUsedOnce VName
arr))
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayRearrange Certs
cs VName
arr [Int]
perm) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall {k} (rep :: k). VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
perm)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayRotate Certs
cs VName
arr [SubExp]
rots) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall {k} (rep :: k). VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn [SubExp]
rots)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayReshape Certs
cs VName
arr ReshapeKind
_ Shape
new_shape) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall {k} (rep :: k). VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs forall a. Semigroup a => a -> a -> a
<> forall a. FreeIn a => a -> Names
freeIn Shape
new_shape)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayCopy Certs
cs VName
arr) =
      VName
arr forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (forall {k} (rep :: k). VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn Certs
cs)
    arrayIsMapParam (Pat (VarWisdom, Type)
_, ArrayVar {}) =
      Bool
False

    -- Try to move the array operation @op@ (bound to @pat@ inside the map
    -- body) out of the map and onto the map's input.
    --
    -- This only fires when the array that @op@ operates on is one of the
    -- map parameters (@map_param_names@) and the matching input array from
    -- @arrs@ is not in @consumed@.  In that case @op@ is re-expressed over
    -- the whole input array by adding one leading dimension.  That leading
    -- dimension is:
    --   * a full slice, for an index;
    --   * axis 0 kept in place, for a rearrange;
    --   * a zero offset, for a rotate;
    --   * @w@, for a reshape (@w@ comes from the enclosing scope;
    --     presumably the map width -- confirm against the caller).
    -- The result is bound as @<arr>_transformed@ under the certificates of
    -- @op@.
    --
    -- The result triple is:
    --   1. the new input array;
    --   2. a fresh parameter @<arr>_transformed_row@ typed as its row;
    --   3. the original @(pat, op)@ paired with an 'ArrayVar' naming that
    --      parameter.
    -- In every other case the result is 'Nothing'.
    mapOverArr (pat, op) =
      case lookup (arrayOpArr op) (zip map_param_names arrs) of
        Just arr
          | arr `notNameIn` consumed -> do
              arr_t <- lookupType arr
              let full_outer =
                    DimSlice (intConst Int64 0) (arraySize 0 arr_t) (intConst Int64 1)
                  -- The same operation, lifted to carry the extra outer dimension.
                  lifted = case op of
                    ArrayIndexing _ _ (Slice slice) -> Index arr (Slice (full_outer : slice))
                    ArrayRearrange _ _ perm -> Rearrange (0 : map (+ 1) perm) arr
                    ArrayRotate _ _ rots -> Rotate (intConst Int64 0 : rots) arr
                    ArrayReshape _ _ k new_shape -> Reshape k (Shape [w] <> new_shape) arr
                    ArrayCopy {} -> Copy arr
                    ArrayVar {} -> SubExp (Var arr)
              hoisted <-
                certifying (arrayOpCerts op) $
                  letExp (baseString arr ++ "_transformed") (BasicOp lifted)
              hoisted_t <- lookupType hoisted
              row_name <- newVName $ baseString arr ++ "_transformed_row"
              let row_param = Param mempty row_name (rowType hoisted_t)
                  replacement = ArrayVar mempty row_name
              pure $ Just (hoisted, row_param, ((pat, op), replacement))
        _ -> pure Nothing
moveTransformToInput SymbolTable (Wise SOACS)
_ Pat (LetDec (Wise SOACS))
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ =
  forall {k} (rep :: k). Rule rep
Skip