{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.SOACS.Simplify
  ( simplifySOACS,
    simplifyLambda,
    simplifyFun,
    simplifyStms,
    simplifyConsts,
    simpleSOACS,
    simplifySOAC,
    soacRules,
    HasSOAC (..),
    simplifyKnownIterationSOAC,
    removeReplicateMapping,
    liftIdentityMapping,
    SOACS,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Analysis.DataDependencies
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import qualified Futhark.IR as AST
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify as Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util

-- | Simplifier operations for the SOACS representation, built with
-- 'Simplify.bindableSimpleOps' and delegating the simplification of the
-- SOAC operators themselves to 'simplifySOAC'.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC

-- | Simplify an entire SOACS program, using the SOACS simplifier
-- operations ('simpleSOACS'), the SOACS rule book ('soacRules') and
-- 'Engine.noExtraHoistBlockers' as the hoist blockers.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single SOACS function definition in the context of the
-- given symbol table, using 'simpleSOACS', 'soacRules' and
-- 'Engine.noExtraHoistBlockers'.
simplifyFun ::
  MonadFreshNames m =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a SOACS lambda in the scope provided by the surrounding
-- monad, using 'simpleSOACS', 'soacRules' and
-- 'Engine.noExtraHoistBlockers'.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Lambda ->
  m Lambda
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a sequence of SOACS statements.
--
-- The scope is taken from the surrounding monad via 'askScope'. Returns
-- the symbol table produced by 'Simplify.simplifyStms' together with the
-- simplified statements.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Stms SOACS ->
  m (ST.SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms

-- | Like 'simplifyStms', but simplifies in an empty scope ('mempty')
-- instead of asking the monad for one. Only 'MonadFreshNames' is
-- therefore required.
simplifyConsts ::
  MonadFreshNames m =>
  Stms SOACS ->
  m (ST.SymbolTable (Wise SOACS), Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms simpleSOACS soacRules Engine.noExtraHoistBlockers mempty

-- | Simplify the operands and lambdas of a single 'SOAC'.
--
-- Sizes, input arrays, destination arrays and neutral elements are
-- passed through 'Engine.simplify'; every lambda is passed through
-- 'Engine.simplifyLambda'. The result is the simplified SOAC together
-- with all statements that were hoisted out of its lambdas.
simplifySOAC ::
  Simplify.SimplifiableRep rep =>
  Simplify.SimplifyOp rep (SOAC rep)
simplifySOAC (Stream outerdim arr form nes lam) = do
  outerdim' <- Engine.simplify outerdim
  (form', form_hoisted) <- simplifyStreamForm form
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.simplifyLambda lam
  return
    ( Stream outerdim' arr' form' nes' lam',
      form_hoisted <> lam_hoisted
    )
  where
    -- A parallel stream form carries its own lambda, which must be
    -- simplified too; a sequential form has nothing to simplify.
    simplifyStreamForm (Parallel o comm lam0) = do
      (lam0', hoisted) <- Engine.simplifyLambda lam0
      return (Parallel o comm lam0', hoisted)
    simplifyStreamForm Sequential =
      return (Sequential, mempty)
simplifySOAC (Scatter len lam ivs as) = do
  len' <- Engine.simplify len
  (lam', hoisted) <- Engine.simplifyLambda lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  return (Scatter len' lam' ivs' as', hoisted)
simplifySOAC (Hist w ops bfun imgs) = do
  w' <- Engine.simplify w
  (ops', ops_hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', hoisted) <- Engine.simplifyLambda op
      return (HistOp dests_w' rf' dests' nes' op', hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.simplifyLambda bfun
  return (Hist w' ops' bfun' imgs', mconcat ops_hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda lam
      nes' <- Engine.simplify nes
      return (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda lam
      nes' <- Engine.simplify nes
      return (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <- Engine.simplifyLambda map_lam

  -- The width and input arrays are simplified after the lambdas, in the
  -- same order as before.
  w' <- Engine.simplify w
  arrs' <- Engine.simplify arrs
  return
    ( Screma w' arrs' (ScremaForm scans' reds' map_lam'),
      mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted
    )

-- | 'BuilderOps' for the wisdom-annotated SOACS representation. No
-- methods are overridden here, so whatever defaults the class provides
-- are used.
instance BuilderOps (Wise SOACS) where

-- | Specialise a lambda by fixing some of its parameters to known values.
--
-- The list of fixes is matched positionally against the lambda's
-- parameters and may be shorter than the parameter list. In that case
-- the remaining parameters are treated as 'Nothing'.
--
-- * A parameter paired with @'Just' x@ is bound to @x@ with
--   'letBindNames' inside 'runBodyBuilder' (presumably placed in front
--   of the original body). It is then removed from the parameter list.
-- * A parameter paired with 'Nothing' is kept unchanged.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  AST.Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (AST.Lambda (Rep m))
fixLambdaParams lam fixes = do
  body <- runBodyBuilder $
    localScope (scopeOfLParams $ lambdaParams lam) $ do
      zipWithM_ maybeFix (lambdaParams lam) fixes'
      return $ lambdaBody lam
  return
    lam
      { lambdaBody = body,
        lambdaParams = [p | (p, Nothing) <- zip (lambdaParams lam) fixes']
      }
  where
    -- Pad with 'Nothing' so that parameters without an explicit fix are
    -- kept, rather than being cut off by 'zip'.
    fixes' = fixes ++ repeat Nothing
    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = return ()

-- | Remove some results from a lambda.
--
-- @keep@ is matched positionally against both the body results and the
-- return types. An entry whose flag is 'False' is removed. Entries
-- beyond the end of @keep@ are kept.
removeLambdaResults :: [Bool] -> AST.Lambda rep -> AST.Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body {bodyResult = keep' $ bodyResult lam_body},
      lambdaReturnType = keep' $ lambdaReturnType lam
    }
  where
    lam_body = lambdaBody lam

    -- Padding @keep@ with 'True' keeps every element past its end.
    keep' :: [a] -> [a]
    keep' xs = [x | (True, x) <- zip (keep ++ repeat True) xs]

-- | The rule book used when simplifying SOACS: the generic
-- 'standardRules' extended with this module's 'topDownRules' and
-- 'bottomUpRules'.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules

-- | Representations whose t'Op's may be 'SOAC's.  The simplification
-- rules only work for representations that are instances of this
-- class.
class HasSOAC rep where
  -- | View an operation as a 'SOAC', if it is one.
  asSOAC :: Op rep -> Maybe (SOAC rep)

  -- | Embed a 'SOAC' as an operation of the representation.
  soacOp :: SOAC rep -> Op rep

-- | For 'SOACS' the operation type is the 'SOAC' type itself, so
-- viewing always succeeds and embedding is the identity.
instance HasSOAC (Wise SOACS) where
  asSOAC op = Just op
  soacOp soac = soac

-- | SOAC-specific top-down simplification rules.  Every rule here
-- operates on 'Op' statements; the list order is kept exactly as it
-- was, in case the rule book tries rules in sequence.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  map
    RuleOp
    [ hoistCerts,
      removeReplicateMapping,
      removeReplicateWrite,
      removeUnusedSOACInput,
      simplifyClosedFormReduce,
      simplifyKnownIterationSOAC,
      liftIdentityMapping,
      removeDuplicateMapOutput,
      fuseConcatScatter,
      simplifyMapIota,
      moveTransformToInput
    ]

-- | SOAC-specific bottom-up simplification rules.  All but
-- 'removeUnnecessaryCopy' operate on 'Op' statements; that one is a
-- 'BasicOp' rule.  The list order is kept exactly as it was.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  map RuleOp [removeDeadMapping, removeDeadReduction, removeDeadWrite]
    ++ [RuleBasicOp removeUnnecessaryCopy]
    ++ map RuleOp [liftIdentityStreaming, mapOpToOp]

-- | Any certificates attached to a trivial Stm (a bare @SubExp@
-- binding) in the top-level statements of a SOAC lambda body might as
-- well be applied to the SOAC itself.
--
-- Only certificates that are present in the symbol table at the SOAC
-- are moved; the rest stay on the statement.  The rule is skipped if
-- nothing was hoisted.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify . auxing aux . certifying hoisted . letBind pat $ Op soac'
  where
    mapper :: SOACMapper (Wise SOACS) (Wise SOACS) (State Certs)
    mapper = identitySOACMapper {mapOnSOACLambda = hoistFromLambda}

    -- Rewrite each top-level statement of the lambda body, accumulating
    -- the hoisted certificates in the state.
    hoistFromLambda :: Lambda (Wise SOACS) -> State Certs (Lambda (Wise SOACS))
    hoistFromLambda lam = do
      let body = lambdaBody lam
      stms' <- traverse hoistFromStm $ bodyStms body
      pure lam {lambdaBody = mkBody stms' $ bodyResult body}

    -- Split the certificates of a trivial binding into those visible at
    -- the SOAC (hoisted) and those that must remain on the statement.
    hoistFromStm :: Stm (Wise SOACS) -> State Certs (Stm (Wise SOACS))
    hoistFromStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      let (outer, inner) =
            partition (`ST.elem` vtable) . unCerts $ stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certs inner}
      modify (Certs outer <>)
      pure . Let se_pat se_aux' . BasicOp $ SubExp se
    hoistFromStm stm = pure stm
hoistCerts _ _ _ _ = Skip

-- | Lift results of a @map@ that do not vary across iterations out of
-- the map:
--
-- * a result that is directly one of the lambda's parameters is bound
--   outside the map to the corresponding input array (via 'Copy', or
--   directly when the result has an accumulator type);
--
-- * a result that is a constant, or a variable free in the lambda body,
--   is bound outside the map as a 'Replicate' of the map width.
--
-- All other results stay in the map unchanged, including any
-- certificates attached to them.  The rule is skipped when no result
-- can be lifted.
liftIdentityMapping ::
  forall rep.
  (Buildable rep, Simplify.SimplifiableRep rep, HasSOAC (Wise rep)) =>
  TopDownRuleOp (Wise rep)
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC (Wise rep)) <- asSOAC op,
    Just fun <- isMapSOAC form =
      let inputMap = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
          free = freeIn $ lambdaBody fun
          rettype = lambdaReturnType fun
          ses = bodyResult $ lambdaBody fun

          freeOrConst (Var v) = v `nameIn` free
          freeOrConst Constant {} = True

          -- Classify one result, consing onto (lifted bindings, kept
          -- results, kept return types).
          checkInvariance (outId, SubExpRes _ (Var v), _) (invariant, mapresult, rettype')
            | Just inp <- M.lookup v inputMap =
                ((Pat [outId], e inp) : invariant, mapresult, rettype')
            where
              e inp = case patElemType outId of
                Acc {} -> BasicOp $ SubExp $ Var inp
                _ -> BasicOp (Copy inp)
          checkInvariance (outId, res@(SubExpRes _ se), t) (invariant, mapresult, rettype')
            | freeOrConst se =
                ( (Pat [outId], BasicOp $ Replicate (Shape [w]) se) : invariant,
                  mapresult,
                  rettype'
                )
            | otherwise =
                -- Keep the complete result so its certificates survive;
                -- previously they were dropped by rebuilding it from the
                -- bare SubExp.
                (invariant, (outId, res) : mapresult, t : rettype')
       in case foldr checkInvariance ([], [], []) $ zip3 (patElems pat) ses rettype of
            ([], _, _) -> Skip
            (invariant, mapresult, rettype') -> Simplify $ do
              let (pat', ses') = unzip mapresult
                  fun' =
                    fun
                      { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                        lambdaReturnType = rettype'
                      }
              -- NOTE(review): the lifted bindings are not wrapped in the
              -- original statement's aux (only the residual map is) —
              -- unchanged behaviour, confirm this is intended.
              mapM_ (uncurry letBind) invariant
              auxing aux $
                letBindNames (map patElemName pat') $
                  Op $ soacOp $ Screma w arrs (mapSOAC fun')
liftIdentityMapping _ _ _ _ = Skip

-- | Bottom-up rule for 'Stream'.  Consider a result in the map part
-- (the results after the first @length nes@ fold results) that is
-- directly one of the lambda parameters paired with an input array.
-- Such a result is bound outside the stream as a 'Copy' of that array
-- and removed from the lambda.  Fold results are never touched.  The
-- rule is skipped when no map result qualifies.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs form nes lam)
  | (variant_map, invariant_map) <-
      partitionEithers $ zipWith3 classifyRes map_ts map_pes map_res,
    not (null invariant_map) =
      Simplify $ do
        forM_ invariant_map $ \(pe, arr) ->
          letBind (Pat [pe]) . BasicOp $ Copy arr

        let (variant_ts, variant_pes, variant_res) = unzip3 variant_map
            body = lambdaBody lam
            lam' =
              lam
                { lambdaBody = body {bodyResult = fold_res ++ variant_res},
                  lambdaReturnType = fold_ts ++ variant_ts
                }

        auxing aux . letBind (Pat (fold_pes ++ variant_pes)) . Op $
          Stream w arrs form nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    (fold_res, map_res) = splitAt num_folds $ bodyResult $ lambdaBody lam

    -- Skip the leading parameter (presumably the chunk size) and the
    -- fold accumulators; the remaining parameters are paired with arrs.
    params_to_arrs =
      zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    -- Right: result is an array parameter (lift it out).
    -- Left: result must stay in the stream.
    classifyRes _ pe (SubExpRes _ (Var v))
      | Just arr <- lookup v params_to_arrs = Right (pe, arr)
    classifyRes t pe res = Left (t, pe, res)
liftIdentityStreaming _ _ _ _ = Skip

-- | Remove all arguments to the map that are simply replicates.
-- These can be turned into free variables instead.
--
-- Fires only when the operation is a 'Screma' that is a plain map and
-- 'removeReplicateInput' finds at least one replicated input.  Every
-- binding it returns is emitted (under that binding's own certificates)
-- before the rebuilt map, and the rebuilt map keeps the original
-- statement's auxiliary information.
removeReplicateMapping ::
  (Buildable rep, Simplify.SimplifiableRep rep, HasSOAC (Wise rep)) =>
  TopDownRuleOp (Wise rep)
removeReplicateMapping vtable pat aux op =
  case asSOAC op of
    Just (Screma w arrs form)
      | Just fun <- isMapSOAC form,
        Just (outside, fun', arrs') <- removeReplicateInput vtable fun arrs ->
          Simplify $ do
            -- Bind the former lambda parameters outside the SOAC.
            forM_ outside $ \(names, cs, e) ->
              certifying cs $ letBindNames names e
            auxing aux . letBind pat . Op . soacOp $
              Screma w arrs' (mapSOAC fun')
    _ -> Skip

-- | Like 'removeReplicateMapping', but for 'Scatter'.
--
-- Replicated input arrays of the scatter lambda, as identified by
-- 'removeReplicateInput', are replaced by bindings placed before the
-- 'Scatter'.  The length and the destination specification are passed
-- through unchanged.
removeReplicateWrite :: TopDownRuleOp (Wise SOACS)
removeReplicateWrite vtable pat aux soac =
  case soac of
    Scatter len lam ivs dests
      | Just (outside, lam', ivs') <- removeReplicateInput vtable lam ivs ->
          Simplify $ do
            forM_ outside $ \(names, cs, e) ->
              certifying cs $ letBindNames names e
            auxing aux . letBind pat . Op $ Scatter len lam' ivs' dests
    _ -> Skip

-- | Find the array inputs of a SOAC lambda that are replicates, so that
-- they can become free variables of the lambda instead.
--
-- The trailing @length arrs@ lambda parameters are paired with @arrs@.
-- Any leading parameters are kept unchanged.  An input qualifies when the
-- symbol table knows it as a 'Replicate' and the lambda does not consume
-- the corresponding parameter.  For a qualifying input we produce a
-- binding of the parameter name, carrying the certificates recorded for
-- the replicate:
--
-- * to the replicated value itself, if the replicate had one dimension;
-- * otherwise to a replicate over the remaining inner dimensions.
--
-- Returns the bindings, the lambda without the qualifying parameters, and
-- the remaining input arrays.  Returns 'Nothing' when no input qualifies.
removeReplicateInput ::
  Aliased rep =>
  ST.SymbolTable rep ->
  AST.Lambda rep ->
  [VName] ->
  Maybe
    ( [([VName], Certs, AST.Exp rep)],
      AST.Lambda rep,
      [VName]
    )
removeReplicateInput vtable fun arrs =
  case partitionEithers (zipWith classify arr_params arrs) of
    (_, []) -> Nothing
    (kept, outside) ->
      let (kept_params, kept_arrs) = unzip kept
       in Just (outside, fun {lambdaParams = lead_params ++ kept_params}, kept_arrs)
  where
    all_params = lambdaParams fun
    (lead_params, arr_params) =
      splitAt (length all_params - length arrs) all_params
    consumed = consumedByLambda fun

    -- Right: the parameter can be bound outside the lambda.
    -- Left: keep the (parameter, array) pair as a SOAC input.
    classify p v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Replicate (Shape (_ : inner)) x), v_cs)
          | not (paramName p `nameIn` consumed) ->
              Right
                ( [paramName p],
                  v_cs,
                  BasicOp $
                    if null inner
                      then SubExp x
                      else Replicate (Shape inner) x
                )
        _ -> Left (p, v)

-- | Remove inputs that are not used inside the SOAC.
--
-- An input array is dropped when the name of its map-lambda parameter is
-- not free in the map lambda's body.  The scan and reduce parts of the
-- 'ScremaForm' are passed through unchanged.  The rule fires only if at
-- least one input is dropped.
removeUnusedSOACInput :: TopDownRuleOp (Wise SOACS)
removeUnusedSOACInput _ pat aux (Screma w arrs (ScremaForm scan reduce map_lam))
  | any (not . isLive) inputs = Simplify $ do
    let (live_params, live_arrs) = unzip $ filter isLive inputs
        map_lam' = map_lam {lambdaParams = live_params}
    auxing aux . letBind pat . Op $
      Screma w live_arrs (ScremaForm scan reduce map_lam')
  where
    inputs = zip (lambdaParams map_lam) arrs
    free_in_body = freeIn (lambdaBody map_lam)
    isLive (p, _) = paramName p `nameIn` free_in_body
removeUnusedSOACInput _ _ _ _ = Skip

-- | Remove results of a @map@ whose pattern elements are not used.
--
-- This is a bottom-up rule, and liveness is read from the usage table.
-- Each unused pattern element is dropped together with the corresponding
-- lambda result and return type.  If nothing would be removed, the rule
-- does not fire.
removeDeadMapping :: BottomUpRuleOp (Wise SOACS)
removeDeadMapping (_, used) pat aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
    let isLive (pe, _, _) = patElemName pe `UT.used` used
        (live_pes, live_res, live_ts) =
          unzip3 . filter isLive $
            zip3 (patElems pat) (bodyResult (lambdaBody fun)) (lambdaReturnType fun)
        pat' = Pat live_pes
        fun' =
          fun
            { lambdaBody = (lambdaBody fun) {bodyResult = live_res},
              lambdaReturnType = live_ts
            }
     in if pat == pat'
          then Skip
          else
            Simplify . auxing aux $
              letBind pat' $ Op $ Screma w arrs $ mapSOAC fun'
removeDeadMapping _ _ _ _ = Skip

-- | Remove duplicate results from a @map@.
--
-- When several lambda results are the same 'SubExp', only the first
-- occurrence stays a result of the map.  Each later pattern element that
-- received the same value is instead bound to a 'Copy' of the pattern
-- element produced by the first occurrence.  The rule does not fire if
-- there are no duplicates.
--
-- NOTE(review): results are compared by 'resSubExp' only, so the
-- certificates of a dropped duplicate result are not attached to its
-- 'Copy'.  Confirm this is intended.
removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS)
removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form)
  | Just fun <- isMapSOAC form =
    let ses = bodyResult $ lambdaBody fun
        ts = lambdaReturnType fun
        (ses_ts_pes', copies) = dedupResults $ zip3 ses ts pes
     in if null copies
          then Skip
          else Simplify $ do
            let (ses', ts', pes') = unzip3 ses_ts_pes'
                fun' =
                  fun
                    { lambdaBody = (lambdaBody fun) {bodyResult = ses'},
                      lambdaReturnType = ts'
                    }
            auxing aux $ letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun'
            forM_ copies $ \(from, to) ->
              letBind (Pat [to]) $ BasicOp $ Copy $ patElemName from
  where
    -- Split the (result, type, pattern element) triples into two lists.
    -- The first holds the triples kept as map results, in their original
    -- order.  The second holds (first occurrence, duplicate) pattern
    -- element pairs.  A strict left fold that accumulates the kept list in
    -- reverse avoids the thunk build-up of a lazy 'foldl' and the
    -- quadratic cost of a repeated @xs ++ [x]@.
    dedupResults triples =
      let (kept_rev, copies) = foldl' checkForDuplicates ([], []) triples
       in (reverse kept_rev, copies)

    checkForDuplicates (kept_rev, copies) (se, t, pe)
      | Just (_, _, pe') <- find (\(x, _, _) -> resSubExp x == resSubExp se) kept_rev =
        -- This result has been returned before, producing the
        -- array pe'.
        (kept_rev, (pe', pe) : copies)
      | otherwise = ((se, t, pe) : kept_rev, copies)
removeDuplicateMapOutput _ _ _ _ = Skip

-- Mapping some operations becomes an extension of that operation.
mapOpToOp :: BottomUpRuleOp (Wise SOACS)
mapOpToOp :: RuleOp (Wise SOACS) (BottomUp (Wise SOACS))
mapOpToOp (SymbolTable (Wise SOACS)
_, UsageTable
used) Pat (Wise SOACS)
pat StmAux (ExpDec (Wise SOACS))
aux1 Op (Wise SOACS)
e
  | Just (PatElemT (VarWisdom, Type)
map_pe, Certs
cs, SubExp
w, BasicOp (Reshape ShapeChange SubExp
newshape VName
reshape_arr), [Param Type
p], [VName
arr]) <-
      PatT (VarWisdom, Type)
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT (VarWisdom, Type), Certs, SubExp, Exp (Wise SOACS),
      [Param Type], [VName])
forall dec.
PatT dec
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT dec, Certs, SubExp, Exp (Wise SOACS), [Param Type],
      [VName])
isMapWithOp PatT (VarWisdom, Type)
Pat (Wise SOACS)
pat Op (Wise SOACS)
SOAC (Wise SOACS)
e,
    Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
reshape_arr,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> UsageTable -> Bool
UT.isConsumed (PatElemT (VarWisdom, Type) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (VarWisdom, Type)
map_pe) UsageTable
used = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
    let redim :: DimChange SubExp
redim
          | Maybe [SubExp] -> Bool
forall a. Maybe a -> Bool
isJust (Maybe [SubExp] -> Bool) -> Maybe [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape = SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimCoercion SubExp
w
          | Bool
otherwise = SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew SubExp
w
    Certs -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux (ExpWisdom, ()) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux1 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
      Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
        BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape (DimChange SubExp
redim DimChange SubExp -> ShapeChange SubExp -> ShapeChange SubExp
forall a. a -> [a] -> [a]
: ShapeChange SubExp
newshape) VName
arr
  | Just
      ( PatElemT (VarWisdom, Type)
_,
        Certs
cs,
        SubExp
_,
        BasicOp (Concat Int
d VName
arr [VName]
arrs SubExp
dw),
        [Param Type]
ps,
        VName
outer_arr : [VName]
outer_arrs
        ) <-
      PatT (VarWisdom, Type)
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT (VarWisdom, Type), Certs, SubExp, Exp (Wise SOACS),
      [Param Type], [VName])
forall dec.
PatT dec
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT dec, Certs, SubExp, Exp (Wise SOACS), [Param Type],
      [VName])
isMapWithOp PatT (VarWisdom, Type)
Pat (Wise SOACS)
pat Op (Wise SOACS)
SOAC (Wise SOACS)
e,
    (VName
arr VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs) [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
== (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
ps =
    RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
      Certs -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux (ExpWisdom, ()) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux1 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
        Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp
Concat (Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) VName
outer_arr [VName]
outer_arrs SubExp
dw
  | Just
      ( PatElemT (VarWisdom, Type)
map_pe,
        Certs
cs,
        SubExp
_,
        BasicOp (Rearrange [Int]
perm VName
rearrange_arr),
        [Param Type
p],
        [VName
arr]
        ) <-
      PatT (VarWisdom, Type)
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT (VarWisdom, Type), Certs, SubExp, Exp (Wise SOACS),
      [Param Type], [VName])
forall dec.
PatT dec
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT dec, Certs, SubExp, Exp (Wise SOACS), [Param Type],
      [VName])
isMapWithOp PatT (VarWisdom, Type)
Pat (Wise SOACS)
pat Op (Wise SOACS)
SOAC (Wise SOACS)
e,
    Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
rearrange_arr,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> UsageTable -> Bool
UT.isConsumed (PatElemT (VarWisdom, Type) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (VarWisdom, Type)
map_pe) UsageTable
used =
    RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
      Certs -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux (ExpWisdom, ()) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux1 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
        Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange (Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+) [Int]
perm) VName
arr
  | Just (PatElemT (VarWisdom, Type)
map_pe, Certs
cs, SubExp
_, BasicOp (Rotate [SubExp]
rots VName
rotate_arr), [Param Type
p], [VName
arr]) <-
      PatT (VarWisdom, Type)
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT (VarWisdom, Type), Certs, SubExp, Exp (Wise SOACS),
      [Param Type], [VName])
forall dec.
PatT dec
-> SOAC (Wise SOACS)
-> Maybe
     (PatElemT dec, Certs, SubExp, Exp (Wise SOACS), [Param Type],
      [VName])
isMapWithOp PatT (VarWisdom, Type)
Pat (Wise SOACS)
pat Op (Wise SOACS)
SOAC (Wise SOACS)
e,
    Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
rotate_arr,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> UsageTable -> Bool
UT.isConsumed (PatElemT (VarWisdom, Type) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (VarWisdom, Type)
map_pe) UsageTable
used =
    RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
      Certs -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux (ExpWisdom, ()) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux1 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
        Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
rots) VName
arr
mapOpToOp BottomUp (Wise SOACS)
_ Pat (Wise SOACS)
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ = Rule (Wise SOACS)
forall rep. Rule rep
Skip

-- | Recognise a single-result @map@ whose lambda body consists of
-- exactly one statement that binds a single variable, where that
-- variable is also the lambda's sole result.  On success, returns
-- the map's pattern element, the certificates of the inner
-- statement, the map width, the inner expression, the lambda
-- parameters, and the input arrays.  Certificates attached to the
-- lambda result itself are ignored.
isMapWithOp ::
  PatT dec ->
  SOAC (Wise SOACS) ->
  Maybe
    ( PatElemT dec,
      Certs,
      SubExp,
      AST.Exp (Wise SOACS),
      [Param Type],
      [VName]
    )
isMapWithOp pat soac = do
  -- Any refutable bind below that fails yields 'Nothing' (MonadFail).
  Pat [map_pe] <- pure pat
  Screma w arrs form <- pure soac
  map_lam <- isMapSOAC form
  let lam_body = lambdaBody map_lam
  [Let (Pat [pe]) aux2 e'] <- pure $ stmsToList $ bodyStms lam_body
  [SubExpRes _ (Var r)] <- pure $ bodyResult lam_body
  -- The single statement must compute exactly what the lambda returns.
  guard $ r == patElemName pe
  pure (map_pe, stmAuxCerts aux2, w, e', lambdaParams map_lam, arrs)

-- | Some of the results of a reduction (or really: Redomap) may be
-- dead.  We remove them here.  The trick is that we need to look at
-- the data dependencies to see that the "dead" result is not
-- actually used for computing one of the live ones.
--
-- Only Redomaps with exactly one 'Reduce' are handled.  A reduction
-- result is kept if its pattern element is used afterwards, or if it
-- is needed (through the reduction lambda's data dependencies) to
-- compute another kept result.  For each removed result, both of the
-- corresponding reduction-lambda parameters are fixed to the neutral
-- element (via 'fixLambdaParams'), and the result is dropped from the
-- pattern, the map lambda, and the reduction lambda.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
  | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form,
    not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check
    let (red_pes, map_pes) = splitAt (length nes) $ patElems pat,
    let redlam_deps = dataDependencies $ lambdaBody redlam,
    let redlam_res = bodyResult $ lambdaBody redlam,
    let redlam_params = lambdaParams redlam,
    -- The first (length nes) reduction-lambda parameters line up with
    -- the reduction pattern elements.
    let used_after =
          map snd $
            filter ((`UT.used` used) . patElemName . fst) $
              zip red_pes redlam_params,
    -- Both halves of the parameter list feed the same results, hence
    -- the duplicated result list.
    let necessary =
          findNecessaryForReturned
            (`elem` used_after)
            (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
            redlam_deps,
    let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
    -- Liveness of the first-half parameters, i.e. of the reduction
    -- results themselves.  Only this part of the mask decides what
    -- gets removed below.
    let red_alive = take (length nes) alive_mask,
    -- If every reduction result is alive, nothing would be removed
    -- (the pattern would be rebuilt unchanged), so do not fire.
    -- Checking the whole 'alive_mask' here would also fire when only
    -- a second-half parameter is dead, producing a no-op rewrite.
    not $ and red_alive = Simplify $ do
    let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
        dead_fix = zipWith fixDeadToNeutral red_alive nes
        (used_red_pes, _, used_nes) =
          unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $
            zip3 red_pes redlam_params nes

    let maplam' = removeLambdaResults red_alive maplam
    -- The same fix applies to both the first and the second operand.
    redlam' <-
      removeLambdaResults red_alive
        <$> fixLambdaParams redlam (dead_fix ++ dead_fix)

    auxing aux $
      letBind (Pat $ used_red_pes ++ map_pes) $
        Op $ Screma w arrs $ redomapSOAC [Reduce comm redlam' used_nes] maplam'
removeDeadReduction _ _ _ _ = Skip

-- | If we are writing to an array that is never used, get rid of it.
--
-- The lambda's results and return types are grouped per destination
-- with 'groupScatterResults''.  Each destination whose pattern
-- element is unused is dropped, together with its index/value
-- results, their types, and the pattern element.  If nothing is
-- dropped, the rule does not fire.
removeDeadWrite :: BottomUpRuleOp (Wise SOACS)
removeDeadWrite (_, used) pat aux (Scatter w fun arrs dests)
  | pat /= Pat kept_pes =
      Simplify . auxing aux . letBind (Pat kept_pes) . Op $
        Scatter w fun' arrs kept_dests
  | otherwise = Skip
  where
    lam_body = lambdaBody fun
    -- One (indices, value) group per destination.
    res_groups = groupScatterResults' dests $ bodyResult lam_body
    type_groups = groupScatterResults' dests $ lambdaReturnType fun
    isLive (pe, _, _) = patElemName pe `UT.used` used
    (kept_pes, kept_groups, kept_dests) =
      unzip3 . filter isLive $
        zip3 (patElems pat) (zip res_groups type_groups) dests
    (kept_res, kept_types) = unzip kept_groups
    -- All index results come first, followed by all value results.
    fun' =
      fun
        { lambdaBody =
            lam_body {bodyResult = concatMap fst kept_res ++ map snd kept_res},
          lambdaReturnType = concatMap fst kept_types ++ map snd kept_types
        }
removeDeadWrite _ _ _ _ = Skip

-- Now also handles concatenations of more than two arrays.
fuseConcatScatter :: TopDownRuleOp (Wise SOACS)
fuseConcatScatter :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
fuseConcatScatter SymbolTable (Wise SOACS)
vtable Pat (Wise SOACS)
pat StmAux (ExpDec (Wise SOACS))
_ (Scatter _ fun arrs dests)
  | Just (ws :: [SubExp]
ws@(SubExp
w' : [SubExp]
_), [[VName]]
xss, [Certs]
css) <- [(SubExp, [VName], Certs)] -> ([SubExp], [[VName]], [Certs])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(SubExp, [VName], Certs)] -> ([SubExp], [[VName]], [Certs]))
-> Maybe [(SubExp, [VName], Certs)]
-> Maybe ([SubExp], [[VName]], [Certs])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> Maybe (SubExp, [VName], Certs))
-> [VName] -> Maybe [(SubExp, [VName], Certs)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> Maybe (SubExp, [VName], Certs)
isConcat [VName]
arrs,
    [[VName]]
xivs <- [[VName]] -> [[VName]]
forall a. [[a]] -> [[a]]
transpose [[VName]]
xss,
    (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp
w' SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
==) [SubExp]
ws = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
    let r :: Int
r = [[VName]] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [[VName]]
xivs
    [Lambda (Wise SOACS)]
fun2s <- (Int -> RuleM (Wise SOACS) (Lambda (Wise SOACS)))
-> [Int] -> RuleM (Wise SOACS) [Lambda (Wise SOACS)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\Int
_ -> Lambda (Wise SOACS) -> RuleM (Wise SOACS) (Lambda (Wise SOACS))
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda (Wise SOACS)
fun) [Int
1 .. Int
r Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
    let ([Result]
fun_is, [Result]
fun_vs) =
          [(Result, Result)] -> ([Result], [Result])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Result, Result)] -> ([Result], [Result]))
-> [(Result, Result)] -> ([Result], [Result])
forall a b. (a -> b) -> a -> b
$
            (Lambda (Wise SOACS) -> (Result, Result))
-> [Lambda (Wise SOACS)] -> [(Result, Result)]
forall a b. (a -> b) -> [a] -> [b]
map
              ( [(Shape, Int, VName)] -> Result -> (Result, Result)
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, VName)]
dests
                  (Result -> (Result, Result))
-> (Lambda (Wise SOACS) -> Result)
-> Lambda (Wise SOACS)
-> (Result, Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body (Wise SOACS) -> Result
forall rep. BodyT rep -> Result
bodyResult
                  (Body (Wise SOACS) -> Result)
-> (Lambda (Wise SOACS) -> Body (Wise SOACS))
-> Lambda (Wise SOACS)
-> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody
              )
              (Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s)
        ([[Type]]
its, [[Type]]
vts) =
          [([Type], [Type])] -> ([[Type]], [[Type]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([([Type], [Type])] -> ([[Type]], [[Type]]))
-> [([Type], [Type])] -> ([[Type]], [[Type]])
forall a b. (a -> b) -> a -> b
$
            Int -> ([Type], [Type]) -> [([Type], [Type])]
forall a. Int -> a -> [a]
replicate Int
r (([Type], [Type]) -> [([Type], [Type])])
-> ([Type], [Type]) -> [([Type], [Type])]
forall a b. (a -> b) -> a -> b
$
              [(Shape, Int, VName)] -> [Type] -> ([Type], [Type])
forall array a. [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(Shape, Int, VName)]
dests ([Type] -> ([Type], [Type])) -> [Type] -> ([Type], [Type])
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Wise SOACS)
fun
        new_stmts :: Stms (Wise SOACS)
new_stmts = [Stms (Wise SOACS)] -> Stms (Wise SOACS)
forall a. Monoid a => [a] -> a
mconcat ([Stms (Wise SOACS)] -> Stms (Wise SOACS))
-> [Stms (Wise SOACS)] -> Stms (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ (Lambda (Wise SOACS) -> Stms (Wise SOACS))
-> [Lambda (Wise SOACS)] -> [Stms (Wise SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (Body (Wise SOACS) -> Stms (Wise SOACS)
forall rep. BodyT rep -> Stms rep
bodyStms (Body (Wise SOACS) -> Stms (Wise SOACS))
-> (Lambda (Wise SOACS) -> Body (Wise SOACS))
-> Lambda (Wise SOACS)
-> Stms (Wise SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody) (Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s)
    let fun' :: Lambda (Wise SOACS)
fun' =
          Lambda :: forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda
            { lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = [[Param Type]] -> [Param Type]
forall a. Monoid a => [a] -> a
mconcat ([[Param Type]] -> [Param Type]) -> [[Param Type]] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ (Lambda (Wise SOACS) -> [Param Type])
-> [Lambda (Wise SOACS)] -> [[Param Type]]
forall a b. (a -> b) -> [a] -> [b]
map Lambda (Wise SOACS) -> [Param Type]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams (Lambda (Wise SOACS)
fun Lambda (Wise SOACS)
-> [Lambda (Wise SOACS)] -> [Lambda (Wise SOACS)]
forall a. a -> [a] -> [a]
: [Lambda (Wise SOACS)]
fun2s),
              lambdaBody :: Body (Wise SOACS)
lambdaBody =
                Stms (Wise SOACS) -> Result -> Body (Wise SOACS)
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms (Wise SOACS)
new_stmts (Result -> Body (Wise SOACS)) -> Result -> Body (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
                  [Result] -> Result
forall a. [[a]] -> [a]
mix [Result]
fun_is Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> [Result] -> Result
forall a. [[a]] -> [a]
mix [Result]
fun_vs,
              lambdaReturnType :: [Type]
lambdaReturnType = [[Type]] -> [Type]
forall a. [[a]] -> [a]
mix [[Type]]
its [Type] -> [Type] -> [Type]
forall a. Semigroup a => a -> a -> a
<> [[Type]] -> [Type]
forall a. [[a]] -> [a]
mix [[Type]]
vts
            }
    Certs -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ([Certs] -> Certs
forall a. Monoid a => [a] -> a
mconcat [Certs]
css) (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
      Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Op (Wise SOACS) -> Exp (Wise SOACS)
forall rep. Op rep -> ExpT rep
Op (Op (Wise SOACS) -> Exp (Wise SOACS))
-> Op (Wise SOACS) -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp
-> Lambda (Wise SOACS)
-> [VName]
-> [(Shape, Int, VName)]
-> SOAC (Wise SOACS)
forall rep.
SubExp
-> Lambda rep -> [VName] -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
w' Lambda (Wise SOACS)
fun' ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
xivs) ([(Shape, Int, VName)] -> SOAC (Wise SOACS))
-> [(Shape, Int, VName)] -> SOAC (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ ((Shape, Int, VName) -> (Shape, Int, VName))
-> [(Shape, Int, VName)] -> [(Shape, Int, VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> (Shape, Int, VName) -> (Shape, Int, VName)
forall b a c. Num b => b -> (a, b, c) -> (a, b, c)
incWrites Int
r) [(Shape, Int, VName)]
dests
  where
    sizeOf :: VName -> Maybe SubExp
    sizeOf :: VName -> Maybe SubExp
sizeOf VName
x = Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 (Type -> SubExp)
-> (Entry (Wise SOACS) -> Type) -> Entry (Wise SOACS) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Entry (Wise SOACS) -> Type
forall t. Typed t => t -> Type
typeOf (Entry (Wise SOACS) -> SubExp)
-> Maybe (Entry (Wise SOACS)) -> Maybe SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> SymbolTable (Wise SOACS) -> Maybe (Entry (Wise SOACS))
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
x SymbolTable (Wise SOACS)
vtable
    mix :: [[a]] -> [a]
mix = [[a]] -> [a]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[a]] -> [a]) -> ([[a]] -> [[a]]) -> [[a]] -> [a]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [[a]] -> [[a]]
forall a. [[a]] -> [[a]]
transpose
    incWrites :: b -> (a, b, c) -> (a, b, c)
incWrites b
r (a
w, b
n, c
a) = (a
w, b
n b -> b -> b
forall a. Num a => a -> a -> a
* b
r, c
a) -- ToDO: is it (n*r) or (n+r-1)??
    isConcat :: VName -> Maybe (SubExp, [VName], Certs)
isConcat VName
v = case VName
-> SymbolTable (Wise SOACS) -> Maybe (Exp (Wise SOACS), Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v SymbolTable (Wise SOACS)
vtable of
      Just (BasicOp (Concat Int
0 VName
x [VName]
ys SubExp
_), Certs
cs) -> do
        SubExp
x_w <- VName -> Maybe SubExp
sizeOf VName
x
        [SubExp]
y_ws <- (VName -> Maybe SubExp) -> [VName] -> Maybe [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> Maybe SubExp
sizeOf [VName]
ys
        Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp
x_w SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
==) [SubExp]
y_ws
        (SubExp, [VName], Certs) -> Maybe (SubExp, [VName], Certs)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
x_w, VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
ys, Certs
cs)
      Just (BasicOp (Reshape ShapeChange SubExp
reshape VName
arr), Certs
cs) -> do
        Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ Maybe [SubExp] -> Bool
forall a. Maybe a -> Bool
isJust (Maybe [SubExp] -> Bool) -> Maybe [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
reshape
        (SubExp
a, [VName]
b, Certs
cs') <- VName -> Maybe (SubExp, [VName], Certs)
isConcat VName
arr
        (SubExp, [VName], Certs) -> Maybe (SubExp, [VName], Certs)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
a, [VName]
b, Certs
cs Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs')
      Maybe (Exp (Wise SOACS), Certs)
_ -> Maybe (SubExp, [VName], Certs)
forall a. Maybe a
Nothing
fuseConcatScatter SymbolTable (Wise SOACS)
_ Pat (Wise SOACS)
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ = Rule (Wise SOACS)
forall rep. Rule rep
Skip

-- | Closed-form simplification of reductions.
--
-- * A redomap whose width is a constant that is zero-ish produces its
--   neutral elements as reduction results.  Its mapped outputs (if any)
--   have zero rows, and are bound to empty scratch arrays so that no
--   name of the original pattern is left undefined by the rewrite.  The
--   rule only fires when every mapped output is an array type.
--
-- * A plain reduce is delegated to 'foldClosedForm', which is given
--   expression lookups into the symbol table.
--
-- In both cases the certificates of the original statement are kept on
-- the statements produced by the rewrite.
simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS)
simplifyClosedFormReduce _ pat aux (Screma (Constant w) _ form)
  | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form,
    zeroIsh w,
    -- Redomap patterns list the reduction results first, then the
    -- mapped outputs.
    (red_pes, map_pes) <- splitAt (length nes) $ patElems pat,
    all ((> 0) . arrayRank . patElemType) map_pes =
    Simplify . certifying (stmAuxCerts aux) $ do
      forM_ (zip red_pes nes) $ \(pe, ne) ->
        letBindNames [patElemName pe] $ BasicOp $ SubExp ne
      -- The mapped outputs have zero rows, so an uninitialised array
      -- of the right shape is a valid value for them.
      forM_ map_pes $ \pe -> do
        let t = patElemType pe
        letBindNames [patElemName pe] $
          BasicOp $ Scratch (elemType t) (arrayDims t)
simplifyClosedFormReduce vtable pat aux (Screma _ arrs form)
  | Just [Reduce _ red_fun nes] <- isReduceSOAC form =
    Simplify . certifying (stmAuxCerts aux) $
      foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs
simplifyClosedFormReduce _ _ _ _ = Skip

-- | Eliminate SOACs whose iteration count is a constant that is one-ish.
-- For now only such singleton SOACs are handled.
--
-- * A 'Screma' is inlined: each map parameter is bound to row 0 of its
--   input array, the map body is inlined, and the (single, combined)
--   scan and reduce operators are applied once to their neutral
--   elements and the corresponding map results.  Scan and map results
--   are wrapped in singleton array literals; reduce results are bound
--   directly.
--
-- * A 'Stream' is inlined: the chunk-size parameter is bound to 1, the
--   accumulator parameters to the neutral elements, and the slice
--   parameters to the input arrays; the lambda body is then inlined.
--
-- The certificates of the original statement are attached to every
-- statement produced by the rewrite, so the inlined code keeps the
-- dependencies the SOAC statement had.
simplifyKnownIterationSOAC ::
  (Buildable rep, Simplify.SimplifiableRep rep, HasSOAC (Wise rep)) =>
  TopDownRuleOp (Wise rep)
simplifyKnownIterationSOAC _ pat aux op
  | Just (Screma (Constant k) arrs (ScremaForm scans reds map_lam)) <- asSOAC op,
    oneIsh k =
    Simplify . certifying (stmAuxCerts aux) $ do
      let (Reduce _ red_lam red_nes) = singleReduce reds
          (Scan scan_lam scan_nes) = singleScan scans
          (scan_pes, red_pes, map_pes) =
            splitAt3 (length scan_nes) (length red_nes) $
              patElems pat
          -- Each map parameter receives the sole row of its input array.
          bindMapParam p a = do
            a_t <- lookupType a
            letBindNames [paramName p] $
              BasicOp $ Index a $ fullSlice a_t [DimFix $ constant (0 :: Int64)]
          -- Scan and map outputs are arrays containing a single row.
          bindArrayResult pe (SubExpRes cs se) =
            certifying cs . letBindNames [patElemName pe] $
              BasicOp $ ArrayLit [se] $ rowType $ patElemType pe
          -- Reduce outputs are the combined value itself.
          bindResult pe (SubExpRes cs se) =
            certifying cs $ letBindNames [patElemName pe] $ BasicOp $ SubExp se

      zipWithM_ bindMapParam (lambdaParams map_lam) arrs
      (to_scan, to_red, map_res) <-
        splitAt3 (length scan_nes) (length red_nes)
          <$> bodyBind (lambdaBody map_lam)
      scan_res <- eLambda scan_lam $ map eSubExp $ scan_nes ++ map resSubExp to_scan
      red_res <- eLambda red_lam $ map eSubExp $ red_nes ++ map resSubExp to_red

      zipWithM_ bindArrayResult scan_pes scan_res
      zipWithM_ bindResult red_pes red_res
      zipWithM_ bindArrayResult map_pes map_res
simplifyKnownIterationSOAC _ pat aux op
  | Just (Stream (Constant k) arrs _ nes fold_lam) <- asSOAC op,
    oneIsh k =
    Simplify . certifying (stmAuxCerts aux) $ do
      let (chunk_param, acc_params, slice_params) =
            partitionChunkedFoldParameters (length nes) (lambdaParams fold_lam)

      -- The whole input is processed as one chunk of size one.
      letBindNames [paramName chunk_param] $
        BasicOp $ SubExp $ intConst Int64 1

      forM_ (zip acc_params nes) $ \(p, ne) ->
        letBindNames [paramName p] $ BasicOp $ SubExp ne

      forM_ (zip slice_params arrs) $ \(p, arr) ->
        letBindNames [paramName p] $ BasicOp $ SubExp $ Var arr

      res <- bodyBind $ lambdaBody fold_lam

      forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
        certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationSOAC _ _ _ _ = Skip

-- | An array operation recognised by 'isArrayOp' and converted back to
-- an expression by 'fromArrayOp'.  Each constructor carries the
-- certificates supplied to 'isArrayOp' and the array being operated on.
data ArrayOp
  = -- | An 'Index' of an array with a slice.
    ArrayIndexing Certs VName (Slice SubExp)
  | -- | A 'Rearrange' of an array with a permutation.
    ArrayRearrange Certs VName [Int]
  | -- | A 'Rotate' of an array.
    ArrayRotate Certs VName [SubExp]
  | -- | A 'Copy' of an array.
    ArrayCopy Certs VName
  | -- | A plain reference to an array.  Never constructed.
    ArrayVar Certs VName
  deriving (Eq, Ord, Show)

-- | The array that an 'ArrayOp' is applied to.
arrayOpArr :: ArrayOp -> VName
arrayOpArr op =
  case op of
    ArrayIndexing _ arr _ -> arr
    ArrayRearrange _ arr _ -> arr
    ArrayRotate _ arr _ -> arr
    ArrayCopy _ arr -> arr
    ArrayVar _ arr -> arr

-- | The certificates carried by an 'ArrayOp'.
arrayOpCerts :: ArrayOp -> Certs
arrayOpCerts op =
  case op of
    ArrayIndexing certs _ _ -> certs
    ArrayRearrange certs _ _ -> certs
    ArrayRotate certs _ _ -> certs
    ArrayCopy certs _ -> certs
    ArrayVar certs _ -> certs

-- | Recognise an expression as an 'ArrayOp', tagging it with the given
-- certificates.  Only 'Index', 'Rearrange', 'Rotate' and 'Copy' are
-- recognised; every other expression yields 'Nothing'.
isArrayOp :: Certs -> AST.Exp (Wise SOACS) -> Maybe ArrayOp
isArrayOp cs (BasicOp basic) =
  case basic of
    Index arr slice -> Just (ArrayIndexing cs arr slice)
    Rearrange perm arr -> Just (ArrayRearrange cs arr perm)
    Rotate rots arr -> Just (ArrayRotate cs arr rots)
    Copy arr -> Just (ArrayCopy cs arr)
    _ -> Nothing
isArrayOp _ _ = Nothing

-- | Turn an 'ArrayOp' back into an expression, returned together with
-- the certificates the operation carries.  For the constructors that
-- 'isArrayOp' recognises this rebuilds the same basic operation; an
-- 'ArrayVar' becomes a plain reference to the variable.
fromArrayOp :: ArrayOp -> (Certs, AST.Exp (Wise SOACS))
fromArrayOp op = (cs, BasicOp bop)
  where
    (cs, bop) =
      case op of
        ArrayIndexing c arr slice -> (c, Index arr slice)
        ArrayRearrange c arr perm -> (c, Rearrange perm arr)
        ArrayRotate c arr rots -> (c, Rotate rots arr)
        ArrayCopy c arr -> (c, Copy arr)
        ArrayVar c arr -> (c, SubExp (Var arr))

-- | Collect every array operation (as recognised by 'isArrayOp') bound
-- anywhere in a body, each paired with the pattern it is bound to.
--
-- The search descends into the nested bodies of an expression and into
-- the lambdas of SOACs.  A statement whose expression is itself an array
-- operation contributes just that operation and is not searched further.
arrayOps :: AST.Body (Wise SOACS) -> S.Set (AST.Pat (Wise SOACS), ArrayOp)
arrayOps = foldMap fromStm . bodyStms
  where
    fromStm (Let pat aux e) =
      maybe (nested e) (\op -> S.singleton (pat, op)) $
        isArrayOp (stmAuxCerts aux) e

    -- Search the sub-bodies and SOAC operands of an expression that is
    -- not itself an array operation.
    nested ex = execWriter $ walkExpM collector ex

    collector =
      identityWalker
        { walkOnBody = \_ b -> tell (arrayOps b),
          walkOnOp = tell . fromSOAC
        }

    -- Gather the array operations found in every lambda of a SOAC.
    fromSOAC soac =
      execWriter $
        mapSOACM identitySOACMapper {mapOnSOACLambda = visitLambda} soac

    visitLambda lam = do
      tell $ arrayOps $ lambdaBody lam
      return lam

-- | Rewrite a body, replacing every array operation that is a key of
-- the substitution map with the corresponding value.
--
-- The rewrite descends into nested bodies and into the lambdas of
-- SOACs.  Each statement keeps its pattern identifiers and auxiliary
-- information.  It is then certified with the certificates of the
-- (possibly replaced) operation.
replaceArrayOps ::
  M.Map ArrayOp ArrayOp ->
  AST.Body (Wise SOACS) ->
  AST.Body (Wise SOACS)
replaceArrayOps substs (Body _ stms res) =
  mkBody (fmap rewriteStm stms) res
  where
    rewriteStm (Let pat aux e) =
      certify cs' $ mkLet' (patIdents pat) aux e'
      where
        (cs', e') = rewriteExp (stmAuxCerts aux) e

    -- Substitute the expression if it is a mapped array operation;
    -- otherwise keep it and recurse into its subterms.
    rewriteExp cs e =
      case isArrayOp cs e >>= (`M.lookup` substs) of
        Just op' -> fromArrayOp op'
        Nothing -> (cs, mapExp deeper e)

    deeper =
      identityMapper
        { mapOnBody = \_ b -> pure $ replaceArrayOps substs b,
          mapOnOp = pure . rewriteSOAC
        }

    rewriteSOAC =
      runIdentity
        . mapSOACM identitySOACMapper {mapOnSOACLambda = pure . rewriteLambda}

    rewriteLambda lam =
      lam {lambdaBody = replaceArrayOps substs (lambdaBody lam)}

-- Turn
--
--    map (\i -> ... xs[i] ...) (iota n)
--
-- into
--
--    map (\i x -> ... x ...) (iota n) xs
--
-- This is not because we want to encourage the map-iota pattern, but
-- it may be present in generated code.  This is an unfortunately
-- expensive simplification rule, since it requires multiple passes
-- over the entire lambda body.  It only handles the very simplest
-- case - if you find yourself planning to extend it to handle more
-- complex situations (rotate or whatnot), consider turning it into a
-- separate compiler pass instead.
--
-- Only an iota input with constant offset 0 and constant stride 1 is
-- recognised.  An indexing is rewritten only when the indexed array and
-- all of its certificates are bound outside the lambda.  If the indexed
-- array's outer size differs from the Screma width, a prefix slice of
-- width @w@ is taken first.
simplifyMapIota :: TopDownRuleOp (Wise SOACS)
simplifyMapIota vtable pat aux (Screma w arrs (ScremaForm scan reduce map_lam))
  | Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs),
    -- 'S.map snd' deduplicates: the same indexing can be bound by
    -- several statements (with different patterns) in the body.  Without
    -- this, each duplicate would get its own new parameter and array
    -- input, all but one of which would end up unused.
    indexings <-
      filter (indexesWith (paramName p)) $
        S.toList $
          S.map snd $
            arrayOps $
              lambdaBody map_lam,
    not $ null indexings = Simplify $ do
    -- For each indexing with iota, add the corresponding array to
    -- the Screma, and construct a new lambda parameter.
    (more_arrs, more_params, replacements) <-
      unzip3 . catMaybes <$> mapM mapOverArr indexings
    -- Every element of 'indexings' is an 'ArrayIndexing' (that is all
    -- 'indexesWith' accepts), so 'mapOverArr' never returns 'Nothing'
    -- here and the two lists below stay aligned.
    let substs = M.fromList $ zip indexings replacements
        map_lam' =
          map_lam
            { lambdaParams = lambdaParams map_lam <> more_params,
              lambdaBody = replaceArrayOps substs $ lambdaBody map_lam
            }

    auxing aux $
      letBind pat $
        Op $
          Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam')
  where
    -- Is this input an @iota@ with constant offset 0 and stride 1?
    isIota :: (Param Type, VName) -> Bool
    isIota (_, arr) = case ST.lookupBasicOp arr vtable of
      Just (Iota _ (Constant o) (Constant s) _, _) ->
        zeroIsh o && oneIsh s
      _ -> False

    -- Does the operation index (in its outermost dimension) with the
    -- given variable an array whose name and certificates are all
    -- available outside the lambda?
    indexesWith :: VName -> ArrayOp -> Bool
    indexesWith v (ArrayIndexing cs arr (Slice (DimFix (Var i) : _)))
      | arr `ST.elem` vtable,
        all (`ST.elem` vtable) $ unCerts cs =
        i == v
    indexesWith _ _ = False

    -- Produce the new Screma input, the new lambda parameter, and the
    -- replacement indexing (on the parameter) for an indexing operation.
    mapOverArr :: ArrayOp -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
    mapOverArr (ArrayIndexing cs arr slice) = do
      arr_elem <- newVName $ baseString arr ++ "_elem"
      arr_t <- lookupType arr
      arr' <-
        if arraySize 0 arr_t == w
          then return arr
          else
            certifying cs . letExp (baseString arr ++ "_prefix") . BasicOp . Index arr $
              fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)]
      return $
        Just
          ( arr',
            Param arr_elem (rowType arr_t),
            ArrayIndexing cs arr_elem (Slice (drop 1 (unSlice slice)))
          )
    mapOverArr _ = return Nothing
simplifyMapIota _ _ _ _ = Skip

-- If a Screma's map function contains a transformation
-- (e.g. transpose) on a parameter, create a new parameter
-- corresponding to that transformation performed on the rows of the
-- full array.
moveTransformToInput :: TopDownRuleOp (Wise SOACS)
moveTransformToInput :: RuleOp (Wise SOACS) (SymbolTable (Wise SOACS))
moveTransformToInput SymbolTable (Wise SOACS)
vtable Pat (Wise SOACS)
pat StmAux (ExpDec (Wise SOACS))
aux (Screma w arrs (ScremaForm scan reduce map_lam))
  | [ArrayOp]
ops <- ((PatT (VarWisdom, Type), ArrayOp) -> ArrayOp)
-> [(PatT (VarWisdom, Type), ArrayOp)] -> [ArrayOp]
forall a b. (a -> b) -> [a] -> [b]
map (PatT (VarWisdom, Type), ArrayOp) -> ArrayOp
forall a b. (a, b) -> b
snd ([(PatT (VarWisdom, Type), ArrayOp)] -> [ArrayOp])
-> [(PatT (VarWisdom, Type), ArrayOp)] -> [ArrayOp]
forall a b. (a -> b) -> a -> b
$ ((PatT (VarWisdom, Type), ArrayOp) -> Bool)
-> [(PatT (VarWisdom, Type), ArrayOp)]
-> [(PatT (VarWisdom, Type), ArrayOp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (PatT (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam ([(PatT (VarWisdom, Type), ArrayOp)]
 -> [(PatT (VarWisdom, Type), ArrayOp)])
-> [(PatT (VarWisdom, Type), ArrayOp)]
-> [(PatT (VarWisdom, Type), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ Set (PatT (VarWisdom, Type), ArrayOp)
-> [(PatT (VarWisdom, Type), ArrayOp)]
forall a. Set a -> [a]
S.toList (Set (PatT (VarWisdom, Type), ArrayOp)
 -> [(PatT (VarWisdom, Type), ArrayOp)])
-> Set (PatT (VarWisdom, Type), ArrayOp)
-> [(PatT (VarWisdom, Type), ArrayOp)]
forall a b. (a -> b) -> a -> b
$ Body (Wise SOACS) -> Set (Pat (Wise SOACS), ArrayOp)
arrayOps (Body (Wise SOACS) -> Set (Pat (Wise SOACS), ArrayOp))
-> Body (Wise SOACS) -> Set (Pat (Wise SOACS), ArrayOp)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Wise SOACS)
map_lam,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [ArrayOp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [ArrayOp]
ops = RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM (Wise SOACS) () -> Rule (Wise SOACS))
-> RuleM (Wise SOACS) () -> Rule (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ do
    ([VName]
more_arrs, [Param Type]
more_params, [ArrayOp]
replacements) <-
      [(VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(VName, Param Type, ArrayOp)]
 -> ([VName], [Param Type], [ArrayOp]))
-> ([Maybe (VName, Param Type, ArrayOp)]
    -> [(VName, Param Type, ArrayOp)])
-> [Maybe (VName, Param Type, ArrayOp)]
-> ([VName], [Param Type], [ArrayOp])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (VName, Param Type, ArrayOp)]
-> [(VName, Param Type, ArrayOp)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (VName, Param Type, ArrayOp)]
 -> ([VName], [Param Type], [ArrayOp]))
-> RuleM (Wise SOACS) [Maybe (VName, Param Type, ArrayOp)]
-> RuleM (Wise SOACS) ([VName], [Param Type], [ArrayOp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (ArrayOp
 -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp)))
-> [ArrayOp]
-> RuleM (Wise SOACS) [Maybe (VName, Param Type, ArrayOp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ArrayOp -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
mapOverArr [ArrayOp]
ops

    Bool -> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
more_arrs) RuleM (Wise SOACS) ()
forall rep a. RuleM rep a
cannotSimplify

    let substs :: Map ArrayOp ArrayOp
substs = [(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp)
-> [(ArrayOp, ArrayOp)] -> Map ArrayOp ArrayOp
forall a b. (a -> b) -> a -> b
$ [ArrayOp] -> [ArrayOp] -> [(ArrayOp, ArrayOp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [ArrayOp]
ops [ArrayOp]
replacements
        map_lam' :: Lambda (Wise SOACS)
map_lam' =
          Lambda (Wise SOACS)
map_lam
            { lambdaParams :: [LParam (Wise SOACS)]
lambdaParams = Lambda (Wise SOACS) -> [LParam (Wise SOACS)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
more_params,
              lambdaBody :: Body (Wise SOACS)
lambdaBody =
                Map ArrayOp ArrayOp -> Body (Wise SOACS) -> Body (Wise SOACS)
replaceArrayOps Map ArrayOp ArrayOp
substs (Body (Wise SOACS) -> Body (Wise SOACS))
-> Body (Wise SOACS) -> Body (Wise SOACS)
forall a b. (a -> b) -> a -> b
$
                  Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Wise SOACS)
map_lam
            }

    StmAux (ExpWisdom, ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpWisdom, ())
StmAux (ExpDec (Wise SOACS))
aux (RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ())
-> RuleM (Wise SOACS) () -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$
      Pat (Rep (RuleM (Wise SOACS)))
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep (RuleM (Wise SOACS)))
Pat (Wise SOACS)
pat (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ())
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) ()
forall a b. (a -> b) -> a -> b
$ Op (Wise SOACS) -> Exp (Wise SOACS)
forall rep. Op rep -> ExpT rep
Op (Op (Wise SOACS) -> Exp (Wise SOACS))
-> Op (Wise SOACS) -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm (Wise SOACS) -> SOAC (Wise SOACS)
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
arrs [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
more_arrs) ([Scan (Wise SOACS)]
-> [Reduce (Wise SOACS)]
-> Lambda (Wise SOACS)
-> ScremaForm (Wise SOACS)
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan (Wise SOACS)]
scan [Reduce (Wise SOACS)]
reduce Lambda (Wise SOACS)
map_lam')
  where
    map_param_names :: [VName]
map_param_names = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda (Wise SOACS) -> [LParam (Wise SOACS)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda (Wise SOACS)
map_lam)
    topLevelPat :: PatT (VarWisdom, Type) -> Bool
topLevelPat = (PatT (VarWisdom, Type) -> Seq (PatT (VarWisdom, Type)) -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` (Stm (Wise SOACS) -> PatT (VarWisdom, Type))
-> Stms (Wise SOACS) -> Seq (PatT (VarWisdom, Type))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Stm (Wise SOACS) -> PatT (VarWisdom, Type)
forall rep. Stm rep -> Pat rep
stmPat (Body (Wise SOACS) -> Stms (Wise SOACS)
forall rep. BodyT rep -> Stms rep
bodyStms (Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Wise SOACS)
map_lam)))
    onlyUsedOnce :: VName -> Bool
onlyUsedOnce VName
arr =
      case (Stm (Wise SOACS) -> Bool)
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName
arr VName -> Names -> Bool
`nameIn`) (Names -> Bool)
-> (Stm (Wise SOACS) -> Names) -> Stm (Wise SOACS) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm (Wise SOACS) -> Names
forall a. FreeIn a => a -> Names
freeIn) ([Stm (Wise SOACS)] -> [Stm (Wise SOACS)])
-> [Stm (Wise SOACS)] -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall rep. Stms rep -> [Stm rep]
stmsToList (Stms (Wise SOACS) -> [Stm (Wise SOACS)])
-> Stms (Wise SOACS) -> [Stm (Wise SOACS)]
forall a b. (a -> b) -> a -> b
$ Body (Wise SOACS) -> Stms (Wise SOACS)
forall rep. BodyT rep -> Stms rep
bodyStms (Body (Wise SOACS) -> Stms (Wise SOACS))
-> Body (Wise SOACS) -> Stms (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda (Wise SOACS) -> Body (Wise SOACS)
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda (Wise SOACS)
map_lam of
        Stm (Wise SOACS)
_ : Stm (Wise SOACS)
_ : [Stm (Wise SOACS)]
_ -> Bool
False
        [Stm (Wise SOACS)]
_ -> Bool
True

    -- It's not just about whether the array is a parameter;
    -- everything else must be map-invariant.
    arrayIsMapParam :: (PatT (VarWisdom, Type), ArrayOp) -> Bool
arrayIsMapParam (PatT (VarWisdom, Type)
pat', ArrayIndexing Certs
cs VName
arr Slice SubExp
slice) =
      VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice)
        Bool -> Bool -> Bool
&& Bool -> Bool
not (Slice SubExp -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Slice SubExp
slice)
        Bool -> Bool -> Bool
&& (Bool -> Bool
not ([SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice) Bool -> Bool -> Bool
|| (PatT (VarWisdom, Type) -> Bool
topLevelPat PatT (VarWisdom, Type)
pat' Bool -> Bool -> Bool
&& VName -> Bool
onlyUsedOnce VName
arr))
    arrayIsMapParam (PatT (VarWisdom, Type)
_, ArrayRearrange Certs
cs VName
arr [Int]
perm) =
      VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs)
        Bool -> Bool -> Bool
&& Bool -> Bool
not ([Int] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
perm)
    arrayIsMapParam (PatT (VarWisdom, Type)
_, ArrayRotate Certs
cs VName
arr [SubExp]
rots) =
      VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
rots)
    arrayIsMapParam (PatT (VarWisdom, Type)
_, ArrayCopy Certs
cs VName
arr) =
      VName
arr VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
map_param_names
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Wise SOACS) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Wise SOACS)
vtable) (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Certs -> Names
forall a. FreeIn a => a -> Names
freeIn Certs
cs)
    arrayIsMapParam (PatT (VarWisdom, Type)
_, ArrayVar {}) =
      Bool
False

    mapOverArr :: ArrayOp -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
mapOverArr ArrayOp
op
      | Just (VName
_, VName
arr) <- ((VName, VName) -> Bool)
-> [(VName, VName)] -> Maybe (VName, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== ArrayOp -> VName
arrayOpArr ArrayOp
op) (VName -> Bool)
-> ((VName, VName) -> VName) -> (VName, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName) -> VName
forall a b. (a, b) -> a
fst) ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
map_param_names [VName]
arrs) = do
        Type
arr_t <- VName -> RuleM (Wise SOACS) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
        let whole_dim :: DimIndex SubExp
whole_dim = SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
arr_t) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)
        VName
arr_transformed <- Certs -> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (ArrayOp -> Certs
arrayOpCerts ArrayOp
op) (RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName)
-> RuleM (Wise SOACS) VName -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
          String
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_transformed") (Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName)
-> Exp (Rep (RuleM (Wise SOACS))) -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$
            case ArrayOp
op of
              ArrayIndexing Certs
_ VName
_ (Slice [DimIndex SubExp]
slice) ->
                BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ DimIndex SubExp
whole_dim DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
slice
              ArrayRearrange Certs
_ VName
_ [Int]
perm ->
                BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange (Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [Int]
perm) VName
arr
              ArrayRotate Certs
_ VName
_ [SubExp]
rots ->
                BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
rots) VName
arr
              ArrayCopy {} ->
                BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr
              ArrayVar {} ->
                BasicOp -> Exp (Wise SOACS)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise SOACS)) -> BasicOp -> Exp (Wise SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
        Type
arr_transformed_t <- VName -> RuleM (Wise SOACS) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr_transformed
        VName
arr_transformed_row <- String -> RuleM (Wise SOACS) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM (Wise SOACS) VName)
-> String -> RuleM (Wise SOACS) VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_transformed_row"
        Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (VName, Param Type, ArrayOp)
 -> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp)))
-> Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall a b. (a -> b) -> a -> b
$
          (VName, Param Type, ArrayOp) -> Maybe (VName, Param Type, ArrayOp)
forall a. a -> Maybe a
Just
            ( VName
arr_transformed,
              VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
arr_transformed_row (Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
arr_transformed_t),
              Certs -> VName -> ArrayOp
ArrayVar Certs
forall a. Monoid a => a
mempty VName
arr_transformed_row
            )
    mapOverArr ArrayOp
_ = Maybe (VName, Param Type, ArrayOp)
-> RuleM (Wise SOACS) (Maybe (VName, Param Type, ArrayOp))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (VName, Param Type, ArrayOp)
forall a. Maybe a
Nothing
moveTransformToInput SymbolTable (Wise SOACS)
_ Pat (Wise SOACS)
_ StmAux (ExpDec (Wise SOACS))
_ Op (Wise SOACS)
_ =
  Rule (Wise SOACS)
forall rep. Rule rep
Skip