{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Futhark.IR.GPU.Op
  ( -- * Size operations
    SizeOp (..),

    -- * Host operations
    HostOp (..),
    traverseHostOpStms,
    typeCheckHostOp,

    -- * SegOp refinements
    SegLevel (..),

    -- * Reexports
    module Futhark.IR.GPU.Sizes,
    module Futhark.IR.SegOp,
  )
where

import Control.Monad
import Data.Sequence qualified as SQ
import Data.Text qualified as T
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeBodyAliases)
import Futhark.IR.GPU.Sizes
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.Pretty
  ( commasep,
    parens,
    ppTuple',
    pretty,
    (<+>),
  )
import Futhark.Util.Pretty qualified as PP

-- | At which level the *body* of a t'SegOp' executes.
--
-- Both constructors carry the same three fields, so the record
-- selectors below are total.
data SegLevel
  = -- | The body executes at the level of individual threads.
    SegThread
      { -- | Number of workgroups to launch.
        segNumGroups :: Count NumGroups SubExp,
        -- | Number of threads in each workgroup.
        segGroupSize :: Count GroupSize SubExp,
        -- | How (and whether) the segment space is virtualised.
        segVirt :: SegVirt
      }
  | -- | The body executes at the level of whole workgroups.
    SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- | Renders a level as
-- @(thread; #groups=G; groupsize=S)@ or @(group; ...)@, optionally
-- followed by a virtualisation suffix (@; full D@ or @; virtualise@).
instance PP.Pretty SegLevel where
  pretty lvl =
    PP.parens $
      levelName
        <> PP.semi
        <+> "#groups="
        <> pretty (segNumGroups lvl)
        <> PP.semi
        <+> "groupsize="
        <> pretty (segGroupSize lvl)
        <> virtSuffix (segVirt lvl)
    where
      levelName
        | SegThread {} <- lvl = "thread"
        | otherwise = "group"

      -- No suffix at all when the level is not virtualised.
      virtSuffix SegNoVirt = mempty
      virtSuffix (SegNoVirtFull dims) =
        PP.semi <+> "full" <+> pretty (segSeqDims dims)
      virtSuffix SegVirt = PP.semi <+> "virtualise"

-- | Simplifies the group count and group size of a level, keeping its
-- constructor and virtualisation setting unchanged.
instance Engine.Simplifiable SegLevel where
  simplify lvl =
    rebuild
      <$> traverse Engine.simplify (segNumGroups lvl)
      <*> traverse Engine.simplify (segGroupSize lvl)
      <*> pure (segVirt lvl)
    where
      -- Reconstruct with whichever constructor we started from.
      rebuild = case lvl of
        SegThread {} -> SegThread
        SegGroup {} -> SegGroup

-- | Substitutes names in the group count and group size; the
-- constructor and the 'segVirt' field are left untouched.
instance Substitute SegLevel where
  substituteNames substs lvl =
    -- Both constructors have these fields, so the record update is total.
    lvl
      { segNumGroups = substituteNames substs $ segNumGroups lvl,
        segGroupSize = substituteNames substs $ segGroupSize lvl
      }

-- | Renaming is implemented in terms of the 'Substitute' instance via
-- 'substituteRename'.
instance Rename SegLevel where
  rename = substituteRename

-- | The free names of a level are those of its group count and group
-- size; 'segVirt' contributes nothing.
instance FreeIn SegLevel where
  freeIn' lvl = freeIn' (segNumGroups lvl) <> freeIn' (segGroupSize lvl)

-- | A simple size-level query or computation.
data SizeOp
  = -- | Produce some runtime-configurable size.
    GetSize Name SizeClass
  | -- | The maximum size of some class.
    GetSizeMax SizeClass
  | -- | Compare a size (likely a threshold) with some integer value.
    CmpSizeLe Name SizeClass SubExp
  | -- | @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @max_num_groups@ argument is the name of a size.  Note that
    -- @w@ is an i64 to avoid overflow issues.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)

-- | Substitutes names in the 'SubExp' operands only; size names and
-- size classes are left alone.
instance Substitute SizeOp where
  substituteNames substs op = case op of
    CmpSizeLe name sclass x ->
      CmpSizeLe name sclass (subst x)
    CalcNumGroups w max_num_groups group_size ->
      CalcNumGroups (subst w) max_num_groups (subst group_size)
    -- 'GetSize' and 'GetSizeMax' have no 'SubExp' operands.
    _ -> op
    where
      -- Only ever used at 'SubExp', so a monomorphic binding suffices.
      subst = substituteNames substs

-- | Renames the 'SubExp' operands; size names and classes are kept.
instance Rename SizeOp where
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumGroups w max_num_groups group_size) = do
    w' <- rename w
    group_size' <- rename group_size
    pure $ CalcNumGroups w' max_num_groups group_size'
  rename op = pure op

-- | Every size operation is considered both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True

-- | Every size operation returns a single scalar: a boolean for
-- 'CmpSizeLe' and an @i64@ for everything else.
instance TypedOp SizeOp where
  opType op = pure [Prim resType]
    where
      resType = case op of
        GetSize {} -> int64
        GetSizeMax {} -> int64
        CmpSizeLe {} -> Bool
        CalcNumGroups {} -> int64

-- | The single result of a size operation aliases nothing, and no
-- operand is consumed.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty

-- | Only the 'SubExp' operands contribute free names.
instance FreeIn SizeOp where
  freeIn' op = case op of
    CmpSizeLe _ _ x -> freeIn' x
    CalcNumGroups w _ group_size -> freeIn' w <> freeIn' group_size
    GetSize {} -> mempty
    GetSizeMax {} -> mempty

-- | Renders size operations in a function-call syntax, e.g.
-- @get_size(name, class)@ or @cmp_size(name, class) <= x@.
instance PP.Pretty SizeOp where
  pretty op = case op of
    GetSize name size_class ->
      call "get_size" [pretty name, pretty size_class]
    GetSizeMax size_class ->
      call "get_size_max" [pretty size_class]
    CmpSizeLe name size_class x ->
      call "cmp_size" [pretty name, pretty size_class] <+> "<=" <+> pretty x
    CalcNumGroups w max_num_groups group_size ->
      call "calc_num_groups" [pretty w, pretty max_num_groups, pretty group_size]
    where
      -- @call f args@ renders as @f(arg1, arg2, ...)@.
      call fname args = fname <> parens (commasep args)

-- | Records one occurrence of the constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    GetSize {} -> "GetSize"
    GetSizeMax {} -> "GetSizeMax"
    CmpSizeLe {} -> "CmpSizeLe"
    CalcNumGroups {} -> "CalcNumGroups"

-- | Type-check a 'SizeOp'.  Every 'SubExp' operand must be an @i64@;
-- operations without such operands always check.  Operands are
-- checked left to right.
typeCheckSizeOp :: TC.Checkable rep => SizeOp -> TC.TypeM rep ()
typeCheckSizeOp op =
  mapM_ (TC.require [Prim int64]) $ case op of
    GetSize {} -> []
    GetSizeMax {} -> []
    CmpSizeLe _ _ x -> [x]
    CalcNumGroups w _ group_size -> [w, group_size]

-- | A host-level operation; parameterised by what else it can do.
data HostOp rep op
  = -- | A segmented operation.
    SegOp (SegOp SegLevel rep)
  | -- | A size query or computation; see 'SizeOp'.
    SizeOp SizeOp
  | -- | Some other operation, of the type given by the @op@ parameter.
    OtherOp op
  | -- | Code to run sequentially on the GPU,
    -- in a single thread.  The list of types appears to describe the
    -- body's results; the 'TypedOp' instance exposes each of them as an
    -- array with a single row.
    GPUBody [Type] (Body rep)
  deriving (Eq, Ord, Show)

-- | A helper for defining 'TraverseOpStms'.
--
-- The first argument handles 'OtherOp'.  'SegOp's are traversed with
-- 'traverseSegOpStms'.  'SizeOp's contain no statements and are
-- returned unchanged.
traverseHostOpStms ::
  Monad m =>
  OpStmsTraverser m op rep ->
  OpStmsTraverser m (HostOp rep op) rep
traverseHostOpStms onOtherOp f hop = case hop of
  SegOp segop -> SegOp <$> traverseSegOpStms f segop
  SizeOp {} -> pure hop
  OtherOp other -> OtherOp <$> onOtherOp f other
  GPUBody ts body -> do
    -- The 'GPUBody' constructor carries no parameters that would bind
    -- names for its body, so the statements get an empty scope.
    stms' <- f mempty (bodyStms body)
    pure (GPUBody ts (body {bodyStms = stms'}))

-- | Substitution distributes over every constructor; for 'GPUBody' it
-- applies to both the result types and the body.
instance (ASTRep rep, Substitute op) => Substitute (HostOp rep op) where
  substituteNames substs hop = case hop of
    SegOp op -> SegOp $ substituteNames substs op
    OtherOp op -> OtherOp $ substituteNames substs op
    SizeOp op -> SizeOp $ substituteNames substs op
    GPUBody ts body ->
      GPUBody (substituteNames substs ts) (substituteNames substs body)

-- | Renaming distributes over every constructor; for 'GPUBody' the
-- result types are renamed before the body.
instance (ASTRep rep, Rename op) => Rename (HostOp rep op) where
  rename hop = case hop of
    SegOp op -> SegOp <$> rename op
    OtherOp op -> OtherOp <$> rename op
    SizeOp op -> SizeOp <$> rename op
    GPUBody ts body -> do
      ts' <- rename ts
      body' <- rename body
      pure $ GPUBody ts' body'

-- | Safety and cheapness are delegated to the wrapped operation, with
-- special cases for 'GPUBody'.
instance (ASTRep rep, IsOp op) => IsOp (HostOp rep op) where
  safeOp hop = case hop of
    SegOp op -> safeOp op
    OtherOp op -> safeOp op
    SizeOp op -> safeOp op
    -- NOTE(review): every 'GPUBody' is reported safe regardless of the
    -- statements in its body -- confirm this is intended.
    GPUBody {} -> True

  cheapOp hop = case hop of
    SegOp op -> cheapOp op
    OtherOp op -> cheapOp op
    SizeOp op -> cheapOp op
    -- Current GPUBody usage only benefits from hoisting kernels that
    -- transfer scalars to device, i.e. bodies with no statements whose
    -- results all have rank zero.
    GPUBody types body ->
      SQ.null (bodyStms body) && all ((== 0) . arrayRank) types

-- | Result types are delegated to the wrapped operation.  Each result of
-- a 'GPUBody' is seen from outside as an array with a single row.
instance TypedOp op => TypedOp (HostOp rep op) where
  opType hop = case hop of
    SegOp op -> opType op
    OtherOp op -> opType op
    SizeOp op -> opType op
    GPUBody ts _ ->
      pure . staticShapes $ map (\t -> arrayOfRow t (intConst Int64 1)) ts

-- | Aliasing is delegated to the wrapped operation.  The results of a
-- 'GPUBody' alias nothing, and it consumes whatever its body consumes.
instance (Aliased rep, AliasedOp op, ASTRep rep) => AliasedOp (HostOp rep op) where
  opAliases hop = case hop of
    SegOp op -> opAliases op
    OtherOp op -> opAliases op
    SizeOp op -> opAliases op
    -- One empty alias set per result.
    GPUBody ts _ -> mempty <$ ts

  consumedInOp hop = case hop of
    SegOp op -> consumedInOp op
    OtherOp op -> consumedInOp op
    SizeOp op -> consumedInOp op
    GPUBody _ body -> consumedInBody body

-- | Free variables of a host operation.  For a 'GPUBody', both the
-- declared result types and the body contribute.
instance (ASTRep rep, FreeIn op) => FreeIn (HostOp rep op) where
  freeIn' hostop =
    case hostop of
      SegOp segop -> freeIn' segop
      OtherOp other -> freeIn' other
      SizeOp sizeop -> freeIn' sizeop
      GPUBody ts body ->
        let fromTypes = freeIn' ts
            fromBody = freeIn' body
         in fromTypes <> fromBody

-- | Adding and removing alias annotations on host operations.  Wrapped
-- operations are handled by their own instances.  A 'GPUBody' body is
-- analysed with 'Alias.analyseBody' and stripped with 'removeBodyAliases'.
-- Size operations carry no annotations and are only rewrapped.
instance (CanBeAliased (Op rep), CanBeAliased op, ASTRep rep) => CanBeAliased (HostOp rep op) where
  type OpWithAliases (HostOp rep op) = HostOp (Aliases rep) (OpWithAliases op)

  addOpAliases atable hostop =
    case hostop of
      SegOp segop -> SegOp (addOpAliases atable segop)
      GPUBody ts body -> GPUBody ts (Alias.analyseBody atable body)
      OtherOp other -> OtherOp (addOpAliases atable other)
      SizeOp sizeop -> SizeOp sizeop

  removeOpAliases hostop =
    case hostop of
      SegOp segop -> SegOp (removeOpAliases segop)
      OtherOp other -> OtherOp (removeOpAliases other)
      SizeOp sizeop -> SizeOp sizeop
      GPUBody ts body -> GPUBody ts (removeBodyAliases body)

-- | Adding and removing simplifier wisdom on host operations.  Wrapped
-- operations use their own instances.  A 'GPUBody' body is annotated with
-- 'informBody' and stripped with 'removeBodyWisdom'.  Size operations are
-- only rewrapped.
instance (CanBeWise (Op rep), CanBeWise op, ASTRep rep) => CanBeWise (HostOp rep op) where
  type OpWithWisdom (HostOp rep op) = HostOp (Wise rep) (OpWithWisdom op)

  removeOpWisdom hostop =
    case hostop of
      SegOp segop -> SegOp (removeOpWisdom segop)
      OtherOp other -> OtherOp (removeOpWisdom other)
      SizeOp sizeop -> SizeOp sizeop
      GPUBody ts body -> GPUBody ts (removeBodyWisdom body)

  addOpWisdom hostop =
    case hostop of
      SegOp segop -> SegOp (addOpWisdom segop)
      OtherOp other -> OtherOp (addOpWisdom other)
      SizeOp sizeop -> SizeOp sizeop
      GPUBody ts body -> GPUBody ts (informBody body)

-- | Symbol-table indexing into the results of a host operation.  Only
-- 'SegOp' and 'OtherOp' forward the query; size operations and 'GPUBody'
-- always answer 'Nothing'.
instance (ASTRep rep, ST.IndexOp op) => ST.IndexOp (HostOp rep op) where
  indexOp vtable k hostop idxs =
    case hostop of
      SegOp segop -> ST.indexOp vtable k segop idxs
      OtherOp other -> ST.indexOp vtable k other idxs
      _ -> Nothing

-- | Pretty-printing of host operations.  Wrapped operations print
-- themselves.  A 'GPUBody' prints as @gpu : (t1, ...) { body }@.
instance (PrettyRep rep, PP.Pretty op) => PP.Pretty (HostOp rep op) where
  pretty hostop =
    case hostop of
      SegOp segop -> pretty segop
      OtherOp other -> pretty other
      SizeOp sizeop -> pretty sizeop
      GPUBody ts body ->
        let header = "gpu" <+> PP.colon <+> ppTuple' (map pretty ts)
         in header <+> PP.nestedBlock "{" "}" (pretty body)

-- | Metrics collection for host operations.  The body of a 'GPUBody' is
-- measured inside a @"GPUBody"@ context.
instance (OpMetrics (Op rep), OpMetrics op) => OpMetrics (HostOp rep op) where
  opMetrics hostop =
    case hostop of
      SegOp segop -> opMetrics segop
      OtherOp other -> opMetrics other
      SizeOp sizeop -> opMetrics sizeop
      GPUBody _ body -> inside "GPUBody" (bodyMetrics body)

-- | Check a 'SegLevel' against the level of the enclosing 'SegOp', if any.
--
-- * With no enclosing level, the number of groups and the group size must
--   both have type @i64@.
--
-- * A 'SegOp' may not appear inside a 'SegThread' body.
--
-- * A nested level may not be identical to its parent.
--
-- * A nested level must use the same number of groups and the same group
--   size as its parent.
checkSegLevel ::
  TC.Checkable rep =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM rep ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int64] $ unCount $ segNumGroups lvl
  TC.require [Prim int64] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread {}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y =
      TC.bad $ TC.TypeError $ "Already at level " <> prettyText x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $
        TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      pure ()

-- | Type-check a host operation.
--
-- The arguments are:
--
-- * a checker for operations nested inside a 'SegOp', given that op's level;
-- * the level of the enclosing 'SegOp', or 'Nothing' at top level;
-- * a checker for 'OtherOp' payloads;
-- * the operation to check.
--
-- A 'GPUBody' is rejected when nested inside a 'SegOp'.  At top level, its
-- declared types must be well formed, its body must type-check, and the
-- body's result types must equal the declared types.
typeCheckHostOp ::
  TC.Checkable rep =>
  (SegLevel -> OpWithAliases (Op rep) -> TC.TypeM rep ()) ->
  Maybe SegLevel ->
  (op -> TC.TypeM rep ()) ->
  HostOp (Aliases rep) op ->
  TC.TypeM rep ()
typeCheckHostOp checker lvl checkOther hostop =
  case hostop of
    SegOp segop ->
      TC.checkOpWith (checker (segLevel segop)) $
        typeCheckSegOp (checkSegLevel lvl) segop
    OtherOp other ->
      checkOther other
    SizeOp sizeop ->
      typeCheckSizeOp sizeop
    GPUBody ts body ->
      case lvl of
        Just {} ->
          TC.bad $ TC.TypeError "GPUBody may not be nested in SegOps."
        Nothing -> do
          mapM_ TC.checkType ts
          void $ TC.checkBody body
          -- The result may refer to names bound in the body, so compute the
          -- result types in a scope that includes the body's statements.
          body_ts <-
            extendedScope
              (traverse subExpResType (bodyResult body))
              (scopeOf (bodyStms body))
          unless (body_ts == ts) $
            TC.bad . TC.TypeError $
              T.unlines
                [ "Expected type: " <> prettyTuple ts,
                  "Got body type: " <> prettyTuple body_ts
                ]