{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.IR.Kernels.Kernel
(
SizeOp(..)
, HostOp(..)
, typeCheckHostOp
, SegLevel(..)
, module Futhark.IR.Kernels.Sizes
, module Futhark.IR.SegOp
)
where
import Futhark.IR
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
((</>), (<+>), ppr, commasep, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.IR.Prop.Aliases
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.SegOp
import Futhark.IR.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
-- | The level at which a 'SegOp' (as carried by 'HostOp') is executed.
-- Both alternatives carry the same three fields; they differ only in which
-- level they denote (pretty-printed as @_thread@ vs. @_group@).
data SegLevel
  = SegThread
      { -- | Number of groups.
        segNumGroups :: Count NumGroups SubExp,
        -- | Size of each group.
        segGroupSize :: Count GroupSize SubExp,
        -- | Virtualisation mode.
        segVirt :: SegVirt
      }
  | SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)
-- | Renders a level as @_thread@ or @_group@, followed by a parenthesised
-- list of its group count, group size and (if any) virtualisation mode.
instance PP.Pretty SegLevel where
  ppr lvl = kind </> PP.parens settings
    where
      kind = case lvl of
        SegThread {} -> "_thread"
        SegGroup {} -> "_group"

      settings =
        text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
          <+> text "groupsize=" <> ppr (segGroupSize lvl)
          <> virtSuffix (segVirt lvl)

      -- The default mode prints nothing; the others append a tag.
      virtSuffix SegNoVirt = mempty
      virtSuffix SegNoVirtFull = PP.semi <+> text "full"
      virtSuffix SegVirt = PP.semi <+> text "virtualise"
-- | Simplifies the group count and group size of a level, keeping its
-- constructor and virtualisation mode unchanged.
instance Engine.Simplifiable SegLevel where
  simplify lvl = do
    -- Group count first, then group size (same order as the fields).
    num_groups <- traverse Engine.simplify $ segNumGroups lvl
    group_size <- traverse Engine.simplify $ segGroupSize lvl
    pure lvl {segNumGroups = num_groups, segGroupSize = group_size}
-- | Applies a name substitution to the size fields of a level. The
-- constructor and the virtualisation mode are left untouched.
instance Substitute SegLevel where
  substituteNames substs lvl =
    lvl
      { segNumGroups = substituteNames substs $ segNumGroups lvl,
        segGroupSize = substituteNames substs $ segGroupSize lvl
      }
-- | A level binds no names, so renaming it reduces to substituting the
-- current renaming into it.
instance Rename SegLevel where
  rename lvl = substituteRename lvl
-- | Only the group count and group size contribute free variables; the
-- virtualisation mode is ignored.
instance FreeIn SegLevel where
  freeIn' lvl = freeIn' (segNumGroups lvl) <> freeIn' (segGroupSize lvl)
-- | Operations that compute or query sizes. Each produces a single
-- scalar result (see the 'TypedOp' instance).
data SizeOp
  = -- | Split ordering followed by three 32-bit operands; yields an @i32@.
    SplitSpace SplitOrdering SubExp SubExp SubExp
  | -- | Query the size with the given name and class; yields an @i32@.
    GetSize Name SizeClass
  | -- | Query the maximum for a size class; yields an @i32@.
    GetSizeMax SizeClass
  | -- | Compare the named size against an @i32@ operand; yields a @bool@.
    CmpSizeLe Name SizeClass SubExp
  | -- | Takes an @i64@ operand, a name, and an @i32@ operand; yields an @i32@.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)
-- | Substitutes names in every 'SubExp' (and split ordering) operand.
-- 'Name' fields are not 'VName's and are left alone, as are 'GetSize' and
-- 'GetSizeMax', which have no substitutable operands.
instance Substitute SizeOp where
  substituteNames substs op = case op of
    SplitSpace o w i elems_per_thread ->
      SplitSpace (sub o) (sub w) (sub i) (sub elems_per_thread)
    CmpSizeLe name sclass x ->
      CmpSizeLe name sclass (sub x)
    CalcNumGroups w max_num_groups group_size ->
      CalcNumGroups (sub w) max_num_groups (sub group_size)
    _ -> op
    where
      -- Explicit signature: used at several operand types.
      sub :: Substitute a => a -> a
      sub = substituteNames substs
-- | Renames the operands of a size operation, left to right. 'Name'
-- fields are kept verbatim; operations without operands are returned
-- unchanged.
instance Rename SizeOp where
  rename op = case op of
    SplitSpace o w i elems_per_thread -> do
      o' <- rename o
      w' <- rename w
      i' <- rename i
      elems_per_thread' <- rename elems_per_thread
      pure $ SplitSpace o' w' i' elems_per_thread'
    CmpSizeLe name sclass x ->
      fmap (CmpSizeLe name sclass) (rename x)
    CalcNumGroups w max_num_groups group_size -> do
      w' <- rename w
      group_size' <- rename group_size
      pure $ CalcNumGroups w' max_num_groups group_size'
    _ -> pure op
-- | Every size operation is treated as both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True
-- | Every size operation returns exactly one scalar: a boolean for
-- 'CmpSizeLe' and a 32-bit integer otherwise.
instance TypedOp SizeOp where
  opType op = pure [Prim $ resultType op]
    where
      resultType SplitSpace {} = int32
      resultType GetSize {} = int32
      resultType GetSizeMax {} = int32
      resultType CmpSizeLe {} = Bool
      resultType CalcNumGroups {} = int32
-- | The single scalar result of a size operation aliases nothing, and the
-- operation consumes nothing.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty
-- | Free variables are those of the operand 'SubExp's (and split
-- ordering). Operations without operands contribute nothing.
instance FreeIn SizeOp where
  freeIn' op = case op of
    SplitSpace o w i elems_per_thread ->
      freeIn' o <> freeIn' [w, i, elems_per_thread]
    CmpSizeLe _ _ x -> freeIn' x
    CalcNumGroups w _ group_size -> freeIn' w <> freeIn' group_size
    GetSize {} -> mempty
    GetSizeMax {} -> mempty
-- | Pretty-prints a size operation in function-call syntax. Each
-- constructor gets its own keyword, so the printed forms stay distinct.
instance PP.Pretty SizeOp where
  ppr (SplitSpace o w i elems_per_thread) =
    text "splitSpace" <> suff
      <> parens (commasep [ppr w, ppr i, ppr elems_per_thread])
    where
      -- Contiguous splits get no suffix; strided ones show their stride.
      suff = case o of
        SplitContiguous -> mempty
        SplitStrided stride -> text "Strided" <> parens (ppr stride)
  ppr (GetSize name size_class) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])
  ppr (GetSizeMax size_class) =
    text "get_size_max" <> parens (commasep [ppr size_class])
  -- Previously printed as "get_size", which made a comparison print with
  -- the same prefix as a plain 'GetSize' query.
  ppr (CmpSizeLe name size_class x) =
    text "cmp_size" <> parens (commasep [ppr name, ppr size_class])
      <+> text "<="
      <+> ppr x
  ppr (CalcNumGroups w max_num_groups group_size) =
    text "calc_num_groups"
      <> parens (commasep [ppr w, ppr max_num_groups, ppr group_size])
-- | Records one occurrence of the operation, tagged by constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    SplitSpace {} -> "SplitSpace"
    GetSize {} -> "GetSize"
    GetSizeMax {} -> "GetSizeMax"
    CmpSizeLe {} -> "CmpSizeLe"
    CalcNumGroups {} -> "CalcNumGroups"
-- | Type-checks the operands of a size operation. All 'SubExp' operands
-- must be @i32@, except the first operand of 'CalcNumGroups', which must
-- be @i64@. 'GetSize' and 'GetSizeMax' have nothing to check.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp op = case op of
  SplitSpace o w i elems_per_thread -> do
    case o of
      SplitContiguous -> pure ()
      SplitStrided stride -> requireI32 stride
    mapM_ requireI32 [w, i, elems_per_thread]
  GetSize {} -> pure ()
  GetSizeMax {} -> pure ()
  CmpSizeLe _ _ x -> requireI32 x
  CalcNumGroups w _ group_size -> do
    TC.require [Prim int64] w
    requireI32 group_size
  where
    requireI32 = TC.require [Prim int32]
-- | An operation at host level: a segmented operation at some 'SegLevel',
-- a size operation, or an operation of the extension type @op@.
data HostOp lore op
  = SegOp (SegOp SegLevel lore)
  | SizeOp SizeOp
  | OtherOp op
  deriving (Eq, Ord, Show)
-- | Delegates substitution to whichever operation is wrapped.
instance (ASTLore lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hostop = case hostop of
    SegOp op -> SegOp $ substituteNames substs op
    OtherOp op -> OtherOp $ substituteNames substs op
    SizeOp op -> SizeOp $ substituteNames substs op
-- | Delegates renaming to whichever operation is wrapped.
instance (ASTLore lore, Rename op) => Rename (HostOp lore op) where
  rename hostop = case hostop of
    SegOp op -> fmap SegOp (rename op)
    OtherOp op -> fmap OtherOp (rename op)
    SizeOp op -> fmap SizeOp (rename op)
-- | Safety and cheapness are those of the wrapped operation.
instance (ASTLore lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hostop = case hostop of
    SegOp op -> safeOp op
    OtherOp op -> safeOp op
    SizeOp op -> safeOp op

  cheapOp hostop = case hostop of
    SegOp op -> cheapOp op
    OtherOp op -> cheapOp op
    SizeOp op -> cheapOp op
-- | The result types are those of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType hostop = case hostop of
    SegOp op -> opType op
    OtherOp op -> opType op
    SizeOp op -> opType op
-- | Aliasing and consumption information come from the wrapped operation.
instance (Aliased lore, AliasedOp op, ASTLore lore) => AliasedOp (HostOp lore op) where
  opAliases hostop = case hostop of
    SegOp op -> opAliases op
    OtherOp op -> opAliases op
    SizeOp op -> opAliases op

  consumedInOp hostop = case hostop of
    SegOp op -> consumedInOp op
    OtherOp op -> consumedInOp op
    SizeOp op -> consumedInOp op
-- | The free variables are those of the wrapped operation.
instance (ASTLore lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hostop = case hostop of
    SegOp op -> freeIn' op
    OtherOp op -> freeIn' op
    SizeOp op -> freeIn' op
-- | Adding or removing alias annotations on a 'HostOp' delegates to the
-- wrapped 'SegOp' or inner operation. A 'SizeOp' payload does not mention
-- the lore, so it is rebuilt unchanged at the new type.
instance (CanBeAliased (Op lore), CanBeAliased op, ASTLore lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  -- Each constructor is rebuilt, rather than passed through, because the
  -- result type differs from the input type.
  addOpAliases hostOp =
    case hostOp of
      SegOp segOp   -> SegOp (addOpAliases segOp)
      OtherOp other -> OtherOp (addOpAliases other)
      SizeOp sizeOp -> SizeOp sizeOp

  removeOpAliases hostOp =
    case hostOp of
      SegOp segOp   -> SegOp (removeOpAliases segOp)
      OtherOp other -> OtherOp (removeOpAliases other)
      SizeOp sizeOp -> SizeOp sizeOp
-- | Stripping simplifier wisdom from a 'HostOp' delegates to the wrapped
-- operation. A 'SizeOp' payload is rebuilt unchanged at the new type.
instance (CanBeWise (Op lore), CanBeWise op, ASTLore lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom hostOp =
    case hostOp of
      SegOp segOp   -> SegOp (removeOpWisdom segOp)
      OtherOp other -> OtherOp (removeOpWisdom other)
      SizeOp sizeOp -> SizeOp sizeOp
-- | Symbolic indexing of a 'HostOp' result is forwarded to the wrapped
-- 'SegOp' or inner operation. Any other constructor yields 'Nothing'.
instance (ASTLore lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hostOp idxs =
    case hostOp of
      SegOp segOp   -> ST.indexOp vtable k segOp idxs
      OtherOp other -> ST.indexOp vtable k other idxs
      _             -> Nothing
-- | A 'HostOp' prints exactly as the operation it wraps, with no extra
-- decoration.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr hostOp =
    case hostOp of
      SegOp segOp   -> ppr segOp
      OtherOp other -> ppr other
      SizeOp sizeOp -> ppr sizeOp
-- | Metrics for a 'HostOp' are the metrics of the operation it wraps.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics hostOp =
    case hostOp of
      SegOp segOp   -> opMetrics segOp
      OtherOp other -> opMetrics other
      SizeOp sizeOp -> opMetrics sizeOp
-- | Check a 'SegLevel' against the level of the enclosing 'SegOp', if any.
--
-- * With no enclosing level ('Nothing'), the number of groups and the
--   group size must both have type @'Prim' int32@.
-- * A 'SegOp' may not occur inside a 'SegThread'-level 'SegOp'.
-- * A nested level may not be equal to its parent level.
-- * A nested level must use the same number of groups and group size as
--   its parent.
checkSegLevel :: TC.Checkable lore =>
                 Maybe SegLevel -> SegLevel -> TC.TypeM lore ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int32] $ unCount $ segNumGroups lvl
  TC.require [Prim int32] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread{}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y =
      -- Fixed typo: the message previously read "Already at at level ".
      TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError
        "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      return ()
-- | Type-check a 'HostOp'.
--
-- * A 'SegOp' is checked with 'typeCheckSegOp'. Its level is validated
--   against the enclosing level using 'checkSegLevel'. While the 'SegOp'
--   is being checked, @checker@ applied to the 'SegOp''s own level is
--   installed (via 'TC.checkOpWith') as the checker for operations nested
--   inside it.
-- * An 'OtherOp' is checked with the supplied function.
-- * A 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp :: TC.Checkable lore =>
                   (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ())
                -> Maybe SegLevel
                -> (op -> TC.TypeM lore ())
                -> HostOp (Aliases lore) op
                -> TC.TypeM lore ()
typeCheckHostOp checker outerLevel checkOther hostOp =
  case hostOp of
    SegOp segOp ->
      let nestedChecker = checker (segLevel segOp)
          checkBody = typeCheckSegOp (checkSegLevel outerLevel) segOp
      in TC.checkOpWith nestedChecker checkBody
    OtherOp other ->
      checkOther other
    SizeOp sizeOp ->
      typeCheckSizeOp sizeOp