{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.GPU.Op
(
SizeOp (..),
HostOp (..),
traverseHostOpStms,
typeCheckHostOp,
SegLevel (..),
segVirt,
SegVirt (..),
SegSeqDims (..),
KernelGrid (..),
module Futhark.IR.GPU.Sizes,
module Futhark.IR.SegOp,
)
where
import Control.Monad
import Data.Sequence qualified as SQ
import Data.Text qualified as T
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases, CanBeAliased (..))
import Futhark.IR.GPU.Sizes
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.Pretty
( commasep,
parens,
ppTuple',
pretty,
(<+>),
)
import Futhark.Util.Pretty qualified as PP
-- | A list of dimension indexes, carried by 'SegNoVirtFull'.
--
-- NOTE(review): the name suggests these identify segment dimensions
-- that are iterated sequentially; that meaning is not established in
-- this module, so confirm against the consumers before relying on it.
newtype SegSeqDims = SegSeqDims {segSeqDims :: [Int]}
  deriving (Eq, Ord, Show)
-- | The virtualisation mode carried by every 'SegLevel'.
--
-- The derived 'Ord' instance follows constructor order, so do not
-- reorder the constructors without checking users of that ordering.
data SegVirt
  = -- | Pretty-printed as @virtualise@.
    SegVirt
  | -- | Pretty-printed as the empty document.
    SegNoVirt
  | -- | Pretty-printed as @full@ followed by the 'SegSeqDims' list.
    SegNoVirtFull SegSeqDims
  deriving (Eq, Ord, Show)
-- | An explicit launch configuration: a number of groups and a group
-- size, each a 'SubExp' tagged by its 'Count' kind.
data KernelGrid = KernelGrid
  { gridNumGroups :: Count NumGroups SubExp,
    gridGroupSize :: Count GroupSize SubExp
  }
  deriving (Eq, Ord, Show)
-- | The level annotation used for the 'SegOp's embedded in 'HostOp'.
--
-- Every constructor carries a 'SegVirt'; only the thread and group
-- levels may carry an explicit 'KernelGrid'.
data SegLevel
  = -- | Pretty-printed with the @thread@ tag.
    SegThread SegVirt (Maybe KernelGrid)
  | -- | Pretty-printed with the @group@ tag.
    SegGroup SegVirt (Maybe KernelGrid)
  | -- | Pretty-printed with the @ingroup@ tag; has no grid of its own.
    SegThreadInGroup SegVirt
  deriving (Eq, Ord, Show)
-- | Project the 'SegVirt' component out of any 'SegLevel'.
-- Total: every constructor carries exactly one 'SegVirt'.
segVirt :: SegLevel -> SegVirt
segVirt (SegThread v _) = v
segVirt (SegGroup v _) = v
segVirt (SegThreadInGroup v) = v
-- | 'SegNoVirt' prints as nothing, 'SegNoVirtFull' as @full@ followed
-- by its dimension list, and 'SegVirt' as @virtualise@.
instance PP.Pretty SegVirt where
  pretty SegNoVirt = mempty
  pretty (SegNoVirtFull dims) = "full" <+> pretty (segSeqDims dims)
  pretty SegVirt = "virtualise"
-- | Renders as @groups=N; groupsize=M@.
instance PP.Pretty KernelGrid where
  pretty (KernelGrid num_groups group_size) =
    "groups="
      <> pretty num_groups
      <> PP.semi
      <+> "groupsize="
      <> pretty group_size
-- | Renders a parenthesised, semicolon-separated list: the level tag,
-- the 'SegVirt', and (for thread/group levels) the optional grid.
-- A 'Nothing' grid is rendered by the 'Maybe' instance of 'pretty'.
instance PP.Pretty SegLevel where
  pretty (SegThread virt grid) =
    PP.parens ("thread" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid)
  pretty (SegGroup virt grid) =
    PP.parens ("group" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid)
  pretty (SegThreadInGroup virt) =
    PP.parens ("ingroup" <> PP.semi <+> pretty virt)
-- | Simplify the 'SubExp' inside each 'Count' of the grid.
instance Engine.Simplifiable KernelGrid where
  simplify (KernelGrid num_groups group_size) =
    KernelGrid
      <$> traverse Engine.simplify num_groups
      <*> traverse Engine.simplify group_size
-- | Only the optional 'KernelGrid' is simplified; the 'SegVirt' is kept
-- unchanged, and 'SegThreadInGroup' has nothing to simplify.
instance Engine.Simplifiable SegLevel where
  simplify (SegThread virt grid) = SegThread virt <$> Engine.simplify grid
  simplify (SegGroup virt grid) = SegGroup virt <$> Engine.simplify grid
  simplify (SegThreadInGroup virt) = pure $ SegThreadInGroup virt
-- | Apply the substitution to both grid components.
instance Substitute KernelGrid where
  substituteNames substs (KernelGrid num_groups group_size) =
    KernelGrid
      (substituteNames substs num_groups)
      (substituteNames substs group_size)
-- | Substitution reaches only the optional 'KernelGrid'; 'SegVirt'
-- holds no names and is left as is.
instance Substitute SegLevel where
  substituteNames substs (SegThread virt grid) =
    SegThread virt (substituteNames substs grid)
  substituteNames substs (SegGroup virt grid) =
    SegGroup virt (substituteNames substs grid)
  substituteNames _ (SegThreadInGroup virt) =
    SegThreadInGroup virt
-- | A 'SegLevel' binds no names, so renaming is plain substitution.
instance Rename SegLevel where
  rename = substituteRename
-- | Free variables of both grid components.
instance FreeIn KernelGrid where
  freeIn' (KernelGrid num_groups group_size) =
    freeIn' (num_groups, group_size)
-- | Only the optional grid can mention variables; 'SegVirt' is ignored.
instance FreeIn SegLevel where
  freeIn' (SegThread _virt grid) = freeIn' grid
  freeIn' (SegGroup _virt grid) = freeIn' grid
  freeIn' (SegThreadInGroup _virt) = mempty
-- | Operations that query or compute sizes.  Per the 'TypedOp'
-- instance, each produces a single scalar result.
data SizeOp
  = -- | Printed as @get_size(name, class)@; result type @i64@.
    GetSize Name SizeClass
  | -- | Printed as @get_size_max(class)@; result type @i64@.
    GetSizeMax SizeClass
  | -- | Printed as @cmp_size(name, class) <= x@; result type @bool@.
    -- The 'SubExp' operand must be @i64@.
    CmpSizeLe Name SizeClass SubExp
  | -- | Printed as @calc_num_groups(w, max_num_groups, group_size)@;
    -- result type @i64@.  Both 'SubExp' operands must be @i64@.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)
-- | Substitute within the 'SubExp' operands only; the 'Name's identify
-- sizes rather than variables and are left untouched.
instance Substitute SizeOp where
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass (substituteNames substs x)
  substituteNames substs (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups
      (substituteNames substs w)
      max_num_groups
      (substituteNames substs group_size)
  -- GetSize and GetSizeMax contain no 'SubExp's.
  substituteNames _ op = op
-- | Rename the 'SubExp' operands; size 'Name's are kept as they are.
instance Rename SizeOp where
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups <$> rename w <*> pure max_num_groups <*> rename group_size
  -- GetSize and GetSizeMax contain nothing to rename.
  rename x = pure x
-- | Every 'SizeOp' is treated as both safe and cheap.
instance IsOp SizeOp where
  safeOp _ = True
  cheapOp _ = True
-- | Each 'SizeOp' yields one scalar: @bool@ for 'CmpSizeLe', @i64@ for
-- all other constructors.
instance TypedOp SizeOp where
  opType GetSize {} = pure [Prim int64]
  opType GetSizeMax {} = pure [Prim int64]
  opType CmpSizeLe {} = pure [Prim Bool]
  opType CalcNumGroups {} = pure [Prim int64]
-- | The single result of a 'SizeOp' aliases nothing, and no input is
-- consumed.
instance AliasedOp SizeOp where
  opAliases _ = [mempty]
  consumedInOp _ = mempty
-- | Free variables come only from the 'SubExp' operands.
instance FreeIn SizeOp where
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' (CalcNumGroups w _ group_size) = freeIn' w <> freeIn' group_size
  freeIn' _ = mempty
-- | Function-call style rendering; 'CmpSizeLe' appends @<= x@.
instance PP.Pretty SizeOp where
  pretty (GetSize name size_class) =
    "get_size" <> parens (commasep [pretty name, pretty size_class])
  pretty (GetSizeMax size_class) =
    "get_size_max" <> parens (commasep [pretty size_class])
  pretty (CmpSizeLe name size_class x) =
    "cmp_size"
      <> parens (commasep [pretty name, pretty size_class])
      <+> "<="
      <+> pretty x
  pretty (CalcNumGroups w max_num_groups group_size) =
    "calc_num_groups"
      <> parens (commasep [pretty w, pretty max_num_groups, pretty group_size])
-- | Report each 'SizeOp' to the metrics collector under its
-- constructor name.
instance OpMetrics SizeOp where
  opMetrics GetSize {} = seen "GetSize"
  opMetrics GetSizeMax {} = seen "GetSizeMax"
  opMetrics CmpSizeLe {} = seen "CmpSizeLe"
  opMetrics CalcNumGroups {} = seen "CalcNumGroups"
-- | Type-check a 'SizeOp': every 'SubExp' operand must be @i64@.
-- 'GetSize' and 'GetSizeMax' have no operands and always pass.
typeCheckSizeOp :: TC.Checkable rep => SizeOp -> TC.TypeM rep ()
typeCheckSizeOp GetSize {} = pure ()
typeCheckSizeOp GetSizeMax {} = pure ()
typeCheckSizeOp (CmpSizeLe _ _ x) = TC.require [Prim int64] x
typeCheckSizeOp (CalcNumGroups w _ group_size) = do
  TC.require [Prim int64] w
  TC.require [Prim int64] group_size
-- | Host-level operations, parameterised by an additional operation
-- type @op@ that is embedded via 'OtherOp'.
data HostOp op rep
  = -- | A segmented operation annotated with a 'SegLevel'.
    SegOp (SegOp SegLevel rep)
  | -- | A size query or computation.
    SizeOp SizeOp
  | -- | An operation of the embedded @op@ type.
    OtherOp (op rep)
  | -- | A body with the given result types.  Per the 'TypedOp'
    -- instance, each result is that type with an outer dimension of 1.
    GPUBody [Type] (Body rep)
  deriving (Eq, Ord, Show)
-- | Lift a statement traverser for the embedded @op@ type to one for
-- 'HostOp'.
--
-- 'SegOp's are handled by 'traverseSegOpStms', 'OtherOp's by the given
-- @onOtherOp@, and a 'GPUBody' has its statements passed to @f@ with an
-- empty scope.  'SizeOp's contain no statements and are returned
-- unchanged.
traverseHostOpStms ::
  Monad m =>
  OpStmsTraverser m (op rep) rep ->
  OpStmsTraverser m (HostOp op rep) rep
traverseHostOpStms _ f (SegOp segop) =
  SegOp <$> traverseSegOpStms f segop
traverseHostOpStms _ _ (SizeOp sizeop) =
  pure $ SizeOp sizeop
traverseHostOpStms onOtherOp f (OtherOp other) =
  OtherOp <$> onOtherOp f other
traverseHostOpStms _ f (GPUBody ts body) = do
  stms <- f mempty $ bodyStms body
  pure $ GPUBody ts $ body {bodyStms = stms}
-- | Substitution distributes over every constructor, including both the
-- result types and the body of a 'GPUBody'.
instance (ASTRep rep, Substitute (op rep)) => Substitute (HostOp op rep) where
  substituteNames substs (SegOp op) =
    SegOp $ substituteNames substs op
  substituteNames substs (OtherOp op) =
    OtherOp $ substituteNames substs op
  substituteNames substs (SizeOp op) =
    SizeOp $ substituteNames substs op
  substituteNames substs (GPUBody ts body) =
    GPUBody (substituteNames substs ts) (substituteNames substs body)
-- | Renaming distributes over every constructor.
instance (ASTRep rep, Rename (op rep)) => Rename (HostOp op rep) where
  rename (SegOp op) = SegOp <$> rename op
  rename (OtherOp op) = OtherOp <$> rename op
  rename (SizeOp op) = SizeOp <$> rename op
  rename (GPUBody ts body) = GPUBody <$> rename ts <*> rename body
-- | Safety and cheapness delegate to the wrapped operation.
--
-- A 'GPUBody' is safe iff every statement expression in its body is
-- safe, and cheap only if its body has no statements and all of its
-- result types are scalars (rank 0).
instance (ASTRep rep, IsOp (op rep)) => IsOp (HostOp op rep) where
  safeOp (SegOp op) = safeOp op
  safeOp (OtherOp op) = safeOp op
  safeOp (SizeOp op) = safeOp op
  safeOp (GPUBody _ body) = all (safeExp . stmExp) (bodyStms body)

  cheapOp (SegOp op) = cheapOp op
  cheapOp (OtherOp op) = cheapOp op
  cheapOp (SizeOp op) = cheapOp op
  cheapOp (GPUBody types body) =
    SQ.null (bodyStms body) && all ((== 0) . arrayRank) types
-- | Result types delegate to the wrapped operation.  A 'GPUBody'
-- returns each declared type with an extra outer dimension of size 1,
-- and all of these shapes are static.
instance TypedOp (op rep) => TypedOp (HostOp op rep) where
  opType (SegOp op) = opType op
  opType (OtherOp op) = opType op
  opType (SizeOp op) = opType op
  opType (GPUBody ts _) =
    pure $ staticShapes $ map (`arrayOfRow` intConst Int64 1) ts
-- | Alias and consumption information delegates to the wrapped
-- operation.  A 'GPUBody''s results alias nothing (one empty set per
-- result type), and it consumes whatever its body consumes.
instance (Aliased rep, AliasedOp (op rep)) => AliasedOp (HostOp op rep) where
  opAliases (SegOp op) = opAliases op
  opAliases (OtherOp op) = opAliases op
  opAliases (SizeOp op) = opAliases op
  opAliases (GPUBody ts _) = map (const mempty) ts

  consumedInOp (SegOp op) = consumedInOp op
  consumedInOp (OtherOp op) = consumedInOp op
  consumedInOp (SizeOp op) = consumedInOp op
  consumedInOp (GPUBody _ body) = consumedInBody body
-- | Free variables of a host operation.  For a 'GPUBody' these are the
-- variables free in its declared result types together with those free
-- in its body.
instance (ASTRep rep, FreeIn (op rep)) => FreeIn (HostOp op rep) where
  freeIn' hostop =
    case hostop of
      SegOp op -> freeIn' op
      OtherOp op -> freeIn' op
      SizeOp op -> freeIn' op
      GPUBody ts body -> freeIn' ts <> freeIn' body
-- | Annotate a host operation with alias information.
--
-- 'SegOp' and 'OtherOp' payloads are annotated recursively with the given
-- alias table.  A 'GPUBody' body is analysed with 'Alias.analyseBody'.  A
-- 'SizeOp' contains no representation-dependent data, so it is simply
-- rebuilt.
instance CanBeAliased op => CanBeAliased (HostOp op) where
  addOpAliases aliases hostop =
    case hostop of
      SegOp op -> SegOp (addOpAliases aliases op)
      GPUBody ts body -> GPUBody ts (Alias.analyseBody aliases body)
      OtherOp op -> OtherOp (addOpAliases aliases op)
      SizeOp op -> SizeOp op
-- | Add simplifier wisdom to a host operation.  Nested operations and
-- 'GPUBody' bodies are informed recursively.  A 'SizeOp' is rebuilt
-- unchanged.
instance CanBeWise op => CanBeWise (HostOp op) where
  addOpWisdom hostop =
    case hostop of
      SegOp op -> SegOp (addOpWisdom op)
      OtherOp op -> OtherOp (addOpWisdom op)
      SizeOp op -> SizeOp op
      GPUBody ts body -> GPUBody ts (informBody body)
-- | Symbol-table indexing into the results of a host operation.  Only
-- 'SegOp' and 'OtherOp' delegate to the wrapped operation.  Every other
-- constructor (e.g. 'SizeOp', 'GPUBody') yields 'Nothing'.
instance (ASTRep rep, ST.IndexOp (op rep)) => ST.IndexOp (HostOp op rep) where
  indexOp vtable k hostop is =
    case hostop of
      SegOp op -> ST.indexOp vtable k op is
      OtherOp op -> ST.indexOp vtable k op is
      _ -> Nothing
-- | Pretty-printing of host operations.  A 'GPUBody' is rendered as the
-- keyword @gpu@, a colon, the tuple of its result types, and then its body
-- in a braced nested block.
instance (PrettyRep rep, PP.Pretty (op rep)) => PP.Pretty (HostOp op rep) where
  pretty hostop =
    case hostop of
      SegOp op -> pretty op
      OtherOp op -> pretty op
      SizeOp op -> pretty op
      GPUBody ts body ->
        "gpu"
          <+> PP.colon
          <+> ppTuple' (pretty <$> ts)
          <+> PP.nestedBlock "{" "}" (pretty body)
-- | Metrics for host operations.  The statements of a 'GPUBody' are
-- counted under the @GPUBody@ label.
instance (OpMetrics (Op rep), OpMetrics (op rep)) => OpMetrics (HostOp op rep) where
  opMetrics hostop =
    case hostop of
      SegOp op -> opMetrics op
      OtherOp op -> opMetrics op
      SizeOp op -> opMetrics op
      GPUBody _ body -> inside "GPUBody" (bodyMetrics body)
-- | Change the representation of a host operation using the given
-- 'Rephraser'.  A 'SizeOp' does not depend on the representation and is
-- returned as-is.
instance RephraseOp op => RephraseOp (HostOp op) where
  rephraseInOp r hostop =
    case hostop of
      SegOp op -> SegOp <$> rephraseInOp r op
      OtherOp op -> OtherOp <$> rephraseInOp r op
      SizeOp op -> pure (SizeOp op)
      GPUBody ts body -> GPUBody ts <$> rephraseBody r body
-- | Require that both components of a kernel grid are @i64@ values.  The
-- number of groups is checked first, then the group size.
checkGrid :: TC.Checkable rep => KernelGrid -> TC.TypeM rep ()
checkGrid grid =
  mapM_ (TC.require [Prim int64]) [numGroups, groupSize]
  where
    numGroups = unCount (gridNumGroups grid)
    groupSize = unCount (gridGroupSize grid)
-- | Check that a SegOp at level @inner@ may occur where the enclosing
-- SegOp has level @outer@ ('Nothing' when not nested inside any SegOp).
--
-- The rules are:
--
-- * An intra-group ('SegThreadInGroup') SegOp must sit directly inside a
--   'SegGroup'.
-- * Nothing may be nested below thread level or intra-group level.
-- * A 'SegThread' without an explicit grid is otherwise always allowed.
-- * A grid-carrying 'SegThread', and any 'SegGroup', must be at the top
--   level.  Its grid, if present, is checked with 'checkGrid'.
--
-- The alternatives below overlap, so their order matters.
checkSegLevel ::
  TC.Checkable rep =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM rep ()
checkSegLevel outer inner =
  case (outer, inner) of
    (Just SegGroup {}, SegThreadInGroup _) ->
      pure ()
    (_, SegThreadInGroup _) ->
      reject "ingroup SegOp not in group SegOp."
    (Just SegThread {}, _) ->
      reject "SegOps cannot occur when already at thread level."
    (Just SegThreadInGroup {}, _) ->
      reject "SegOps cannot occur when already at ingroup level."
    (_, SegThread _ Nothing) ->
      pure ()
    (Just _, SegThread {}) ->
      reject "thread-level SegOp cannot be nested"
    (Nothing, SegThread _ grid) ->
      mapM_ checkGrid grid
    (Just _, SegGroup {}) ->
      reject "group-level SegOp cannot be nested"
    (Nothing, SegGroup _ grid) ->
      mapM_ checkGrid grid
  where
    reject = TC.bad . TC.TypeError
-- | Type check a 'HostOp'.
--
-- * The first argument is specialised to a 'SegOp''s own level and
--   installed via 'TC.checkOpWith' while that 'SegOp' is checked.
-- * The 'Maybe' 'SegLevel' is the level of the enclosing 'SegOp', or
--   'Nothing' when not nested in one.
-- * The third argument checks the payload of an 'OtherOp'.
--
-- A 'GPUBody' is rejected when nested in a 'SegOp'.  Otherwise its
-- declared types are checked, its body is checked, and the body's result
-- types must equal the declared types exactly.
typeCheckHostOp ::
  TC.Checkable rep =>
  (SegLevel -> Op (Aliases rep) -> TC.TypeM rep ()) ->
  Maybe SegLevel ->
  (op (Aliases rep) -> TC.TypeM rep ()) ->
  HostOp op (Aliases rep) ->
  TC.TypeM rep ()
typeCheckHostOp checker lvl checkOther hostop =
  -- Alternatives are tried in the same order as the original equations.
  case (lvl, hostop) of
    (_, SegOp op) ->
      TC.checkOpWith (checker (segLevel op)) $
        typeCheckSegOp (checkSegLevel lvl) op
    (Just {}, GPUBody {}) ->
      TC.bad $ TC.TypeError "GPUBody may not be nested in SegOps."
    (_, OtherOp op) ->
      checkOther op
    (_, SizeOp op) ->
      typeCheckSizeOp op
    (Nothing, GPUBody ts body) ->
      checkGPUBody ts body
  where
    -- Check a top-level GPUBody against its declared result types.
    checkGPUBody ::
      TC.Checkable r =>
      [Type] ->
      Body (Aliases r) ->
      TC.TypeM r ()
    checkGPUBody declared body = do
      mapM_ TC.checkType declared
      void $ TC.checkBody body
      -- The result may mention names bound by the body's own statements,
      -- so their scope is added while computing the result types.
      actual <-
        extendedScope
          (traverse subExpResType (bodyResult body))
          (scopeOf (bodyStms body))
      unless (actual == declared) . TC.bad . TC.TypeError . T.unlines $
        [ "Expected type: " <> prettyTuple declared,
          "Got body type: " <> prettyTuple actual
        ]