{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Representation.Kernels.Kernel
       ( HistOp(..)
       , histType
       , SegRedOp(..)
       , segRedResults
       , KernelBody(..)
       , aliasAnalyseKernelBody
       , consumedInKernelBody
       , ResultManifest(..)
       , KernelResult(..)
       , kernelResultSubExp
       , SplitOrdering(..)

       -- * Segmented operations
       , SegOp(..)
       , SegLevel(..)
       , SegVirt(..)
       , segLevel
       , segSpace
       , typeCheckSegOp
       , SegSpace(..)
       , scopeOfSegSpace
       , segSpaceDims

       -- ** Generic traversal
       , SegOpMapper(..)
       , identitySegOpMapper
       , mapSegOpM

       -- * Size operations
       , SizeOp(..)

       -- * Host operations
       , HostOp(..)
       , typeCheckHostOp

       -- * Reexports
       , module Futhark.Representation.Kernels.Sizes
       )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Control.Monad.Identity hiding (mapM_)
import qualified Data.Map.Strict as M
import Data.List (intersperse)

import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
  ((</>), (<+>), ppr, commasep, Pretty, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges
  (Ranges, removeLambdaRanges, removeStmRanges, mkBodyRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
  (Aliases, removeLambdaAliases, removeStmAliases)
import Futhark.Representation.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.Range as Range
import Futhark.Util (maybeNth)

-- | Describes how an array is divided into chunks.
data SplitOrdering
  = SplitContiguous
    -- ^ The array is split into contiguous chunks.
  | SplitStrided SubExp
    -- ^ The array is split with a stride; the 'SubExp' is that stride.
  deriving (Eq, Ord, Show)

-- | A contiguous split mentions no names; a strided split mentions
-- whatever its stride mentions.
instance FreeIn SplitOrdering where
  freeIn' ordering =
    case ordering of
      SplitContiguous     -> mempty
      SplitStrided stride -> freeIn' stride

-- | Substitution only affects the stride of a strided split.
instance Substitute SplitOrdering where
  substituteNames subst ordering =
    case ordering of
      SplitContiguous     -> SplitContiguous
      SplitStrided stride -> SplitStrided (substituteNames subst stride)

-- | Renaming only affects the stride of a strided split.
instance Rename SplitOrdering where
  rename SplitContiguous       = pure SplitContiguous
  rename (SplitStrided stride) = fmap SplitStrided (rename stride)

-- | A histogram operator.  See 'histType' for the type of the
-- histogram it produces.
data HistOp lore = HistOp
  { histWidth :: SubExp
    -- ^ Becomes the outermost dimension of the histogram type
    -- computed by 'histType'.
  , histRaceFactor :: SubExp
  , histDest :: [VName]
    -- ^ Destination arrays.
  , histNeutral :: [SubExp]
  , histShape :: Shape
    -- ^ In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
  , histOp :: Lambda lore
    -- ^ The operator; its return types determine 'histType'.
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  This can be
-- different from the type of the 'histDest' arrays in case we are
-- dealing with a segmented histogram.
--
-- Each return type of the 'histOp' lambda is first extended with the
-- dimensions of 'histShape', and then given an outermost dimension of
-- size 'histWidth'.
histType :: HistOp lore -> [Type]
histType op =
  map ((`arrayOfRow` histWidth op) . (`arrayOfShape` histShape op)) $
  lambdaReturnType $ histOp op

-- | A reduction operator, as counted by 'segRedResults'.
data SegRedOp lore = SegRedOp
  { segRedComm :: Commutativity
    -- ^ Commutativity of the operator.
  , segRedLambda :: Lambda lore
    -- ^ The reduction operator itself.
  , segRedNeutral :: [SubExp]
    -- ^ Neutral elements; their number is what 'segRedResults'
    -- counts.
  , segRedShape :: Shape
    -- ^ In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegRedOp's?
-- This is the total number of neutral elements across all operators.
segRedResults :: [SegRedOp lore] -> Int
segRedResults ops = sum [length (segRedNeutral op) | op <- ops]

-- | The body of a kernel: a sequence of statements followed by the
-- 'KernelResult's produced by each worker.
data KernelBody lore = KernelBody
  { kernelBodyLore :: BodyAttr lore
    -- ^ The lore-specific body attribute.
  , kernelBodyStms :: Stms lore
  , kernelBodyResult :: [KernelResult]
  }

deriving instance Annotations lore => Eq (KernelBody lore)
deriving instance Annotations lore => Ord (KernelBody lore)
deriving instance Annotations lore => Show (KernelBody lore)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  It protects things like tiling, which the
-- simplifier might otherwise remove for being semantically
-- redundant.  It has no semantic effect and can be ignored during
-- code generation.
data ResultManifest
  = ResultNoSimplify
    -- ^ Don't simplify this one!
  | ResultMaySimplify
    -- ^ Go nuts.
  | ResultPrivate
    -- ^ The results produced are only used within the same physical
    -- thread later on, and can thus be kept in registers.
  deriving (Eq, Show, Ord)

-- | A result produced by the workers of a kernel.
data KernelResult
  = Returns ResultManifest SubExp
    -- ^ Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-group depends on the 'SegLevel'.
  | WriteReturns
      [SubExp]             -- Size of array.  Must match number of dims.
      VName                -- Which array
      [([SubExp], SubExp)] -- Arbitrary number of index/value pairs.
  | ConcatReturns
      SplitOrdering        -- Permuted?
      SubExp               -- The final size.
      SubExp               -- Per-thread/group (max) chunk size.
      VName                -- Chunk by this worker.
  | TileReturns
      [(SubExp, SubExp)]   -- Total/tile for each dimension
      VName                -- Tile written by this worker.
    -- The TileReturns must not expect more than one
    -- result to be written per physical thread.
  deriving (Eq, Show, Ord)

-- | The 'SubExp' that stands for a kernel result.  For 'Returns' it is
-- the returned value.  For the other constructors it is the named
-- array, chunk or tile variable.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp kr =
  case kr of
    Returns _ se          -> se
    WriteReturns _ arr _  -> Var arr
    ConcatReturns _ _ _ v -> Var v
    TileReturns _ v       -> Var v

-- | The free variables of a kernel result are those of all its
-- components, except the 'ResultManifest' of a 'Returns'.
instance FreeIn KernelResult where
  freeIn' kr =
    case kr of
      Returns _ what ->
        freeIn' what
      WriteReturns rws arr res ->
        freeIn' rws <> freeIn' arr <> freeIn' res
      ConcatReturns o w per_thread_elems v ->
        freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
      TileReturns dims v ->
        freeIn' dims <> freeIn' v

-- | Free variables of the attribute, statements and results.  Names
-- bound by the body's own statements are passed to 'fvBind', so they
-- are treated as bound.
instance Attributes lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody attr stms res) =
    fvBind bound (freeIn' attr <> freeIn' stms <> freeIn' res)
    where bound = foldMap boundByStm stms

-- | Apply the substitution to every component of the body.
instance Attributes lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody attr stms res) =
    KernelBody (sub attr) (sub stms) (sub res)
    -- The explicit signature is required: TypeFamilies implies
    -- MonoLocalBinds, so 'sub' would not be generalised otherwise.
    where sub :: Substitute a => a -> a
          sub = substituteNames subst

-- | Apply the substitution to every component of the result, except
-- the 'ResultManifest' of a 'Returns'.
instance Substitute KernelResult where
  substituteNames subst kr =
    case kr of
      Returns manifest se ->
        Returns manifest (sub se)
      WriteReturns rws arr res ->
        WriteReturns (sub rws) (sub arr) (sub res)
      ConcatReturns o w per_thread_elems v ->
        ConcatReturns (sub o) (sub w) (sub per_thread_elems) (sub v)
      TileReturns dims v ->
        TileReturns (sub dims) (sub v)
    -- Signature needed so 'sub' can be used at several types under
    -- MonoLocalBinds (implied by TypeFamilies).
    where sub :: Substitute a => a -> a
          sub = substituteNames subst

-- | Rename the attribute first.  Then rename the statements with
-- 'renamingStms' and rename the results within its scope.
instance Attributes lore => Rename (KernelBody lore) where
  rename (KernelBody attr stms res) = do
    attr' <- rename attr
    renamingStms stms $ \stms' -> do
      res' <- rename res
      pure $ KernelBody attr' stms' res'

-- | Kernel results are renamed via 'substituteRename'.
instance Rename KernelResult where
  rename kr = substituteRename kr

-- | Perform alias analysis on a 'KernelBody'.  The statements are
-- analysed as an ordinary 'Body' with an empty result list, starting
-- from an empty alias table.  The kernel results are kept unchanged.
aliasAnalyseKernelBody :: (Attributes lore,
                           CanBeAliased (Op lore)) =>
                          KernelBody lore
                       -> KernelBody (Aliases lore)
aliasAnalyseKernelBody (KernelBody attr stms res) =
  KernelBody attr' stms' res
  where Body attr' stms' _ = Alias.analyseBody mempty $ Body attr stms []

-- | Strip alias information from a 'KernelBody'.  The alias part of
-- the body attribute is dropped, and 'removeStmAliases' is applied to
-- every statement.
removeKernelBodyAliases :: CanBeAliased (Op lore) =>
                           KernelBody (Aliases lore) -> KernelBody lore
removeKernelBodyAliases (KernelBody (_, attr) stms res) =
  KernelBody attr (removeStmAliases <$> stms) res

-- | Annotate a 'KernelBody' with range information.  The statements
-- are analysed with 'Range.analyseStms'.  The body attribute is paired
-- with ranges computed by 'mkBodyRanges' from the original statements
-- and the 'kernelResultSubExp' of each result.
addKernelBodyRanges :: (Attributes lore, CanBeRanged (Op lore)) =>
                       KernelBody lore -> Range.RangeM (KernelBody (Ranges lore))
addKernelBodyRanges (KernelBody attr stms res) =
  Range.analyseStms stms $ \stms' ->
    let ranges = mkBodyRanges stms $ map kernelResultSubExp res
    in return $ KernelBody (ranges, attr) stms' res

-- | Strip range information from a 'KernelBody'.  The range component
-- of the body attribute is discarded and every statement has its ranges
-- removed.  The kernel results pass through untouched.
removeKernelBodyRanges :: CanBeRanged (Op lore) =>
                          KernelBody (Ranges lore) -> KernelBody lore
removeKernelBodyRanges kbody =
  case kbody of
    KernelBody (_ranges, body_attr) body_stms body_res ->
      KernelBody body_attr (removeStmRanges <$> body_stms) body_res

-- | Strip simplifier wisdom from a 'KernelBody'.
--
-- The attribute and statements are temporarily packaged as an ordinary
-- 'Body' with no results so that 'removeBodyWisdom' can be reused.  The
-- cleaned attribute and statements are then put back together with the
-- original kernel results.
removeKernelBodyWisdom :: CanBeWise (Op lore) =>
                          KernelBody (Wise lore) -> KernelBody lore
removeKernelBodyWisdom (KernelBody attr stms res) =
  KernelBody plain_attr plain_stms res
  where Body plain_attr plain_stms _ = removeBodyWisdom (Body attr stms [])

-- | The names consumed by a 'KernelBody'.
--
-- This is everything consumed by its statements, as reported by
-- 'consumedInBody', plus the destination array of every 'WriteReturns'
-- result.  Other kinds of result consume nothing.
consumedInKernelBody :: Aliased lore =>
                        KernelBody lore -> Names
consumedInKernelBody (KernelBody attr stms res) =
  stms_consumed <> foldMap consumedByResult res
  where stms_consumed = consumedInBody (Body attr stms [])
        consumedByResult kres =
          case kres of
            WriteReturns _ dest _ -> oneName dest
            _                     -> mempty

-- | Type-check a 'KernelBody' against the types its results are
-- expected to have.
--
-- The body attribute is checked first, then the statements.  Within the
-- scope of those statements, the number of 'KernelResult's must equal
-- the number of expected types.  Each result is then checked against
-- its corresponding type.
checkKernelBody :: TC.Checkable lore =>
                   [Type] -> KernelBody (Aliases lore) -> TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, attr) stms kres) = do
  TC.checkBodyLore attr
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $ TC.TypeError $ "Kernel return type is " ++ prettyTuple ts ++
      ", but body returns " ++ show (length kres) ++ " values."
    zipWithM_ checkKernelResult kres ts

  where
    -- A directly returned value must have exactly the expected type.
    checkKernelResult (Returns _ what) t =
      TC.require [t] what

    -- In-place writes into @arr@.  @rws@ is the destination shape: the
    -- destination must have type @t@ with shape @rws@.  Each write gives
    -- int32 indices and a value of type @t@.  The destination is consumed.
    checkKernelResult (WriteReturns rws arr res) t = do
      mapM_ (TC.require [Prim int32]) rws
      arr_t <- lookupType arr
      forM_ res $ \(is, e) -> do
        mapM_ (TC.require [Prim int32]) is
        -- Each write must supply exactly one index per dimension of the
        -- destination shape.  Without this check, a write with too few
        -- or too many indices would be accepted.
        unless (length is == length rws) $
          TC.bad $ TC.TypeError $ "WriteReturns into " ++ pretty arr ++
          " uses " ++ show (length is) ++ " indices, but destination shape " ++
          pretty rws ++ " has " ++ show (length rws) ++ " dimensions."
        TC.require [t] e
        unless (arr_t == t `arrayOfShape` Shape rws) $
          TC.bad $ TC.TypeError $ "WriteReturns returning " ++
          pretty e ++ " of type " ++ pretty t ++ ", shape=" ++ pretty rws ++
          ", but destination array has type " ++ pretty arr_t
      TC.consume =<< TC.lookupAliases arr

    -- A concatenation result.  The optional stride and both sizes must
    -- be int32.  @v@ must be an array whose rows have type @t@, with its
    -- outer size read from @v@'s own type.
    checkKernelResult (ConcatReturns o w per_thread_elems v) t = do
      case o of
        SplitContiguous     -> return ()
        SplitStrided stride -> TC.require [Prim int32] stride
      TC.require [Prim int32] w
      TC.require [Prim int32] per_thread_elems
      vt <- lookupType v
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v

    -- A tiled result.  Every (dimension, tile) pair must be int32, and
    -- @v@ must have type @t@ extended with the tile sizes as its shape.
    checkKernelResult (TileReturns dims v) t = do
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int32] dim
        TC.require [Prim int32] tile
      vt <- lookupType v
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v

-- | Record metrics for every statement of a 'KernelBody'.  The kernel
-- results themselves contribute nothing.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics kbody =
  forM_ (kernelBodyStms kbody) bindingMetrics

-- | A kernel body prints as its statements stacked vertically, followed
-- by a @return@ line that lists the kernel results in braces.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms res) = stms_doc </> return_doc
    where stms_doc   = PP.stack (ppr <$> stmsToList stms)
          return_doc = text "return" <+> PP.braces (PP.commasep (ppr <$> res))

-- | Textual rendering of each kind of 'KernelResult'.
--
-- * 'Returns': a @returns@ keyword annotated with the manifest kind.
-- * 'WriteReturns': the destination, then each write as
--   @[i < rw, ...] <- e@.
-- * 'ConcatReturns': @concat@ (or @concatStrided(stride)@) followed by
--   the sizes and the variable.
-- * 'TileReturns': @tile@ with each @dim / tile@ pair, then the variable.
instance Pretty KernelResult where
  ppr kres =
    case kres of
      Returns manifest what ->
        text (returnsKeyword manifest) <+> ppr what
      WriteReturns rws arr res ->
        let bounded i rw = ppr i <+> text "<" <+> ppr rw
            ppWrite (is, e) =
              PP.brackets (PP.commasep (zipWith bounded is rws))
              <+> text "<-" <+> ppr e
        in ppr arr <+> text "with" <+> PP.apply (map ppWrite res)
      ConcatReturns o w per_thread_elems v ->
        let suffix = case o of
                       SplitContiguous     -> mempty
                       SplitStrided stride -> text "Strided" <> parens (ppr stride)
        in text "concat" <> suffix
           <> parens (commasep [ppr w, ppr per_thread_elems]) <+> ppr v
      TileReturns dims v ->
        let ppDim (dim, tile) = ppr dim <+> text "/" <+> ppr tile
        in text "tile" <> parens (commasep (map ppDim dims)) <+> ppr v
    where returnsKeyword :: ResultManifest -> String
          returnsKeyword ResultNoSimplify  = "returns (manifest)"
          returnsKeyword ResultPrivate     = "returns (private)"
          returnsKeyword ResultMaySimplify = "returns"

--- Segmented operations

-- | Whether code generation for a segmented operation should use group
-- virtualisation.
--
-- Usually it must.  For some simple kernels, however, the full number of
-- groups is computed in advance, and virtualisation is then an
-- unnecessary (though generally very small) overhead.  This mostly
-- matters for fairly trivial but very wide @map@ kernels in which each
-- thread performs constant-time work on scalars.
data SegVirt
  = SegVirt
  | SegNoVirt
  deriving (Eq, Ord, Show)

-- | The level at which the /body/ of a 'SegOp' executes.  Both
-- constructors carry the same three fields; only the level differs.
data SegLevel
  = SegThread
      { segNumGroups :: Count NumGroups SubExp
      , segGroupSize :: Count GroupSize SubExp
      , segVirt      :: SegVirt
        -- ^ Whether group virtualisation is required.
      }
  | SegGroup
      { segNumGroups :: Count NumGroups SubExp
      , segGroupSize :: Count GroupSize SubExp
      , segVirt      :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- | The index space of a 'SegOp'.
data SegSpace = SegSpace
  { segFlat :: VName
    -- ^ Flat physical index corresponding to the dimensions.  At code
    -- generation it is used for a thread ID or similar.
  , unSegSpace :: [(VName, SubExp)]
    -- ^ One entry per dimension: the index variable for that dimension
    -- and the size of the dimension.
  }
  deriving (Eq, Ord, Show)


-- | The size of every dimension of a 'SegSpace', outermost first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace

-- | The scope introduced by a 'SegSpace'.  The flat index and every
-- per-dimension index variable are each bound as an 'Int32' index.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace phys space) =
  M.fromList [ (v, IndexInfo Int32) | v <- phys : map fst space ]

-- | Type-check a 'SegSpace': every dimension size must be an int32.
-- The flat index is not checked here.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  forM_ (map snd dims) $ \dim_size ->
    TC.require [Prim int32] dim_size

-- | A segmented operation.  Every constructor carries a 'SegLevel' and a
-- 'SegSpace' (see 'segLevel' and 'segSpace'), a list of 'Type's and a
-- 'KernelBody'.  Constructors other than 'SegMap' also carry
-- operator-specific components between the space and the types.
data SegOp lore
  = SegMap SegLevel SegSpace [Type] (KernelBody lore)
  | SegRed SegLevel SegSpace [SegRedOp lore] [Type] (KernelBody lore)
    -- ^ The index space must always have at least two dimensions,
    -- implying that the result of a 'SegRed' is always an array.
  | SegScan SegLevel SegSpace (Lambda lore) [SubExp] [Type] (KernelBody lore)
  | SegHist SegLevel SegSpace [HistOp lore] [Type] (KernelBody lore)
  deriving (Eq, Ord, Show)

-- | The 'SegLevel' of any segmented operation.  It is always the first
-- constructor argument.
segLevel :: SegOp lore -> SegLevel
segLevel op =
  case op of
    SegMap lvl _ _ _      -> lvl
    SegRed lvl _ _ _ _    -> lvl
    SegScan lvl _ _ _ _ _ -> lvl
    SegHist lvl _ _ _ _   -> lvl

-- | The 'SegSpace' stored in a 'SegOp'; every constructor carries it
-- as its second field.
segSpace :: SegOp lore -> SegSpace
segSpace op =
  case op of
    SegMap _ space _ _      -> space
    SegRed _ space _ _ _    -> space
    SegScan _ space _ _ _ _ -> space
    SegHist _ space _ _ _   -> space

-- | The full type of one 'SegOp' result, given the space, the declared
-- per-thread type @t@ and the 'KernelResult' producing it.
--
-- * 'WriteReturns' yields an array of the written shape.
-- * 'Returns' adds every dimension of the space, outermost first.
-- * 'ConcatReturns' adds a single outer dimension of the given width.
-- * 'TileReturns' adds the first component of each tile dimension.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t res =
  case res of
    WriteReturns rws _ _ ->
      t `arrayOfShape` Shape rws
    Returns _ _ ->
      foldr (\d acc -> acc `arrayOfRow` d) t (segSpaceDims space)
    ConcatReturns _ w _ _ ->
      t `arrayOfRow` w
    TileReturns dims _ ->
      t `arrayOfShape` Shape (map fst dims)

-- | The result types of a 'SegOp'.
--
-- Operator results (reductions, scans) come first.  They are followed
-- by the remaining per-thread results, each shaped by
-- 'segResultShape'.  A 'SegHist' returns only its histogram arrays.
segOpType :: SegOp lore -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts (kernelBodyResult kbody)

segOpType (SegRed _ space reds ts kbody) =
  red_ts ++ zipWith (segResultShape space) map_ts map_res
  where -- Reduction results are arrays over all but the innermost
        -- space dimension, extended by the operator's own shape.
        segment_dims = init (segSpaceDims space)
        redTypes op =
          let shape = Shape segment_dims <> segRedShape op
          in map (`arrayOfShape` shape) (lambdaReturnType (segRedLambda op))
        red_ts = concatMap redTypes reds
        num_red = length red_ts
        map_ts = drop num_red ts
        map_res = drop num_red (kernelBodyResult kbody)

segOpType (SegScan _ space _ nes ts kbody) =
  scan_ts' ++ zipWith (segResultShape space) map_ts map_res
  where -- Scan results are arrays over the entire space.
        (scan_ts, map_ts) = splitAt (length nes) ts
        scan_ts' = map (`arrayOfShape` Shape (segSpaceDims space)) scan_ts
        map_res = drop (length scan_ts) (kernelBodyResult kbody)

segOpType (SegHist _ space ops _ _) =
  concatMap histTypes ops
  where -- Histogram arrays: segment dimensions (all but the innermost
        -- space dimension), then the histogram width, then the
        -- operator's own shape.
        segment_dims = init (segSpaceDims space)
        histTypes op =
          let shape = Shape (segment_dims <> [histWidth op]) <> histShape op
          in map (`arrayOfShape` shape) (lambdaReturnType (histOp op))

-- | The type of a 'SegOp' is always statically known: it is
-- 'segOpType' lifted to existential shapes.
instance TypedOp (SegOp lore) where
  opType op = pure (staticShapes (segOpType op))

-- | Alias information for 'SegOp's.
instance (Attributes lore, Aliased lore) => AliasedOp (SegOp lore) where
  -- No result of a 'SegOp' aliases anything.
  opAliases op = [mempty | _ <- segOpType op]

  -- Everything consumed by the kernel body is consumed by the
  -- operation.  A 'SegHist' additionally consumes its destination
  -- arrays.
  consumedInOp op =
    case op of
      SegHist _ _ ops _ kbody ->
        namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody
      SegMap _ _ _ kbody ->
        consumedInKernelBody kbody
      SegRed _ _ _ _ kbody ->
        consumedInKernelBody kbody
      SegScan _ _ _ _ _ kbody ->
        consumedInKernelBody kbody

-- | Check that a 'SegOp' at the level given as the second argument may
-- occur where the current level is the first argument.  'Nothing' is
-- always accepted (presumably "not inside any SegOp").
--
-- Fails when:
--
-- * the current level is 'SegThread';
-- * the two levels are equal; or
-- * they differ in number of groups or group size (physical layout).
checkSegLevel :: Maybe SegLevel -> SegLevel -> TC.TypeM lore ()
checkSegLevel Nothing _ =
  return ()
checkSegLevel (Just SegThread{}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y =
      TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      return ()

-- | Checks shared by every 'SegOp'.  In order, it checks:
--
-- 1. the nesting of levels ('checkSegLevel');
-- 2. the iteration space ('checkSegSpace');
-- 3. each declared result type.
checkSegBasics :: TC.Checkable lore =>
                  Maybe SegLevel -> SegLevel -> SegSpace -> [Type] -> TC.TypeM lore ()
checkSegBasics cur_lvl lvl space ts =
  checkSegLevel cur_lvl lvl >>
  checkSegSpace space >>
  mapM_ TC.checkType ts

-- | Type-check a 'SegOp'.  The first argument is the current level,
-- if any; 'checkSegLevel' checks it against the level of this
-- operation.
--
-- 'SegMap', 'SegRed' and 'SegScan' are handled by 'checkScanRed'.  A
-- map is treated as having no operators, and a scan as having a
-- single operator with an empty shape.  'SegHist' is checked here.
typeCheckSegOp :: TC.Checkable lore =>
                  Maybe SegLevel -> SegOp (Aliases lore) -> TC.TypeM lore ()
typeCheckSegOp cur_lvl (SegMap lvl space ts kbody) =
  checkScanRed cur_lvl lvl space [] ts kbody

typeCheckSegOp cur_lvl (SegRed lvl space reds ts body) =
  checkScanRed cur_lvl lvl space (map asOp reds) ts body
  where asOp red = (segRedLambda red, segRedNeutral red, segRedShape red)

typeCheckSegOp cur_lvl (SegScan lvl space scan_op nes ts body) =
  checkScanRed cur_lvl lvl space [(scan_op, nes, mempty)] ts body

typeCheckSegOp cur_lvl (SegHist lvl space ops ts kbody) = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    op_ts <- mapM checkHistOp ops

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let expected_ts = replicate (length ops) (Prim int32) ++ concat op_ts
    unless (expected_ts == ts) $
      TC.bad $ TC.TypeError $ "SegHist body has return type " ++
      prettyTuple ts ++ " but should have type " ++
      prettyTuple expected_ts

  where segment_dims = init $ segSpaceDims space

        -- Check a single histogram operation.  Returns the types of its
        -- neutral elements, extended by the operator shape.
        checkHistOp (HistOp dest_w rf dests nes shape op) = do
          TC.require [Prim int32] dest_w
          TC.require [Prim int32] rf
          nes' <- mapM TC.checkArg nes
          mapM_ (TC.require [Prim int32]) $ shapeDims shape

          -- Operator type must match the type of neutral elements.  The
          -- lambda receives two copies of the neutral-element arguments,
          -- with as many outer dimensions stripped as the operator
          -- shape has.
          let stripVecDims = stripArray $ shapeRank shape
              nes_t = map TC.argType nes'
          TC.checkLambda op $
            map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
          unless (nes_t == lambdaReturnType op) $
            TC.bad $ TC.TypeError $ "SegHist operator has return type " ++
            prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
            prettyTuple nes_t

          -- Arrays must have proper type: segment dimensions, then the
          -- histogram width, then the operator shape.  Each destination
          -- (and everything it aliases) is consumed.
          let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
          forM_ (zip nes_t dests) $ \(t, dest) -> do
            TC.requireI [t `arrayOfShape` dest_shape] dest
            TC.consume =<< TC.lookupAliases dest

          return $ map (`arrayOfShape` shape) nes_t

-- | Shared type checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is given as its lambda, its neutral elements and its
-- shape.  For each operator, the lambda's return type must equal the
-- neutral-element types.  The leading body results must then match
-- the neutral-element types extended by each operator's shape.
-- Finally the whole kernel body is checked against @ts@.
checkScanRed :: TC.Checkable lore =>
                Maybe SegLevel -> SegLevel
             -> SegSpace
             -> [(Lambda (Aliases lore), [SubExp], Shape)]
             -> [Type]
             -> KernelBody (Aliases lore)
             -> TC.TypeM lore ()
checkScanRed cur_lvl lvl space ops ts kbody = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    expecting <- concat <$> mapM checkOp ops

    -- Only the leading results belong to the operators; any remaining
    -- results are map results.
    let got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $ TC.TypeError $
      "Wrong return for body (does not match neutral elements; expected " ++
      pretty expecting ++ "; found " ++
      pretty got ++ ")"

    checkKernelBody ts kbody

  where -- Check one operator and return the types it contributes to the
        -- body result.
        checkOp (lam, nes, shape) = do
          mapM_ (TC.require [Prim int32]) $ shapeDims shape
          nes' <- mapM TC.checkArg nes

          -- Operator type must match the type of neutral elements.  The
          -- lambda receives two copies of the neutral-element arguments,
          -- with as many outer dimensions stripped as the operator
          -- shape has.
          let stripVecDims = stripArray $ shapeRank shape
              nes_t = map TC.argType nes'
          TC.checkLambda lam $
            map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
          unless (lambdaReturnType lam == nes_t) $
            TC.bad $ TC.TypeError "wrong type for operator or neutral elements."

          return $ map (`arrayOfShape` shape) nes_t

-- | Like 'Mapper', but just for 'SegOp's.  Each field is a monadic
-- transformation for one kind of component, converting from lore
-- @flore@ to lore @tlore@.
data SegOpMapper flore tlore m =
  SegOpMapper
  { mapOnSegOpSubExp :: SubExp -> m SubExp
    -- ^ Transformation for 'SubExp's.
  , mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore)
    -- ^ Transformation for operator lambdas.
  , mapOnSegOpBody   :: KernelBody flore -> m (KernelBody tlore)
    -- ^ Transformation for kernel bodies.
  , mapOnSegOpVName  :: VName -> m VName
    -- ^ Transformation for variable names.
  }

-- | A mapper that simply returns the 'SegOp' verbatim.
--
-- Every field is 'pure', so no component is changed.
identitySegOpMapper :: Monad m => SegOpMapper lore lore m
identitySegOpMapper =
  SegOpMapper
  { mapOnSegOpSubExp = pure
  , mapOnSegOpLambda = pure
  , mapOnSegOpBody   = pure
  , mapOnSegOpVName  = pure
  }

-- | Map the dimension sizes of a 'SegSpace' with 'mapOnSegOpSubExp'.
-- The physical index name and the per-dimension index names are kept
-- as-is (they are not passed to 'mapOnSegOpVName').
mapOnSegSpace :: Monad f =>
                 SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = SegSpace phys <$> mapM onDim dims
  where -- Only the size half of each (index, size) pair is mapped.
        onDim (i, d) = (,) i <$> mapOnSegOpSubExp tv d

-- | Traverse a 'SegOp' with a 'SegOpMapper', rebuilding it in the target
-- lore.  Every constructor is handled the same way: first the
-- 'SegLevel', then the 'SegSpace', then any operator-specific parts
-- (reduction/histogram operators, or the scan operator and its neutral
-- elements), then the result types, and finally the 'KernelBody'.
-- Effects of the mapper run in exactly that field order.
mapSegOpM :: (Applicative m, Monad m) =>
              SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM tv segop =
  case segop of
    SegMap lvl space ts body ->
      SegMap
      <$> onLevel lvl
      <*> onSpace space
      <*> onTypes ts
      <*> onBody body
    SegRed lvl space reds ts body ->
      SegRed
      <$> onLevel lvl
      <*> onSpace space
      <*> mapM onRedOp reds
      <*> onTypes ts
      <*> onBody body
    SegScan lvl space scan_op nes ts body ->
      SegScan
      <$> onLevel lvl
      <*> onSpace space
      <*> onLambda scan_op
      <*> onSubExps nes
      <*> onTypes ts
      <*> onBody body
    SegHist lvl space ops ts body ->
      SegHist
      <$> onLevel lvl
      <*> onSpace space
      <*> mapM onHistOp ops
      <*> onTypes ts
      <*> onBody body
  where onSubExp se = mapOnSegOpSubExp tv se
        onSubExps ses = mapM onSubExp ses
        onLambda l = mapOnSegOpLambda tv l
        onBody kb = mapOnSegOpBody tv kb
        onLevel l = mapOnSegLevel tv l
        onSpace sp = mapOnSegSpace tv sp
        -- Every constructor maps its result types through the same
        -- function, so all four cases treat types identically.
        onTypes tys = mapM (mapOnSegOpType tv) tys
        onShape (Shape ds) = Shape <$> onSubExps ds

        -- Reduction operator: lambda, neutral elements, then shape.
        onRedOp (SegRedOp comm red_op nes shape) =
          SegRedOp comm
          <$> onLambda red_op
          <*> onSubExps nes
          <*> onShape shape

        -- Histogram operator: fields are visited in constructor order.
        onHistOp (HistOp w rf arrs nes shape op) =
          HistOp
          <$> onSubExp w
          <*> onSubExp rf
          <*> mapM (mapOnSegOpVName tv) arrs
          <*> onSubExps nes
          <*> onShape shape
          <*> onLambda op

-- | Map the size subexpressions of a 'SegLevel' (its 'Count NumGroups'
-- and 'Count GroupSize' fields) with 'mapOnSegOpSubExp'.  The
-- constructor and its 'SegVirt' are preserved unchanged.
mapOnSegLevel :: Monad m =>
                 SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel tv lvl =
  case lvl of
    SegThread num_groups group_size virt ->
      rebuild SegThread num_groups group_size virt
    SegGroup num_groups group_size virt ->
      rebuild SegGroup num_groups group_size virt
  where onSubExp = mapOnSegOpSubExp tv
        -- Both constructors have the same field layout, so they share
        -- one rebuilding routine parameterised over the constructor.
        rebuild mk ng gs v =
          mk <$> traverse onSubExp ng
             <*> traverse onSubExp gs
             <*> pure v

-- | Map the dimension subexpressions of a 'Type' with
-- 'mapOnSegOpSubExp'.  Only 'Array' types contain such subexpressions;
-- 'Prim' and 'Mem' types are returned unchanged.
mapOnSegOpType :: Monad m =>
                  SegOpMapper flore tlore m -> Type -> m Type
mapOnSegOpType tv t =
  case t of
    Array pt (Shape dims) u -> do
      dims' <- mapM (mapOnSegOpSubExp tv) dims
      return $ Array pt (Shape dims') u
    Prim pt -> pure $ Prim pt
    Mem s -> pure $ Mem s

-- | Name substitution is applied to every component that 'mapSegOpM'
-- visits: subexpressions, operator lambdas, the kernel body and the
-- variable names held directly by the operation.
instance Attributes lore => Substitute (SegOp lore) where
  substituteNames subst = runIdentity . mapSegOpM substituter
    where substituter = SegOpMapper onAny onAny onAny onAny
          -- The explicit signature makes this helper polymorphic even
          -- under MonoLocalBinds, so one definition serves all four
          -- mapper fields.
          onAny :: Substitute a => a -> Identity a
          onAny = pure . substituteNames subst

-- | Renaming applies 'rename' to every component that 'mapSegOpM'
-- visits.
instance Attributes lore => Rename (SegOp lore) where
  rename segop = mapSegOpM renamer segop
    where renamer = SegOpMapper { mapOnSegOpSubExp = rename
                                , mapOnSegOpLambda = rename
                                , mapOnSegOpBody   = rename
                                , mapOnSegOpVName  = rename
                                }

-- | The free variables of a 'SegOp' are collected by visiting every
-- component that 'mapSegOpM' exposes and accumulating each component's
-- 'freeIn'' into a state, starting from 'mempty'.
instance (Attributes lore, FreeIn (LParamAttr lore)) =>
         FreeIn (SegOp lore) where
  freeIn' e = flip execState mempty $ mapSegOpM free e
    where -- 'modify'' forces the accumulated set at each step, so the
          -- state does not build up a chain of unevaluated '<>' thunks
          -- over the course of the traversal.
          walk f x = modify' (<> f x) >> return x
          free = SegOpMapper { mapOnSegOpSubExp = walk freeIn'
                             , mapOnSegOpLambda = walk freeIn'
                             , mapOnSegOpBody = walk freeIn'
                             , mapOnSegOpVName = walk freeIn'
                             }

-- | Metrics for a 'SegOp' are recorded inside a scope named after its
-- constructor: first the metrics of its operator lambdas (if any), then
-- those of its kernel body.
instance OpMetrics (Op lore) => OpMetrics (SegOp lore) where
  opMetrics segop =
    case segop of
      SegMap _ _ _ body ->
        measured "SegMap" [] body
      SegRed _ _ reds _ body ->
        measured "SegRed" (map segRedLambda reds) body
      SegScan _ _ scan_op _ _ body ->
        measured "SegScan" [scan_op] body
      SegHist _ _ ops _ body ->
        measured "SegHist" (map histOp ops) body
    where measured label lams kbody =
            inside label $ do mapM_ lambdaMetrics lams
                              kernelBodyMetrics kbody

-- | Render a segment space as @(i1 < d1, i2 < d2, ...) (~phys)@.
-- Each dimension is shown as its index variable bounded by its size.
-- This is followed by the space's @phys@ name, prefixed with @~@.
instance Pretty SegSpace where
  ppr (SegSpace phys dims) =
    parens (commasep $ map ppDim dims) <+>
    parens (text "~" <> ppr phys)
    where ppDim (i, d) = ppr i <+> "<" <+> ppr d

-- | Print only which kind of level this is (@thread@ or @group@).
-- The group count, group size and virtualisation flag are rendered
-- separately by 'ppSegLevel'.
instance PP.Pretty SegLevel where
  ppr SegThread{} = "thread"
  ppr SegGroup{} = "group"

-- | Render the launch parameters of a 'SegLevel' as
-- @(#groups=N; groupsize=M)@.
-- When 'segVirt' is 'SegVirt', @; virtualise@ is appended inside the
-- parentheses. When it is 'SegNoVirt', nothing is appended.
ppSegLevel :: SegLevel -> PP.Doc
ppSegLevel lvl =
  PP.parens $
  text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi <+>
  text "groupsize=" <> ppr (segGroupSize lvl) <>
  virt
  where virt = case segVirt lvl of
                 SegNoVirt -> mempty
                 SegVirt -> PP.semi <+> text "virtualise"

-- | Layout shared by every 'SegOp' constructor:
--
-- > <name><level> </> (#groups=..; groupsize=..) </> <extras> </>
-- > <space> : <result types> { <body> }
--
-- @extras@ holds the operator-specific documents (reduction, scan or
-- histogram operators). It is empty for @segmap@. The pieces are joined
-- with right-nested '</>', exactly as the hand-written chains did.
ppSegOpLayout :: String -> SegLevel -> [PP.Doc] -> SegSpace -> [Type]
              -> PP.Doc -> PP.Doc
ppSegOpLayout name lvl extras space ts body_doc =
  foldr (</>) rest $ (text name <> ppr lvl) : ppSegLevel lvl : extras
  where rest = PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
               PP.nestedBlock "{" "}" body_doc

-- | Wrap a list of per-operator documents as @({op1,\nop2,\n...})@.
-- Each operator is separated by a comma and a line break.
ppSegOpList :: [PP.Doc] -> PP.Doc
ppSegOpList = PP.parens . PP.braces . mconcat . intersperse (PP.comma <> PP.line)

-- | Pretty-print a segmented operation. The keyword (@segmap_@,
-- @segred_@, @segscan_@, @seghist_@) is followed by its level, launch
-- parameters, any operator description, the segment space, the result
-- types and finally the kernel body in braces.
instance PrettyLore lore => PP.Pretty (SegOp lore) where
  ppr (SegMap lvl space ts body) =
    ppSegOpLayout "segmap_" lvl [] space ts (ppr body)

  ppr (SegRed lvl space reds ts body) =
    ppSegOpLayout "segred_" lvl [ppSegOpList $ map ppOp reds] space ts (ppr body)
    -- Each reduction operator prints as
    -- {neutral elements}, shape, [commutative ]lambda
    where ppOp (SegRedOp comm lam nes shape) =
            PP.braces (PP.commasep $ map ppr nes) <> PP.comma </>
            ppr shape <> PP.comma </>
            comm' <> ppr lam
            where comm' = case comm of
                            Commutative -> text "commutative "
                            Noncommutative -> mempty

  ppr (SegScan lvl space scan_op nes ts body) =
    ppSegOpLayout "segscan_" lvl [scan_doc] space ts (ppr body)
    -- The scan operator is followed by its neutral elements in braces.
    where scan_doc = PP.parens $ ppr scan_op <> PP.comma </>
                     PP.braces (PP.commasep $ map ppr nes)

  ppr (SegHist lvl space ops ts body) =
    ppSegOpLayout "seghist_" lvl [ppSegOpList $ map ppOp ops] space ts (ppr body)
    -- Each histogram operator prints as
    -- w, rf, {destinations}, {neutral elements}, shape, lambda
    where ppOp (HistOp w rf dests nes shape op) =
            ppr w <> PP.comma <+> ppr rf <> PP.comma </>
            PP.braces (PP.commasep $ map ppr dests) <> PP.comma </>
            PP.braces (PP.commasep $ map ppr nes) <> PP.comma </>
            ppr shape <> PP.comma </>
            ppr op

-- | Every result of a 'SegOp' (one per type in 'segOpType') is
-- assigned 'unknownRange'. No range information is derived from the
-- operation itself.
instance Attributes inner => RangedOp (SegOp inner) where
  opRanges op = replicate (length $ segOpType op) unknownRange

-- | Range annotations are added and removed by mapping over the lambdas
-- and kernel bodies of the operation with 'mapSegOpM'. Subexpressions
-- and names are passed through untouched.
instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (SegOp lore) where
  type OpWithRanges (SegOp lore) = SegOp (Ranges lore)

  -- Pure traversal (Identity) that strips range information.
  removeOpRanges = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return (return . removeLambdaRanges)
                   (return . removeKernelBodyRanges) return

  -- Traversal in the range-analysis monad; the result is extracted
  -- with 'Range.runRangeM'.
  addOpRanges = Range.runRangeM . mapSegOpM add
    where add = SegOpMapper return Range.analyseLambda
                addKernelBodyRanges return

-- | Alias information is added to and removed from the lambdas and
-- kernel bodies of the operation via 'mapSegOpM'. Subexpressions and
-- names are passed through untouched.
instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SegOp lore) where
  type OpWithAliases (SegOp lore) = SegOp (Aliases lore)

  -- Run alias analysis on every lambda and kernel body.
  addOpAliases = runIdentity . mapSegOpM alias
    where alias = SegOpMapper return (return . Alias.analyseLambda)
                  (return . aliasAnalyseKernelBody) return

  -- Strip the alias annotations again.
  removeOpAliases = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return (return . removeLambdaAliases)
                   (return . removeKernelBodyAliases) return

-- | Wisdom is removed by mapping 'removeLambdaWisdom' and
-- 'removeKernelBodyWisdom' over the operation with 'mapSegOpM'.
-- Subexpressions and names are passed through untouched.
instance (CanBeWise (Op lore), Attributes lore) => CanBeWise (SegOp lore) where
  type OpWithWisdom (SegOp lore) = SegOp (Wise lore)

  removeOpWisdom = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return
                   (return . removeLambdaWisdom)
                   (return . removeKernelBodyWisdom)
                   return

-- | Try to express "result @k@ of this operation, indexed by @is@" as an
-- 'ST.Indexed' value, so that the simplifier can look through it.
--
-- Only 'SegMap' is handled, and only when result @k@ is a
-- @Returns ResultMaySimplify@ and there are at least as many indices
-- as segment dimensions. The leading indices are bound to the segment's
-- thread-index variables. The kernel body statements are then walked
-- in order, recording any statement whose value can be expressed in
-- terms of already-known entries. All other operations give 'Nothing'.
instance Attributes lore => ST.IndexOp (SegOp lore) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    -- Bind each segment index variable to the corresponding index.
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty) is
        -- Extend the table with every body statement we can express.
        idx_table' = foldl expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing

    where (gtids, _) = unzip $ unSegSpace space
          -- Indexes in excess of what is used to index through the
          -- segment dimensions.
          excess_is = drop (length gtids) is

          -- Record a single-name statement in the table when either
          -- (1) its expression converts to a PrimExp over known names, or
          -- (2) it indexes an array that is in the outer symbol table with
          --     as many slice dimensions as there are excess indices. In
          --     that case the excess indices are applied via 'fixSlice'.
          -- Certificates of the statement and of any looked-up entries
          -- are carried along.
          expandIndexedTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe,cs) <-
                  runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table

            | [v] <- patternNames $ stmPattern stm,
              BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.elem` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice =
                let idx = ST.IndexedArray (stmCerts stm <> cs)
                          arr (fixSlice slice' excess_is)
                in M.insert v idx table

            | otherwise =
                table

          -- Convert every component of a slice, accumulating certificates.
          asPrimExpSlice table =
            runWriterT . mapM (traverse (primExpFromSubExpM (asPrimExp table)))

          -- A name is known if it is already in the table (its
          -- certificates are emitted), or if the outer symbol table gives
          -- it a primitive type (it becomes a leaf). Otherwise the
          -- conversion fails.
          asPrimExp table v
            | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing

  indexOp _ _ _ _ = Nothing

-- | A 'SegOp' is never reported as cheap, and is always reported as
-- safe, regardless of its contents.
instance Attributes lore => IsOp (SegOp lore) where
  cheapOp _ = False
  safeOp _ = True

--- Host operations

-- | A simple size-level query or computation.
data SizeOp
  = SplitSpace SplitOrdering SubExp SubExp SubExp
    -- ^ @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
  | GetSize Name SizeClass
    -- ^ Produce some runtime-configurable size.
  | GetSizeMax SizeClass
    -- ^ The maximum size of some class.
  | CmpSizeLe Name SizeClass SubExp
    -- ^ Compare size (likely a threshold) with some integer value.
  | CalcNumGroups SubExp Name SubExp
    -- ^ @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @Name@ is a size name.  Note that @w@ is an i64 to avoid
    -- overflow issues.
  deriving (SizeOp -> SizeOp -> Bool
(SizeOp -> SizeOp -> Bool)
-> (SizeOp -> SizeOp -> Bool) -> Eq SizeOp
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
/= :: SizeOp -> SizeOp -> Bool
$c/= :: SizeOp -> SizeOp -> Bool
== :: SizeOp -> SizeOp -> Bool
$c== :: SizeOp -> SizeOp -> Bool
Eq, Eq SizeOp
Eq SizeOp
-> (SizeOp -> SizeOp -> Ordering)
-> (SizeOp -> SizeOp -> Bool)
-> (SizeOp -> SizeOp -> Bool)
-> (SizeOp -> SizeOp -> Bool)
-> (SizeOp -> SizeOp -> Bool)
-> (SizeOp -> SizeOp -> SizeOp)
-> (SizeOp -> SizeOp -> SizeOp)
-> Ord SizeOp
SizeOp -> SizeOp -> Bool
SizeOp -> SizeOp -> Ordering
SizeOp -> SizeOp -> SizeOp
forall a.
Eq a
-> (a -> a -> Ordering)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> Bool)
-> (a -> a -> a)
-> (a -> a -> a)
-> Ord a
min :: SizeOp -> SizeOp -> SizeOp
$cmin :: SizeOp -> SizeOp -> SizeOp
max :: SizeOp -> SizeOp -> SizeOp
$cmax :: SizeOp -> SizeOp -> SizeOp
>= :: SizeOp -> SizeOp -> Bool
$c>= :: SizeOp -> SizeOp -> Bool
> :: SizeOp -> SizeOp -> Bool
$c> :: SizeOp -> SizeOp -> Bool
<= :: SizeOp -> SizeOp -> Bool
$c<= :: SizeOp -> SizeOp -> Bool
< :: SizeOp -> SizeOp -> Bool
$c< :: SizeOp -> SizeOp -> Bool
compare :: SizeOp -> SizeOp -> Ordering
$ccompare :: SizeOp -> SizeOp -> Ordering
$cp1Ord :: Eq SizeOp
Ord, Int -> SizeOp -> ShowS
[SizeOp] -> ShowS
SizeOp -> String
(Int -> SizeOp -> ShowS)
-> (SizeOp -> String) -> ([SizeOp] -> ShowS) -> Show SizeOp
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
showList :: [SizeOp] -> ShowS
$cshowList :: [SizeOp] -> ShowS
show :: SizeOp -> String
$cshow :: SizeOp -> String
showsPrec :: Int -> SizeOp -> ShowS
$cshowsPrec :: Int -> SizeOp -> ShowS
Show)

-- | Substitution only affects the 'SubExp' operands and the split
-- ordering; names of sizes and size classes are left untouched.
instance Substitute SizeOp where
  substituteNames substs op =
    case op of
      SplitSpace o w i elems_per_thread ->
        SplitSpace (sub o) (sub w) (sub i) (sub elems_per_thread)
      CmpSizeLe name sclass x ->
        CmpSizeLe name sclass (sub x)
      CalcNumGroups w max_num_groups group_size ->
        CalcNumGroups (sub w) max_num_groups (sub group_size)
      GetSize{} -> op
      GetSizeMax{} -> op
    where sub :: Substitute a => a -> a
          sub = substituteNames substs

-- | Renaming visits the 'SubExp' operands (and the split ordering)
-- left to right; size names are kept as they are.
instance Rename SizeOp where
  rename op =
    case op of
      SplitSpace o w i elems_per_thread -> do
        o' <- rename o
        w' <- rename w
        i' <- rename i
        elems_per_thread' <- rename elems_per_thread
        return $ SplitSpace o' w' i' elems_per_thread'
      CmpSizeLe name sclass x ->
        CmpSizeLe name sclass <$> rename x
      CalcNumGroups w max_num_groups group_size -> do
        w' <- rename w
        group_size' <- rename group_size
        return $ CalcNumGroups w' max_num_groups group_size'
      GetSize{} -> pure op
      GetSizeMax{} -> pure op

-- | Size operations are always reported as both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True

-- | Every size operation has a single scalar result: a boolean for
-- 'CmpSizeLe' and an @i32@ for all other constructors.
instance TypedOp SizeOp where
  opType op = pure [Prim (resultType op)]
    where resultType SplitSpace{}    = int32
          resultType GetSize{}       = int32
          resultType GetSizeMax{}    = int32
          resultType CmpSizeLe{}     = Bool
          resultType CalcNumGroups{} = int32

-- | A size operation consumes nothing, and its single result has an
-- empty alias set.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty

-- | The chunk size computed by 'SplitSpace' lies between zero and
-- @elems_per_thread@; nothing is known about other size operations.
instance RangedOp SizeOp where
  opRanges (SplitSpace _ _ _ elems_per_thread) =
    let lower = ScalarBound 0
        upper = ScalarBound $ SE.subExpToScalExp elems_per_thread int32
    in [(Just lower, Just upper)]
  opRanges _ = [unknownRange]

-- | The free variables of a size operation are those of its 'SubExp'
-- operands (and split ordering); size names do not contribute.
instance FreeIn SizeOp where
  freeIn' op =
    case op of
      SplitSpace o w i elems_per_thread ->
        freeIn' o <> freeIn' [w, i, elems_per_thread]
      CmpSizeLe _ _ x ->
        freeIn' x
      CalcNumGroups w _ group_size ->
        freeIn' w <> freeIn' group_size
      GetSize{} -> mempty
      GetSizeMax{} -> mempty

-- | Size operations are printed in function-call syntax, e.g.
-- @calc_num_groups(w, max, gs)@.  'CmpSizeLe' is shown as a
-- @get_size(...)@ call compared against its operand with @<=@.
instance PP.Pretty SizeOp where
  ppr op =
    case op of
      SplitSpace o w i elems_per_thread ->
        text "splitSpace" <> ordering o <> args [ppr w, ppr i, ppr elems_per_thread]
      GetSize name size_class ->
        call "get_size" [ppr name, ppr size_class]
      GetSizeMax size_class ->
        call "get_size_max" [ppr size_class]
      CmpSizeLe name size_class x ->
        call "get_size" [ppr name, ppr size_class] <+> text "<=" <+> ppr x
      CalcNumGroups w max_num_groups group_size ->
        call "calc_num_groups" [ppr w, ppr max_num_groups, ppr group_size]
    where -- Parenthesised, comma-separated argument list.
          args = parens . commasep
          call f xs = text f <> args xs
          -- Contiguous splits get no suffix; strided ones show the stride.
          ordering SplitContiguous       = mempty
          ordering (SplitStrided stride) = text "Strided" <> parens (ppr stride)

-- | Each size operation is recorded under its constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    SplitSpace{}    -> "SplitSpace"
    GetSize{}       -> "GetSize"
    GetSizeMax{}    -> "GetSizeMax"
    CmpSizeLe{}     -> "CmpSizeLe"
    CalcNumGroups{} -> "CalcNumGroups"

-- | Type check a 'SizeOp'.  Every 'SubExp' operand must be an @i32@,
-- except the input size @w@ of 'CalcNumGroups', which must be an
-- @i64@.  'GetSize' and 'GetSizeMax' have nothing to check.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp op =
  case op of
    SplitSpace o w i elems_per_thread ->
      -- The stride (if any) is checked first, then the other operands.
      mapM_ (TC.require [Prim int32]) $ strideOf o ++ [w, i, elems_per_thread]
    GetSize{} ->
      return ()
    GetSizeMax{} ->
      return ()
    CmpSizeLe _ _ x ->
      TC.require [Prim int32] x
    CalcNumGroups w _ group_size -> do
      TC.require [Prim int64] w
      TC.require [Prim int32] group_size
  where strideOf SplitContiguous       = []
        strideOf (SplitStrided stride) = [stride]

-- | A host-level operation; parameterised by what else it can do.
data HostOp lore op
  = SegOp (SegOp lore)
    -- ^ A segmented operation.
  | SizeOp SizeOp
    -- ^ A size-level query or computation; see 'SizeOp'.
  | OtherOp op
    -- ^ Some other operation, of the parameter type @op@.
  deriving (Eq, Ord, Show)

-- | Substitution is forwarded to the wrapped operation.
instance (Attributes lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hostop =
    case hostop of
      SegOp op   -> SegOp (substituteNames substs op)
      OtherOp op -> OtherOp (substituteNames substs op)
      SizeOp op  -> SizeOp (substituteNames substs op)

-- | Renaming is forwarded to the wrapped operation.
instance (Attributes lore, Rename op) => Rename (HostOp lore op) where
  rename hostop =
    case hostop of
      SegOp op   -> fmap SegOp (rename op)
      OtherOp op -> fmap OtherOp (rename op)
      SizeOp op  -> fmap SizeOp (rename op)

-- | Safety and cheapness are those of the wrapped operation.
instance (Attributes lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hostop =
    case hostop of
      SegOp op   -> safeOp op
      OtherOp op -> safeOp op
      SizeOp op  -> safeOp op

  cheapOp hostop =
    case hostop of
      SegOp op   -> cheapOp op
      OtherOp op -> cheapOp op
      SizeOp op  -> cheapOp op

-- | The result type is that of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType hostop =
    case hostop of
      SegOp op   -> opType op
      OtherOp op -> opType op
      SizeOp op  -> opType op

-- | Aliasing and consumption information is taken from the wrapped
-- operation.
instance (Aliased lore, AliasedOp op, Attributes lore) => AliasedOp (HostOp lore op) where
  opAliases hostop =
    case hostop of
      SegOp op   -> opAliases op
      OtherOp op -> opAliases op
      SizeOp op  -> opAliases op

  consumedInOp hostop =
    case hostop of
      SegOp op   -> consumedInOp op
      OtherOp op -> consumedInOp op
      SizeOp op  -> consumedInOp op

-- | Result ranges are those of the wrapped operation.
instance (Attributes lore, RangedOp op) => RangedOp (HostOp lore op) where
  opRanges hostop =
    case hostop of
      SegOp op   -> opRanges op
      OtherOp op -> opRanges op
      SizeOp op  -> opRanges op

-- | The free variables are those of the wrapped operation.
instance (Attributes lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hostop =
    case hostop of
      SegOp op   -> freeIn' op
      OtherOp op -> freeIn' op
      SizeOp op  -> freeIn' op

-- | Alias annotations are added to (or stripped from) the wrapped
-- 'SegOp' or other operation.  A 'SizeOp' carries no lore, so it is
-- merely re-wrapped at the new type.
instance (CanBeAliased (Op lore), CanBeAliased op, Attributes lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases hostop =
    case hostop of
      SegOp op   -> SegOp (addOpAliases op)
      OtherOp op -> OtherOp (addOpAliases op)
      SizeOp op  -> SizeOp op

  removeOpAliases hostop =
    case hostop of
      SegOp op   -> SegOp (removeOpAliases op)
      OtherOp op -> OtherOp (removeOpAliases op)
      SizeOp op  -> SizeOp op

-- | Range annotations are added to (or stripped from) the wrapped
-- 'SegOp' or other operation.  A 'SizeOp' carries no lore, so it is
-- merely re-wrapped at the new type.
instance (CanBeRanged (Op lore), CanBeRanged op, Attributes lore) => CanBeRanged (HostOp lore op) where
  type OpWithRanges (HostOp lore op) = HostOp (Ranges lore) (OpWithRanges op)

  addOpRanges hostop =
    case hostop of
      SegOp op   -> SegOp (addOpRanges op)
      OtherOp op -> OtherOp (addOpRanges op)
      SizeOp op  -> SizeOp op

  removeOpRanges hostop =
    case hostop of
      SegOp op   -> SegOp (removeOpRanges op)
      OtherOp op -> OtherOp (removeOpRanges op)
      SizeOp op  -> SizeOp op

-- | Wisdom is stripped from the wrapped 'SegOp' or other operation.
-- A 'SizeOp' carries no lore, so it is merely re-wrapped.
instance (CanBeWise (Op lore), CanBeWise op, Attributes lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom hostop =
    case hostop of
      SegOp op   -> SegOp (removeOpWisdom op)
      OtherOp op -> OtherOp (removeOpWisdom op)
      SizeOp op  -> SizeOp op

-- | Indexing into the result of a 'SegOp' or other operation is
-- delegated to that operation; size operations are never indexable.
instance (Attributes lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hostop is =
    case hostop of
      SegOp op   -> ST.indexOp vtable k op is
      OtherOp op -> ST.indexOp vtable k op is
      SizeOp _   -> Nothing

-- | A host operation prints exactly like the operation it wraps.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr hostop =
    case hostop of
      SegOp op   -> ppr op
      OtherOp op -> ppr op
      SizeOp op  -> ppr op

-- | Metrics are collected from the wrapped operation.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics hostop =
    case hostop of
      SegOp op   -> opMetrics op
      OtherOp op -> opMetrics op
      SizeOp op  -> opMetrics op

-- | Type check a host operation.
--
-- * A 'SegOp' is checked with 'typeCheckSegOp' at the given level.
--   Nested operations inside it are checked with the supplied checker,
--   specialised to the 'SegOp''s own level.
-- * An 'OtherOp' is handed to the supplied function.
-- * A 'SizeOp' is checked by 'typeCheckSizeOp'.
typeCheckHostOp :: TC.Checkable lore =>
                   (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ())
                -> Maybe SegLevel
                -> (op -> TC.TypeM lore ())
                -> HostOp (Aliases lore) op
                -> TC.TypeM lore ()
typeCheckHostOp checker lvl checkOther hostop =
  case hostop of
    SegOp op ->
      let nested = checker (segLevel op)
      in TC.checkOpWith nested (typeCheckSegOp lvl op)
    OtherOp op ->
      checkOther op
    SizeOp op ->
      typeCheckSizeOp op