{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.Kernels.Simplify
       ( simplifyKernels
       , simplifyLambda

       -- * Building blocks
       , simplifyKernelOp
       )
where

import Control.Monad
import Data.Foldable
import Data.List (isPrefixOf, groupBy, partition)
import Data.Maybe
import qualified Data.Map.Strict as M

import Futhark.Representation.Kernels
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Lore
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass
import Futhark.Representation.SOACS.Simplify (simplifySOAC)
import qualified Futhark.Optimise.Simplify as Simplify
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Util (chunks)
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | Simplifier operations for the 'Kernels' representation.  Host
-- operations go through 'simplifyKernelOp', and the 'SOAC Kernels'
-- payload of 'OtherOp' is handed to 'simplifySOAC'.
simpleKernels :: Simplify.SimpleOps Kernels
simpleKernels = Simplify.bindableSimpleOps (simplifyKernelOp simplifySOAC)

-- | Simplify a whole 'Kernels' program with 'simpleKernels', the
-- 'kernelRules' rule book and no additional hoisting blockers.
simplifyKernels :: Prog Kernels -> PassM (Prog Kernels)
simplifyKernels prog =
  Simplify.simplifyProg simpleKernels kernelRules Simplify.noExtraHoistBlockers prog

-- | Simplify a single 'Kernels' lambda in any monad that has a 'Kernels'
-- scope and a fresh-name source.  The '[Maybe VName]' argument is
-- forwarded as-is to 'Simplify.simplifyLambda'.
simplifyLambda :: (HasScope Kernels m, MonadFreshNames m) =>
                  Lambda Kernels -> [Maybe VName] -> m (Lambda Kernels)
simplifyLambda lam args =
  Simplify.simplifyLambda simpleKernels kernelRules Engine.noExtraHoistBlockers lam args

-- | Simplify a host-level operation.
--
-- * 'OtherOp' payloads are delegated to the supplied operation simplifier.
-- * Segmented operations have their level, space, result types, operators
--   and kernel body simplified.  Statements hoisted out of the operator
--   lambdas and out of the kernel body are returned alongside the new
--   operation (operator hoistings first, then body hoistings).
-- * Size operations have every 'SubExp' operand simplified and never
--   hoist any statements.
simplifyKernelOp :: (Engine.SimplifiableLore lore,
                     BodyAttr lore ~ ()) =>
                    Simplify.SimplifyOp lore op
                 -> HostOp lore op
                 -> Engine.SimpleM lore (HostOp (Wise lore) (OpWithWisdom op), Stms (Wise lore))

simplifyKernelOp f (OtherOp op) = do
  (op', stms) <- f op
  return (OtherOp op', stms)

simplifyKernelOp _ (SegOp (SegMap lvl space ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  return (SegOp $ SegMap lvl' space' ts' kbody',
          body_hoisted)

simplifyKernelOp _ (SegOp (SegRed lvl space reds ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- fmap unzip $ forM reds $ \(SegRedOp comm lam nes shape) -> do
    (lam', hoisted) <- simplifyOperatorLambda space lam (length nes)
    shape' <- Engine.simplify shape
    nes' <- mapM Engine.simplify nes
    return (SegRedOp comm lam' nes' shape', hoisted)

  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  return (SegOp $ SegRed lvl' space' reds' ts' kbody',
          mconcat reds_hoisted <> body_hoisted)

simplifyKernelOp _ (SegOp (SegScan lvl space scan_op nes ts kbody)) = do
  lvl' <- Engine.simplify lvl
  (space', scan_op', nes', ts', kbody', hoisted) <-
    simplifyRedOrScan space scan_op nes ts kbody

  return (SegOp $ SegScan lvl' space' scan_op' nes' ts' kbody',
          hoisted)

simplifyKernelOp _ (SegOp (SegHist lvl space ops ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  (ops', ops_hoisted) <- fmap unzip $ forM ops $
    \(HistOp w rf arrs nes dims lam) -> do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <- simplifyOperatorLambda space lam (length nes)
      return (HistOp w' rf' arrs' nes' dims' lam',
              op_hoisted)

  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  return (SegOp $ SegHist lvl' space' ops' ts' kbody',
          mconcat ops_hoisted <> body_hoisted)

simplifyKernelOp _ (SizeOp (SplitSpace o w i elems_per_thread)) =
  (,) <$> (SizeOp <$>
           (SplitSpace <$> Engine.simplify o <*> Engine.simplify w
            <*> Engine.simplify i <*> Engine.simplify elems_per_thread))
      <*> pure mempty
simplifyKernelOp _ (SizeOp (GetSize key size_class)) =
  return (SizeOp $ GetSize key size_class, mempty)
simplifyKernelOp _ (SizeOp (GetSizeMax size_class)) =
  return (SizeOp $ GetSizeMax size_class, mempty)
simplifyKernelOp _ (SizeOp (CmpSizeLe key size_class x)) = do
  x' <- Engine.simplify x
  return (SizeOp $ CmpSizeLe key size_class x', mempty)
simplifyKernelOp _ (SizeOp (CalcNumGroups w max_num_groups group_size)) = do
  w' <- Engine.simplify w
  -- The group size is a 'SubExp' like 'w', so it is simplified as well.
  -- The original left it untouched and so missed copy propagation.
  group_size' <- Engine.simplify group_size
  return (SizeOp $ CalcNumGroups w' max_num_groups group_size', mempty)

-- | Simplify the combining operator of a segmented reduction or histogram.
-- The bindings of the segment space are added to the symbol table, and
-- memory simplification is enabled while the lambda is processed.  The
-- operator's @2 * num_nes@ parameters are all passed 'Nothing'.
simplifyOperatorLambda :: Engine.SimplifiableLore lore =>
                          SegSpace -> Lambda lore -> Int
                       -> Engine.SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyOperatorLambda space lam num_nes =
  Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space)) $
  Engine.localVtable (\vtable -> vtable { ST.simplifyMemory = True }) $
  Engine.simplifyLambda lam $ replicate (num_nes * 2) Nothing

-- | Simplify the components of a segmented scan: the segment space, the
-- neutral elements, the result types, the operator lambda and the kernel
-- body.  The operator is simplified with the segment space's bindings in
-- the symbol table and memory simplification enabled.  In the returned
-- 'Stms', statements hoisted out of the operator precede those hoisted
-- out of the body.
simplifyRedOrScan :: (Engine.SimplifiableLore lore, BodyAttr lore ~ ()) =>
                     SegSpace
                  -> Lambda lore -> [SubExp] -> [Type]
                  -> KernelBody lore
                  -> Simplify.SimpleM lore
                  (SegSpace, Lambda (Wise lore), [SubExp], [Type], KernelBody (Wise lore),
                   Stms (Wise lore))
simplifyRedOrScan space op neutral types body = do
  space_s <- Engine.simplify space
  neutral_s <- traverse Engine.simplify neutral
  types_s <- traverse Engine.simplify types

  let op_params = 2 * length neutral
      withSegScope = Engine.localVtable (<> ST.fromScope (scopeOfSegSpace space))
      withMemory = Engine.localVtable $ \vt -> vt { ST.simplifyMemory = True }
  (op_s, op_stms) <-
    withSegScope . withMemory $
      Engine.simplifyLambda op (replicate op_params Nothing)

  (body_s, body_stms) <- simplifyKernelBody space body

  pure (space_s, op_s, neutral_s, types_s, body_s, op_stms <> body_stms)

-- | Simplify the body of a segmented operation.
--
-- The statements are simplified with the index variables of the
-- segment space added to the symbol table, and with
-- 'ST.simplifyMemory' switched on.  A statement is blocked from being
-- hoisted out of the body if it refers to a name bound by the segment
-- space, is an 'Op', is blocked by the engine's parallel hoisting
-- blocker, or is consumed.  Statements that were hoisted are returned
-- as the second component of the result.
simplifyKernelBody :: (Engine.SimplifiableLore lore, BodyAttr lore ~ ()) =>
                      SegSpace -> KernelBody lore
                   -> Engine.SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  let keep_inside = Engine.hasFree bound_here
                    `Engine.orIf` Engine.isOp
                    `Engine.orIf` par_blocker
                    `Engine.orIf` Engine.isConsumed

      -- Names bound by the body statements; passed to
      -- 'ST.hideCertified' while the kernel results are simplified.
      body_names = namesFromList $ M.keys $ scopeOf stms

      simplifyResults = do
        res' <- Engine.localVtable (ST.hideCertified body_names) $
                mapM Engine.simplify res
        return ((res', UT.usages $ freeIn res'), mempty)

      -- Extend the symbol table with the segment space and enable
      -- memory simplification.
      inKernelScope m =
        Engine.localVtable (<> scope_vtable) $
        Engine.localVtable (\vtable -> vtable { ST.simplifyMemory = True }) m

  ((body_stms, body_res), hoisted) <-
    inKernelScope $ Engine.blockIf keep_inside $
    Engine.simplifyStms stms simplifyResults

  return (mkWiseKernelBody () body_stms body_res, hoisted)

  where scope_vtable = ST.fromScope $ scopeOfSegSpace space
        bound_here = namesFromList $ M.keys $ scopeOfSegSpace space

-- | Build a 'KernelBody' in the 'Wise' lore.  The body attribute is
-- obtained from 'mkWiseBody', applied to the same statements and to
-- the 'SubExp' of each kernel result (via 'kernelResultSubExp').  The
-- kernel results themselves are stored unchanged.
mkWiseKernelBody :: (Attributes lore, CanBeWise (Op lore)) =>
                    BodyAttr lore -> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody attr bnds res = KernelBody wise_attr bnds res
  where result_ses = map kernelResultSubExp res
        -- Only the attribute of the plain wise body is used.
        Body wise_attr _ _ = mkWiseBody attr bnds result_ses

instance Engine.Simplifiable SplitOrdering where
  -- Contiguous splits carry no data; strided splits have their stride
  -- simplified.
  simplify ordering =
    case ordering of
      SplitContiguous     -> return SplitContiguous
      SplitStrided stride -> SplitStrided <$> Engine.simplify stride

instance Engine.Simplifiable SegLevel where
  -- Simplify the group count and then the group size of a level; the
  -- 'SegVirt' setting is passed through unchanged.
  simplify lvl =
    case lvl of
      SegThread num_groups group_size virt ->
        rebuild SegThread num_groups group_size virt
      SegGroup num_groups group_size virt ->
        rebuild SegGroup num_groups group_size virt
    where rebuild mk num_groups group_size virt = do
            num_groups' <- traverse Engine.simplify num_groups
            group_size' <- traverse Engine.simplify group_size
            return $ mk num_groups' group_size' virt

instance Engine.Simplifiable SegSpace where
  -- The flat index name and the index names are kept; only the
  -- dimension sizes are simplified.
  simplify (SegSpace phys dims) = do
    dims' <- forM dims $ \(gtid, dim) -> do
      dim' <- Engine.simplify dim
      return (gtid, dim')
    return $ SegSpace phys dims'

instance Engine.Simplifiable KernelResult where
  -- Every sub-component is simplified, left to right; a 'Returns'
  -- keeps its 'ResultManifest' as-is.
  simplify (Returns manifest what) = do
    what' <- Engine.simplify what
    return $ Returns manifest what'
  simplify (WriteReturns ws a res) = do
    ws' <- Engine.simplify ws
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    return $ WriteReturns ws' a' res'
  simplify (ConcatReturns o w pte what) = do
    o' <- Engine.simplify o
    w' <- Engine.simplify w
    pte' <- Engine.simplify pte
    what' <- Engine.simplify what
    return $ ConcatReturns o' w' pte' what'
  simplify (TileReturns dims what) = do
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    return $ TileReturns dims' what'

-- Binder operations for the wise kernels lore all delegate to the
-- generic 'Bindable'-based implementations.
instance BinderOps (Wise Kernels) where
  mkExpAttrB pat e = bindableMkExpAttrB pat e
  mkBodyB stms res = bindableMkBodyB stms res
  mkLetNamesB names e = bindableMkLetNamesB names e

-- | The standard simplification rules, extended with the
-- kernel-specific top-down and bottom-up rules listed below.
kernelRules :: RuleBook (Wise Kernels)
kernelRules = standardRules <> ruleBook top_down bottom_up
  where top_down = [ RuleOp removeInvariantKernelResults
                   , RuleOp mergeSegRedOps
                   , RuleOp redomapIotaToLoop
                   ]
        bottom_up = [ RuleOp distributeKernelResults
                    , RuleBasicOp removeUnnecessaryCopy
                    ]

-- | If a 'SegMap' returns, via 'Returns' with 'ResultMaySimplify', a
-- value that is invariant to the kernel, that result is removed from
-- the kernel.  Its pattern element is instead bound to a 'Replicate'
-- of the value over the dimensions of the segment space.  A value
-- counts as invariant if it is a constant, or a variable found in the
-- rule's symbol table.  Any other operation is skipped.
removeInvariantKernelResults :: TopDownRuleOp (Wise Kernels)
removeInvariantKernelResults vtable (Pattern [] kpes) attr
  (SegOp (SegMap lvl space ts (KernelBody _ kstms kres))) = Simplify $ do
  kept <- filterM keepResult $ zip3 ts kpes kres
  let (kept_ts, kept_pes, kept_res) = unzip3 kept

  -- Bail out unless at least one result was turned into a replicate.
  when (kept_res == kres) cannotSimplify

  let kbody' = mkWiseKernelBody () kstms kept_res
  addStm $ Let (Pattern [] kept_pes) attr $ Op $ SegOp $
    SegMap lvl space kept_ts kbody'

  where invariant Constant{} = True
        invariant (Var v) = isJust $ ST.lookup v vtable

        -- Returns False (and emits a replicate) for results that can be
        -- moved out of the kernel; True for results that must stay.
        keepResult (_, pe, Returns ResultMaySimplify se)
          | invariant se = do
              letBindNames_ [patElemName pe] $ BasicOp $
                Replicate (Shape $ segSpaceDims space) se
              return False
        keepResult _ = return True
removeInvariantKernelResults _ _ _ _ = Skip

-- Some kernel results can be moved outside the kernel, which can
-- simplify further analysis.
distributeKernelResults :: BottomUpRuleOp (Wise Kernels)
distributeKernelResults :: RuleOp (Wise Kernels) (BottomUp (Wise Kernels))
distributeKernelResults (TopDown (Wise Kernels)
vtable, UsageTable
used)
  (Pattern [] [PatElemT (LetAttr (Wise Kernels))]
kpes) StmAux (ExpAttr (Wise Kernels))
attr (SegOp (SegMap lvl space kts (KernelBody _ kstms kres))) = RuleM (Wise Kernels) () -> Rule (Wise Kernels)
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM (Wise Kernels) () -> Rule (Wise Kernels))
-> RuleM (Wise Kernels) () -> Rule (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.
  ([PatElemT (VarWisdom, Type)]
kpes', [Type]
kts', [KernelResult]
kres', Stms (Wise Kernels)
kstms') <- Scope (Wise Kernels)
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope (Wise Kernels)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (RuleM
   (Wise Kernels)
   ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
    Stms (Wise Kernels))
 -> RuleM
      (Wise Kernels)
      ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
       Stms (Wise Kernels)))
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
forall a b. (a -> b) -> a -> b
$
    (([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
  Stms (Wise Kernels))
 -> Stm (Wise Kernels)
 -> RuleM
      (Wise Kernels)
      ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
       Stms (Wise Kernels)))
-> ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
    Stms (Wise Kernels))
-> Stms (Wise Kernels)
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
 Stms (Wise Kernels))
-> Stm (Wise Kernels)
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
distribute ([PatElemT (VarWisdom, Type)]
[PatElemT (LetAttr (Wise Kernels))]
kpes, [Type]
kts, [KernelResult]
kres, Stms (Wise Kernels)
forall a. Monoid a => a
mempty) Stms (Wise Kernels)
kstms

  Bool -> RuleM (Wise Kernels) () -> RuleM (Wise Kernels) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([PatElemT (VarWisdom, Type)]
kpes' [PatElemT (VarWisdom, Type)]
-> [PatElemT (VarWisdom, Type)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (VarWisdom, Type)]
[PatElemT (LetAttr (Wise Kernels))]
kpes)
    RuleM (Wise Kernels) ()
forall lore a. RuleM lore a
cannotSimplify

  Stm (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ())
-> Stm (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall a b. (a -> b) -> a -> b
$ Pattern (Wise Kernels)
-> StmAux (ExpAttr (Wise Kernels))
-> Exp (Wise Kernels)
-> Stm (Wise Kernels)
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT (VarWisdom, Type)]
-> [PatElemT (VarWisdom, Type)] -> PatternT (VarWisdom, Type)
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT (VarWisdom, Type)]
kpes') StmAux (ExpAttr (Wise Kernels))
attr (Exp (Wise Kernels) -> Stm (Wise Kernels))
-> Exp (Wise Kernels) -> Stm (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ Op (Wise Kernels) -> Exp (Wise Kernels)
forall lore. Op lore -> ExpT lore
Op (Op (Wise Kernels) -> Exp (Wise Kernels))
-> Op (Wise Kernels) -> Exp (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ SegOp (Wise Kernels) -> HostOp (Wise Kernels) (SOAC (Wise Kernels))
forall lore op. SegOp lore -> HostOp lore op
SegOp (SegOp (Wise Kernels)
 -> HostOp (Wise Kernels) (SOAC (Wise Kernels)))
-> SegOp (Wise Kernels)
-> HostOp (Wise Kernels) (SOAC (Wise Kernels))
forall a b. (a -> b) -> a -> b
$
    SegLevel
-> SegSpace
-> [Type]
-> KernelBody (Wise Kernels)
-> SegOp (Wise Kernels)
forall lore.
SegLevel -> SegSpace -> [Type] -> KernelBody lore -> SegOp lore
SegMap SegLevel
lvl SegSpace
space [Type]
kts' (KernelBody (Wise Kernels) -> SegOp (Wise Kernels))
-> KernelBody (Wise Kernels) -> SegOp (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ BodyAttr Kernels
-> Stms (Wise Kernels)
-> [KernelResult]
-> KernelBody (Wise Kernels)
forall lore.
(Attributes lore, CanBeWise (Op lore)) =>
BodyAttr lore
-> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody () Stms (Wise Kernels)
kstms' [KernelResult]
kres'
  where
    free_in_kstms :: Names
free_in_kstms = (Stm (Wise Kernels) -> Names) -> Stms (Wise Kernels) -> Names
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm (Wise Kernels) -> Names
forall a. FreeIn a => a -> Names
freeIn Stms (Wise Kernels)
kstms

    sliceWithGtidsFixed :: Stm (Wise Kernels) -> Maybe ([DimIndex SubExp], VName)
sliceWithGtidsFixed Stm (Wise Kernels)
stm
      | Let Pattern (Wise Kernels)
_ StmAux (ExpAttr (Wise Kernels))
_ (BasicOp (Index VName
arr [DimIndex SubExp]
slice)) <- Stm (Wise Kernels)
stm,
        [DimIndex SubExp]
space_slice <- ((VName, SubExp) -> DimIndex SubExp)
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> ((VName, SubExp) -> SubExp)
-> (VName, SubExp)
-> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [DimIndex SubExp])
-> [(VName, SubExp)] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space,
        [DimIndex SubExp]
space_slice [DimIndex SubExp] -> [DimIndex SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` [DimIndex SubExp]
slice,
        [DimIndex SubExp]
remaining_slice <- Int -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Int -> [a] -> [a]
drop ([DimIndex SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
space_slice) [DimIndex SubExp]
slice,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe (Entry (Wise Kernels)) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry (Wise Kernels)) -> Bool)
-> (VName -> Maybe (Entry (Wise Kernels))) -> VName -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> TopDown (Wise Kernels) -> Maybe (Entry (Wise Kernels)))
-> TopDown (Wise Kernels) -> VName -> Maybe (Entry (Wise Kernels))
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> TopDown (Wise Kernels) -> Maybe (Entry (Wise Kernels))
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup TopDown (Wise Kernels)
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          VName -> Names
forall a. FreeIn a => a -> Names
freeIn VName
arr Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [DimIndex SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [DimIndex SubExp]
remaining_slice =
          ([DimIndex SubExp], VName) -> Maybe ([DimIndex SubExp], VName)
forall a. a -> Maybe a
Just ([DimIndex SubExp]
remaining_slice, VName
arr)

      | Bool
otherwise =
          Maybe ([DimIndex SubExp], VName)
forall a. Maybe a
Nothing

    distribute :: ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
 Stms (Wise Kernels))
-> Stm (Wise Kernels)
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
distribute ([PatElemT (VarWisdom, Type)]
kpes', [Type]
kts', [KernelResult]
kres', Stms (Wise Kernels)
kstms') Stm (Wise Kernels)
stm
      | Let (Pattern [] [PatElemT (LetAttr (Wise Kernels))
pe]) StmAux (ExpAttr (Wise Kernels))
_ Exp (Wise Kernels)
_ <- Stm (Wise Kernels)
stm,
        Just ([DimIndex SubExp]
remaining_slice, VName
arr) <- Stm (Wise Kernels) -> Maybe ([DimIndex SubExp], VName)
sliceWithGtidsFixed Stm (Wise Kernels)
stm,
        Just (PatElemT (VarWisdom, Type)
kpe, [PatElemT (VarWisdom, Type)]
kpes'', [Type]
kts'', [KernelResult]
kres'') <- [PatElemT (VarWisdom, Type)]
-> [Type]
-> [KernelResult]
-> PatElemT (VarWisdom, Type)
-> Maybe
     (PatElemT (VarWisdom, Type), [PatElemT (VarWisdom, Type)], [Type],
      [KernelResult])
forall a b attr.
[a]
-> [b]
-> [KernelResult]
-> PatElemT attr
-> Maybe (a, [a], [b], [KernelResult])
isResult [PatElemT (VarWisdom, Type)]
kpes' [Type]
kts' [KernelResult]
kres' PatElemT (VarWisdom, Type)
PatElemT (LetAttr (Wise Kernels))
pe = do
          let outer_slice :: [DimIndex SubExp]
outer_slice = (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (\SubExp
d -> SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice
                                       (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32))
                                       SubExp
d
                                       (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
1::Int32))) ([SubExp] -> [DimIndex SubExp]) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> a -> b
$
                            SegSpace -> [SubExp]
segSpaceDims SegSpace
space
              index :: PatElemT (VarWisdom, Type) -> RuleM (Wise Kernels) ()
index PatElemT (VarWisdom, Type)
kpe' = Pattern (Lore (RuleM (Wise Kernels)))
-> Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ ([PatElemT (VarWisdom, Type)]
-> [PatElemT (VarWisdom, Type)] -> PatternT (VarWisdom, Type)
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT (VarWisdom, Type)
kpe']) (Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ())
-> Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall a b. (a -> b) -> a -> b
$ BasicOp (Wise Kernels) -> Exp (Wise Kernels)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Wise Kernels) -> Exp (Wise Kernels))
-> BasicOp (Wise Kernels) -> Exp (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ VName -> [DimIndex SubExp] -> BasicOp (Wise Kernels)
forall lore. VName -> [DimIndex SubExp] -> BasicOp lore
Index VName
arr ([DimIndex SubExp] -> BasicOp (Wise Kernels))
-> [DimIndex SubExp] -> BasicOp (Wise Kernels)
forall a b. (a -> b) -> a -> b
$
                           [DimIndex SubExp]
outer_slice [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. Semigroup a => a -> a -> a
<> [DimIndex SubExp]
remaining_slice
          if PatElemT (VarWisdom, Type) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, Type)
kpe VName -> UsageTable -> Bool
`UT.isConsumed` UsageTable
used
            then do VName
precopy <- String -> RuleM (Wise Kernels) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM (Wise Kernels) VName)
-> String -> RuleM (Wise Kernels) VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (PatElemT (VarWisdom, Type) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, Type)
kpe) String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_precopy"
                    PatElemT (VarWisdom, Type) -> RuleM (Wise Kernels) ()
index PatElemT (VarWisdom, Type)
kpe { patElemName :: VName
patElemName = VName
precopy }
                    Pattern (Lore (RuleM (Wise Kernels)))
-> Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ ([PatElemT (VarWisdom, Type)]
-> [PatElemT (VarWisdom, Type)] -> PatternT (VarWisdom, Type)
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [PatElemT (VarWisdom, Type)
kpe]) (Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ())
-> Exp (Lore (RuleM (Wise Kernels))) -> RuleM (Wise Kernels) ()
forall a b. (a -> b) -> a -> b
$ BasicOp (Wise Kernels) -> Exp (Wise Kernels)
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp (Wise Kernels) -> Exp (Wise Kernels))
-> BasicOp (Wise Kernels) -> Exp (Wise Kernels)
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp (Wise Kernels)
forall lore. VName -> BasicOp lore
Copy VName
precopy
            else PatElemT (VarWisdom, Type) -> RuleM (Wise Kernels) ()
index PatElemT (VarWisdom, Type)
kpe
          ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
 Stms (Wise Kernels))
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return ([PatElemT (VarWisdom, Type)]
kpes'', [Type]
kts'', [KernelResult]
kres'',
                  if PatElemT (VarWisdom, Type) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (VarWisdom, Type)
PatElemT (LetAttr (Wise Kernels))
pe VName -> Names -> Bool
`nameIn` Names
free_in_kstms
                  then Stms (Wise Kernels)
kstms' Stms (Wise Kernels) -> Stms (Wise Kernels) -> Stms (Wise Kernels)
forall a. Semigroup a => a -> a -> a
<> Stm (Wise Kernels) -> Stms (Wise Kernels)
forall lore. Stm lore -> Stms lore
oneStm Stm (Wise Kernels)
stm
                  else Stms (Wise Kernels)
kstms')

    distribute ([PatElemT (VarWisdom, Type)]
kpes', [Type]
kts', [KernelResult]
kres', Stms (Wise Kernels)
kstms') Stm (Wise Kernels)
stm =
      ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
 Stms (Wise Kernels))
-> RuleM
     (Wise Kernels)
     ([PatElemT (VarWisdom, Type)], [Type], [KernelResult],
      Stms (Wise Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return ([PatElemT (VarWisdom, Type)]
kpes', [Type]
kts', [KernelResult]
kres', Stms (Wise Kernels)
kstms' Stms (Wise Kernels) -> Stms (Wise Kernels) -> Stms (Wise Kernels)
forall a. Semigroup a => a -> a -> a
<> Stm (Wise Kernels) -> Stms (Wise Kernels)
forall lore. Stm lore -> Stms lore
oneStm Stm (Wise Kernels)
stm)

    isResult :: [a]
-> [b]
-> [KernelResult]
-> PatElemT attr
-> Maybe (a, [a], [b], [KernelResult])
isResult [a]
kpes' [b]
kts' [KernelResult]
kres' PatElemT attr
pe =
      case ((a, b, KernelResult) -> Bool)
-> [(a, b, KernelResult)]
-> ([(a, b, KernelResult)], [(a, b, KernelResult)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (a, b, KernelResult) -> Bool
matches ([(a, b, KernelResult)]
 -> ([(a, b, KernelResult)], [(a, b, KernelResult)]))
-> [(a, b, KernelResult)]
-> ([(a, b, KernelResult)], [(a, b, KernelResult)])
forall a b. (a -> b) -> a -> b
$ [a] -> [b] -> [KernelResult] -> [(a, b, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [a]
kpes' [b]
kts' [KernelResult]
kres' of
        ([(a
kpe,b
_,KernelResult
_)], [(a, b, KernelResult)]
kpes_and_kres)
          | ([a]
kpes'', [b]
kts'', [KernelResult]
kres'') <- [(a, b, KernelResult)] -> ([a], [b], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(a, b, KernelResult)]
kpes_and_kres ->
              (a, [a], [b], [KernelResult])
-> Maybe (a, [a], [b], [KernelResult])
forall a. a -> Maybe a
Just (a
kpe, [a]
kpes'', [b]
kts'', [KernelResult]
kres'')
        ([(a, b, KernelResult)], [(a, b, KernelResult)])
_ -> Maybe (a, [a], [b], [KernelResult])
forall a. Maybe a
Nothing
      where matches :: (a, b, KernelResult) -> Bool
matches (a
_, b
_, Returns ResultManifest
_ (Var VName
v)) = VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== PatElemT attr -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT attr
pe
            matches (a, b, KernelResult)
_ = Bool
False
distributeKernelResults BottomUp (Wise Kernels)
_ Pattern (Wise Kernels)
_ StmAux (ExpAttr (Wise Kernels))
_ Op (Wise Kernels)
_ = Rule (Wise Kernels)
forall lore. Rule lore
Skip

-- | If a 'SegRed' contains two or more adjacent reduction operators that
-- have the same vector shape ('segRedShape'), merge each such run into a
-- single operator.  This saves on communication overhead, but can in
-- principle lead to more local memory usage.
--
-- Only neighbouring operators are merged ('groupBy' groups adjacent
-- elements), so the relative order of the reduction results - and hence
-- of the pattern elements, result types and kernel results - is
-- unchanged.  The non-reduction results that follow the
-- @segRedResults ops@ reduction results are passed through as-is.  The
-- certificates of the statement being replaced are kept on the new one.
mergeSegRedOps :: TopDownRuleOp (Wise Kernels)
mergeSegRedOps _ (Pattern [] pes) aux (SegOp (SegRed lvl space ops ts kbody))
  | length ops > 1,
    op_groupings <-
      groupBy sameShape $
      zip ops $
      chunks (map (length . segRedNeutral) ops) $
      zip3 red_pes red_ts red_res,
    -- Only fire when at least one group actually contains several
    -- operators; otherwise the rewrite would be a no-op.
    any ((> 1) . length) op_groupings = Simplify $ do
      let (ops', merged_res) = unzip $ mapMaybe combineOps op_groupings
          (red_pes', red_ts', red_res') = unzip3 $ concat merged_res
          pes' = red_pes' ++ map_pes
          ts' = red_ts' ++ map_ts
          kbody' = kbody { kernelBodyResult = red_res' ++ map_res }
      -- Keep the certificates of the statement we are replacing (as the
      -- other rules in this module, e.g. redomapIotaToLoop, do);
      -- otherwise any dependency they express would be silently lost.
      certifying (stmAuxCerts aux) $
        letBind_ (Pattern [] pes') $
        Op $ SegOp $ SegRed lvl space ops' ts' kbody'
  where
    -- The first 'segRedResults' pattern elements, types and kernel
    -- results belong to the reduction operators; the rest are passed
    -- through unchanged.
    num_red_res = segRedResults ops
    (red_pes, map_pes) = splitAt num_red_res pes
    (red_ts, map_ts) = splitAt num_red_res ts
    (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

    -- Two operators may be merged only if their vector shapes agree.
    sameShape :: (SegRedOp lore, b) -> (SegRedOp lore, b) -> Bool
    sameShape (op1, _) (op2, _) = segRedShape op1 == segRedShape op2

    -- Fold a group of same-shaped operators (each paired with the
    -- results it produces) into a single operator.
    combineOps :: [(SegRedOp (Wise Kernels), [a])]
               -> Maybe (SegRedOp (Wise Kernels), [a])
    combineOps [] = Nothing
    combineOps (x:xs) = Just $ foldl' combine x xs

    -- Merge two operators by concatenating their neutral elements, lambda
    -- statements, lambda results and return types.  Each lambda's
    -- parameters are split at @length (segRedNeutral op)@ into an "x"
    -- part and a "y" part; the merged lambda takes all x parameters of
    -- both operators first, followed by all y parameters.
    combine :: (SegRedOp (Wise Kernels), [a])
            -> (SegRedOp (Wise Kernels), [a])
            -> (SegRedOp (Wise Kernels), [a])
    combine (op1, op1_aux) (op2, op2_aux) =
      let lam1 = segRedLambda op1
          lam2 = segRedLambda op2
          (op1_xparams, op1_yparams) =
            splitAt (length (segRedNeutral op1)) $ lambdaParams lam1
          (op2_xparams, op2_yparams) =
            splitAt (length (segRedNeutral op2)) $ lambdaParams lam2
          lam = Lambda { lambdaParams = op1_xparams ++ op2_xparams ++
                                        op1_yparams ++ op2_yparams
                       , lambdaReturnType =
                           lambdaReturnType lam1 ++ lambdaReturnType lam2
                       , lambdaBody =
                           mkBody (bodyStms (lambdaBody lam1) <>
                                   bodyStms (lambdaBody lam2)) $
                           bodyResult (lambdaBody lam1) <>
                           bodyResult (lambdaBody lam2)
                       }
      in ( SegRedOp { segRedComm = segRedComm op1 <> segRedComm op2
                    , segRedLambda = lam
                    , segRedNeutral = segRedNeutral op1 ++ segRedNeutral op2
                      -- Same as the shape of op2, due to the grouping.
                    , segRedShape = segRedShape op1
                    }
         , op1_aux ++ op2_aux )
mergeSegRedOps _ _ _ _ = Skip

-- | Turn a redomap whose single input array is bound to an 'Iota' into a
-- sequential do-loop, because there is no useful structure to exploit
-- here anyway.  This is mostly a hack to work around the fact that loop
-- tiling would otherwise pointlessly tile such reductions.
--
-- The replacement code is produced by 'FOT.transformSOAC' and carries
-- the certificates of the original statement.  Anything else is skipped.
redomapIotaToLoop :: TopDownRuleOp (Wise Kernels)
redomapIotaToLoop vtable pat aux (OtherOp soac@(Screma _ form [arr]))
  | is_redomap, input_is_iota =
      Simplify $ certifying certs $ FOT.transformSOAC pat soac
  where
    -- The Screma must have the shape of a redomap.
    is_redomap = isJust $ isRedomapSOAC form
    -- Its only input must be defined by an Iota in the symbol table.
    input_is_iota =
      case ST.lookupBasicOp arr vtable of
        Just (Iota{}, _) -> True
        _                -> False
    certs = stmAuxCerts aux
redomapIotaToLoop _ _ _ _ = Skip