{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.GPU.Op
(
SizeOp (..),
HostOp (..),
traverseHostOpStms,
typeCheckHostOp,
SegLevel (..),
segVirt,
SegVirt (..),
SegSeqDims (..),
KernelGrid (..),
module Futhark.IR.GPU.Sizes,
module Futhark.IR.SegOp,
)
where
import Control.Monad
import Data.Sequence qualified as SQ
import Data.Text qualified as T
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases, CanBeAliased (..))
import Futhark.IR.GPU.Sizes
import Futhark.IR.Mem (OpReturns (..), extReturns)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.Pretty
( commasep,
parens,
ppTuple',
pretty,
(<+>),
)
import Futhark.Util.Pretty qualified as PP
-- | A list of dimension indices, wrapped in a newtype so that it is
-- not confused with other @[Int]@ values.  It is the payload of the
-- 'SegNoVirtFull' constructor of 'SegVirt'.
--
-- NOTE(review): presumably these are the dimensions of a segmented
-- operation's space that should be handled sequentially -- confirm
-- against the code generator.
newtype SegSeqDims = SegSeqDims {segSeqDims :: [Int]}
  deriving (Eq, Ord, Show)
-- | The virtualisation setting stored in every 'SegLevel' constructor
-- (retrieve it with 'segVirt').
--
-- NOTE(review): presumably this controls whether block virtualisation
-- is generated for the segmented operation -- confirm.
data SegVirt
  = -- | Pretty-printed as @virtualise@.
    SegVirt
  | -- | Pretty-printed as the empty document.
    SegNoVirt
  | -- | Pretty-printed as @full@ followed by the wrapped dimension
    -- list of the 'SegSeqDims'.
    SegNoVirtFull SegSeqDims
  deriving (Eq, Ord, Show)
-- | A pair of grid dimensions: a number of blocks and a block size.
-- Both are 'SubExp's tagged with a 'Count' phantom so that they cannot
-- be mixed up.
data KernelGrid = KernelGrid
  { -- | The number of blocks.
    gridNumBlocks :: Count NumBlocks SubExp,
    -- | The size of each block.
    gridBlockSize :: Count BlockSize SubExp
  }
  deriving (Eq, Ord, Show)
-- | The level annotation used for the 'SegOp's wrapped by 'HostOp'.
-- Every constructor carries a 'SegVirt'; the first two additionally
-- carry an optional 'KernelGrid'.
--
-- NOTE(review): presumably this says at which level of the GPU
-- hierarchy the body of the segmented operation runs -- confirm.
data SegLevel
  = -- | Pretty-printed with the tag @thread@.
    SegThread SegVirt (Maybe KernelGrid)
  | -- | Pretty-printed with the tag @block@.
    SegBlock SegVirt (Maybe KernelGrid)
  | -- | Pretty-printed with the tag @inblock@; has no grid.
    SegThreadInBlock SegVirt
  deriving (Eq, Ord, Show)
-- | Extract the 'SegVirt' stored in a 'SegLevel', whichever
-- constructor it was built with.  Any 'KernelGrid' is ignored.
segVirt :: SegLevel -> SegVirt
segVirt (SegThread v _) = v
segVirt (SegBlock v _) = v
segVirt (SegThreadInBlock v) = v
-- | Textual form of a 'SegVirt': nothing at all for 'SegNoVirt',
-- @full@ followed by the dimension list for 'SegNoVirtFull', and
-- @virtualise@ for 'SegVirt'.
instance PP.Pretty SegVirt where
  pretty SegNoVirt = mempty
  pretty (SegNoVirtFull dims) = "full" <+> pretty (segSeqDims dims)
  pretty SegVirt = "virtualise"
-- | Textual form of a 'KernelGrid': @grid=@ and the number of blocks,
-- a semicolon, then @blocksize=@ and the block size.
instance PP.Pretty KernelGrid where
  pretty (KernelGrid num_tblocks tblock_size) =
    "grid="
      <> pretty num_tblocks
      <> PP.semi
      <+> "blocksize="
      <> pretty tblock_size
-- | Textual form of a 'SegLevel': a parenthesised, semicolon-separated
-- list starting with a tag (@thread@, @block@ or @inblock@), followed
-- by the 'SegVirt' and, for the first two constructors, the optional
-- 'KernelGrid'.
instance PP.Pretty SegLevel where
  pretty (SegThread virt grid) =
    PP.parens ("thread" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid)
  pretty (SegBlock virt grid) =
    PP.parens ("block" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid)
  pretty (SegThreadInBlock virt) =
    PP.parens ("inblock" <> PP.semi <+> pretty virt)
-- | Simplify a 'KernelGrid' by simplifying the 'SubExp' inside each of
-- its two 'Count' fields, number of blocks first.
instance Engine.Simplifiable KernelGrid where
  simplify (KernelGrid num_tblocks tblock_size) =
    KernelGrid
      <$> traverse Engine.simplify num_tblocks
      <*> traverse Engine.simplify tblock_size
-- | Simplify a 'SegLevel' by simplifying its optional 'KernelGrid'.
-- The 'SegVirt' is passed through untouched, so 'SegThreadInBlock'
-- (which has no grid) is returned as is.
instance Engine.Simplifiable SegLevel where
  simplify (SegThread virt grid) =
    SegThread virt <$> Engine.simplify grid
  simplify (SegBlock virt grid) =
    SegBlock virt <$> Engine.simplify grid
  simplify (SegThreadInBlock virt) =
    pure $ SegThreadInBlock virt
-- | Apply a name substitution to both fields of a 'KernelGrid'.
instance Substitute KernelGrid where
  substituteNames substs (KernelGrid num_tblocks tblock_size) =
    KernelGrid
      (substituteNames substs num_tblocks)
      (substituteNames substs tblock_size)
-- | Apply a name substitution to the optional 'KernelGrid' of a
-- 'SegLevel'.  The 'SegVirt' is never rewritten, so the substitution
-- is ignored entirely for 'SegThreadInBlock'.
instance Substitute SegLevel where
  substituteNames substs (SegThread virt grid) =
    SegThread virt (substituteNames substs grid)
  substituteNames substs (SegBlock virt grid) =
    SegBlock virt (substituteNames substs grid)
  substituteNames _ (SegThreadInBlock virt) =
    SegThreadInBlock virt
-- | Renaming a 'SegLevel' is delegated to 'substituteRename', i.e. it
-- goes through the 'Substitute' instance.
instance Rename SegLevel where
  rename = substituteRename
-- | The free names of a 'KernelGrid' are those of its two fields,
-- computed via the 'FreeIn' instance for pairs.
instance FreeIn KernelGrid where
  freeIn' (KernelGrid num_tblocks tblock_size) =
    freeIn' (num_tblocks, tblock_size)
-- | The free names of a 'SegLevel' are those of its optional
-- 'KernelGrid'.  The 'SegVirt' contributes nothing, so
-- 'SegThreadInBlock' has no free names.
instance FreeIn SegLevel where
  freeIn' (SegThread _virt grid) = freeIn' grid
  freeIn' (SegBlock _virt grid) = freeIn' grid
  freeIn' (SegThreadInBlock _virt) = mempty
-- | Operations on named sizes and size classes.
--
-- NOTE(review): presumably these are resolved against tuning
-- parameters that are configurable at run time -- confirm.
data SizeOp
  = -- | @GetSize name size_class@; pretty-printed as
    -- @get_size(name, size_class)@.
    GetSize Name SizeClass
  | -- | @GetSizeMax size_class@; pretty-printed as
    -- @get_size_max(size_class)@.
    GetSizeMax SizeClass
  | -- | @CmpSizeLe name size_class x@; pretty-printed as
    -- @cmp_size(name, size_class) <= x@.  The type checker requires
    -- @x@ to be of type @i64@.
    CmpSizeLe Name SizeClass SubExp
  | -- | @CalcNumBlocks w max_num_tblocks tblock_size@; pretty-printed
    -- as @calc_num_tblocks(w, max_num_tblocks, tblock_size)@.  The
    -- type checker requires both @w@ and @tblock_size@ to be of type
    -- @i64@.
    CalcNumBlocks SubExp Name SubExp
  deriving (Eq, Ord, Show)
-- | Apply a name substitution to the 'SubExp' operands of a 'SizeOp'.
-- The 'Name' and 'SizeClass' components are left alone, so 'GetSize'
-- and 'GetSizeMax' are returned unchanged.
instance Substitute SizeOp where
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass (substituteNames substs x)
  substituteNames substs (CalcNumBlocks w max_num_tblocks tblock_size) =
    CalcNumBlocks
      (substituteNames substs w)
      max_num_tblocks
      (substituteNames substs tblock_size)
  substituteNames _ op = op
-- | Rename the 'SubExp' operands of a 'SizeOp'.  The 'Name' and
-- 'SizeClass' components are not renamed, so 'GetSize' and
-- 'GetSizeMax' are returned unchanged.
instance Rename SizeOp where
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumBlocks w max_num_tblocks tblock_size) =
    CalcNumBlocks
      <$> rename w
      <*> pure max_num_tblocks
      <*> rename tblock_size
  rename x = pure x
-- | The free names of a 'SizeOp' are those of its 'SubExp' operands.
-- 'GetSize' and 'GetSizeMax' have no such operands and therefore no
-- free names.
instance FreeIn SizeOp where
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' (CalcNumBlocks w _ tblock_size) = freeIn' w <> freeIn' tblock_size
  freeIn' _ = mempty
-- | Textual form of a 'SizeOp': a function-call-like keyword followed
-- by the parenthesised, comma-separated operands.  'CmpSizeLe'
-- additionally prints @<=@ and the value being compared against.
instance PP.Pretty SizeOp where
  pretty (GetSize name size_class) =
    "get_size" <> parens (commasep [pretty name, pretty size_class])
  pretty (GetSizeMax size_class) =
    "get_size_max" <> parens (commasep [pretty size_class])
  pretty (CmpSizeLe name size_class x) =
    "cmp_size"
      <> parens (commasep [pretty name, pretty size_class])
      <+> "<="
      <+> pretty x
  pretty (CalcNumBlocks w max_num_tblocks tblock_size) =
    "calc_num_tblocks" <> parens (commasep [pretty w, pretty max_num_tblocks, pretty tblock_size])
-- | Each 'SizeOp' is recorded in the metrics under the name of its
-- constructor.
instance OpMetrics SizeOp where
  opMetrics GetSize {} = seen "GetSize"
  opMetrics GetSizeMax {} = seen "GetSizeMax"
  opMetrics CmpSizeLe {} = seen "CmpSizeLe"
  opMetrics CalcNumBlocks {} = seen "CalcNumBlocks"
-- | Type-check a 'SizeOp'.  'GetSize' and 'GetSizeMax' always pass.
-- For 'CmpSizeLe' the compared value, and for 'CalcNumBlocks' both
-- 'SubExp' operands, are required to have type @i64@.
typeCheckSizeOp :: (TC.Checkable rep) => SizeOp -> TC.TypeM rep ()
typeCheckSizeOp GetSize {} = pure ()
typeCheckSizeOp GetSizeMax {} = pure ()
typeCheckSizeOp (CmpSizeLe _ _ x) = TC.require [Prim int64] x
typeCheckSizeOp (CalcNumBlocks w _ tblock_size) = do
  TC.require [Prim int64] w
  TC.require [Prim int64] tblock_size
-- | An operation that is one of four alternatives, parameterised over
-- an additional operation type @op@ and the representation @rep@.
data HostOp op rep
  = -- | A segmented operation whose level annotation is a 'SegLevel'.
    SegOp (SegOp SegLevel rep)
  | -- | A 'SizeOp'.
    SizeOp SizeOp
  | -- | An operation of the caller-supplied type @op@.
    OtherOp (op rep)
  | -- | A list of types together with a body.
    --
    -- NOTE(review): presumably the body is code to be run on the GPU
    -- and the types describe its results -- confirm.
    GPUBody [Type] (Body rep)
  deriving (Eq, Ord, Show)
-- | Lift a statement traverser for the inner operation type @op@ to
-- one for 'HostOp'.
--
-- * 'SegOp' is handled by 'traverseSegOpStms'.
--
-- * 'SizeOp' contains no statements and is returned unchanged.
--
-- * 'OtherOp' is handled by the supplied traverser for @op@.
--
-- * 'GPUBody' has the statements of its body replaced by the result
--   of running the statement function on them; the types and the
--   rest of the body are kept.
traverseHostOpStms ::
  (Monad m) =>
  OpStmsTraverser m (op rep) rep ->
  OpStmsTraverser m (HostOp op rep) rep
traverseHostOpStms _ f (SegOp segop) = SegOp <$> traverseSegOpStms f segop
traverseHostOpStms _ _ (SizeOp sizeop) = pure $ SizeOp sizeop
traverseHostOpStms onOtherOp f (OtherOp other) = OtherOp <$> onOtherOp f other
traverseHostOpStms _ f (GPUBody ts body) = do
  -- The statement function is given an empty scope here.
  stms <- f mempty $ bodyStms body
  pure $ GPUBody ts $ body {bodyStms = stms}
-- | Name substitution is pushed into the payload of every constructor;
-- for 'GPUBody' both the declared types and the body are substituted.
instance (ASTRep rep, Substitute (op rep)) => Substitute (HostOp op rep) where
  substituteNames substs hostop =
    case hostop of
      SegOp segop -> SegOp (substituteNames substs segop)
      OtherOp other -> OtherOp (substituteNames substs other)
      SizeOp sizeop -> SizeOp (substituteNames substs sizeop)
      GPUBody ts body ->
        GPUBody
          (substituteNames substs ts)
          (substituteNames substs body)
-- | Renaming is delegated to the payload of every constructor.  For
-- 'GPUBody' the types are renamed before the body.
instance (ASTRep rep, Rename (op rep)) => Rename (HostOp op rep) where
  rename hostop =
    case hostop of
      SegOp segop -> fmap SegOp (rename segop)
      OtherOp other -> fmap OtherOp (rename other)
      SizeOp sizeop -> fmap SizeOp (rename sizeop)
      GPUBody ts body -> do
        ts' <- rename ts
        body' <- rename body
        pure (GPUBody ts' body')
-- | 'SegOp' and 'OtherOp' defer to the 'IsOp' instance of their payload.
instance (IsOp op) => IsOp (HostOp op) where
  -- A 'SizeOp' is always safe; a 'GPUBody' is safe when the expression of
  -- every statement in its body is safe.
  safeOp hostop =
    case hostop of
      SegOp segop -> safeOp segop
      OtherOp other -> safeOp other
      SizeOp _ -> True
      GPUBody _ body -> all (\stm -> safeExp (stmExp stm)) (bodyStms body)

  -- A 'SizeOp' is always cheap; a 'GPUBody' is cheap only when it has no
  -- statements and every declared type has array rank zero.
  cheapOp hostop =
    case hostop of
      SegOp segop -> cheapOp segop
      OtherOp other -> cheapOp other
      SizeOp _ -> True
      GPUBody types body ->
        SQ.null (bodyStms body) && all (\t -> arrayRank t == 0) types

  -- A 'SizeOp' has a single result depending on its free variables.  Each
  -- result of a 'GPUBody' is reported as depending on all names free in
  -- the body.
  opDependencies hostop =
    case hostop of
      SegOp segop -> opDependencies segop
      OtherOp other -> opDependencies other
      SizeOp sizeop -> [freeIn sizeop]
      GPUBody _ body -> freeIn body <$ bodyResult body
-- | Result types of a host operation.
--
-- 'SegOp' and 'OtherOp' defer to their payload.  'CmpSizeLe' yields a
-- single boolean; the other size operations yield a single @int64@.  A
-- 'GPUBody' declared with types @ts@ yields, for each of them, an array
-- with one row of that type.
instance (TypedOp op) => TypedOp (HostOp op) where
  opType hostop =
    case hostop of
      SegOp segop -> opType segop
      OtherOp other -> opType other
      SizeOp (GetSize _ _) -> pure [Prim int64]
      SizeOp (GetSizeMax _) -> pure [Prim int64]
      SizeOp CmpSizeLe {} -> pure [Prim Bool]
      SizeOp (CalcNumBlocks {}) -> pure [Prim int64]
      GPUBody ts _ ->
        pure (staticShapes [t `arrayOfRow` intConst Int64 1 | t <- ts])
-- | 'SegOp' and 'OtherOp' defer to the 'AliasedOp' instance of their
-- payload.
instance (AliasedOp op) => AliasedOp (HostOp op) where
  -- The single result of a 'SizeOp' has no aliases, and neither does any
  -- result of a 'GPUBody' (one empty alias set per declared type).
  opAliases hostop =
    case hostop of
      SegOp segop -> opAliases segop
      OtherOp other -> opAliases other
      SizeOp _ -> [mempty]
      GPUBody ts _ -> mempty <$ ts

  -- A 'SizeOp' consumes nothing; a 'GPUBody' consumes whatever its body
  -- consumes.
  consumedInOp hostop =
    case hostop of
      SegOp segop -> consumedInOp segop
      OtherOp other -> consumedInOp other
      SizeOp _ -> mempty
      GPUBody _ body -> consumedInBody body
-- | The free variables of a host operation are those of its payload; for
-- 'GPUBody' they are those of the declared types combined with those of
-- the body.
instance (ASTRep rep, FreeIn (op rep)) => FreeIn (HostOp op rep) where
  freeIn' hostop =
    case hostop of
      SegOp segop -> freeIn' segop
      OtherOp other -> freeIn' other
      SizeOp sizeop -> freeIn' sizeop
      GPUBody ts body -> freeIn' ts <> freeIn' body
-- | Add alias information to a host operation.  'SegOp' and 'OtherOp'
-- defer to their payload, the body of a 'GPUBody' is analysed with
-- 'Alias.analyseBody', and a 'SizeOp' is carried over as-is.
instance (CanBeAliased op) => CanBeAliased (HostOp op) where
  addOpAliases aliases hostop =
    case hostop of
      SegOp segop -> SegOp (addOpAliases aliases segop)
      GPUBody ts body -> GPUBody ts (Alias.analyseBody aliases body)
      OtherOp other -> OtherOp (addOpAliases aliases other)
      SizeOp sizeop -> SizeOp sizeop
-- | Convert a host operation to the 'Wise' representation.  'SegOp' and
-- 'OtherOp' defer to their payload, the body of a 'GPUBody' goes through
-- 'informBody', and a 'SizeOp' is carried over as-is.
instance (CanBeWise op) => CanBeWise (HostOp op) where
  addOpWisdom hostop =
    case hostop of
      SegOp segop -> SegOp (addOpWisdom segop)
      OtherOp other -> OtherOp (addOpWisdom other)
      SizeOp sizeop -> SizeOp sizeop
      GPUBody ts body -> GPUBody ts (informBody body)
-- | Return information for host operations without an extension
-- operation.  A 'SegOp' uses 'segOpReturns'; everything else is derived
-- from its 'opType' via 'extReturns'.
instance OpReturns (HostOp NoOp) where
  opReturns hostop =
    case hostop of
      SegOp segop -> segOpReturns segop
      _ -> fmap extReturns (opType hostop)
-- | Indexing is forwarded to the payload of 'SegOp' and 'OtherOp'; every
-- other host operation gives 'Nothing'.
instance (ASTRep rep, ST.IndexOp (op rep)) => ST.IndexOp (HostOp op rep) where
  indexOp vtable k hostop is =
    case hostop of
      SegOp segop -> ST.indexOp vtable k segop is
      OtherOp other -> ST.indexOp vtable k other is
      _ -> Nothing
-- | 'SegOp', 'OtherOp' and 'SizeOp' are printed as their payload.  A
-- 'GPUBody' is printed as @gpu : <types> { <body> }@, with the types
-- rendered by 'ppTuple''.
instance (PrettyRep rep, PP.Pretty (op rep)) => PP.Pretty (HostOp op rep) where
  pretty hostop =
    case hostop of
      SegOp segop -> pretty segop
      OtherOp other -> pretty other
      SizeOp sizeop -> pretty sizeop
      GPUBody ts body ->
        "gpu"
          <+> PP.colon
          <+> ppTuple' (map pretty ts)
          <+> PP.nestedBlock "{" "}" (pretty body)
-- | Metrics are collected from the payload of each constructor.  The
-- metrics of a 'GPUBody' body are recorded 'inside' the @"GPUBody"@ label.
instance (OpMetrics (Op rep), OpMetrics (op rep)) => OpMetrics (HostOp op rep) where
  opMetrics hostop =
    case hostop of
      SegOp segop -> opMetrics segop
      OtherOp other -> opMetrics other
      SizeOp sizeop -> opMetrics sizeop
      GPUBody _ body -> inside "GPUBody" (bodyMetrics body)
-- | Rephrase a host operation from one representation to another.
-- 'SegOp' and 'OtherOp' defer to their payload, a 'SizeOp' is carried
-- over as-is, and the body of a 'GPUBody' goes through 'rephraseBody'
-- while its declared types are kept.
instance (RephraseOp op) => RephraseOp (HostOp op) where
  rephraseInOp r hostop =
    case hostop of
      SegOp segop -> fmap SegOp (rephraseInOp r segop)
      OtherOp other -> fmap OtherOp (rephraseInOp r other)
      SizeOp sizeop -> pure (SizeOp sizeop)
      GPUBody ts body -> fmap (GPUBody ts) (rephraseBody r body)
-- | Require that both components of a kernel grid have type @int64@.
-- The number of blocks is checked before the block size.
checkGrid :: (TC.Checkable rep) => KernelGrid -> TC.TypeM rep ()
checkGrid grid =
  mapM_
    (TC.require [Prim int64])
    [ unCount (gridNumBlocks grid),
      unCount (gridBlockSize grid)
    ]
-- | Check that a 'SegOp' level is permitted in its context.
--
-- The first argument is the level of the enclosing 'SegOp', if any, and
-- the second is the level being checked.  The alternatives overlap and
-- are tried from top to bottom, so their order is significant.  Where a
-- level carries a grid and there is no enclosing level, the grid (if
-- present) is checked with 'checkGrid'.
checkSegLevel ::
  (TC.Checkable rep) =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM rep ()
checkSegLevel outer inner =
  case (outer, inner) of
    -- In-block operations are accepted only directly inside a block.
    (Just SegBlock {}, SegThreadInBlock _) ->
      pure ()
    (_, SegThreadInBlock _) ->
      failWith "inblock SegOp not in block SegOp."
    -- Nothing else may be nested below thread or in-block level.
    (Just SegThread {}, _) ->
      failWith "SegOps cannot occur when already at thread level."
    (Just SegThreadInBlock {}, _) ->
      failWith "SegOps cannot occur when already at inblock level."
    -- A thread-level operation without a grid is accepted in any
    -- remaining context.
    (_, SegThread _ Nothing) ->
      pure ()
    (Just _, SegThread {}) ->
      failWith "thread-level SegOp cannot be nested"
    (Nothing, SegThread _ grid) ->
      mapM_ checkGrid grid
    (Just _, SegBlock {}) ->
      failWith "block-level SegOp cannot be nested"
    (Nothing, SegBlock _ grid) ->
      mapM_ checkGrid grid
  where
    -- Abort type checking with the given message.
    failWith msg = TC.bad (TC.TypeError msg)
typeCheckHostOp ::
(TC.Checkable rep) =>
(SegLevel -> Op (Aliases rep) -> TC.TypeM rep ()) ->
Maybe SegLevel ->
(op (Aliases rep) -> TC.TypeM rep ()) ->
HostOp op (Aliases rep) ->
TC.TypeM rep ()
typeCheckHostOp :: forall rep (op :: * -> *).
Checkable rep =>
(SegLevel -> Op (Aliases rep) -> TypeM rep ())
-> Maybe SegLevel
-> (op (Aliases rep) -> TypeM rep ())
-> HostOp op (Aliases rep)
-> TypeM rep ()
typeCheckHostOp SegLevel -> Op (Aliases rep) -> TypeM rep ()
checker Maybe SegLevel
lvl op (Aliases rep) -> TypeM rep ()
_ (SegOp SegOp SegLevel (Aliases rep)
op) =
(Op (Aliases rep) -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall rep a.
(Op (Aliases rep) -> TypeM rep ()) -> TypeM rep a -> TypeM rep a
TC.checkOpWith (SegLevel -> Op (Aliases rep) -> TypeM rep ()
checker (SegLevel -> Op (Aliases rep) -> TypeM rep ())
-> SegLevel -> Op (Aliases rep) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel (Aliases rep) -> SegLevel
forall lvl rep. SegOp lvl rep -> lvl
segLevel SegOp SegLevel (Aliases rep)
op) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
(SegLevel -> TypeM rep ())
-> SegOp SegLevel (Aliases rep) -> TypeM rep ()
forall rep lvl.
Checkable rep =>
(lvl -> TypeM rep ()) -> SegOp lvl (Aliases rep) -> TypeM rep ()
typeCheckSegOp (Maybe SegLevel -> SegLevel -> TypeM rep ()
forall rep.
Checkable rep =>
Maybe SegLevel -> SegLevel -> TypeM rep ()
checkSegLevel Maybe SegLevel
lvl) SegOp SegLevel (Aliases rep)
op
typeCheckHostOp SegLevel -> Op (Aliases rep) -> TypeM rep ()
_ Just {} op (Aliases rep) -> TypeM rep ()
_ GPUBody {} =
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError Text
"GPUBody may not be nested in SegOps."
typeCheckHostOp SegLevel -> Op (Aliases rep) -> TypeM rep ()
_ Maybe SegLevel
_ op (Aliases rep) -> TypeM rep ()
f (OtherOp op (Aliases rep)
op) = op (Aliases rep) -> TypeM rep ()
f op (Aliases rep)
op
typeCheckHostOp SegLevel -> Op (Aliases rep) -> TypeM rep ()
_ Maybe SegLevel
_ op (Aliases rep) -> TypeM rep ()
_ (SizeOp SizeOp
op) = SizeOp -> TypeM rep ()
forall rep. Checkable rep => SizeOp -> TypeM rep ()
typeCheckSizeOp SizeOp
op
-- A 'GPUBody' with no segment level supplied ('Nothing'):
--
--   1. every declared result type is checked with 'TC.checkType';
--   2. the body itself is type checked (its alias result is discarded);
--   3. the types of the body results are computed in a scope extended
--      with the body's own statements and must equal the declared types,
--      otherwise a 'TC.TypeError' listing both is raised.
typeCheckHostOp _ Nothing _ (GPUBody declared gpuBody) = do
  mapM_ TC.checkType declared
  void (TC.checkBody gpuBody)
  actual <-
    extendedScope
      (traverse subExpResType $ bodyResult gpuBody)
      (scopeOf $ bodyStms gpuBody)
  unless (actual == declared) $
    TC.bad . TC.TypeError $
      T.unlines
        [ "Expected type: " <> prettyTuple declared,
          "Got body type: " <> prettyTuple actual
        ]