{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Representation.Kernels.Kernel
( HistOp(..)
, histType
, SegRedOp(..)
, segRedResults
, KernelBody(..)
, aliasAnalyseKernelBody
, consumedInKernelBody
, ResultManifest(..)
, KernelResult(..)
, kernelResultSubExp
, SplitOrdering(..)
, SegOp(..)
, SegLevel(..)
, SegVirt(..)
, segLevel
, segSpace
, typeCheckSegOp
, SegSpace(..)
, scopeOfSegSpace
, segSpaceDims
, SegOpMapper(..)
, identitySegOpMapper
, mapSegOpM
, SizeOp(..)
, HostOp(..)
, typeCheckHostOp
, module Futhark.Representation.Kernels.Sizes
)
where
import Control.Arrow (first)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Control.Monad.Identity hiding (mapM_)
import qualified Data.Map.Strict as M
import Data.List (intersperse)
import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
((</>), (<+>), ppr, commasep, Pretty, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges
(Ranges, removeLambdaRanges, removeStmRanges, mkBodyRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
(Aliases, removeLambdaAliases, removeStmAliases)
import Futhark.Representation.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.Range as Range
import Futhark.Util (maybeNth)
-- | Ordering tag stored in a 'ConcatReturns' kernel result.
-- 'SplitContiguous' carries nothing; 'SplitStrided' carries a 'SubExp'
-- (presumably the stride, judging by the variable name used in the
-- instances for this type — confirm).
data SplitOrdering
  = SplitContiguous
  | SplitStrided SubExp
  deriving (Eq, Ord, Show)
-- | Free variables of a split ordering: none for 'SplitContiguous';
-- those of the payload for 'SplitStrided'.
instance FreeIn SplitOrdering where
  freeIn' ordering =
    case ordering of
      SplitContiguous     -> mempty
      SplitStrided stride -> freeIn' stride
-- | Apply a name substitution.  Only the 'SplitStrided' payload can
-- mention names, so 'SplitContiguous' is returned unchanged.
instance Substitute SplitOrdering where
  substituteNames subst ordering =
    case ordering of
      SplitContiguous     -> SplitContiguous
      SplitStrided stride -> SplitStrided (substituteNames subst stride)
-- | Rename the names inside a split ordering.  Only the 'SplitStrided'
-- payload is renamed.
instance Rename SplitOrdering where
  rename ordering =
    case ordering of
      SplitContiguous     -> pure SplitContiguous
      SplitStrided stride -> fmap SplitStrided (rename stride)
-- | Description of a histogram operator.  'histType' derives the
-- operator's result types from 'histWidth', 'histShape' and 'histOp'.
data HistOp lore =
  HistOp
  { histWidth      :: SubExp
    -- ^ Row count that 'histType' wraps around each result (via
    -- 'arrayOfRow').
  , histRaceFactor :: SubExp
  , histDest       :: [VName]
  , histNeutral    :: [SubExp]
  , histShape      :: Shape
    -- ^ Dimensions that 'histType' adds to each lambda return type
    -- (via 'arrayOfShape').
  , histOp         :: Lambda lore
    -- ^ The operator lambda; its return types are the starting point
    -- of 'histType'.
  }
  deriving (Eq, Ord, Show)
-- | Result types of a histogram operator.  Each return type of
-- 'histOp' is first extended with the dimensions of 'histShape'
-- (via 'arrayOfShape') and then turned into an array of 'histWidth'
-- rows (via 'arrayOfRow').
histType :: HistOp lore -> [Type]
histType op =
  [ arrayOfRow (arrayOfShape t (histShape op)) (histWidth op)
  | t <- lambdaReturnType (histOp op)
  ]
-- | Description of a reduction operator.
data SegRedOp lore =
  SegRedOp
  { segRedComm    :: Commutativity
  , segRedLambda  :: Lambda lore
  , segRedNeutral :: [SubExp]
    -- ^ 'segRedResults' takes the length of this list as the number
    -- of results this operator contributes.
  , segRedShape   :: Shape
  }
  deriving (Eq, Ord, Show)
-- | Total number of results produced by a list of reduction
-- operators: the sum of the lengths of their 'segRedNeutral' lists.
segRedResults :: [SegRedOp lore] -> Int
segRedResults ops = sum [ length (segRedNeutral op) | op <- ops ]
-- | The body of a kernel: a lore-specific body annotation, a sequence
-- of statements, and a list of 'KernelResult's (instead of the plain
-- result list an ordinary 'Body' has).
data KernelBody lore =
  KernelBody
  { kernelBodyLore   :: BodyAttr lore
  , kernelBodyStms   :: Stms lore
  , kernelBodyResult :: [KernelResult]
  }

deriving instance Annotations lore => Eq (KernelBody lore)
deriving instance Annotations lore => Ord (KernelBody lore)
deriving instance Annotations lore => Show (KernelBody lore)
-- | Tag attached to a 'Returns' kernel result.  The constructor names
-- suggest it tells the simplifier whether it may touch the result
-- (NOTE(review): semantics are not visible in this module — confirm).
-- The derived 'Ord' follows constructor order, so do not reorder them.
data ResultManifest
  = ResultNoSimplify
  | ResultMaySimplify
  | ResultPrivate
  deriving (Eq, Ord, Show)
-- | A single result of a 'KernelBody'.
data KernelResult
  = Returns ResultManifest SubExp
    -- ^ A subexpression tagged with a 'ResultManifest';
    -- 'kernelResultSubExp' yields the subexpression itself.
  | WriteReturns [SubExp] VName [([SubExp], SubExp)]
    -- ^ 'consumedInKernelBody' reports the 'VName' as consumed.
  | ConcatReturns SplitOrdering SubExp SubExp VName
  | TileReturns [(SubExp, SubExp)] VName
  deriving (Eq, Ord, Show)
-- | The 'SubExp' that a kernel result stands for: the returned value
-- for 'Returns', and otherwise the 'VName' carried by the constructor,
-- wrapped in 'Var'.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp kres =
  case kres of
    Returns _ se          -> se
    WriteReturns _ arr _  -> Var arr
    ConcatReturns _ _ _ v -> Var v
    TileReturns _ v       -> Var v
-- | Free variables of a kernel result: the free variables of all of
-- its components combined with the 'FV' monoid.  The 'ResultManifest'
-- tag of 'Returns' contributes nothing.
instance FreeIn KernelResult where
  freeIn' kres =
    case kres of
      Returns _ what ->
        freeIn' what
      WriteReturns rws arr res ->
        mconcat [freeIn' rws, freeIn' arr, freeIn' res]
      ConcatReturns o w per_thread_elems v ->
        mconcat [freeIn' o, freeIn' w, freeIn' per_thread_elems, freeIn' v]
      TileReturns dims v ->
        mconcat [freeIn' dims, freeIn' v]
-- | Free variables of a kernel body.  The free variables of the
-- annotation, statements and results are collected, and the names
-- bound by the statements are then passed to 'fvBind' (presumably
-- removing them from the set — confirm against 'fvBind').
instance Attributes lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody attr stms res) =
    let boundHere = foldMap boundByStm stms
        mentioned = mconcat [freeIn' attr, freeIn' stms, freeIn' res]
    in fvBind boundHere mentioned
-- | Apply a name substitution to the annotation, statements and
-- results of a kernel body.
instance Attributes lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody attr stms res) =
    KernelBody (sub attr) (sub stms) (sub res)
    where -- The explicit signature is needed: TypeFamilies implies
          -- MonoLocalBinds, so 'sub' would not otherwise be generalised
          -- to the three different field types.
          sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Apply a name substitution to every component of a kernel result.
-- The 'ResultManifest' tag of 'Returns' is kept as-is.
instance Substitute KernelResult where
  substituteNames subst kres =
    case kres of
      Returns manifest se ->
        Returns manifest (sub se)
      WriteReturns rws arr res ->
        WriteReturns (sub rws) (sub arr) (sub res)
      ConcatReturns o w per_thread_elems v ->
        ConcatReturns (sub o) (sub w) (sub per_thread_elems) (sub v)
      TileReturns dims v ->
        TileReturns (sub dims) (sub v)
    where -- Explicitly polymorphic: used at several component types,
          -- and MonoLocalBinds (implied by TypeFamilies) would not
          -- generalise it.
          sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Rename a kernel body.  The annotation is renamed first.  The
-- results are renamed inside the 'renamingStms' continuation,
-- presumably so they observe the renaming of names bound by the
-- statements.
instance Attributes lore => Rename (KernelBody lore) where
  rename (KernelBody attr stms res) =
    rename attr >>= \attr' ->
      renamingStms stms $ \stms' -> do
        res' <- rename res
        return $ KernelBody attr' stms' res'
-- | Kernel results are renamed through 'substituteRename', reusing the
-- 'Substitute' instance for 'KernelResult'.
instance Rename KernelResult where
  rename kres = substituteRename kres
-- | Run alias analysis on a kernel body.  The annotation and
-- statements are wrapped in an ordinary 'Body' with an empty result
-- list and analysed with an empty alias table.  The kernel results are
-- reattached unchanged.
aliasAnalyseKernelBody :: (Attributes lore,
                           CanBeAliased (Op lore)) =>
                          KernelBody lore
                       -> KernelBody (Aliases lore)
aliasAnalyseKernelBody (KernelBody attr stms res) =
  KernelBody attr' stms' res
  where -- Lazy pattern binding, like the original 'let'.
        Body attr' stms' _ = Alias.analyseBody mempty (Body attr stms [])
-- | Strip alias information from a kernel body.  The first component of
-- the annotation pair is discarded, and 'removeStmAliases' is mapped
-- over the statements.  Results are kept unchanged.
removeKernelBodyAliases :: CanBeAliased (Op lore) =>
                           KernelBody (Aliases lore) -> KernelBody lore
removeKernelBodyAliases (KernelBody (_, attr) stms res) =
  KernelBody attr stms' res
  where stms' = removeStmAliases <$> stms
-- | Run range analysis over a kernel body.  The statements are
-- analysed with 'Range.analyseStms'.  The body annotation is paired
-- with the ranges that 'mkBodyRanges' computes (from the original,
-- un-ranged statements) for the subexpressions of the kernel results.
addKernelBodyRanges :: (Attributes lore, CanBeRanged (Op lore)) =>
                       KernelBody lore -> Range.RangeM (KernelBody (Ranges lore))
addKernelBodyRanges (KernelBody attr stms res) =
  Range.analyseStms stms $ \stms' ->
    return $ KernelBody (resultRanges, attr) stms' res
  where resultRanges = mkBodyRanges stms (map kernelResultSubExp res)
-- | Strip range information from a kernel body.  The first component of
-- the annotation pair is discarded, and 'removeStmRanges' is mapped
-- over the statements.  Results are kept unchanged.
removeKernelBodyRanges :: CanBeRanged (Op lore) =>
                          KernelBody (Ranges lore) -> KernelBody lore
removeKernelBodyRanges (KernelBody (_, attr) stms res) =
  KernelBody attr stms' res
  where stms' = removeStmRanges <$> stms
-- | Strip wisdom from a kernel body.  The annotation and statements are
-- wrapped in an ordinary 'Body' with an empty result list and passed
-- through 'removeBodyWisdom'.  The kernel results are reattached
-- unchanged.
removeKernelBodyWisdom :: CanBeWise (Op lore) =>
                          KernelBody (Wise lore) -> KernelBody lore
removeKernelBodyWisdom (KernelBody attr stms res) =
  KernelBody attr' stms' res
  where -- Lazy pattern binding, like the original 'let'.
        Body attr' stms' _ = removeBodyWisdom (Body attr stms [])
-- | The names consumed by a kernel body.
--
-- This is everything consumed by its statements, plus the destination
-- array of each 'WriteReturns' result.  No other kind of result
-- consumes anything.
consumedInKernelBody :: Aliased lore =>
                        KernelBody lore -> Names
consumedInKernelBody (KernelBody attr stms kres) =
  consumedInBody (Body attr stms []) <> foldMap returnConsumes kres
  where returnConsumes :: KernelResult -> Names
        returnConsumes (WriteReturns _ dest _) = oneName dest
        returnConsumes _                       = mempty
-- | Type-check a kernel body against the expected per-result types.
--
-- The body attribute is checked first, then the statements.  Within the
-- scope of the statements, the number of results must equal the number
-- of expected types, and each result is checked against its
-- corresponding type.  'WriteReturns' results additionally consume the
-- aliases of their destination array.
checkKernelBody :: TC.Checkable lore =>
                   [Type] -> KernelBody (Aliases lore) -> TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, attr) stms kres) = do
  TC.checkBodyLore attr
  TC.checkStms stms checkResults
  where
    checkResults = do
      unless (length ts == length kres) $
        TC.bad $ TC.TypeError $ concat
          [ "Kernel return type is ", prettyTuple ts
          , ", but body returns ", show (length kres), " values." ]
      zipWithM_ checkKernelResult kres ts

    -- Eta-expanded so the binding is a function, not a
    -- monomorphism-restricted pattern binding.
    requireInt32 se = TC.require [Prim int32] se

    checkKernelResult (Returns _ what) t =
      TC.require [t] what

    checkKernelResult (WriteReturns rws arr writes) t = do
      mapM_ requireInt32 rws
      arr_t <- lookupType arr
      let expected_arr_t = t `arrayOfShape` Shape rws
      forM_ writes $ \(is, e) -> do
        mapM_ requireInt32 is
        TC.require [t] e
        -- Only reported when there is at least one write, mirroring the
        -- placement of this check inside the per-write loop.
        unless (arr_t == expected_arr_t) $
          TC.bad $ TC.TypeError $ concat
            [ "WriteReturns returning ", pretty e
            , " of type ", pretty t
            , ", shape=", pretty rws
            , ", but destination array has type ", pretty arr_t ]
      TC.consume =<< TC.lookupAliases arr

    checkKernelResult (ConcatReturns o w per_thread_elems v) t = do
      case o of
        SplitContiguous     -> return ()
        SplitStrided stride -> requireInt32 stride
      mapM_ requireInt32 [w, per_thread_elems]
      vt <- lookupType v
      -- The concatenated array must be a 't' array whose outer dimension
      -- is taken from 'v' itself.
      unless (vt == t `arrayOfRow` arraySize 0 vt) $
        TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v

    checkKernelResult (TileReturns dims v) t = do
      forM_ dims $ \(dim, tile) -> mapM_ requireInt32 [dim, tile]
      vt <- lookupType v
      -- Here 'v' must have the tile sizes (second components) as its shape.
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v
-- | Record metrics for every statement of a kernel body, in order, via
-- 'bindingMetrics'.  The kernel results contribute nothing.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics kbody =
  forM_ (kernelBodyStms kbody) bindingMetrics
-- | Print the statements of a kernel body one per line, followed by a
-- @return {...}@ clause listing the results.  The body attribute is not
-- printed.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms kres) =
    stms_doc </> text "return" <+> PP.braces res_doc
    where stms_doc = PP.stack $ map ppr $ stmsToList stms
          res_doc  = PP.commasep $ map ppr kres
-- | Pretty-printing of the individual kinds of kernel result.
instance Pretty KernelResult where
  -- The manifest annotation is shown in parentheses, except for
  -- 'ResultMaySimplify', which prints as a plain @returns@.
  ppr (Returns manifest what) =
    text keyword <+> ppr what
    where keyword = case manifest of
            ResultNoSimplify  -> "returns (manifest)"
            ResultPrivate     -> "returns (private)"
            ResultMaySimplify -> "returns"

  -- Each write prints its indices paired with the corresponding bound
  -- from 'rws' (as @i < rw@), followed by @<-@ and the written value.
  ppr (WriteReturns rws arr writes) =
    ppr arr <+> text "with" <+> PP.apply (map showWrite writes)
    where showWrite (is, e) =
            PP.brackets (PP.commasep (zipWith showBound is rws))
            <+> text "<-" <+> ppr e
          showBound i rw = ppr i <+> text "<" <+> ppr rw

  ppr (ConcatReturns o w per_thread_elems v) =
    text "concat" <> ordering_doc <> parens (commasep [ppr w, ppr per_thread_elems])
    <+> ppr v
    where ordering_doc = case o of
            SplitContiguous     -> mempty
            SplitStrided stride -> text "Strided" <> parens (ppr stride)

  ppr (TileReturns dims v) =
    text "tile" <> parens (commasep (map showDim dims)) <+> ppr v
    where showDim (dim, tile) = ppr dim <+> text "/" <+> ppr tile
-- | The virtualisation flag carried by a 'SegLevel'.
--
-- NOTE(review): presumably 'SegVirt' permits the iteration space to be
-- virtualised over the physical threads/groups and 'SegNoVirt' does not;
-- confirm against code generation.
data SegVirt
  = SegVirt
  | SegNoVirt
  deriving (Eq, Ord, Show)
-- | The level at which a segmented operation runs.
--
-- Both constructors carry the same fields, so the record selectors are
-- total.
data SegLevel
  = SegThread
    { segNumGroups :: Count NumGroups SubExp
    , segGroupSize :: Count GroupSize SubExp
    , segVirt      :: SegVirt
    }
  | SegGroup
    { segNumGroups :: Count NumGroups SubExp
    , segGroupSize :: Count GroupSize SubExp
    , segVirt      :: SegVirt
    }
  deriving (Eq, Ord, Show)
-- | The iteration space of a segmented operation.
data SegSpace = SegSpace
  { segFlat    :: VName
    -- ^ An extra name, bound as an 'Int32' index alongside the
    -- dimension indices (see 'scopeOfSegSpace').
  , unSegSpace :: [(VName, SubExp)]
    -- ^ For each dimension, the index name and the dimension size.
  }
  deriving (Eq, Ord, Show)
-- | The sizes of the dimensions of a segmented space, outermost first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims = map snd . unSegSpace
-- | The scope introduced by a segmented space.
--
-- The flat name and every dimension index name are bound as 'Int32'
-- indices.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace flat dims) =
  M.fromList [ (name, IndexInfo Int32) | name <- flat : map fst dims ]
-- | Check that every dimension size of a segmented space is an 'int32'.
-- The index names themselves are not checked here.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  forM_ dims $ \(_, dim_size) -> TC.require [Prim int32] dim_size
-- | A segmented operation.
--
-- Every constructor carries its 'SegLevel', its 'SegSpace', a list of
-- types and a kernel body.  'SegRed' and 'SegHist' also carry their
-- operators.  'SegScan' carries a lambda and a list of 'SubExp's, whose
-- length determines how many of the types are scan results (see
-- 'segOpType').
data SegOp lore
  = SegMap  SegLevel SegSpace [Type] (KernelBody lore)
  | SegRed  SegLevel SegSpace [SegRedOp lore] [Type] (KernelBody lore)
  | SegScan SegLevel SegSpace (Lambda lore) [SubExp] [Type] (KernelBody lore)
  | SegHist SegLevel SegSpace [HistOp lore] [Type] (KernelBody lore)
  deriving (Eq, Ord, Show)
-- | The level of a segmented operation, which is always its first field.
segLevel :: SegOp lore -> SegLevel
segLevel op =
  case op of
    SegMap  level _ _ _     -> level
    SegRed  level _ _ _ _   -> level
    SegScan level _ _ _ _ _ -> level
    SegHist level _ _ _ _   -> level
-- | The iteration space of a segmented operation, which is always its
-- second field.
segSpace :: SegOp lore -> SegSpace
segSpace op =
  case op of
    SegMap  _ space _ _     -> space
    SegRed  _ space _ _ _   -> space
    SegScan _ space _ _ _ _ -> space
    SegHist _ space _ _ _   -> space
-- | The full type produced by one kernel result of element type @t@.
--
-- * 'WriteReturns': a @t@ array with the shape given by its size list.
-- * 'Returns': one outer dimension per dimension of the space, with the
--   outermost space dimension outermost.
-- * 'ConcatReturns': a single outer dimension of its total width.
-- * 'TileReturns': the full dimensions (first components of the pairs).
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape space t kres =
  case kres of
    WriteReturns rws _ _  -> t `arrayOfShape` Shape rws
    Returns _ _           -> foldr (\d acc -> acc `arrayOfRow` d) t (segSpaceDims space)
    ConcatReturns _ w _ _ -> t `arrayOfRow` w
    TileReturns dims _    -> t `arrayOfShape` Shape (map fst dims)
-- | The types of the results produced by a 'SegOp'.
--
-- * 'SegMap': one result per kernel body result, shaped by
--   'segResultShape'.
-- * 'SegRed': first the reduction results (each operator's lambda
--   return types, prefixed by all but the innermost dimension of the
--   space and then the operator's shape), followed by the remaining
--   map-style results.
-- * 'SegScan': first the scanned results, shaped by the full space,
--   followed by the remaining map-style results.
-- * 'SegHist': for each operator, its lambda return types prefixed by
--   the segment dimensions, the histogram width and the operator shape.
segOpType :: SegOp lore -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts (kernelBodyResult kbody)

segOpType (SegRed _ space reds ts kbody) =
  red_ts ++ zipWith (segResultShape space) map_ts map_res
  where segment_dims = init $ segSpaceDims space
        red_ts = concatMap redResultTypes reds
        redResultTypes red =
          let red_shape = Shape segment_dims <> segRedShape red
          in map (`arrayOfShape` red_shape) $ lambdaReturnType $ segRedLambda red
        -- Everything after the reduction results is map-style.
        num_red = length red_ts
        map_ts = drop num_red ts
        map_res = drop num_red $ kernelBodyResult kbody

segOpType (SegScan _ space _ nes ts kbody) =
  scanned ++ zipWith (segResultShape space) map_ts map_res
  where (scan_ts, map_ts) = splitAt (length nes) ts
        full_shape = Shape $ segSpaceDims space
        scanned = [ t `arrayOfShape` full_shape | t <- scan_ts ]
        map_res = drop (length scan_ts) $ kernelBodyResult kbody

segOpType (SegHist _ space ops _ _) =
  concatMap histResultTypes ops
  where segment_dims = init $ segSpaceDims space
        histResultTypes op =
          let hist_shape = Shape (segment_dims ++ [histWidth op]) <> histShape op
          in map (`arrayOfShape` hist_shape) $ lambdaReturnType $ histOp op
-- | The result types of a 'SegOp' are those given by 'segOpType',
-- lifted to existential types with 'staticShapes'.
instance TypedOp (SegOp lore) where
  opType op = pure $ staticShapes $ segOpType op
-- | No result of a 'SegOp' aliases anything.  What a 'SegOp' consumes
-- is whatever its kernel body consumes, plus, for 'SegHist', the
-- destination arrays of all its operators.
instance (Attributes lore, Aliased lore) => AliasedOp (SegOp lore) where
  opAliases op = [ mempty | _ <- segOpType op ]

  consumedInOp op =
    case op of
      SegMap _ _ _ kbody -> consumedInKernelBody kbody
      SegRed _ _ _ _ kbody -> consumedInKernelBody kbody
      SegScan _ _ _ _ _ kbody -> consumedInKernelBody kbody
      SegHist _ _ hist_ops _ kbody ->
        namesFromList (concatMap histDest hist_ops) <> consumedInKernelBody kbody
-- | Check that a 'SegOp' at level @y@ may appear inside the current
-- level, if there is one.  Three rules apply:
--
-- * nothing may be nested at thread level;
-- * the nested level must differ from the current one;
-- * both levels must agree on the number of groups and the group size.
checkSegLevel :: Maybe SegLevel -> SegLevel -> TC.TypeM lore ()
checkSegLevel Nothing _ =
  return ()
checkSegLevel (Just SegThread{}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y =
      -- Previously read "Already at at level" (duplicated word).
      TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      return ()
-- | Checks shared by every 'SegOp', run in this order:
--
-- 1. the level is checked against the current level with 'checkSegLevel';
-- 2. the space is checked with 'checkSegSpace';
-- 3. every result type is checked with 'TC.checkType'.
checkSegBasics :: TC.Checkable lore =>
                  Maybe SegLevel -> SegLevel -> SegSpace -> [Type] -> TC.TypeM lore ()
checkSegBasics cur_lvl lvl space ts =
  checkSegLevel cur_lvl lvl >>
  checkSegSpace space >>
  mapM_ TC.checkType ts
-- | Type-check a 'SegOp'.  The first argument is the level of the
-- enclosing 'SegOp', if any.
--
-- 'SegMap', 'SegRed' and 'SegScan' are delegated to 'checkScanRed'.
-- 'SegHist' is checked directly here.
typeCheckSegOp :: TC.Checkable lore =>
                  Maybe SegLevel -> SegOp (Aliases lore) -> TC.TypeM lore ()
typeCheckSegOp cur_lvl (SegMap lvl space ts kbody) =
  -- A map has no operators.
  checkScanRed cur_lvl lvl space [] ts kbody

typeCheckSegOp cur_lvl (SegRed lvl space reds ts body) =
  checkScanRed cur_lvl lvl space red_ops ts body
  where red_ops = [ (segRedLambda red, segRedNeutral red, segRedShape red)
                  | red <- reds ]

typeCheckSegOp cur_lvl (SegScan lvl space scan_op nes ts body) =
  -- A scan is checked as a single operator with an empty shape.
  checkScanRed cur_lvl lvl space [(scan_op, nes, mempty)] ts body

typeCheckSegOp cur_lvl (SegHist lvl space ops ts kbody) = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- mapM checkHistOp ops

    checkKernelBody ts kbody

    -- The body must return one int32 bucket index per operator,
    -- followed by the values to write.
    let bucket_ret_t = replicate (length ops) (Prim int32) ++ concat nes_ts
    when (bucket_ret_t /= ts) $
      TC.bad $ TC.TypeError $ "SegHist body has return type " ++
      prettyTuple ts ++ " but should have type " ++
      prettyTuple bucket_ret_t

  where
    segment_dims = init $ segSpaceDims space

    -- Check one histogram operator and return its neutral element
    -- types, lifted to the operator's shape.
    checkHistOp (HistOp dest_w rf dests nes shape op) = do
      mapM_ (TC.require [Prim int32]) [dest_w, rf]
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int32]) $ shapeDims shape

      -- The operator is applied to two copies of the neutral elements
      -- (with the operator's shape stripped).  It must return the
      -- neutral elements' types.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      when (nes_t /= lambdaReturnType op) $
        TC.bad $ TC.TypeError $ "SegHist operator has return type " ++
        prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
        prettyTuple nes_t

      -- Each destination array must have the neutral element's type,
      -- extended by the segment dimensions, the histogram width and the
      -- operator shape.  The destination is consumed.
      let dest_shape = Shape (segment_dims ++ [dest_w]) <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape] dest
        TC.consume =<< TC.lookupAliases dest

      return $ map (`arrayOfShape` shape) nes_t
-- | Shared checker for 'SegMap', 'SegRed' and 'SegScan'.
--
-- Each operator is a triple of its lambda, its neutral elements and
-- the shape it is applied over.  The leading results of the kernel
-- body must have the neutral elements' types, each lifted to its
-- operator's shape.
checkScanRed :: TC.Checkable lore =>
                Maybe SegLevel -> SegLevel
             -> SegSpace
             -> [(Lambda (Aliases lore), [SubExp], Shape)]
             -> [Type]
             -> KernelBody (Aliases lore)
             -> TC.TypeM lore ()
checkScanRed cur_lvl lvl space ops ts kbody = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    op_ts <- mapM checkOperator ops

    let expecting = concat op_ts
        got = take (length expecting) ts
    when (expecting /= got) $
      TC.bad $ TC.TypeError $
      "Wrong return for body (does not match neutral elements; expected " ++
      pretty expecting ++ "; found " ++
      pretty got ++ ")"

    checkKernelBody ts kbody

  where
    -- Check one operator and return its neutral element types, lifted
    -- to the operator's shape.
    checkOperator (lam, nes, shape) = do
      mapM_ (TC.require [Prim int32]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- The operator is applied to two copies of the neutral elements,
      -- with the operator's shape stripped.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda lam $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'

      let nes_t = map TC.argType nes'
      when (lambdaReturnType lam /= nes_t) $
        TC.bad $ TC.TypeError "wrong type for operator or neutral elements."

      return $ map (`arrayOfShape` shape) nes_t
-- | A record of monadic functions that 'mapSegOpM' uses to transform
-- the components of a 'SegOp' from lore @flore@ to lore @tlore@.
data SegOpMapper flore tlore m = SegOpMapper
  { mapOnSegOpSubExp :: SubExp -> m SubExp
  , mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore)
  , mapOnSegOpBody :: KernelBody flore -> m (KernelBody tlore)
  , mapOnSegOpVName :: VName -> m VName
  }
-- | A 'SegOpMapper' that returns every component unchanged.
identitySegOpMapper :: Monad m => SegOpMapper lore lore m
identitySegOpMapper =
  SegOpMapper { mapOnSegOpSubExp = pure
              , mapOnSegOpLambda = pure
              , mapOnSegOpBody = pure
              , mapOnSegOpVName = pure
              }
-- | Apply 'mapOnSegOpSubExp' to the size of every dimension in a
-- 'SegSpace'.  The 'VName' fields (@phys@ and the name paired with each
-- dimension) are left unchanged.
mapOnSegSpace :: Monad f =>
                 SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) = do
  dims' <- forM dims $ \(gtid, dim) -> do
    dim' <- mapOnSegOpSubExp tv dim
    return (gtid, dim')
  return $ SegSpace phys dims'
mapSegOpM :: (Applicative m, Monad m) =>
SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM :: SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM SegOpMapper flore tlore m
tv (SegMap SegLevel
lvl SegSpace
space [Type]
ts KernelBody flore
body) =
SegLevel -> SegSpace -> [Type] -> KernelBody tlore -> SegOp tlore
forall lore.
SegLevel -> SegSpace -> [Type] -> KernelBody lore -> SegOp lore
SegMap
(SegLevel -> SegSpace -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m SegLevel
-> m (SegSpace -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> SegLevel -> m SegLevel
forall (m :: * -> *) flore tlore.
Monad m =>
SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel SegOpMapper flore tlore m
tv SegLevel
lvl
m (SegSpace -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m SegSpace -> m ([Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> SegSpace -> m SegSpace
forall (f :: * -> *) flore tlore.
Monad f =>
SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace SegOpMapper flore tlore m
tv SegSpace
space
m ([Type] -> KernelBody tlore -> SegOp tlore)
-> m [Type] -> m (KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Type -> m Type) -> [Type] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> Type -> m Type
forall (m :: * -> *) flore tlore.
Monad m =>
SegOpMapper flore tlore m -> Type -> m Type
mapOnSegOpType SegOpMapper flore tlore m
tv) [Type]
ts
m (KernelBody tlore -> SegOp tlore)
-> m (KernelBody tlore) -> m (SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
mapOnSegOpBody SegOpMapper flore tlore m
tv KernelBody flore
body
mapSegOpM SegOpMapper flore tlore m
tv (SegRed SegLevel
lvl SegSpace
space [SegRedOp flore]
reds [Type]
ts KernelBody flore
lam) =
SegLevel
-> SegSpace
-> [SegRedOp tlore]
-> [Type]
-> KernelBody tlore
-> SegOp tlore
forall lore.
SegLevel
-> SegSpace
-> [SegRedOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lore
SegRed
(SegLevel
-> SegSpace
-> [SegRedOp tlore]
-> [Type]
-> KernelBody tlore
-> SegOp tlore)
-> m SegLevel
-> m (SegSpace
-> [SegRedOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> SegLevel -> m SegLevel
forall (m :: * -> *) flore tlore.
Monad m =>
SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel SegOpMapper flore tlore m
tv SegLevel
lvl
m (SegSpace
-> [SegRedOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m SegSpace
-> m ([SegRedOp tlore]
-> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> SegSpace -> m SegSpace
forall (f :: * -> *) flore tlore.
Monad f =>
SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace SegOpMapper flore tlore m
tv SegSpace
space
m ([SegRedOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m [SegRedOp tlore]
-> m ([Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SegRedOp flore -> m (SegRedOp tlore))
-> [SegRedOp flore] -> m [SegRedOp tlore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SegRedOp flore -> m (SegRedOp tlore)
onSegOp [SegRedOp flore]
reds
m ([Type] -> KernelBody tlore -> SegOp tlore)
-> m [Type] -> m (KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Type -> m Type) -> [Type] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> m SubExp) -> Type -> m Type
forall (m :: * -> *).
Monad m =>
(SubExp -> m SubExp) -> Type -> m Type
mapOnType ((SubExp -> m SubExp) -> Type -> m Type)
-> (SubExp -> m SubExp) -> Type -> m Type
forall a b. (a -> b) -> a -> b
$ SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [Type]
ts
m (KernelBody tlore -> SegOp tlore)
-> m (KernelBody tlore) -> m (SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
mapOnSegOpBody SegOpMapper flore tlore m
tv KernelBody flore
lam
where onSegOp :: SegRedOp flore -> m (SegRedOp tlore)
onSegOp (SegRedOp Commutativity
comm Lambda flore
red_op [SubExp]
nes Shape
shape) =
Commutativity
-> Lambda tlore -> [SubExp] -> Shape -> SegRedOp tlore
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegRedOp lore
SegRedOp Commutativity
comm
(Lambda tlore -> [SubExp] -> Shape -> SegRedOp tlore)
-> m (Lambda tlore) -> m ([SubExp] -> Shape -> SegRedOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSegOpLambda SegOpMapper flore tlore m
tv Lambda flore
red_op
m ([SubExp] -> Shape -> SegRedOp tlore)
-> m [SubExp] -> m (Shape -> SegRedOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [SubExp]
nes
m (Shape -> SegRedOp tlore) -> m Shape -> m (SegRedOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> m [SubExp] -> m Shape
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape))
mapSegOpM SegOpMapper flore tlore m
tv (SegScan SegLevel
lvl SegSpace
space Lambda flore
scan_op [SubExp]
nes [Type]
ts KernelBody flore
body) =
SegLevel
-> SegSpace
-> Lambda tlore
-> [SubExp]
-> [Type]
-> KernelBody tlore
-> SegOp tlore
forall lore.
SegLevel
-> SegSpace
-> Lambda lore
-> [SubExp]
-> [Type]
-> KernelBody lore
-> SegOp lore
SegScan
(SegLevel
-> SegSpace
-> Lambda tlore
-> [SubExp]
-> [Type]
-> KernelBody tlore
-> SegOp tlore)
-> m SegLevel
-> m (SegSpace
-> Lambda tlore
-> [SubExp]
-> [Type]
-> KernelBody tlore
-> SegOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> SegLevel -> m SegLevel
forall (m :: * -> *) flore tlore.
Monad m =>
SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel SegOpMapper flore tlore m
tv SegLevel
lvl
m (SegSpace
-> Lambda tlore
-> [SubExp]
-> [Type]
-> KernelBody tlore
-> SegOp tlore)
-> m SegSpace
-> m (Lambda tlore
-> [SubExp] -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> SegSpace -> m SegSpace
forall (f :: * -> *) flore tlore.
Monad f =>
SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace SegOpMapper flore tlore m
tv SegSpace
space
m (Lambda tlore
-> [SubExp] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m (Lambda tlore)
-> m ([SubExp] -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSegOpLambda SegOpMapper flore tlore m
tv Lambda flore
scan_op
m ([SubExp] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m [SubExp] -> m ([Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [SubExp]
nes
m ([Type] -> KernelBody tlore -> SegOp tlore)
-> m [Type] -> m (KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Type -> m Type) -> [Type] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> m SubExp) -> Type -> m Type
forall (m :: * -> *).
Monad m =>
(SubExp -> m SubExp) -> Type -> m Type
mapOnType ((SubExp -> m SubExp) -> Type -> m Type)
-> (SubExp -> m SubExp) -> Type -> m Type
forall a b. (a -> b) -> a -> b
$ SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [Type]
ts
m (KernelBody tlore -> SegOp tlore)
-> m (KernelBody tlore) -> m (SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
mapOnSegOpBody SegOpMapper flore tlore m
tv KernelBody flore
body
mapSegOpM SegOpMapper flore tlore m
tv (SegHist SegLevel
lvl SegSpace
space [HistOp flore]
ops [Type]
ts KernelBody flore
body) =
SegLevel
-> SegSpace
-> [HistOp tlore]
-> [Type]
-> KernelBody tlore
-> SegOp tlore
forall lore.
SegLevel
-> SegSpace
-> [HistOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lore
SegHist
(SegLevel
-> SegSpace
-> [HistOp tlore]
-> [Type]
-> KernelBody tlore
-> SegOp tlore)
-> m SegLevel
-> m (SegSpace
-> [HistOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> SegLevel -> m SegLevel
forall (m :: * -> *) flore tlore.
Monad m =>
SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel SegOpMapper flore tlore m
tv SegLevel
lvl
m (SegSpace
-> [HistOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m SegSpace
-> m ([HistOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> SegSpace -> m SegSpace
forall (f :: * -> *) flore tlore.
Monad f =>
SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace SegOpMapper flore tlore m
tv SegSpace
space
m ([HistOp tlore] -> [Type] -> KernelBody tlore -> SegOp tlore)
-> m [HistOp tlore]
-> m ([Type] -> KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (HistOp flore -> m (HistOp tlore))
-> [HistOp flore] -> m [HistOp tlore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM HistOp flore -> m (HistOp tlore)
onHistOp [HistOp flore]
ops
m ([Type] -> KernelBody tlore -> SegOp tlore)
-> m [Type] -> m (KernelBody tlore -> SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Type -> m Type) -> [Type] -> m [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> m SubExp) -> Type -> m Type
forall (m :: * -> *).
Monad m =>
(SubExp -> m SubExp) -> Type -> m Type
mapOnType ((SubExp -> m SubExp) -> Type -> m Type)
-> (SubExp -> m SubExp) -> Type -> m Type
forall a b. (a -> b) -> a -> b
$ SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [Type]
ts
m (KernelBody tlore -> SegOp tlore)
-> m (KernelBody tlore) -> m (SegOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m
-> KernelBody flore -> m (KernelBody tlore)
mapOnSegOpBody SegOpMapper flore tlore m
tv KernelBody flore
body
where onHistOp :: HistOp flore -> m (HistOp tlore)
onHistOp (HistOp SubExp
w SubExp
rf [VName]
arrs [SubExp]
nes Shape
shape Lambda flore
op) =
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda tlore
-> HistOp tlore
forall lore.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda lore
-> HistOp lore
HistOp (SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda tlore
-> HistOp tlore)
-> m SubExp
-> m (SubExp
-> [VName] -> [SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv SubExp
w
m (SubExp
-> [VName] -> [SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
-> m SubExp
-> m ([VName] -> [SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv SubExp
rf
m ([VName] -> [SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
-> m [VName]
-> m ([SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> VName -> m VName
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> VName -> m VName
mapOnSegOpVName SegOpMapper flore tlore m
tv) [VName]
arrs
m ([SubExp] -> Shape -> Lambda tlore -> HistOp tlore)
-> m [SubExp] -> m (Shape -> Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [SubExp]
nes
m (Shape -> Lambda tlore -> HistOp tlore)
-> m Shape -> m (Lambda tlore -> HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> m [SubExp] -> m Shape
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape))
m (Lambda tlore -> HistOp tlore)
-> m (Lambda tlore) -> m (HistOp tlore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> Lambda flore -> m (Lambda tlore)
mapOnSegOpLambda SegOpMapper flore tlore m
tv Lambda flore
op
mapOnSegLevel :: Monad m =>
SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel :: SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel SegOpMapper flore tlore m
tv (SegThread Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
virt) =
Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread
(Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel)
-> m (Count NumGroups SubExp)
-> m (Count GroupSize SubExp -> SegVirt -> SegLevel)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp)
-> Count NumGroups SubExp -> m (Count NumGroups SubExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) Count NumGroups SubExp
num_groups
m (Count GroupSize SubExp -> SegVirt -> SegLevel)
-> m (Count GroupSize SubExp) -> m (SegVirt -> SegLevel)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp)
-> Count GroupSize SubExp -> m (Count GroupSize SubExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) Count GroupSize SubExp
group_size
m (SegVirt -> SegLevel) -> m SegVirt -> m SegLevel
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegVirt -> m SegVirt
forall (f :: * -> *) a. Applicative f => a -> f a
pure SegVirt
virt
mapOnSegLevel SegOpMapper flore tlore m
tv (SegGroup Count NumGroups SubExp
num_groups Count GroupSize SubExp
group_size SegVirt
virt) =
Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegGroup
(Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel)
-> m (Count NumGroups SubExp)
-> m (Count GroupSize SubExp -> SegVirt -> SegLevel)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp)
-> Count NumGroups SubExp -> m (Count NumGroups SubExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) Count NumGroups SubExp
num_groups
m (Count GroupSize SubExp -> SegVirt -> SegLevel)
-> m (Count GroupSize SubExp) -> m (SegVirt -> SegLevel)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp)
-> Count GroupSize SubExp -> m (Count GroupSize SubExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) Count GroupSize SubExp
group_size
m (SegVirt -> SegLevel) -> m SegVirt -> m SegLevel
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SegVirt -> m SegVirt
forall (f :: * -> *) a. Applicative f => a -> f a
pure SegVirt
virt
mapOnSegOpType :: Monad m =>
SegOpMapper flore tlore m -> Type -> m Type
mapOnSegOpType :: SegOpMapper flore tlore m -> Type -> m Type
mapOnSegOpType SegOpMapper flore tlore m
_tv (Prim PrimType
pt) = Type -> m Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Type -> m Type) -> Type -> m Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
mapOnSegOpType SegOpMapper flore tlore m
tv (Array PrimType
pt Shape
shape NoUniqueness
u) = PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
pt (Shape -> NoUniqueness -> Type)
-> m Shape -> m (NoUniqueness -> Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Shape -> m Shape
f Shape
shape m (NoUniqueness -> Type) -> m NoUniqueness -> m Type
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> NoUniqueness -> m NoUniqueness
forall (f :: * -> *) a. Applicative f => a -> f a
pure NoUniqueness
u
where f :: Shape -> m Shape
f (Shape [SubExp]
dims) = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> m [SubExp] -> m Shape
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SegOpMapper flore tlore m -> SubExp -> m SubExp
forall flore tlore (m :: * -> *).
SegOpMapper flore tlore m -> SubExp -> m SubExp
mapOnSegOpSubExp SegOpMapper flore tlore m
tv) [SubExp]
dims
mapOnSegOpType SegOpMapper flore tlore m
_tv (Mem Space
s) = Type -> m Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Type -> m Type) -> Type -> m Type
forall a b. (a -> b) -> a -> b
$ Space -> Type
forall shape u. Space -> TypeBase shape u
Mem Space
s
instance Attributes lore => Substitute (SegOp lore) where
substituteNames :: Map VName VName -> SegOp lore -> SegOp lore
substituteNames Map VName VName
subst = Identity (SegOp lore) -> SegOp lore
forall a. Identity a -> a
runIdentity (Identity (SegOp lore) -> SegOp lore)
-> (SegOp lore -> Identity (SegOp lore))
-> SegOp lore
-> SegOp lore
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegOpMapper lore lore Identity
-> SegOp lore -> Identity (SegOp lore)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM SegOpMapper lore lore Identity
substitute
where substitute :: SegOpMapper lore lore Identity
substitute =
SegOpMapper :: forall flore tlore (m :: * -> *).
(SubExp -> m SubExp)
-> (Lambda flore -> m (Lambda tlore))
-> (KernelBody flore -> m (KernelBody tlore))
-> (VName -> m VName)
-> SegOpMapper flore tlore m
SegOpMapper { mapOnSegOpSubExp :: SubExp -> Identity SubExp
mapOnSegOpSubExp = SubExp -> Identity SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> Identity SubExp)
-> (SubExp -> SubExp) -> SubExp -> Identity SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName VName -> SubExp -> SubExp
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst
, mapOnSegOpLambda :: Lambda lore -> Identity (Lambda lore)
mapOnSegOpLambda = Lambda lore -> Identity (Lambda lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda lore -> Identity (Lambda lore))
-> (Lambda lore -> Lambda lore)
-> Lambda lore
-> Identity (Lambda lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName VName -> Lambda lore -> Lambda lore
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst
, mapOnSegOpBody :: KernelBody lore -> Identity (KernelBody lore)
mapOnSegOpBody = KernelBody lore -> Identity (KernelBody lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelBody lore -> Identity (KernelBody lore))
-> (KernelBody lore -> KernelBody lore)
-> KernelBody lore
-> Identity (KernelBody lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName VName -> KernelBody lore -> KernelBody lore
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst
, mapOnSegOpVName :: VName -> Identity VName
mapOnSegOpVName = VName -> Identity VName
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> Identity VName)
-> (VName -> VName) -> VName -> Identity VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName VName -> VName -> VName
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
subst
}
instance Attributes lore => Rename (SegOp lore) where
rename :: SegOp lore -> RenameM (SegOp lore)
rename = SegOpMapper lore lore RenameM -> SegOp lore -> RenameM (SegOp lore)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM SegOpMapper lore lore RenameM
renamer
where renamer :: SegOpMapper lore lore RenameM
renamer = (SubExp -> RenameM SubExp)
-> (Lambda lore -> RenameM (Lambda lore))
-> (KernelBody lore -> RenameM (KernelBody lore))
-> (VName -> RenameM VName)
-> SegOpMapper lore lore RenameM
forall flore tlore (m :: * -> *).
(SubExp -> m SubExp)
-> (Lambda flore -> m (Lambda tlore))
-> (KernelBody flore -> m (KernelBody tlore))
-> (VName -> m VName)
-> SegOpMapper flore tlore m
SegOpMapper SubExp -> RenameM SubExp
forall a. Rename a => a -> RenameM a
rename Lambda lore -> RenameM (Lambda lore)
forall a. Rename a => a -> RenameM a
rename KernelBody lore -> RenameM (KernelBody lore)
forall a. Rename a => a -> RenameM a
rename VName -> RenameM VName
forall a. Rename a => a -> RenameM a
rename
instance (Attributes lore, FreeIn (LParamAttr lore)) =>
FreeIn (SegOp lore) where
freeIn' :: SegOp lore -> FV
freeIn' SegOp lore
e = (State FV (SegOp lore) -> FV -> FV)
-> FV -> State FV (SegOp lore) -> FV
forall a b c. (a -> b -> c) -> b -> a -> c
flip State FV (SegOp lore) -> FV -> FV
forall s a. State s a -> s -> s
execState FV
forall a. Monoid a => a
mempty (State FV (SegOp lore) -> FV) -> State FV (SegOp lore) -> FV
forall a b. (a -> b) -> a -> b
$ SegOpMapper lore lore (StateT FV Identity)
-> SegOp lore -> State FV (SegOp lore)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM SegOpMapper lore lore (StateT FV Identity)
free SegOp lore
e
where walk :: (b -> s) -> b -> m b
walk b -> s
f b
x = (s -> s) -> m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (s -> s -> s
forall a. Semigroup a => a -> a -> a
<>b -> s
f b
x) m () -> m b -> m b
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> b -> m b
forall (m :: * -> *) a. Monad m => a -> m a
return b
x
free :: SegOpMapper lore lore (StateT FV Identity)
free = SegOpMapper :: forall flore tlore (m :: * -> *).
(SubExp -> m SubExp)
-> (Lambda flore -> m (Lambda tlore))
-> (KernelBody flore -> m (KernelBody tlore))
-> (VName -> m VName)
-> SegOpMapper flore tlore m
SegOpMapper { mapOnSegOpSubExp :: SubExp -> StateT FV Identity SubExp
mapOnSegOpSubExp = (SubExp -> FV) -> SubExp -> StateT FV Identity SubExp
forall (m :: * -> *) s b.
(MonadState s m, Semigroup s) =>
(b -> s) -> b -> m b
walk SubExp -> FV
forall a. FreeIn a => a -> FV
freeIn'
, mapOnSegOpLambda :: Lambda lore -> StateT FV Identity (Lambda lore)
mapOnSegOpLambda = (Lambda lore -> FV)
-> Lambda lore -> StateT FV Identity (Lambda lore)
forall (m :: * -> *) s b.
(MonadState s m, Semigroup s) =>
(b -> s) -> b -> m b
walk Lambda lore -> FV
forall a. FreeIn a => a -> FV
freeIn'
, mapOnSegOpBody :: KernelBody lore -> StateT FV Identity (KernelBody lore)
mapOnSegOpBody = (KernelBody lore -> FV)
-> KernelBody lore -> StateT FV Identity (KernelBody lore)
forall (m :: * -> *) s b.
(MonadState s m, Semigroup s) =>
(b -> s) -> b -> m b
walk KernelBody lore -> FV
forall a. FreeIn a => a -> FV
freeIn'
, mapOnSegOpVName :: VName -> StateT FV Identity VName
mapOnSegOpVName = (VName -> FV) -> VName -> StateT FV Identity VName
forall (m :: * -> *) s b.
(MonadState s m, Semigroup s) =>
(b -> s) -> b -> m b
walk VName -> FV
forall a. FreeIn a => a -> FV
freeIn'
}
instance OpMetrics (Op lore) => OpMetrics (SegOp lore) where
opMetrics :: SegOp lore -> MetricsM ()
opMetrics (SegMap SegLevel
_ SegSpace
_ [Type]
_ KernelBody lore
body) =
Text -> MetricsM () -> MetricsM ()
inside Text
"SegMap" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ KernelBody lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics KernelBody lore
body
opMetrics (SegRed SegLevel
_ SegSpace
_ [SegRedOp lore]
reds [Type]
_ KernelBody lore
body) =
Text -> MetricsM () -> MetricsM ()
inside Text
"SegRed" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ do (SegRedOp lore -> MetricsM ()) -> [SegRedOp lore] -> MetricsM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics (Lambda lore -> MetricsM ())
-> (SegRedOp lore -> Lambda lore) -> SegRedOp lore -> MetricsM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegRedOp lore -> Lambda lore
forall lore. SegRedOp lore -> Lambda lore
segRedLambda) [SegRedOp lore]
reds
KernelBody lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics KernelBody lore
body
opMetrics (SegScan SegLevel
_ SegSpace
_ Lambda lore
scan_op [SubExp]
_ [Type]
_ KernelBody lore
body) =
Text -> MetricsM () -> MetricsM ()
inside Text
"SegScan" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics Lambda lore
scan_op MetricsM () -> MetricsM () -> MetricsM ()
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> KernelBody lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics KernelBody lore
body
opMetrics (SegHist SegLevel
_ SegSpace
_ [HistOp lore]
ops [Type]
_ KernelBody lore
body) =
Text -> MetricsM () -> MetricsM ()
inside Text
"SegHist" (MetricsM () -> MetricsM ()) -> MetricsM () -> MetricsM ()
forall a b. (a -> b) -> a -> b
$ do (HistOp lore -> MetricsM ()) -> [HistOp lore] -> MetricsM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Lambda lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => Lambda lore -> MetricsM ()
lambdaMetrics (Lambda lore -> MetricsM ())
-> (HistOp lore -> Lambda lore) -> HistOp lore -> MetricsM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp lore -> Lambda lore
forall lore. HistOp lore -> Lambda lore
histOp) [HistOp lore]
ops
KernelBody lore -> MetricsM ()
forall lore. OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics KernelBody lore
body
instance Pretty SegSpace where
ppr :: SegSpace -> Doc
ppr (SegSpace VName
phys [(VName, SubExp)]
dims) = Doc -> Doc
parens ([Doc] -> Doc
commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ do (VName
i,SubExp
d) <- [(VName, SubExp)]
dims
Doc -> [Doc]
forall (m :: * -> *) a. Monad m => a -> m a
return (Doc -> [Doc]) -> Doc -> [Doc]
forall a b. (a -> b) -> a -> b
$ VName -> Doc
forall a. Pretty a => a -> Doc
ppr VName
i Doc -> Doc -> Doc
<+> Doc
"<" Doc -> Doc -> Doc
<+> SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr SubExp
d) Doc -> Doc -> Doc
<+>
Doc -> Doc
parens (String -> Doc
text String
"~" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> VName -> Doc
forall a. Pretty a => a -> Doc
ppr VName
phys)
instance PP.Pretty SegLevel where
ppr :: SegLevel -> Doc
ppr SegThread{} = Doc
"thread"
ppr SegGroup{} = Doc
"group"
ppSegLevel :: SegLevel -> PP.Doc
ppSegLevel :: SegLevel -> Doc
ppSegLevel SegLevel
lvl =
Doc -> Doc
PP.parens (Doc -> Doc) -> Doc -> Doc
forall a b. (a -> b) -> a -> b
$
String -> Doc
text String
"#groups=" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Count NumGroups SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.semi Doc -> Doc -> Doc
<+>
String -> Doc
text String
"groupsize=" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Count GroupSize SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<>
case SegLevel -> SegVirt
segVirt SegLevel
lvl of
SegVirt
SegNoVirt -> Doc
forall a. Monoid a => a
mempty
SegVirt
SegVirt -> Doc
PP.semi Doc -> Doc -> Doc
<+> String -> Doc
text String
"virtualise"
instance PrettyLore lore => PP.Pretty (SegOp lore) where
ppr :: SegOp lore -> Doc
ppr (SegMap SegLevel
lvl SegSpace
space [Type]
ts KernelBody lore
body) =
String -> Doc
text String
"segmap_" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> SegLevel -> Doc
forall a. Pretty a => a -> Doc
ppr SegLevel
lvl Doc -> Doc -> Doc
</>
SegLevel -> Doc
ppSegLevel SegLevel
lvl Doc -> Doc -> Doc
</>
Doc -> Doc
PP.align (SegSpace -> Doc
forall a. Pretty a => a -> Doc
ppr SegSpace
space) Doc -> Doc -> Doc
<+>
Doc
PP.colon Doc -> Doc -> Doc
<+> [Type] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [Type]
ts Doc -> Doc -> Doc
<+> String -> String -> Doc -> Doc
PP.nestedBlock String
"{" String
"}" (KernelBody lore -> Doc
forall a. Pretty a => a -> Doc
ppr KernelBody lore
body)
ppr (SegRed SegLevel
lvl SegSpace
space [SegRedOp lore]
reds [Type]
ts KernelBody lore
body) =
String -> Doc
text String
"segred_" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> SegLevel -> Doc
forall a. Pretty a => a -> Doc
ppr SegLevel
lvl Doc -> Doc -> Doc
</>
SegLevel -> Doc
ppSegLevel SegLevel
lvl Doc -> Doc -> Doc
</>
Doc -> Doc
PP.parens (Doc -> Doc
PP.braces ([Doc] -> Doc
forall a. Monoid a => [a] -> a
mconcat ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> [Doc] -> [Doc]
forall a. a -> [a] -> [a]
intersperse (Doc
PP.comma Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.line) ([Doc] -> [Doc]) -> [Doc] -> [Doc]
forall a b. (a -> b) -> a -> b
$ (SegRedOp lore -> Doc) -> [SegRedOp lore] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SegRedOp lore -> Doc
forall lore. PrettyLore lore => SegRedOp lore -> Doc
ppOp [SegRedOp lore]
reds)) Doc -> Doc -> Doc
</>
Doc -> Doc
PP.align (SegSpace -> Doc
forall a. Pretty a => a -> Doc
ppr SegSpace
space) Doc -> Doc -> Doc
<+> Doc
PP.colon Doc -> Doc -> Doc
<+> [Type] -> Doc
forall a. Pretty a => [a] -> Doc
ppTuple' [Type]
ts Doc -> Doc -> Doc
<+>
String -> String -> Doc -> Doc
PP.nestedBlock String
"{" String
"}" (KernelBody lore -> Doc
forall a. Pretty a => a -> Doc
ppr KernelBody lore
body)
where ppOp :: SegRedOp lore -> Doc
ppOp (SegRedOp Commutativity
comm Lambda lore
lam [SubExp]
nes Shape
shape) =
Doc -> Doc
PP.braces ([Doc] -> Doc
PP.commasep ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ (SubExp -> Doc) -> [SubExp] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Doc
forall a. Pretty a => a -> Doc
ppr [SubExp]
nes) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.comma Doc -> Doc -> Doc
</>
Shape -> Doc
forall a. Pretty a => a -> Doc
ppr Shape
shape Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
PP.comma Doc -> Doc -> Doc
</>
Doc
comm' Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Lambda lore -> Doc
forall a. Pretty a => a -> Doc
ppr Lambda lore
lam
where comm' :: Doc
comm' = case Commutativity
comm of Commutativity
Commutative -> String -> Doc
text String
"commutative "
Commutativity
Noncommutative -> Doc
forall a. Monoid a => a
mempty
-- Pretty-print a SegScan as "segscan_<level>", then the level's own details,
-- then the scan operator with its neutral elements in parentheses, the
-- segment space with the result types, and finally the kernel body in braces.
ppr (SegScan lvl space scan_op nes ts body) =
  text "segscan_" <> ppr lvl </>
  ppSegLevel lvl </>
  PP.parens operatorAndNeutral </>
  PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
  PP.nestedBlock "{" "}" (ppr body)
  -- The operator lambda, a comma, then the neutral elements in braces.
  where operatorAndNeutral =
          ppr scan_op <> PP.comma </> PP.braces (PP.commasep (map ppr nes))
-- Pretty-print a SegHist as "seghist_<level>", then the level's own details,
-- then the histogram operators (comma- and line-separated, in braces and
-- parentheses), the segment space with the result types, and finally the
-- kernel body in braces.
ppr (SegHist lvl space ops ts body) =
  text "seghist_" <> ppr lvl </>
  ppSegLevel lvl </>
  PP.parens (PP.braces histOps) </>
  PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
  PP.nestedBlock "{" "}" (ppr body)
  where histOps = mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops

        -- One histogram operator: its width and rf fields, the destination
        -- arrays, the neutral elements, the shape, and the operator lambda.
        ppOp (HistOp width rf dests neutrals op_shape op_lam) =
          ppr width <> PP.comma <+> ppr rf <> PP.comma </>
          PP.braces (PP.commasep $ map ppr dests) <> PP.comma </>
          PP.braces (PP.commasep $ map ppr neutrals) <> PP.comma </>
          ppr op_shape <> PP.comma </>
          ppr op_lam
-- | No range information is tracked for the results of a 'SegOp': each
-- result type reported by 'segOpType' gets 'unknownRange'.
instance Attributes inner => RangedOp (SegOp inner) where
  opRanges = map (const unknownRange) . segOpType
-- | Range annotations on a 'SegOp' are added to or removed from its lambdas
-- and kernel bodies; sub-expressions and names pass through unchanged.
instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (SegOp lore) where
  type OpWithRanges (SegOp lore) = SegOp (Ranges lore)

  -- Strip range annotations from every lambda and kernel body.
  removeOpRanges op = runIdentity $ mapSegOpM stripRanges op
    where stripRanges =
            SegOpMapper pure
                        (pure . removeLambdaRanges)
                        (pure . removeKernelBodyRanges)
                        pure

  -- Run range analysis over every lambda and kernel body.
  addOpRanges op = Range.runRangeM $ mapSegOpM annotate op
    where annotate =
            SegOpMapper pure Range.analyseLambda addKernelBodyRanges pure
-- | Alias information on a 'SegOp' is added to or removed from its lambdas
-- and kernel bodies; sub-expressions and names pass through unchanged.
instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SegOp lore) where
  type OpWithAliases (SegOp lore) = SegOp (Aliases lore)

  -- Run alias analysis over every lambda and kernel body.
  addOpAliases op = runIdentity $ mapSegOpM annotate op
    where annotate =
            SegOpMapper pure
                        (pure . Alias.analyseLambda)
                        (pure . aliasAnalyseKernelBody)
                        pure

  -- Drop alias information from every lambda and kernel body.
  removeOpAliases op = runIdentity $ mapSegOpM strip op
    where strip =
            SegOpMapper pure
                        (pure . removeLambdaAliases)
                        (pure . removeKernelBodyAliases)
                        pure
-- | Simplifier wisdom on a 'SegOp' lives in its lambdas and kernel bodies;
-- removing it leaves sub-expressions and names untouched.
instance (CanBeWise (Op lore), Attributes lore) => CanBeWise (SegOp lore) where
  type OpWithWisdom (SegOp lore) = SegOp (Wise lore)

  -- Strip wisdom from every lambda and kernel body.
  removeOpWisdom op = runIdentity $ mapSegOpM strip op
    where strip =
            SegOpMapper pure
                        (pure . removeLambdaWisdom)
                        (pure . removeKernelBodyWisdom)
                        pure
-- | Symbolic indexing into the results of a 'SegMap'.  This applies only
-- when result @k@ of the kernel body is a @Returns ResultMaySimplify@ of a
-- variable whose value can be expressed over the thread indices.  The given
-- index expressions are then substituted for those thread indices.  Every
-- other 'SegOp', and every other kind of result, yields 'Nothing'.
instance Attributes lore => ST.IndexOp (SegOp lore) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
      Returns ResultMaySimplify se <- maybeNth k $ kernelBodyResult kbody
      -- There must be at least one index per thread dimension.
      guard $ length gtids <= length is
      let by_gtid = M.fromList $ zip gtids $ map (ST.Indexed mempty) is
          -- NOTE(review): a strict foldl' would avoid accumulating thunks
          -- here, but it is not imported in this module.
          known = foldl extend by_gtid $ kernelBodyStms kbody
      case se of
        Var v -> M.lookup v known
        Constant{} -> Nothing
    where
      gtids = map fst $ unSegSpace space

      -- Indices beyond the thread dimensions; these fill in the remaining
      -- slice dimensions of arrays indexed inside the kernel body.
      excess_is = drop (length gtids) is

      -- Record what is known about the single name bound by a statement.
      -- That is either a scalar expression over already-known names, or an
      -- index into an array from the outer symbol table whose slice
      -- dimensions match the excess indices.  Statements binding zero or
      -- several names, or matching neither shape, leave the table unchanged.
      extend table stm =
        case patternNames $ stmPattern stm of
          [v]
            | Just (pe, cs) <-
                runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm ->
                M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
            | BasicOp (Index arr slice) <- stmExp stm,
              length (sliceDims slice) == length excess_is,
              arr `ST.elem` vtable,
              Just (slice', cs) <- asPrimExpSlice table slice ->
                let arr_idx = ST.IndexedArray (stmCerts stm <> cs) arr
                                (fixSlice slice' excess_is)
                in M.insert v arr_idx table
          _ -> table

      -- Convert every index of a slice to a 'PrimExp', collecting certificates.
      asPrimExpSlice table =
        runWriterT . mapM (traverse (primExpFromSubExpM (asPrimExp table)))

      -- Convert a name to a 'PrimExp'.  A known scalar is replaced by its
      -- expression, and its certificates are recorded.  A primitive-typed
      -- name from the symbol table becomes a leaf.  Anything else fails.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = e <$ tell cs
        | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt
        | otherwise = lift Nothing

  indexOp _ _ _ _ = Nothing
-- | Every 'SegOp' is reported as safe, and none is reported as cheap.
instance Attributes lore => IsOp (SegOp lore) where
  safeOp = const True
  cheapOp = const False
-- | Operations on sizes.  Each produces a single scalar: 'CmpSizeLe'
-- yields a @bool@, and every other constructor yields an @i32@.
data SizeOp
  = SplitSpace SplitOrdering SubExp SubExp SubExp
    -- ^ Its result is known to lie between 0 and the last operand.
  | GetSize Name SizeClass
  | GetSizeMax SizeClass
  | CmpSizeLe Name SizeClass SubExp
  | CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)
-- | Substitution rewrites the split ordering and the sub-expressions.
-- 'Name's and 'SizeClass'es are left alone, and 'GetSize' / 'GetSizeMax'
-- are returned as-is.
instance Substitute SizeOp where
  substituteNames subst op =
    case op of
      SplitSpace o w i elems_per_thread ->
        SplitSpace (sub o) (sub w) (sub i) (sub elems_per_thread)
      CmpSizeLe name sclass x ->
        CmpSizeLe name sclass (sub x)
      CalcNumGroups w max_num_groups group_size ->
        CalcNumGroups (sub w) max_num_groups (sub group_size)
      _ -> op
    where sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Renaming visits the split ordering and the sub-expressions, from left
-- to right.  'Name's and 'SizeClass'es are kept, and 'GetSize' /
-- 'GetSizeMax' are returned as-is.
instance Rename SizeOp where
  rename (SplitSpace o w i elems_per_thread) = do
    o' <- rename o
    w' <- rename w
    i' <- rename i
    elems_per_thread' <- rename elems_per_thread
    pure $ SplitSpace o' w' i' elems_per_thread'
  rename (CmpSizeLe name sclass x) = do
    x' <- rename x
    pure $ CmpSizeLe name sclass x'
  rename (CalcNumGroups w max_num_groups group_size) = do
    w' <- rename w
    group_size' <- rename group_size
    pure $ CalcNumGroups w' max_num_groups group_size'
  rename other = pure other
-- | Every 'SizeOp' is reported as both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True
-- | Every 'SizeOp' has exactly one scalar result: 'CmpSizeLe' a @bool@,
-- and everything else an @i32@.
instance TypedOp SizeOp where
  opType op = pure [Prim $ resultType op]
    where resultType SplitSpace{} = int32
          resultType GetSize{} = int32
          resultType GetSizeMax{} = int32
          resultType CmpSizeLe{} = Bool
          resultType CalcNumGroups{} = int32
-- | A 'SizeOp' has one result, which aliases nothing, and it consumes nothing.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty
-- | The result of 'SplitSpace' is bounded below by 0 and above by its last
-- operand, taken as an @i32@ scalar expression.  All other size operations
-- have an unknown range.
instance RangedOp SizeOp where
  opRanges op =
    case op of
      SplitSpace _ _ _ elems_per_thread ->
        [ ( Just (ScalarBound 0)
          , Just (ScalarBound (SE.subExpToScalExp elems_per_thread int32)) ) ]
      _ -> [unknownRange]
-- | Free variables of a size operation, gathered from its 'SubExp'
-- operands (plus the split ordering for 'SplitSpace').  'GetSize' and
-- 'GetSizeMax' fall through to the final case and contribute nothing.
instance FreeIn SizeOp where
  freeIn' op =
    case op of
      SplitSpace ordering w i per_thread ->
        freeIn' ordering <> freeIn' [w, i, per_thread]
      CmpSizeLe _ _ bound ->
        freeIn' bound
      CalcNumGroups w _ group_size ->
        freeIn' w <> freeIn' group_size
      _ ->
        mempty
-- | Pretty-print a size operation in a function-call-like syntax, e.g.
-- @calc_num_groups(w, max_num_groups, group_size)@.  'CmpSizeLe' is
-- rendered as a @get_size(...)@ call followed by @<= x@.
instance PP.Pretty SizeOp where
  ppr op =
    case op of
      SplitSpace o w i elems_per_thread ->
        text "splitSpace" <> splitSuffix o <>
        args [ppr w, ppr i, ppr elems_per_thread]
      GetSize name size_class ->
        text "get_size" <> args [ppr name, ppr size_class]
      GetSizeMax size_class ->
        text "get_size_max" <> args [ppr size_class]
      CmpSizeLe name size_class x ->
        text "get_size" <> args [ppr name, ppr size_class] <+>
        text "<=" <+> ppr x
      CalcNumGroups w max_num_groups group_size ->
        text "calc_num_groups" <>
        args [ppr w, ppr max_num_groups, ppr group_size]
    where
      -- Parenthesised, comma-separated argument list.
      args = parens . commasep
      -- Contiguous splits get no suffix; strided ones print their stride.
      splitSuffix SplitContiguous = mempty
      splitSuffix (SplitStrided stride) = text "Strided" <> parens (ppr stride)
-- | Record each size operation in the metrics under its constructor name.
instance OpMetrics SizeOp where
  opMetrics op =
    seen $ case op of
      SplitSpace{}    -> "SplitSpace"
      GetSize{}       -> "GetSize"
      GetSizeMax{}    -> "GetSizeMax"
      CmpSizeLe{}     -> "CmpSizeLe"
      CalcNumGroups{} -> "CalcNumGroups"
-- | Type-check a 'SizeOp' by requiring each 'SubExp' operand to have the
-- scalar type the operation expects:
--
-- * 'SplitSpace': the stride (if strided) and @w@, @i@ and
--   @elems_per_thread@ must all be @int32@.
-- * 'CmpSizeLe': the compared value must be @int32@.
-- * 'CalcNumGroups': @w@ must be @int64@ and the group size @int32@.
--
-- 'GetSize' and 'GetSizeMax' have no 'SubExp' operands to check.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp op =
  case op of
    SplitSpace o w i elems_per_thread -> do
      case o of
        SplitContiguous     -> return ()
        SplitStrided stride -> requireInt32 stride
      mapM_ requireInt32 [w, i, elems_per_thread]
    GetSize{} ->
      return ()
    GetSizeMax{} ->
      return ()
    CmpSizeLe _ _ x ->
      requireInt32 x
    CalcNumGroups w _ group_size -> do
      TC.require [Prim int64] w
      requireInt32 group_size
  where
    requireInt32 se = TC.require [Prim int32] se
-- | An operation that can appear at host level in the kernels
-- representation.  The constructor order is significant for the derived
-- 'Ord' instance.
data HostOp lore op
  = SegOp (SegOp lore)
    -- ^ A segmented operation.
  | SizeOp SizeOp
    -- ^ An operation on sizes (see 'SizeOp').
  | OtherOp op
    -- ^ Any other operation, of the type supplied by the @op@ parameter.
  deriving (Eq, Ord, Show)
-- | Apply a name substitution to whichever operation is wrapped.
instance (Attributes lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hostop =
    case hostop of
      SegOp segop   -> SegOp $ substituteNames substs segop
      OtherOp other -> OtherOp $ substituteNames substs other
      SizeOp sizeop -> SizeOp $ substituteNames substs sizeop
-- | Rename the wrapped operation, keeping the constructor.
instance (Attributes lore, Rename op) => Rename (HostOp lore op) where
  rename hostop =
    case hostop of
      SegOp segop   -> fmap SegOp (rename segop)
      OtherOp other -> fmap OtherOp (rename other)
      SizeOp sizeop -> fmap SizeOp (rename sizeop)
-- | Safety and cheapness are those of the wrapped operation.
instance (Attributes lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hostop =
    case hostop of
      SegOp segop   -> safeOp segop
      OtherOp other -> safeOp other
      SizeOp sizeop -> safeOp sizeop

  cheapOp hostop =
    case hostop of
      SegOp segop   -> cheapOp segop
      OtherOp other -> cheapOp other
      SizeOp sizeop -> cheapOp sizeop
-- | The result types are those of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType hostop =
    case hostop of
      SegOp segop   -> opType segop
      OtherOp other -> opType other
      SizeOp sizeop -> opType sizeop
-- | Aliasing and consumption information is taken from the wrapped
-- operation.
instance (Aliased lore, AliasedOp op, Attributes lore) => AliasedOp (HostOp lore op) where
  opAliases hostop =
    case hostop of
      SegOp segop   -> opAliases segop
      OtherOp other -> opAliases other
      SizeOp sizeop -> opAliases sizeop

  consumedInOp hostop =
    case hostop of
      SegOp segop   -> consumedInOp segop
      OtherOp other -> consumedInOp other
      SizeOp sizeop -> consumedInOp sizeop
-- | Result ranges are those of the wrapped operation.
instance (Attributes lore, RangedOp op) => RangedOp (HostOp lore op) where
  opRanges hostop =
    case hostop of
      SegOp segop   -> opRanges segop
      OtherOp other -> opRanges other
      SizeOp sizeop -> opRanges sizeop
-- | Free variables are those of the wrapped operation.
instance (Attributes lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hostop =
    case hostop of
      SegOp segop   -> freeIn' segop
      OtherOp other -> freeIn' other
      SizeOp sizeop -> freeIn' sizeop
-- | Alias annotations are added to, or stripped from, the wrapped
-- 'SegOp' or other operation.  'SizeOp' carries no lore-dependent
-- content, so it is simply rewrapped at the new type.
instance (CanBeAliased (Op lore), CanBeAliased op, Attributes lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases hostop =
    case hostop of
      SegOp segop   -> SegOp $ addOpAliases segop
      OtherOp other -> OtherOp $ addOpAliases other
      SizeOp sizeop -> SizeOp sizeop

  removeOpAliases hostop =
    case hostop of
      SegOp segop   -> SegOp $ removeOpAliases segop
      OtherOp other -> OtherOp $ removeOpAliases other
      SizeOp sizeop -> SizeOp sizeop
-- | Range annotations are added to, or stripped from, the wrapped
-- 'SegOp' or other operation.  'SizeOp' carries no lore-dependent
-- content, so it is simply rewrapped at the new type.
instance (CanBeRanged (Op lore), CanBeRanged op, Attributes lore) => CanBeRanged (HostOp lore op) where
  type OpWithRanges (HostOp lore op) = HostOp (Ranges lore) (OpWithRanges op)

  addOpRanges hostop =
    case hostop of
      SegOp segop   -> SegOp $ addOpRanges segop
      OtherOp other -> OtherOp $ addOpRanges other
      SizeOp sizeop -> SizeOp sizeop

  removeOpRanges hostop =
    case hostop of
      SegOp segop   -> SegOp $ removeOpRanges segop
      OtherOp other -> OtherOp $ removeOpRanges other
      SizeOp sizeop -> SizeOp sizeop
-- | Wisdom is stripped from the wrapped 'SegOp' or other operation;
-- 'SizeOp' is rewrapped unchanged at the wisdom-free type.
instance (CanBeWise (Op lore), CanBeWise op, Attributes lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom hostop =
    case hostop of
      SegOp segop   -> SegOp $ removeOpWisdom segop
      OtherOp other -> OtherOp $ removeOpWisdom other
      SizeOp sizeop -> SizeOp sizeop
-- | Indexing into the @k@th result is delegated to the wrapped 'SegOp'
-- or other operation; size operations never provide indexing
-- information.
instance (Attributes lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hostop is =
    case hostop of
      SegOp segop   -> ST.indexOp vtable k segop is
      OtherOp other -> ST.indexOp vtable k other is
      SizeOp{}      -> Nothing
-- | A host operation prints exactly as the operation it wraps.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr hostop =
    case hostop of
      SegOp segop   -> ppr segop
      OtherOp other -> ppr other
      SizeOp sizeop -> ppr sizeop
-- | Metrics are collected from the wrapped operation.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics hostop =
    case hostop of
      SegOp segop   -> opMetrics segop
      OtherOp other -> opMetrics other
      SizeOp sizeop -> opMetrics sizeop
-- | Type-check a host operation (with alias annotations).
--
-- * 'SegOp' is checked with 'typeCheckSegOp' at the given optional
--   level.  During that check, 'TC.checkOpWith' installs @checker@,
--   applied to the SegOp's own 'segLevel', as the op checker.
-- * 'OtherOp' is checked with the caller-supplied function.
-- * 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp :: TC.Checkable lore =>
                   (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ())
                -> Maybe SegLevel
                -> (op -> TC.TypeM lore ())
                -> HostOp (Aliases lore) op
                -> TC.TypeM lore ()
typeCheckHostOp checker lvl checkOther hostop =
  case hostop of
    SegOp segop ->
      TC.checkOpWith (checker (segLevel segop)) $
        typeCheckSegOp lvl segop
    OtherOp other ->
      checkOther other
    SizeOp sizeop ->
      typeCheckSizeOp sizeop