{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Segmented operations.  These correspond to perfect @map@ nests on
-- top of /something/, except that the @map@s are conceptually only
-- over @iota@s (so there will be explicit indexing inside them).
module Futhark.IR.SegOp
  ( SegOp(..)
  , SegVirt(..)
  , segLevel
  , segSpace
  , typeCheckSegOp
  , SegSpace(..)
  , scopeOfSegSpace
  , segSpaceDims

    -- * Details
  , HistOp(..)
  , histType
  , SegBinOp(..)
  , segBinOpResults
  , segBinOpChunks
  , KernelBody(..)
  , aliasAnalyseKernelBody
  , consumedInKernelBody
  , ResultManifest(..)
  , KernelResult(..)
  , kernelResultSubExp
  , SplitOrdering(..)

    -- ** Generic traversal
  , SegOpMapper(..)
  , identitySegOpMapper
  , mapSegOpM

    -- * Simplification
  , simplifySegOp
  , HasSegOp(..)
  , segOpRules

    -- * Memory
  , segOpReturns
  )
where

import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Control.Monad.Identity hiding (mapM_)
import Data.Bifunctor (first)
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.List
  (intersperse, foldl', partition, isPrefixOf, groupBy)

import Futhark.IR
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
  ((</>), (<+>), ppr, commasep, Pretty, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.IR.Prop.Aliases
import Futhark.IR.Aliases
  (Aliases, removeLambdaAliases, removeStmAliases)
import Futhark.IR.Mem
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import Futhark.Util (maybeNth, chunks)
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Tools

-- | How an array is split into chunks.
data SplitOrdering
  = SplitContiguous
    -- ^ Chunks are contiguous.
  | SplitStrided SubExp
    -- ^ Strided split; the 'SubExp' is the stride.
  deriving (Eq, Ord, Show)

-- | Only a strided ordering mentions any names (those of its stride).
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride

-- | Substitution only affects the stride of a strided ordering.
instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous =
    SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride

-- | Renaming only affects the stride of a strided ordering.
instance Rename SplitOrdering where
  rename SplitContiguous =
    pure SplitContiguous
  rename (SplitStrided stride) =
    SplitStrided <$> rename stride

-- | An operator for 'SegHist'.
data HistOp lore =
  HistOp { histWidth :: SubExp
         , histRaceFactor :: SubExp
         , histDest :: [VName]
         , histNeutral :: [SubExp]
         , histShape :: Shape
           -- ^ In case this operator is semantically a vectorised
           -- operator (corresponding to a perfect map nest in the
           -- SOACS representation), these are the logical
           -- "dimensions".  This is used to generate more efficient
           -- code.
         , histOp :: Lambda lore
         }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest's in case we are
-- dealing with a segmented histogram.
--
-- Each return type of the operator lambda is first extended with the
-- 'histShape' dimensions (via 'arrayOfShape'), and then given a row
-- dimension of size 'histWidth' (via 'arrayOfRow').
histType :: HistOp lore -> [Type]
histType op =
  map (withWidth . withShape) $ lambdaReturnType $ histOp op
  where
    withShape t = t `arrayOfShape` histShape op
    withWidth t = t `arrayOfRow` histWidth op

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp lore =
  SegBinOp { segBinOpComm :: Commutativity
           , segBinOpLambda :: Lambda lore
           , segBinOpNeutral :: [SubExp]
           , segBinOpShape :: Shape
             -- ^ In case this operator is semantically a vectorised
             -- operator (corresponding to a perfect map nest in the
             -- SOACS representation), these are the logical
             -- "dimensions".  This is used to generate more efficient
             -- code.
           }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- Each operator contributes one result per neutral element.
segBinOpResults :: [SegBinOp lore] -> Int
segBinOpResults = sum . map (length . segBinOpNeutral)

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp'.  The chunk sizes are the number of
-- neutral elements of each operator, in order.
segBinOpChunks :: [SegBinOp lore] -> [a] -> [[a]]
segBinOpChunks = chunks . map (length . segBinOpNeutral)

-- | The body of a 'SegOp'.  Like an ordinary 'Body' it has a
-- decoration and statements, but its result is a list of
-- 'KernelResult's.
data KernelBody lore = KernelBody { kernelBodyLore :: BodyDec lore
                                  , kernelBodyStms :: Stms lore
                                  , kernelBodyResult :: [KernelResult]
                                  }

-- Standalone deriving lets us state the 'Decorations' context explicitly.
deriving instance Decorations lore => Ord (KernelBody lore)
deriving instance Decorations lore => Show (KernelBody lore)
deriving instance Decorations lore => Eq (KernelBody lore)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = ResultNoSimplify
    -- ^ Don't simplify this one!
  | ResultMaySimplify
    -- ^ Go nuts.
  | ResultPrivate
    -- ^ The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult = Returns ResultManifest SubExp
                    -- ^ Each "worker" in the kernel returns this.
                    -- Whether this is a result-per-thread or a
                    -- result-per-group depends on where the 'SegOp' occurs.
                  | WriteReturns
                    [SubExp] -- Size of array.  Must match number of dims.
                    VName -- Which array
                    [(Slice SubExp, SubExp)] -- Arbitrary number of index/value pairs.
                  | ConcatReturns
                    SplitOrdering -- Permuted?
                    SubExp -- The final size.
                    SubExp -- Per-thread/group (max) chunk size.
                    VName -- Chunk by this worker.
                  | TileReturns
                    [(SubExp, SubExp)] -- Total/tile for each dimension
                    VName -- Tile written by this worker.
                    -- The TileReturns must not expect more than one
                    -- result to be written per physical thread.
                  deriving (Eq, Show, Ord)

-- | Get the root t'SubExp' that a 'KernelResult' refers to: the
-- returned value for 'Returns', and otherwise (as a 'Var') the
-- destination array, chunk, or tile variable.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ se) = se
kernelResultSubExp (WriteReturns _ arr _) = Var arr
kernelResultSubExp (ConcatReturns _ _ _ v) = Var v
kernelResultSubExp (TileReturns _ v) = Var v

-- | Free variables of a result are the union of those of all its
-- components (including the array/chunk/tile names themselves).
instance FreeIn KernelResult where
  freeIn' (Returns _ what) = freeIn' what
  freeIn' (WriteReturns rws arr res) =
    freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns o w per_thread_elems v) =
    freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
  freeIn' (TileReturns dims v) =
    freeIn' dims <> freeIn' v

-- | Free variables of a kernel body: everything mentioned by the
-- decoration, statements and results, minus the names bound by the
-- statements themselves.
instance ASTLore lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms

-- | Apply a name substitution to every component of the body.
instance ASTLore lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

-- | Apply a name substitution to every name-bearing field of a result.
-- The 'ResultManifest' contains no names and is kept unchanged.
instance Substitute KernelResult where
  substituteNames subst (Returns manifest se) =
    Returns manifest (substituteNames subst se)
  substituteNames subst (WriteReturns rws arr res) =
    WriteReturns
      (substituteNames subst rws)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (ConcatReturns o w per_thread_elems v) =
    ConcatReturns
      (substituteNames subst o)
      (substituteNames subst w)
      (substituteNames subst per_thread_elems)
      (substituteNames subst v)
  substituteNames subst (TileReturns dims v) =
    TileReturns (substituteNames subst dims) (substituteNames subst v)

-- | Rename the decoration, then rename the statements; the results are
-- renamed inside the scope of the renamed statements so that they see
-- the new names of statement-bound variables.
instance ASTLore lore => Rename (KernelBody lore) where
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

-- | A result binds no names, so renaming reduces to substitution.
instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
--
-- The statements are wrapped in an ordinary 'Body' with an empty
-- result and analysed with an empty alias table; the analysed
-- decoration and statements are kept, and the original
-- 'KernelResult's are reattached unchanged.
aliasAnalyseKernelBody :: (ASTLore lore,
                           CanBeAliased (Op lore)) =>
                          KernelBody lore
                       -> KernelBody (Aliases lore)
aliasAnalyseKernelBody (KernelBody dec stms res) =
  let Body dec' stms' _ = Alias.analyseBody mempty $ Body dec stms []
  in KernelBody dec' stms' res

-- | Strip alias information from a 'KernelBody': the aliases half of
-- the body decoration is dropped and every statement has its aliases
-- removed.  Results are unchanged.
removeKernelBodyAliases :: CanBeAliased (Op lore) =>
                           KernelBody (Aliases lore) -> KernelBody lore
removeKernelBodyAliases (KernelBody (_, dec) stms res) =
  KernelBody dec (fmap removeStmAliases stms) res

-- | Strip simplifier wisdom from a 'KernelBody'.
--
-- The decoration and statements are passed through 'removeBodyWisdom' by
-- wrapping them in a result-less 'Body'; the kernel results are kept as
-- they are.
removeKernelBodyWisdom :: CanBeWise (Op lore) =>
                          KernelBody (Wise lore) -> KernelBody lore
removeKernelBodyWisdom (KernelBody dec stms res) =
  KernelBody plainDec plainStms res
  where Body plainDec plainStms _ = removeBodyWisdom (Body dec stms [])

-- | The variables consumed in the kernel body: whatever its statements
-- consume, plus the destination array of every 'WriteReturns' result.
consumedInKernelBody :: Aliased lore =>
                        KernelBody lore -> Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedByStms <> foldMap consumedByReturn res
  where consumedByStms = consumedInBody (Body dec stms [])

        consumedByReturn kres =
          case kres of
            WriteReturns _ arr _ -> oneName arr
            _                    -> mempty

-- | Type-check a 'KernelBody' against the expected per-result types.
--
-- The body decoration and the statements are checked first.  Within the
-- scope of those statements, the number of results must equal the number
-- of expected types, and each 'KernelResult' is checked against its
-- corresponding type.
checkKernelBody :: TC.Checkable lore =>
                   [Type] -> KernelBody (Aliases lore) -> TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyLore dec
  TC.checkStms stms $ do
    let numResults = length kres
    unless (length ts == numResults) $
      TC.bad $ TC.TypeError $
        "Kernel return type is " ++ prettyTuple ts ++
        ", but body returns " ++ show numResults ++ " values."
    zipWithM_ checkKernelResult kres ts
  where
    -- Indices, dimensions, strides and counts must all be 32-bit integers.
    requireInt32 se = TC.require [Prim int32] se

    checkKernelResult (Returns _ what) t =
      TC.require [t] what

    checkKernelResult (WriteReturns rws arr writes) t = do
      mapM_ requireInt32 rws
      arr_t <- lookupType arr
      forM_ writes $ \(slice, val) -> do
        mapM_ (traverse requireInt32) slice
        TC.require [t] val
        unless (arr_t == t `arrayOfShape` Shape rws) $
          TC.bad $ TC.TypeError $
            "WriteReturns returning " ++ pretty val ++
            " of type " ++ pretty t ++ ", shape=" ++ pretty rws ++
            ", but destination array has type " ++ pretty arr_t
      -- Writing into the destination consumes it and all of its aliases.
      TC.consume =<< TC.lookupAliases arr

    checkKernelResult (ConcatReturns ordering w per_thread_elems v) t = do
      case ordering of
        SplitContiguous     -> return ()
        SplitStrided stride -> requireInt32 stride
      requireInt32 w
      requireInt32 per_thread_elems
      vt <- lookupType v
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v

    checkKernelResult (TileReturns dims v) t = do
      -- Checked pairwise: each dimension, then its tile size.
      mapM_ requireInt32 $ concatMap (\(dim, tile) -> [dim, tile]) dims
      vt <- lookupType v
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v

-- | Record metrics for a 'KernelBody' by running 'stmMetrics' on each of
-- its statements in order.  The kernel results contribute nothing.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics kbody = forM_ (kernelBodyStms kbody) stmMetrics

-- | Prints the statements stacked vertically, followed by a
-- @return {...}@ clause listing the kernel results.  The body
-- decoration is not printed.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms res) = stmsDoc </> (text "return" <+> resultsDoc)
    where
      stmsDoc = PP.stack $ map ppr $ stmsToList stms
      resultsDoc = PP.braces $ PP.commasep $ map ppr res

-- | Textual rendering of each kind of 'KernelResult'.
instance Pretty KernelResult where
  -- The manifest annotation is shown as a parenthesised suffix, except
  -- for 'ResultMaySimplify', which prints plainly as @returns@.
  ppr (Returns manifest what) =
    text ("returns" ++ suffix) <+> ppr what
    where
      suffix = case manifest of
        ResultNoSimplify  -> " (manifest)"
        ResultPrivate     -> " (private)"
        ResultMaySimplify -> ""

  ppr (WriteReturns rws arr writes) =
    ppr arr <+> text "with" <+> PP.apply (map ppWrite writes)
    where
      -- Each index is printed next to the bound it is paired with.
      ppWrite (is, val) =
        PP.brackets (PP.commasep (zipWith ppBounded is rws))
        <+> text "<-" <+> ppr val
      ppBounded i bound = ppr i <+> text "<" <+> ppr bound

  ppr (ConcatReturns ordering w per_thread_elems v) =
    text "concat" <> orderingDoc <> parens (commasep [ppr w, ppr per_thread_elems]) <+> ppr v
    where
      orderingDoc = case ordering of
        SplitContiguous     -> mempty
        SplitStrided stride -> text "Strided" <> parens (ppr stride)

  ppr (TileReturns dims v) =
    text "tile" <> parens (commasep (map ppDim dims)) <+> ppr v
    where
      ppDim (dim, tile) = ppr dim <+> text "/" <+> ppr tile


-- | Do we need group-virtualisation when generating code for the
-- segmented operation?  In most cases, we do, but for some simple
-- kernels, we compute the full number of groups in advance, and then
-- virtualisation is an unnecessary (but generally very small)
-- overhead.  This only really matters for fairly trivial but very
-- wide @map@ kernels where each thread performs constant-time work on
-- scalars.
data SegVirt
  = SegVirt
    -- ^ Group-virtualisation is needed.
  | SegNoVirt
    -- ^ Group-virtualisation is not needed.
  | SegNoVirtFull
    -- ^ Not only do we not need virtualisation, but we /guarantee/
    -- that all physical threads participate in the work.  This can
    -- save some checks in code generation.
  deriving (Eq, Ord, Show)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { segFlat :: VName
    -- ^ Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
  , unSegSpace :: [(VName, SubExp)]
    -- ^ One entry per dimension: the index variable bound for that
    -- dimension, paired with the size of the dimension.
  }
  deriving (Eq, Ord, Show)


-- | The sizes spanned by the indexes of the 'SegSpace', one per
-- dimension and in the same order as the dimensions are listed.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace': the flat index and every per-dimension index, each
-- bound as an 'IndexName' of type 'Int32'.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [ (v, IndexName Int32) | v <- phys : map fst space ]

-- | Check that the size of every dimension of a 'SegSpace' is a 32-bit
-- integer.  The flat index and the dimension index names are not checked.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, size) -> TC.require [Prim int32] size

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a /level/.  The /level/ is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the group-level.
data SegOp lvl lore
  = SegMap lvl SegSpace [Type] (KernelBody lore)
    -- ^ A map nest whose bottommost computation is a scalar
    -- computation.
  | SegRed lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
    -- ^ A map nest with a reduction at the bottom.  The 'SegSpace'
    -- must always have at least two dimensions, implying that the
    -- result of a 'SegRed' is always an array.
  | SegScan lvl SegSpace [SegBinOp lore] [Type] (KernelBody lore)
    -- ^ A map nest with a scan at the bottom.
  | SegHist lvl SegSpace [HistOp lore] [Type] (KernelBody lore)
    -- ^ A map nest with a histogram computation at the bottom.
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp'.  Every constructor stores it as its first
-- field.
segLevel :: SegOp lvl lore -> lvl
segLevel op =
  case op of
    SegMap lvl _ _ _    -> lvl
    SegRed lvl _ _ _ _  -> lvl
    SegScan lvl _ _ _ _ -> lvl
    SegHist lvl _ _ _ _ -> lvl

-- | The space of a 'SegOp'.  Every constructor stores it as its second
-- field.
segSpace :: SegOp lvl lore -> SegSpace
segSpace op =
  case op of
    SegMap _ space _ _    -> space
    SegRed _ space _ _ _  -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space

-- | The full type of the value produced by one 'KernelResult', given the
-- element type @t@ declared for it:
--
-- * 'WriteReturns' gets the shape given by its own dimensions;
-- * 'Returns' gets one outer dimension per dimension of the 'SegSpace';
-- * 'ConcatReturns' gets a single outer dimension of size @w@;
-- * 'TileReturns' gets the first component of each dimension pair.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns rws _ _  -> t `arrayOfShape` Shape rws
    Returns _ _           -> foldr (\d acc -> acc `arrayOfRow` d) t (segSpaceDims space)
    ConcatReturns _ w _ _ -> t `arrayOfRow` w
    TileReturns dims _    -> t `arrayOfShape` Shape (map fst dims)

-- | The return type of a 'SegOp'.
--
-- For 'SegRed' and 'SegScan' the operator results come first, followed
-- by the map-like results of the remaining body results.  Reduction
-- results are arrays over all but the innermost space dimension; scan
-- results are arrays over every space dimension.  A 'SegHist' returns
-- only its histograms, each over the segment dimensions followed by the
-- histogram width and the operator's vector shape.
segOpType :: SegOp lvl lore -> [Type]
segOpType segop =
  case segop of
    SegMap _ space ts kbody ->
      mapPart space ts (kernelBodyResult kbody)
    SegRed _ space reds ts kbody ->
      binOpPart (init (segSpaceDims space)) space reds ts kbody
    SegScan _ space scans ts kbody ->
      binOpPart (segSpaceDims space) space scans ts kbody
    SegHist _ space hists _ _ ->
      concatMap (histTypes (init (segSpaceDims space))) hists
  where
    -- Map-like results: each declared type paired with its KernelResult.
    mapPart space = zipWith (segResultShape space)

    -- Operator results (with the given outer dimensions) followed by the
    -- map-like results of whatever the operators did not consume.
    binOpPart outer space binops ts kbody =
      let op_ts = concatMap (binOpTypes outer) binops
          k = length op_ts
      in op_ts ++ mapPart space (drop k ts) (drop k (kernelBodyResult kbody))

    -- Result types of a single SegBinOp, extended by the outer dimensions
    -- and the operator's own vector shape.
    binOpTypes outer binop =
      map (`arrayOfShape` (Shape outer <> segBinOpShape binop))
          (lambdaReturnType (segBinOpLambda binop))

    -- Result types of a single HistOp.
    histTypes outer hist =
      map (`arrayOfShape` (Shape (outer <> [histWidth hist]) <> histShape hist))
          (lambdaReturnType (histOp hist))

instance TypedOp (SegOp lvl lore) where
  -- The result types are exactly those computed by 'segOpType', lifted
  -- to 'ExtType's with 'staticShapes'.
  opType op = pure $ staticShapes $ segOpType op

instance (ASTLore lore, Aliased lore, ASTConstraints lvl) =>
         AliasedOp (SegOp lvl lore) where
  -- No result of a 'SegOp' aliases anything: one empty name set per
  -- result type.
  opAliases op = mempty <$ segOpType op

  -- Whatever the kernel body consumes; a 'SegHist' additionally
  -- consumes the destination arrays of all its histogram operators.
  consumedInOp op =
    case op of
      SegMap _ _ _ kbody -> consumedInKernelBody kbody
      SegRed _ _ _ _ kbody -> consumedInKernelBody kbody
      SegScan _ _ _ _ kbody -> consumedInKernelBody kbody
      SegHist _ _ hists _ kbody ->
        namesFromList (concatMap histDest hists) <> consumedInKernelBody kbody

-- | Type check a 'SegOp', given a checker for its level.
--
-- The level is checked first in every case.  'SegMap', 'SegRed' and
-- 'SegScan' are delegated to 'checkScanRed' ('SegMap' with an empty
-- operator list).  'SegHist' is checked here, because each of its
-- operators also names destination arrays that are consumed.
typeCheckSegOp :: TC.Checkable lore =>
                  (lvl -> TC.TypeM lore ())
               -> SegOp lvl (Aliases lore) -> TC.TypeM lore ()
typeCheckSegOp checkLvl segop =
  case segop of
    SegMap lvl space ts kbody -> do
      checkLvl lvl
      checkScanRed space [] ts kbody
    SegRed lvl space reds ts kbody -> do
      checkLvl lvl
      checkScanRed space (binOpTriples reds) ts kbody
    SegScan lvl space scans ts kbody -> do
      checkLvl lvl
      checkScanRed space (binOpTriples scans) ts kbody
    SegHist lvl space hists ts kbody -> do
      checkLvl lvl
      checkSegSpace space
      mapM_ TC.checkType ts

      TC.binding (scopeOfSegSpace space) $ do
        -- All space dimensions except the innermost one.
        let segment_dims = init $ segSpaceDims space
        hist_ts <- forM hists $ checkHistOp segment_dims

        checkKernelBody ts kbody

        -- Return type of bucket function must be an index for each
        -- operation followed by the values to write.
        let bucket_ret_t =
              replicate (length hists) (Prim int32) ++ concat hist_ts
        unless (bucket_ret_t == ts) $
          TC.bad $ TC.TypeError $
            "SegHist body has return type " ++ prettyTuple ts ++
            " but should have type " ++ prettyTuple bucket_ret_t
  where
    -- Flatten each SegBinOp into the (operator, neutral elements, vector
    -- shape) triple that 'checkScanRed' expects.
    binOpTriples binops =
      [ (segBinOpLambda binop, segBinOpNeutral binop, segBinOpShape binop)
      | binop <- binops ]

    -- Check a single histogram operator and return the types it
    -- contributes to the expected bucket-function result.
    checkHistOp segment_dims (HistOp dest_w rf dests nes shape lam) = do
      TC.require [Prim int32] dest_w
      TC.require [Prim int32] rf
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int32]) $ shapeDims shape

      -- Operator type must match the type of neutral elements.  The
      -- lambda is checked against two copies of the neutral elements,
      -- with the vector-shape dimensions stripped off.
      let stripVecDims = stripArray $ shapeRank shape
          lam_args = map (TC.noArgAliases . first stripVecDims) nes'
      TC.checkLambda lam $ lam_args ++ lam_args

      let nes_t = map TC.argType nes'
          lam_ret = lambdaReturnType lam
      unless (nes_t == lam_ret) $
        TC.bad $ TC.TypeError $
          "SegHist operator has return type " ++ prettyTuple lam_ret ++
          " but neutral element has type " ++ prettyTuple nes_t

      -- Arrays must have proper type; each destination is consumed.
      let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape] dest
        TC.lookupAliases dest >>= TC.consume

      pure $ map (`arrayOfShape` shape) nes_t

-- | Shared checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each element of the operator list is an operator lambda, its neutral
-- elements, and the shape over which the operator is vectorised.  The
-- leading types of @ts@ (one per neutral element) must equal the neutral
-- element types extended by that shape.  The kernel body is then checked
-- against all of @ts@.
checkScanRed :: TC.Checkable lore =>
                SegSpace
             -> [(Lambda (Aliases lore), [SubExp], Shape)]
             -> [Type]
             -> KernelBody (Aliases lore)
             -> TC.TypeM lore ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      mapM_ (TC.require [Prim int32]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.  The
      -- lambda is checked against two copies of the neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      let nes_t = map TC.argType nes'
          lam_t = lambdaReturnType lam

      -- Report both types (as the SegHist check does) so that a
      -- mismatch can be diagnosed from the message alone.
      unless (lam_t == nes_t) $
        TC.bad $ TC.TypeError $
        "wrong type for operator or neutral elements: operator has return type " ++
        prettyTuple lam_t ++ " but neutral elements have type " ++
        prettyTuple nes_t

      -- The operator's contribution to the declared result types.
      return $ map (`arrayOfShape` shape) nes_t

    -- The leading declared result types must match the operators'.
    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $ TC.TypeError $
      "Wrong return for body (does not match neutral elements; expected " ++
      pretty expecting ++ "; found " ++
      pretty got ++ ")"

    checkKernelBody ts kbody

-- | Like 'Mapper', but just for 'SegOp's: one monadic action per kind of
-- component, each turning a piece of a @flore@ operation into the
-- corresponding piece of a @tlore@ operation.
data SegOpMapper lvl flore tlore m = SegOpMapper
  { mapOnSegOpSubExp :: SubExp -> m SubExp
    -- ^ Applied to subexpressions, e.g. segment space dimension sizes
    -- and the neutral elements and shape of a 'SegBinOp'.
  , mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore)
    -- ^ Applied to operator lambdas.
  , mapOnSegOpBody :: KernelBody flore -> m (KernelBody tlore)
    -- ^ Applied to the kernel body.
  , mapOnSegOpVName :: VName -> m VName
    -- ^ Applied to variable names.
  , mapOnSegOpLevel :: lvl -> m lvl
    -- ^ Applied to the level of the operation.
  }

-- | A mapper that simply returns the 'SegOp' verbatim: every component
-- is handed straight back, unchanged, in the monad.
identitySegOpMapper :: Monad m => SegOpMapper lvl lore lore m
identitySegOpMapper =
  SegOpMapper
    { mapOnSegOpSubExp = pure
    , mapOnSegOpLambda = pure
    , mapOnSegOpBody = pure
    , mapOnSegOpVName = pure
    , mapOnSegOpLevel = pure
    }

-- | Apply the 'SubExp' action of a 'SegOpMapper' to every dimension size
-- of a 'SegSpace', left to right.  The physical index name and the
-- per-dimension index names are passed through unchanged.
mapOnSegSpace :: Monad f =>
                 SegOpMapper lvl flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  dims' <- forM dims $ \(gtid, size) -> do
    size' <- mapOnSegOpSubExp tv size
    pure (gtid, size')
  pure $ SegSpace phys dims'

-- | Apply a 'SegOpMapper' to a 'SegBinOp'.  The effects run in order:
-- the operator lambda first, then the neutral elements, then the
-- dimensions of the vector shape.  The commutativity is kept as-is.
mapSegBinOp :: Monad m =>
               SegOpMapper lvl flore tlore m
            -> SegBinOp flore -> m (SegBinOp tlore)
mapSegBinOp tv (SegBinOp comm lam neutral shape) = do
  lam' <- mapOnSegOpLambda tv lam
  neutral' <- mapM (mapOnSegOpSubExp tv) neutral
  dims' <- mapM (mapOnSegOpSubExp tv) $ shapeDims shape
  pure $ SegBinOp comm lam' neutral' (Shape dims')

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- The components of the operation are visited, and their effects
-- sequenced, in this order: the level ('mapOnSegOpLevel'), the segment
-- space ('mapOnSegSpace'), the operators if any ('mapSegBinOp' for
-- reductions and scans, a local helper for histograms), the list of
-- types ('mapOnSegOpType'), and finally the kernel body
-- ('mapOnSegOpBody').
mapSegOpM :: (Applicative m, Monad m) =>
             SegOpMapper lvl flore tlore m
          -> SegOp lvl flore -> m (SegOp lvl tlore)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
  <$> mapOnSegOpLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM (mapOnSegOpType tv) ts
  <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts body) =
  SegRed
  <$> mapOnSegOpLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM (mapSegBinOp tv) reds
  -- Same type traversal as every other constructor, so all cases agree.
  <*> mapM (mapOnSegOpType tv) ts
  <*> mapOnSegOpBody tv body
mapSegOpM tv (SegScan lvl space scans ts body) =
  SegScan
  <$> mapOnSegOpLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM (mapSegBinOp tv) scans
  <*> mapM (mapOnSegOpType tv) ts
  <*> mapOnSegOpBody tv body
mapSegOpM tv (SegHist lvl space ops ts body) =
  SegHist
  <$> mapOnSegOpLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM onHistOp ops
  <*> mapM (mapOnSegOpType tv) ts
  <*> mapOnSegOpBody tv body
  where
    -- Every 'SubExp' field of a 'HistOp' (including its shape
    -- dimensions) goes through 'mapOnSegOpSubExp', the array names
    -- through 'mapOnSegOpVName', and the operator through
    -- 'mapOnSegOpLambda'.
    onHistOp (HistOp w rf arrs nes shape op) =
      HistOp
      <$> mapOnSegOpSubExp tv w
      <*> mapOnSegOpSubExp tv rf
      <*> mapM (mapOnSegOpVName tv) arrs
      <*> mapM (mapOnSegOpSubExp tv) nes
      <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
      <*> mapOnSegOpLambda tv op

-- | Run the 'SubExp' mapper of a 'SegOpMapper' over the shape of a
-- type.  Only array types carry a shape; primitive and memory types
-- are returned unchanged.
mapOnSegOpType :: Monad m =>
                  SegOpMapper lvl flore tlore m -> Type -> m Type
mapOnSegOpType tv t =
  case t of
    Array pt shape u -> do
      dims' <- mapM (mapOnSegOpSubExp tv) (shapeDims shape)
      pure $ Array pt (Shape dims') u
    Prim {} -> pure t
    Mem {} -> pure t

instance (ASTLore lore, Substitute lvl) =>
         Substitute (SegOp lvl lore) where
  -- Apply the substitution to every component of the operation.  The
  -- traversal runs in 'Identity', so it only rewrites names.
  substituteNames subst op = runIdentity $ mapSegOpM substituter op
    where
      applySubst :: Substitute a => a -> Identity a
      applySubst = Identity . substituteNames subst

      substituter =
        SegOpMapper
          { mapOnSegOpSubExp = applySubst
          , mapOnSegOpLambda = applySubst
          , mapOnSegOpBody = applySubst
          , mapOnSegOpVName = applySubst
          , mapOnSegOpLevel = applySubst
          }

instance (ASTLore lore, ASTConstraints lvl) =>
         Rename (SegOp lvl lore) where
  -- Rename each component of the operation using that component's own
  -- 'Rename' instance.
  rename op = mapSegOpM renamer op
    where
      renamer =
        SegOpMapper
          { mapOnSegOpSubExp = rename
          , mapOnSegOpLambda = rename
          , mapOnSegOpBody = rename
          , mapOnSegOpVName = rename
          , mapOnSegOpLevel = rename
          }

instance (ASTLore lore, FreeIn (LParamInfo lore), FreeIn lvl) =>
         FreeIn (SegOp lvl lore) where
  -- Traverse the operation with a state-threading mapper.  The mapper
  -- leaves every component unchanged and appends its free variables to
  -- the state, which starts out as 'mempty'.
  freeIn' op = execState (mapSegOpM collector op) mempty
    where
      -- Record the free variables of one component, then return the
      -- component as-is.
      collect x = do
        modify (<> freeIn' x)
        pure x

      collector =
        SegOpMapper
          { mapOnSegOpSubExp = collect
          , mapOnSegOpLambda = collect
          , mapOnSegOpBody = collect
          , mapOnSegOpVName = collect
          , mapOnSegOpLevel = collect
          }

instance OpMetrics (Op lore) => OpMetrics (SegOp lvl lore) where
  -- Each kind of operation is recorded under its own label.  Operator
  -- lambdas, if any, are measured before the kernel body.
  opMetrics op =
    case op of
      SegMap _ _ _ body ->
        inside "SegMap" $ kernelBodyMetrics body
      SegRed _ _ reds _ body ->
        withLambdas "SegRed" (map segBinOpLambda reds) body
      SegScan _ _ scans _ body ->
        withLambdas "SegScan" (map segBinOpLambda scans) body
      SegHist _ _ ops _ body ->
        withLambdas "SegHist" (map histOp ops) body
    where
      -- Measure the given lambdas and then the body, inside one label.
      withLambdas label lams body =
        inside label $ do
          mapM_ lambdaMetrics lams
          kernelBodyMetrics body

instance Pretty SegSpace where
  -- Prints each dimension as "i < d", comma-separated in parentheses,
  -- followed by the physical index as "(~phys)".
  ppr (SegSpace phys dims) =
    parens (commasep [ppr i <+> "<" <+> ppr d | (i, d) <- dims])
      <+> parens (text "~" <> ppr phys)

instance PrettyLore lore => Pretty (SegBinOp lore) where
  -- Prints, comma-separated: the neutral elements in braces, then the
  -- shape, then the operator lambda.  The lambda is prefixed with
  -- "commutative " when the operator is marked 'Commutative'.
  ppr (SegBinOp comm lam nes shape) =
    PP.braces (PP.commasep (map ppr nes)) <> PP.comma
      </> ppr shape <> PP.comma
      </> commDoc <> ppr lam
    where
      commDoc
        | Commutative <- comm = text "commutative "
        | otherwise = mempty

instance (PrettyLore lore, PP.Pretty lvl) => PP.Pretty (SegOp lvl lore) where
  ppr :: SegOp lvl lore -> Doc
ppr (SegMap lvl
lvl SegSpace
space [Type]
ts KernelBody lore
body) =
    String -> Doc
text String
"segmap" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> lvl -> Doc
forall a. Pretty a => a -> Doc
ppr lvl
lvl Doc -> Doc -> Doc
</>
    Doc -> Doc
PP.align (SegSpace -> Doc
forall a. Pretty a => a -> Doc
ppr SegSpace
space) Doc -> Doc -> Doc
<+>
    Doc
PP.colon Doc -> Doc -> Doc
<+> [Type] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [Type]
ts Doc -> Doc -> Doc
<+> String -> String -> Doc -> Doc
PP.nestedBlock String
"{" String
"}" (KernelBody lore -> Doc
forall a. Pretty a => a -> Doc
ppr KernelBody lore
body)

  ppr (SegRed lvl
lvl SegSpace
space [SegBinOp lore]
reds [Type]
ts KernelBody lore
body) =
    String -> Doc
text String
"segred" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> lvl -> Doc
forall a. Pretty a => a -> Doc
ppr lvl
lvl Doc -> Doc -> Doc
</>
    Doc -> Doc
PP.parens (Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
PP.comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (SegBinOp lore -> Doc) -> [SegBinOp lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp lore -> Doc
forall a. Pretty a => a -> Doc
ppr [SegBinOp lore]
reds)) Doc -> Doc -> Doc
</>
    Doc -> Doc
PP.align (SegSpace -> Doc
forall a. Pretty a => a -> Doc
ppr SegSpace
space) Doc -> Doc -> Doc
<+> Doc
PP.colon Doc -> Doc -> Doc
<+> [Type] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [Type]
ts Doc -> Doc -> Doc
<+>
    String -> String -> Doc -> Doc
PP.nestedBlock String
"{" String
"}" (KernelBody lore -> Doc
forall a. Pretty a => a -> Doc
ppr KernelBody lore
body)

  -- Pretty-print a 'SegScan': the keyword followed by the level, the
  -- scan operators (comma/line separated, in braces and parens), the
  -- segment space and result types, and finally the kernel body.
  ppr (SegScan lvl space scans ts body) =
    keyword </> operators </>
      PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+> kbody
    where keyword = text "segscan" <> ppr lvl
          operators =
            PP.parens $ PP.braces $ mconcat $
            intersperse (PP.comma <> PP.line) $ map ppr scans
          kbody = PP.nestedBlock "{" "}" $ ppr body

  -- Pretty-print a 'SegHist': the keyword followed by the level, the
  -- histogram operators (comma/line separated, in braces and parens),
  -- the segment space and result types, and finally the kernel body.
  -- The level is printed exactly once, as for the other segmented
  -- operations; printing it twice produced malformed output.
  ppr (SegHist lvl space ops ts body) =
    text "seghist" <> ppr lvl </>
    PP.parens (PP.braces (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)) </>
    PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
    PP.nestedBlock "{" "}" (ppr body)
    -- Each histogram operator is printed as its six constructor
    -- components in order, separated by commas; the name lists and
    -- subexpression lists are wrapped in braces.
    where ppOp (HistOp w rf dests nes shape op) =
            ppr w <> PP.comma <+> ppr rf <> PP.comma </>
            PP.braces (PP.commasep $ map ppr dests) <> PP.comma </>
            PP.braces (PP.commasep $ map ppr nes) <> PP.comma </>
            ppr shape <> PP.comma </>
            ppr op

-- | Alias annotations for segmented operations.  Adding aliases runs
-- 'Alias.analyseLambda' on every lambda and 'aliasAnalyseKernelBody' on
-- every kernel body; removing them applies 'removeLambdaAliases' and
-- 'removeKernelBodyAliases'.  Subexpressions, names and the level are
-- passed through untouched in both directions.
instance (ASTLore lore, ASTLore (Aliases lore),
          CanBeAliased (Op lore), ASTConstraints lvl) =>
         CanBeAliased (SegOp lvl lore) where
  type OpWithAliases (SegOp lvl lore) = SegOp lvl (Aliases lore)

  addOpAliases op = runIdentity $ mapSegOpM withAliases op
    where withAliases =
            SegOpMapper pure
                        (pure . Alias.analyseLambda)
                        (pure . aliasAnalyseKernelBody)
                        pure
                        pure

  removeOpAliases op = runIdentity $ mapSegOpM withoutAliases op
    where withoutAliases =
            SegOpMapper pure
                        (pure . removeLambdaAliases)
                        (pure . removeKernelBodyAliases)
                        pure
                        pure

-- | Wisdom for segmented operations.  Stripping wisdom applies
-- 'removeLambdaWisdom' to every lambda and 'removeKernelBodyWisdom' to
-- every kernel body; subexpressions, names and the level are kept.
instance (CanBeWise (Op lore), ASTLore lore, ASTConstraints lvl) =>
         CanBeWise (SegOp lvl lore) where
  type OpWithWisdom (SegOp lvl lore) = SegOp lvl (Wise lore)

  removeOpWisdom op = runIdentity $ mapSegOpM stripWisdom op
    where stripWisdom =
            SegOpMapper pure
                        (pure . removeLambdaWisdom)
                        (pure . removeKernelBodyWisdom)
                        pure
                        pure

-- | Symbolic indexing into the result of a segmented operation.  Only
-- 'SegMap' is handled: when its @k@th kernel result is a
-- @Returns ResultMaySimplify@ of a variable, and at least as many
-- indexes are given as there are segment dimensions, the kernel body
-- statements are evaluated symbolically with the thread indexes bound
-- to the leading indexes.  The entry computed for the returned
-- variable, if any, is the result.  Every other case yields 'Nothing'.
instance ASTLore lore => ST.IndexOp (SegOp lvl lore) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty) is
        -- Strict left fold: the table is threaded through every
        -- statement in the body, so do not build a chain of thunks.
        idx_table' = foldl' expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing

    where (gtids, _) = unzip $ unSegSpace space
          -- Indexes in excess of what is used to index through the
          -- segment dimensions.
          excess_is = drop (length gtids) is

          -- Extend the table with the single name bound by a
          -- statement, either when its expression can be turned into a
          -- PrimExp over known entries, or when it indexes an array in
          -- the symbol table with as many slice dimensions as there are
          -- excess indexes.  Otherwise the table is left unchanged.
          expandIndexedTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe, cs) <-
                  runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table

            | [v] <- patternNames $ stmPattern stm,
              BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.elem` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice =
                let idx = ST.IndexedArray (stmCerts stm <> cs)
                          arr (fixSlice slice' excess_is)
                in M.insert v idx table

            | otherwise =
                table

          asPrimExpSlice table =
            runWriterT . mapM (traverse (primExpFromSubExpM (asPrimExp table)))

          -- Names already in the table become their recorded PrimExp
          -- (and contribute its certificates); primitive-typed names in
          -- the symbol table become leaves; anything else fails.
          asPrimExp table v
            | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing

  indexOp _ _ _ _ = Nothing

-- | Segmented operations are never reported as cheap, and are always
-- reported as safe, regardless of their contents.
instance (ASTLore lore, ASTConstraints lvl) =>
         IsOp (SegOp lvl lore) where
  cheapOp = const False
  safeOp = const True

--- Simplification

-- | Only a strided ordering contains a subexpression to simplify; a
-- contiguous ordering is returned unchanged.
instance Engine.Simplifiable SplitOrdering where
  simplify ordering =
    case ordering of
      SplitContiguous -> pure SplitContiguous
      SplitStrided stride -> SplitStrided <$> Engine.simplify stride

-- | Simplify the size of every segment dimension.  The physical index
-- name and the per-dimension index names are kept as they are.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) = do
    dims' <- mapM simplifyDim dims
    return $ SegSpace phys dims'
    where simplifyDim (gtid, d) = do
            d' <- Engine.simplify d
            return (gtid, d')

-- | Simplify every component of a kernel result, left to right, except
-- the 'ResultManifest' of a 'Returns', which is kept as is.
instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest what) = do
    what' <- Engine.simplify what
    return $ Returns manifest what'
  simplify (WriteReturns ws a res) = do
    ws' <- Engine.simplify ws
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    return $ WriteReturns ws' a' res'
  simplify (ConcatReturns o w pte what) = do
    o' <- Engine.simplify o
    w' <- Engine.simplify w
    pte' <- Engine.simplify pte
    what' <- Engine.simplify what
    return $ ConcatReturns o' w' pte' what'
  simplify (TileReturns dims what) = do
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    return $ TileReturns dims' what'

-- | Build a kernel body in the 'Wise' lore.  The body decoration is
-- taken from 'mkWiseBody' applied to the statements and to the
-- subexpressions of the kernel results; the kernel results themselves
-- are kept unchanged.
mkWiseKernelBody :: (ASTLore lore, CanBeWise (Op lore)) =>
                    BodyDec lore -> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody dec stms kres = KernelBody wise_dec stms kres
  where Body wise_dec _ _ = mkWiseBody dec stms $ map kernelResultSubExp kres

-- | Build a kernel body in a 'MonadBinder'.  The body decoration is the
-- one 'mkBodyM' produces for the statements and the subexpressions of
-- the kernel results; the statements and kernel results are kept.
mkKernelBodyM :: MonadBinder m =>
                 Stms (Lore m) -> [KernelResult]
              -> m (KernelBody (Lore m))
mkKernelBodyM stms kres =
  toKernelBody <$> mkBodyM stms (map kernelResultSubExp kres)
  where toKernelBody (Body dec' _ _) = KernelBody dec' stms kres

simplifyKernelBody :: (Engine.SimplifiableLore lore, BodyDec lore ~ ()) =>
                      SegSpace -> KernelBody lore
                   -> Engine.SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody :: SegSpace
-> KernelBody lore
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody SegSpace
space (KernelBody BodyDec lore
_ Stms lore
stms [KernelResult]
res) = do
  BlockPred (Wise lore)
par_blocker <- (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
Engine.asksEngineEnv ((Env lore -> BlockPred (Wise lore))
 -> SimpleM lore (BlockPred (Wise lore)))
-> (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> BlockPred (Wise lore)
forall lore. HoistBlockers lore -> BlockPred (Wise lore)
Engine.blockHoistPar (HoistBlockers lore -> BlockPred (Wise lore))
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> BlockPred (Wise lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
Engine.envHoistBlockers

  ((Stms (Wise lore)
body_stms, [KernelResult]
body_res), Stms (Wise lore)
hoisted) <-
    (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
Engine.localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a. Semigroup a => a -> a -> a
<>SymbolTable (Wise lore)
scope_vtable) (SimpleM
   lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
 -> SimpleM
      lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore)))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
    (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
Engine.localVtable (\SymbolTable (Wise lore)
vtable -> SymbolTable (Wise lore)
vtable { simplifyMemory :: Bool
ST.simplifyMemory = Bool
True }) (SimpleM
   lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
 -> SimpleM
      lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore)))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
    BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore [KernelResult])
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore a)
-> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
Engine.blockIf (Names -> BlockPred (Wise lore)
forall lore. ASTLore lore => Names -> BlockPred lore
Engine.hasFree Names
bound_here
                    BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`Engine.orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
Engine.isOp
                    BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`Engine.orIf` BlockPred (Wise lore)
par_blocker
                    BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`Engine.orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
Engine.isConsumed) (SimpleM lore (SimplifiedBody lore [KernelResult])
 -> SimpleM
      lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore)))
-> SimpleM lore (SimplifiedBody lore [KernelResult])
-> SimpleM
     lore ((Stms (Wise lore), [KernelResult]), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
    Stms lore
-> SimpleM lore (SimplifiedBody lore [KernelResult])
-> SimpleM lore (SimplifiedBody lore [KernelResult])
forall lore a.
SimplifiableLore lore =>
Stms lore
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
Engine.simplifyStms Stms lore
stms (SimpleM lore (SimplifiedBody lore [KernelResult])
 -> SimpleM lore (SimplifiedBody lore [KernelResult]))
-> SimpleM lore (SimplifiedBody lore [KernelResult])
-> SimpleM lore (SimplifiedBody lore [KernelResult])
forall a b. (a -> b) -> a -> b
$ do
    [KernelResult]
res' <- (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore [KernelResult] -> SimpleM lore [KernelResult]
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
Engine.localVtable (Names -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore. Names -> SymbolTable lore -> SymbolTable lore
ST.hideCertified (Names -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> Names -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo lore) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo lore) -> [VName])
-> Map VName (NameInfo lore) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stms lore -> Map VName (NameInfo lore)
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Stms lore
stms) (SimpleM lore [KernelResult] -> SimpleM lore [KernelResult])
-> SimpleM lore [KernelResult] -> SimpleM lore [KernelResult]
forall a b. (a -> b) -> a -> b
$
            (KernelResult -> SimpleM lore KernelResult)
-> [KernelResult] -> SimpleM lore [KernelResult]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM KernelResult -> SimpleM lore KernelResult
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
Engine.simplify [KernelResult]
res
    SimplifiedBody lore [KernelResult]
-> SimpleM lore (SimplifiedBody lore [KernelResult])
forall (m :: * -> *) a. Monad m => a -> m a
return (([KernelResult]
res', Names -> UsageTable
UT.usages (Names -> UsageTable) -> Names -> UsageTable
forall a b. (a -> b) -> a -> b
$ [KernelResult] -> Names
forall a. FreeIn a => a -> Names
freeIn [KernelResult]
res'), Stms (Wise lore)
forall a. Monoid a => a
mempty)

  (KernelBody (Wise lore), Stms (Wise lore))
-> SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (BodyDec lore
-> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
forall lore.
(ASTLore lore, CanBeWise (Op lore)) =>
BodyDec lore
-> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody () Stms (Wise lore)
body_stms [KernelResult]
body_res, Stms (Wise lore)
hoisted)

  where scope_vtable :: SymbolTable (Wise lore)
scope_vtable = SegSpace -> SymbolTable (Wise lore)
forall lore. ASTLore lore => SegSpace -> SymbolTable lore
segSpaceSymbolTable SegSpace
space
        bound_here :: Names
bound_here = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Map VName (NameInfo Any) -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName (NameInfo Any) -> [VName])
-> Map VName (NameInfo Any) -> [VName]
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space

-- | Build a symbol table for the names bound by a 'SegSpace'.  The flat
-- index name is entered as an 'Int32' 'IndexName'.  Each
-- @(gtid, dim)@ pair is then added, left to right, with
-- 'ST.insertLoopVar' at type 'Int32' and with @dim@ as its bound
-- argument.
segSpaceSymbolTable :: ASTLore lore => SegSpace -> ST.SymbolTable lore
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' (\table (gtid, dim) -> ST.insertLoopVar gtid Int32 dim table)
         initial_table gtids_and_dims
  where initial_table = ST.fromScope $ M.singleton flat $ IndexName Int32

-- | Simplify a single 'SegBinOp'.  The operator lambda is simplified
-- with 'ST.simplifyMemory' switched on in the symbol table.  The shape
-- and then the neutral elements are simplified in the ambient
-- environment.  Returns the rebuilt operator together with the
-- statements hoisted out of its lambda.
simplifySegBinOp :: Engine.SimplifiableLore lore =>
                    SegBinOp lore
                 -> Engine.SimpleM lore (SegBinOp (Wise lore), Stms (Wise lore))
simplifySegBinOp (SegBinOp comm lam nes shape) = do
  (lam', lam_hoisted) <- withMemorySimplification $ Engine.simplifyLambda lam
  -- Applicative sequencing keeps the original order: shape first, then
  -- the neutral elements.
  (shape', nes') <- (,) <$> Engine.simplify shape <*> mapM Engine.simplify nes
  return (SegBinOp comm lam' nes' shape', lam_hoisted)
  where withMemorySimplification =
          Engine.localVtable $ \vtable -> vtable { ST.simplifyMemory = True }

-- | Simplify the given 'SegOp'.
--
-- For every constructor:
--
-- * The level, space and result types are simplified.
-- * Any operators ('SegBinOp's or 'HistOp' lambdas) are simplified under
--   a symbol table extended with the names bound by the 'SegSpace'.
-- * The kernel body is simplified with 'simplifyKernelBody'.
--
-- The result is the simplified operation paired with all statements
-- hoisted out of it.  Statements hoisted from the operators come first,
-- followed by those hoisted from the body.
simplifySegOp :: (Engine.SimplifiableLore lore,
                  BodyDec lore ~ (),
                  Engine.Simplifiable lvl) =>
                 SegOp lvl lore
              -> Engine.SimpleM lore (SegOp lvl (Wise lore), Stms (Wise lore))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  return (SegMap lvl' space' ts' kbody',
          body_hoisted)

simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- simplifySegBinOpsIn space reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  return (SegRed lvl' space' reds' ts' kbody',
          reds_hoisted <> body_hoisted)

simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <- simplifySegBinOpsIn space scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  return (SegScan lvl' space' scans' ts' kbody',
          scans_hoisted <> body_hoisted)

simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  (ops', ops_hoisted) <- fmap unzip $ forM ops $
    \(HistOp w rf arrs nes dims lam) -> do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      -- The space's names are added first, and memory simplification is
      -- enabled on top of that; this nesting order is deliberate.
      (lam', op_hoisted) <-
        Engine.localVtable (<> segSpaceScopeTable space) $
        Engine.localVtable (\vtable -> vtable { ST.simplifyMemory = True }) $
        Engine.simplifyLambda lam
      return (HistOp w' rf' arrs' nes' dims' lam',
              op_hoisted)

  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  return (SegHist lvl' space' ops' ts' kbody',
          mconcat ops_hoisted <> body_hoisted)

-- | Simplify the 'SegBinOp's of a 'SegRed' or 'SegScan' over the given
-- space.  The operators are simplified with the names bound by the space
-- added to the symbol table.  The statements hoisted out of the
-- individual operators are concatenated in operator order.
simplifySegBinOpsIn :: Engine.SimplifiableLore lore =>
                       SegSpace
                    -> [SegBinOp lore]
                    -> Engine.SimpleM lore ([SegBinOp (Wise lore)], Stms (Wise lore))
simplifySegBinOpsIn space ops = do
  (ops', hoisted) <-
    Engine.localVtable (<> segSpaceScopeTable space) $
    unzip <$> mapM simplifySegBinOp ops
  return (ops', mconcat hoisted)

-- | A symbol table holding only the names bound by a 'SegSpace', as
-- given by 'scopeOfSegSpace'.
segSpaceScopeTable :: ASTLore lore => SegSpace -> ST.SymbolTable lore
segSpaceScopeTable = ST.fromScope . scopeOfSegSpace

-- | Lores whose t'Op's may contain 'SegOp's.  A lore must be an
-- instance of this class for the 'SegOp' simplification rules to work.
class HasSegOp lore where
  -- | The level annotation attached to this lore's 'SegOp's.
  type SegOpLevel lore
  -- | View an t'Op' as a 'SegOp'; 'Nothing' if it is not one.
  asSegOp :: Op lore -> Maybe (SegOp (SegOpLevel lore) lore)
  -- | Embed a 'SegOp' back into an t'Op'.
  segOp :: SegOp (SegOpLevel lore) lore -> Op lore

-- | Simplification rules for 'SegOp's: a single top-down rule and a
-- single bottom-up rule.  Each rule only fires on t'Op's that
-- 'asSegOp' recognises.
segOpRules :: (HasSegOp lore, BinderOps lore, Bindable lore) =>
              RuleBook lore
segOpRules = ruleBook top_down_rules bottom_up_rules
  where top_down_rules = [RuleOp segOpRuleTopDown]
        bottom_up_rules = [RuleOp segOpRuleBottomUp]

-- | Top-down rule: hand the statement to 'topDownSegOp' when its t'Op'
-- is a 'SegOp', and 'Skip' otherwise.
segOpRuleTopDown :: (HasSegOp lore, BinderOps lore, Bindable lore) =>
                    TopDownRuleOp lore
segOpRuleTopDown vtable pat dec op =
  maybe Skip (topDownSegOp vtable pat dec) $ asSegOp op

-- | Bottom-up rule: hand the statement to 'bottomUpSegOp' when its t'Op'
-- is a 'SegOp', and 'Skip' otherwise.
segOpRuleBottomUp :: (HasSegOp lore, BinderOps lore) =>
                     BottomUpRuleOp lore
segOpRuleBottomUp env pat dec op =
  case asSegOp op of
    Just seg_op -> bottomUpSegOp env pat dec seg_op
    Nothing -> Skip

topDownSegOp :: (HasSegOp lore, BinderOps lore, Bindable lore) =>
                ST.SymbolTable lore
             -> Pattern lore
             -> StmAux (ExpDec lore)
             -> SegOp (SegOpLevel lore) lore
             -> Rule lore

-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
topDownSegOp :: SymbolTable lore
-> Pattern lore
-> StmAux (ExpDec lore)
-> SegOp (SegOpLevel lore) lore
-> Rule lore
topDownSegOp SymbolTable lore
vtable (Pattern [] [PatElemT (LetDec lore)]
kpes) StmAux (ExpDec lore)
dec (SegMap SegOpLevel lore
lvl SegSpace
space [Type]
ts (KernelBody BodyDec lore
_ Stms lore
kstms [KernelResult]
kres)) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
  ([Type]
ts', [PatElemT (LetDec lore)]
kpes', [KernelResult]
kres') <-
    [(Type, PatElemT (LetDec lore), KernelResult)]
-> ([Type], [PatElemT (LetDec lore)], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Type, PatElemT (LetDec lore), KernelResult)]
 -> ([Type], [PatElemT (LetDec lore)], [KernelResult]))
-> RuleM lore [(Type, PatElemT (LetDec lore), KernelResult)]
-> RuleM lore ([Type], [PatElemT (LetDec lore)], [KernelResult])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool)
-> [(Type, PatElemT (LetDec lore), KernelResult)]
-> RuleM lore [(Type, PatElemT (LetDec lore), KernelResult)]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM (Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool
checkForInvarianceResult ([Type]
-> [PatElemT (LetDec lore)]
-> [KernelResult]
-> [(Type, PatElemT (LetDec lore), KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
ts [PatElemT (LetDec lore)]
kpes [KernelResult]
kres)

  -- Check if we did anything at all.
  Bool -> RuleM lore () -> RuleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([KernelResult]
kres [KernelResult] -> [KernelResult] -> Bool
forall a. Eq a => a -> a -> Bool
== [KernelResult]
kres')
    RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify

  KernelBody lore
kbody <- Stms (Lore (RuleM lore))
-> [KernelResult] -> RuleM lore (KernelBody (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [KernelResult] -> m (KernelBody (Lore m))
mkKernelBodyM Stms lore
Stms (Lore (RuleM lore))
kstms [KernelResult]
kres'
  Stm (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore (RuleM lore)) -> RuleM lore ())
-> Stm (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
kpes') StmAux (ExpDec lore)
dec (Exp lore -> Stm lore) -> Exp lore -> Stm lore
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp (SegOp (SegOpLevel lore) lore -> Op lore)
-> SegOp (SegOpLevel lore) lore -> Op lore
forall a b. (a -> b) -> a -> b
$
    SegOpLevel lore
-> SegSpace
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegOpLevel lore
lvl SegSpace
space [Type]
ts' KernelBody lore
kbody

  where isInvariant :: SubExp -> Bool
isInvariant Constant{} = Bool
True
        isInvariant (Var VName
v) = Maybe (Entry lore) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Entry lore) -> Bool) -> Maybe (Entry lore) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> SymbolTable lore -> Maybe (Entry lore)
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup VName
v SymbolTable lore
vtable

        checkForInvarianceResult :: (Type, PatElemT (LetDec lore), KernelResult) -> RuleM lore Bool
checkForInvarianceResult (Type
_, PatElemT (LetDec lore)
pe, Returns ResultManifest
rm SubExp
se)
          | ResultManifest
rm ResultManifest -> ResultManifest -> Bool
forall a. Eq a => a -> a -> Bool
== ResultManifest
ResultMaySimplify,
            SubExp -> Bool
isInvariant SubExp
se = do
              [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space) SubExp
se
              Bool -> RuleM lore Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
        checkForInvarianceResult (Type, PatElemT (LetDec lore), KernelResult)
_ =
          Bool -> RuleM lore Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more local memory usage.
topDownSegOp SymbolTable lore
_ (Pattern [] [PatElemT (LetDec lore)]
pes) StmAux (ExpDec lore)
_ (SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops [Type]
ts KernelBody lore
kbody)
  | [SegBinOp lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SegBinOp lore]
ops Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
    [[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings <- ((SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
 -> (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
 -> Bool)
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
-> [[(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])]]
forall a. (a -> a -> Bool) -> [a] -> [[a]]
groupBy (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
-> Bool
forall lore b lore b.
(SegBinOp lore, b) -> (SegBinOp lore, b) -> Bool
sameShape ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
 -> [[(SegBinOp lore,
       [(PatElemT (LetDec lore), Type, KernelResult)])]])
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
-> [[(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])]]
forall a b. (a -> b) -> a -> b
$ [SegBinOp lore]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp lore]
ops ([[(PatElemT (LetDec lore), Type, KernelResult)]]
 -> [(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])])
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. (a -> b) -> a -> b
$ [Int]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp lore -> Int) -> [SegBinOp lore] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp lore -> [SubExp]) -> SegBinOp lore -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp lore]
ops) ([(PatElemT (LetDec lore), Type, KernelResult)]
 -> [[(PatElemT (LetDec lore), Type, KernelResult)]])
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> [[(PatElemT (LetDec lore), Type, KernelResult)]]
forall a b. (a -> b) -> a -> b
$
                    [PatElemT (LetDec lore)]
-> [Type]
-> [KernelResult]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElemT (LetDec lore)]
red_pes [Type]
red_ts [KernelResult]
red_res,
    ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
 -> Bool)
-> [[(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])]]
-> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>Int
1) (Int -> Bool)
-> ([(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])]
    -> Int)
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      let ([SegBinOp lore]
ops', [[(PatElemT (LetDec lore), Type, KernelResult)]]
aux) = [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> ([SegBinOp lore],
    [[(PatElemT (LetDec lore), Type, KernelResult)]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
 -> ([SegBinOp lore],
     [[(PatElemT (LetDec lore), Type, KernelResult)]]))
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
-> ([SegBinOp lore],
    [[(PatElemT (LetDec lore), Type, KernelResult)]])
forall a b. (a -> b) -> a -> b
$ ([(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
 -> Maybe
      (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)]))
-> [[(SegBinOp lore,
      [(PatElemT (LetDec lore), Type, KernelResult)])]]
-> [(SegBinOp lore,
     [(PatElemT (LetDec lore), Type, KernelResult)])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]
-> Maybe
     (SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])
forall lore a.
Bindable lore =>
[(SegBinOp lore, [a])] -> Maybe (SegBinOp lore, [a])
combineOps [[(SegBinOp lore, [(PatElemT (LetDec lore), Type, KernelResult)])]]
op_groupings
          ([PatElemT (LetDec lore)]
red_pes', [Type]
red_ts', [KernelResult]
red_res') = [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(PatElemT (LetDec lore), Type, KernelResult)]
 -> ([PatElemT (LetDec lore)], [Type], [KernelResult]))
-> [(PatElemT (LetDec lore), Type, KernelResult)]
-> ([PatElemT (LetDec lore)], [Type], [KernelResult])
forall a b. (a -> b) -> a -> b
$ [[(PatElemT (LetDec lore), Type, KernelResult)]]
-> [(PatElemT (LetDec lore), Type, KernelResult)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[(PatElemT (LetDec lore), Type, KernelResult)]]
aux
          pes' :: [PatElemT (LetDec lore)]
pes' = [PatElemT (LetDec lore)]
red_pes' [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. [a] -> [a] -> [a]
++ [PatElemT (LetDec lore)]
map_pes
          ts' :: [Type]
ts' = [Type]
red_ts' [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
map_ts
          kbody' :: KernelBody lore
kbody' = KernelBody lore
kbody { kernelBodyResult :: [KernelResult]
kernelBodyResult = [KernelResult]
red_res' [KernelResult] -> [KernelResult] -> [KernelResult]
forall a. [a] -> [a] -> [a]
++ [KernelResult]
map_res }
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)]
pes') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Op lore -> Exp lore
forall lore. Op lore -> ExpT lore
Op (Op lore -> Exp lore) -> Op lore -> Exp lore
forall a b. (a -> b) -> a -> b
$ SegOp (SegOpLevel lore) lore -> Op lore
forall lore.
HasSegOp lore =>
SegOp (SegOpLevel lore) lore -> Op lore
segOp (SegOp (SegOpLevel lore) lore -> Op lore)
-> SegOp (SegOpLevel lore) lore -> Op lore
forall a b. (a -> b) -> a -> b
$ SegOpLevel lore
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp (SegOpLevel lore) lore
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed SegOpLevel lore
lvl SegSpace
space [SegBinOp lore]
ops' [Type]
ts' KernelBody lore
kbody'
  where ([PatElemT (LetDec lore)]
red_pes, [PatElemT (LetDec lore)]
map_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) [PatElemT (LetDec lore)]
pes
        ([Type]
red_ts, [Type]
map_ts) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) [Type]
ts
        ([KernelResult]
red_res, [KernelResult]
map_res) = Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp lore] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp lore]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody lore -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody lore
kbody

        sameShape :: (SegBinOp lore, b) -> (SegBinOp lore, b) -> Bool
sameShape (SegBinOp lore
op1, b
_) (SegBinOp lore
op2, b
_) = SegBinOp lore -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp lore
op1 Shape -> Shape -> Bool
forall a. Eq a => a -> a -> Bool
== SegBinOp lore -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp lore
op2

        combineOps :: [(SegBinOp lore, [a])] -> Maybe (SegBinOp lore, [a])
combineOps [] = Maybe (SegBinOp lore, [a])
forall a. Maybe a
Nothing
        combineOps ((SegBinOp lore, [a])
x:[(SegBinOp lore, [a])]
xs) = (SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a])
forall a. a -> Maybe a
Just ((SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a]))
-> (SegBinOp lore, [a]) -> Maybe (SegBinOp lore, [a])
forall a b. (a -> b) -> a -> b
$ ((SegBinOp lore, [a])
 -> (SegBinOp lore, [a]) -> (SegBinOp lore, [a]))
-> (SegBinOp lore, [a])
-> [(SegBinOp lore, [a])]
-> (SegBinOp lore, [a])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
forall lore a.
Bindable lore =>
(SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
combine (SegBinOp lore, [a])
x [(SegBinOp lore, [a])]
xs

        combine :: (SegBinOp lore, [a])
-> (SegBinOp lore, [a]) -> (SegBinOp lore, [a])
combine (SegBinOp lore
op1, [a]
op1_aux) (SegBinOp lore
op2, [a]
op2_aux) =
          let lam1 :: Lambda lore
lam1 = SegBinOp lore -> Lambda lore
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp lore
op1
              lam2 :: Lambda lore
lam2 = SegBinOp lore -> Lambda lore
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp lore
op2
              ([Param (LParamInfo lore)]
op1_xparams, [Param (LParamInfo lore)]
op1_yparams) =
                Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op1)) ([Param (LParamInfo lore)]
 -> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam1
              ([Param (LParamInfo lore)]
op2_xparams, [Param (LParamInfo lore)]
op2_yparams) =
                Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op2)) ([Param (LParamInfo lore)]
 -> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam2
              lam :: Lambda lore
lam = Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda { lambdaParams :: [Param (LParamInfo lore)]
lambdaParams = [Param (LParamInfo lore)]
op1_xparams [Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo lore)]
op2_xparams [Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++
                                            [Param (LParamInfo lore)]
op1_yparams [Param (LParamInfo lore)]
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo lore)]
op2_yparams
                           , lambdaReturnType :: [Type]
lambdaReturnType = Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam1 [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ Lambda lore -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda lore
lam2
                           , lambdaBody :: BodyT lore
lambdaBody =
                               Stms lore -> [SubExp] -> BodyT lore
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody (BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam1) Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam2)) ([SubExp] -> BodyT lore) -> [SubExp] -> BodyT lore
forall a b. (a -> b) -> a -> b
$
                               BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam1) [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam2)
                           }
          in (SegBinOp :: forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp { segBinOpComm :: Commutativity
segBinOpComm = SegBinOp lore -> Commutativity
forall lore. SegBinOp lore -> Commutativity
segBinOpComm SegBinOp lore
op1 Commutativity -> Commutativity -> Commutativity
forall a. Semigroup a => a -> a -> a
<> SegBinOp lore -> Commutativity
forall lore. SegBinOp lore -> Commutativity
segBinOpComm SegBinOp lore
op2
                       , segBinOpLambda :: Lambda lore
segBinOpLambda = Lambda lore
lam
                       , segBinOpNeutral :: [SubExp]
segBinOpNeutral = SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ SegBinOp lore -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral SegBinOp lore
op2
                       , segBinOpShape :: Shape
segBinOpShape = SegBinOp lore -> Shape
forall lore. SegBinOp lore -> Shape
segBinOpShape SegBinOp lore
op1 -- Same as shape of op2 due to the grouping.
                       },
               [a]
op1_aux [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
op2_aux)
topDownSegOp SymbolTable lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ SegOp (SegOpLevel lore) lore
_ = Rule lore
forall lore. Rule lore
Skip

-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
--
-- This rule only fires on a 'SegMap' whose pattern has no context
-- elements; every other SegOp is left alone ('Skip').  A kernel
-- statement is hoisted when it has the form
--
-- > x = arr[gtid_1, ..., gtid_k, rest...]
--
-- where the leading indices are exactly the segment's thread indices
-- (as 'DimFix'), the free names of @arr@ and @rest@ are all present in
-- the rule's symbol table, and @x@ is returned exactly once by the
-- kernel via 'Returns'.  That result is then dropped from the SegMap,
-- and the matching outer pattern element is bound outside the kernel
-- by indexing @arr@ with a 'DimSlice' (start 0, stride 1) over every
-- segment dimension followed by @rest@.  If no result could be moved
-- out, the rule fails with 'cannotSimplify'.
bottomUpSegOp :: (HasSegOp lore, BinderOps lore) =>
                 (ST.SymbolTable lore, UT.UsageTable)
              -> Pattern lore
              -> StmAux (ExpDec lore)
              -> SegOp (SegOpLevel lore) lore
              -> Rule lore
bottomUpSegOp (vtable, used) (Pattern [] kpes) dec
              (SegMap lvl space kts (KernelBody _ kstms kres)) = Simplify $ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.
  (kpes', kts', kres', kstms') <- localScope (scopeOfSegSpace space) $
    foldM distribute (kpes, kts, kres, mempty) kstms

  -- No result was moved out, so the rule does not apply.
  when (kpes' == kpes)
    cannotSimplify

  kbody <- localScope (scopeOfSegSpace space) $
    mkKernelBodyM kstms' kres'

  addStm $ Let (Pattern [] kpes') dec $ Op $ segOp $
    SegMap lvl space kts' kbody
  where
    -- Every name free in the original kernel statements.  A hoisted
    -- statement stays in the kernel only if its result is in here.
    free_in_kstms = foldMap freeIn kstms

    -- Recognise @arr[gtids..., rest...]@ where @arr@ and @rest@ only
    -- mention names known to the symbol table; return (rest, arr).
    sliceWithGtidsFixed stm
      | Let _ _ (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` slice,
        remaining_slice <- drop (length space_slice) slice,
        all (isJust . flip ST.lookup vtable) $ namesToList $
          freeIn arr <> freeIn remaining_slice =
          Just (remaining_slice, arr)
      | otherwise =
          Nothing

    -- Fold step: either hoist the statement's result out of the kernel
    -- (binding the outer pattern element via an outer slice), or keep
    -- the statement in the kernel unchanged.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pattern [] [pe]) _ _ <- stm,
        Just (remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let outer_slice =
                map (\d -> DimSlice (constant (0 :: Int32)) d (constant (1 :: Int32))) $
                  segSpaceDims space
              index kpe' =
                letBind (Pattern [] [kpe']) $ BasicOp $ Index arr $
                  outer_slice <> remaining_slice
          -- When the outer result is consumed later, bind the slice to a
          -- fresh "_precopy" name and 'Copy' it into the real pattern
          -- element (presumably so the consumer owns a fresh array rather
          -- than a view of @arr@).
          if patElemName kpe `UT.isConsumed` used
            then do
              precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
              index kpe { patElemName = precopy }
              letBind (Pattern [] [kpe]) $ BasicOp $ Copy precopy
            else index kpe
          return ( kpes'', kts'', kres''
                 , if patElemName pe `nameIn` free_in_kstms
                     then kstms' <> oneStm stm
                     else kstms'
                 )
    distribute (kpes', kts', kres', kstms') stm =
      return (kpes', kts', kres', kstms' <> oneStm stm)

    -- If exactly one kernel result is @Returns _ (Var pe)@, split it off:
    -- return its pattern element together with the remaining pattern
    -- elements, types and results (in their original order).
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres) ->
          let (kpes'', kts'', kres'') = unzip3 kpes_and_kres
          in Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ (Var v)) = v == patElemName pe
        matches _ = False
bottomUpSegOp _ _ _ _ = Skip

--- Memory

-- Refine a list of 'ExpReturns' using the results of a kernel body.
-- Results and returns are paired positionally (via 'zipWithM', so any
-- surplus on either side is dropped):
--
--   * a 'WriteReturns' result takes its return information from its
--     destination array, via 'varReturns';
--   * every other result keeps the 'ExpReturns' it was paired with.
kernelBodyReturns :: (Mem lore, HasScope lore m, Monad m) =>
                     KernelBody lore -> [ExpReturns] -> m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
  where
    correct (WriteReturns _ arr _) _ = varReturns arr
    correct _ ret = return ret

-- | Like 'segOpType', but for memory representations.
--
-- 'SegMap', 'SegRed' and 'SegScan' are handled the same way:
-- 1. The extended result types of the operation (from 'opType') are
--    converted with 'extReturns'.
-- 2. That list is passed, together with the operation's 'KernelBody',
--    to 'kernelBodyReturns'.
--
-- For 'SegHist', the kernel body is not consulted.  The result is the
-- concatenation, in operator order, of 'varReturns' applied to each
-- 'histDest' array of every 'HistOp'.
segOpReturns :: (Mem lore, Monad m, HasScope lore m) =>
                SegOp lvl lore -> m [ExpReturns]
segOpReturns op =
  case op of
    SegMap _ _ _ body -> fromKernelBody body
    SegRed _ _ _ _ body -> fromKernelBody body
    SegScan _ _ _ _ body -> fromKernelBody body
    SegHist _ _ histOps _ _ ->
      -- One group of returns per operator, flattened in operator order.
      concat <$> traverse (traverse varReturns . histDest) histOps
  where
    -- Shared path for the three constructors whose results are derived
    -- from the operation type refined by its kernel body.
    fromKernelBody body = do
      expected <- extReturns <$> opType op
      kernelBodyReturns body expected