{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.IR.Kernels.Kernel
  ( -- * Size operations
    SizeOp(..)

    -- * Host operations
  , HostOp(..)
  , typeCheckHostOp

    -- * SegOp refinements
  , SegLevel(..)

    -- * Reexports
  , module Futhark.IR.Kernels.Sizes
  , module Futhark.IR.SegOp
  )
where

import Futhark.IR
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
  ((</>), (<+>), ppr, commasep, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.IR.Prop.Aliases
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.SegOp
import Futhark.IR.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics

-- | At which level the *body* of a t'SegOp' executes.
--
-- Both constructors carry the same three fields, so the record
-- selectors 'segNumGroups', 'segGroupSize' and 'segVirt' are total,
-- and record-update syntax is safe on any 'SegLevel'.
data SegLevel
  = SegThread
      { -- | The number of groups.
        segNumGroups :: Count NumGroups SubExp,
        -- | The size of each group.
        segGroupSize :: Count GroupSize SubExp,
        -- | The 'SegVirt' setting for this level.
        segVirt :: SegVirt
      }
  | SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- | Renders as @_thread@ or @_group@, followed by a parenthesised,
-- semicolon-separated list of the group count, the group size and
-- (unless the level is 'SegNoVirt') the virtualisation setting.
instance PP.Pretty SegLevel where
  ppr lvl = level_tag </> PP.parens attrs
    where
      level_tag = case lvl of
        SegThread {} -> "_thread"
        SegGroup {} -> "_group"
      attrs =
        text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
          <+> text "groupsize=" <> ppr (segGroupSize lvl) <> virt_doc
      virt_doc = case segVirt lvl of
        SegNoVirt -> mempty
        SegNoVirtFull -> PP.semi <+> text "full"
        SegVirt -> PP.semi <+> text "virtualise"

-- | Simplifies the group count and then the group size.  The
-- constructor and the 'SegVirt' setting are kept unchanged.
instance Engine.Simplifiable SegLevel where
  simplify lvl = do
    num_groups' <- traverse Engine.simplify $ segNumGroups lvl
    group_size' <- traverse Engine.simplify $ segGroupSize lvl
    pure lvl {segNumGroups = num_groups', segGroupSize = group_size'}

-- | Applies the substitution to the group count and group size.  The
-- constructor and the 'SegVirt' setting are untouched.
instance Substitute SegLevel where
  substituteNames substs lvl =
    lvl
      { segNumGroups = substituteNames substs $ segNumGroups lvl,
        segGroupSize = substituteNames substs $ segGroupSize lvl
      }

-- | Renaming a 'SegLevel' is delegated to 'substituteRename', which
-- relies on the 'Substitute' instance.
instance Rename SegLevel where
  rename lvl = substituteRename lvl

-- | The free variables of a 'SegLevel' are those of its group count and
-- its group size.  The 'SegVirt' setting contributes none.
instance FreeIn SegLevel where
  freeIn' lvl = freeIn' (segNumGroups lvl) <> freeIn' (segGroupSize lvl)

-- | A simple size-level query or computation.
data SizeOp
  = SplitSpace SplitOrdering SubExp SubExp SubExp
    -- ^ @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
  | GetSize Name SizeClass
    -- ^ Produce some runtime-configurable size.
  | GetSizeMax SizeClass
    -- ^ The maximum size of some class.
  | CmpSizeLe Name SizeClass SubExp
    -- ^ @CmpSizeLe name size_class x@ compares the size @name@
    -- (likely a threshold) with the integer value @x@, i.e.
    -- @size <= x@, producing a boolean.
  | CalcNumGroups SubExp Name SubExp
    -- ^ @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @Name@ is a size name.  Note that @w@ is an i64 to avoid
    -- overflow issues.
  deriving (Eq, Ord, Show)

-- | Substitutes inside the 'SubExp' operands (and the 'SplitOrdering' of
-- 'SplitSpace').  Size names and size classes are left alone, and
-- 'GetSize' / 'GetSizeMax', which hold no 'SubExp's, are returned as is.
instance Substitute SizeOp where
  substituteNames substs op =
    case op of
      SplitSpace o w i elems_per_thread ->
        SplitSpace (substIn o) (substIn w) (substIn i) (substIn elems_per_thread)
      CmpSizeLe name sclass x ->
        CmpSizeLe name sclass (substIn x)
      CalcNumGroups w max_num_groups group_size ->
        CalcNumGroups (substIn w) max_num_groups (substIn group_size)
      GetSize {} -> op
      GetSizeMax {} -> op
    where
      -- Used at two types (SplitOrdering and SubExp), hence the signature.
      substIn :: Substitute a => a -> a
      substIn = substituteNames substs

-- | Renames the 'SubExp' operands (and the 'SplitOrdering' of
-- 'SplitSpace'), left to right.  Size names and size classes are kept.
instance Rename SizeOp where
  rename (SplitSpace o w i elems_per_thread) = do
    o' <- rename o
    w' <- rename w
    i' <- rename i
    elems_per_thread' <- rename elems_per_thread
    pure $ SplitSpace o' w' i' elems_per_thread'
  rename (CmpSizeLe name sclass x) = do
    x' <- rename x
    pure $ CmpSizeLe name sclass x'
  rename (CalcNumGroups w max_num_groups group_size) = do
    w' <- rename w
    group_size' <- rename group_size
    pure $ CalcNumGroups w' max_num_groups group_size'
  rename op@GetSize{} = pure op
  rename op@GetSizeMax{} = pure op

-- | Every size operation is considered both safe and cheap.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True

-- | Every 'SizeOp' has exactly one scalar result: a boolean for
-- 'CmpSizeLe' and an @i32@ for every other constructor.
instance TypedOp SizeOp where
  opType op = pure [Prim ret_type]
    where
      ret_type = case op of
        SplitSpace {} -> int32
        GetSize {} -> int32
        GetSizeMax {} -> int32
        CmpSizeLe {} -> Bool
        CalcNumGroups {} -> int32

-- | A 'SizeOp' has a single result that aliases nothing, and it
-- consumes nothing.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty

-- | Free variables of a 'SizeOp' come from its 'SubExp' operands (and
-- the 'SplitOrdering' of 'SplitSpace').  'GetSize' and 'GetSizeMax'
-- contribute none.
instance FreeIn SizeOp where
  freeIn' op = case op of
    SplitSpace o w i elems_per_thread ->
      freeIn' o <> freeIn' [w, i, elems_per_thread]
    CmpSizeLe _ _ x -> freeIn' x
    CalcNumGroups w _ group_size -> freeIn' w <> freeIn' group_size
    GetSize {} -> mempty
    GetSizeMax {} -> mempty

-- | Pretty-prints each 'SizeOp' under its own keyword, followed by a
-- parenthesised, comma-separated argument list.
instance PP.Pretty SizeOp where
  ppr (SplitSpace o w i elems_per_thread) =
    text "splitSpace" <> suff
      <> parens (commasep [ppr w, ppr i, ppr elems_per_thread])
    where
      -- Contiguous splits get no suffix; strided ones show the stride.
      suff = case o of
        SplitContiguous -> mempty
        SplitStrided stride -> text "Strided" <> parens (ppr stride)
  ppr (GetSize name size_class) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])
  ppr (GetSizeMax size_class) =
    text "get_size_max" <> parens (commasep [ppr size_class])
  ppr (CmpSizeLe name size_class x) =
    -- Previously printed with GetSize's "get_size" keyword; use a
    -- distinct keyword so the comparison op is identifiable in IR dumps.
    text "cmp_size" <> parens (commasep [ppr name, ppr size_class])
      <+> text "<="
      <+> ppr x
  ppr (CalcNumGroups w max_num_groups group_size) =
    text "calc_num_groups"
      <> parens (commasep [ppr w, ppr max_num_groups, ppr group_size])

-- | Each 'SizeOp' is recorded in the metrics under its constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    SplitSpace {} -> "SplitSpace"
    GetSize {} -> "GetSize"
    GetSizeMax {} -> "GetSizeMax"
    CmpSizeLe {} -> "CmpSizeLe"
    CalcNumGroups {} -> "CalcNumGroups"

-- | Type-check a 'SizeOp'.
--
-- Every 'SubExp' operand must be an @i32@ (including the stride of a
-- strided 'SplitSpace'), except the @w@ of 'CalcNumGroups', which must
-- be an @i64@.  'GetSize' and 'GetSizeMax' have nothing to check.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp op = case op of
  SplitSpace o w i elems_per_thread -> do
    case o of
      SplitContiguous -> pure ()
      SplitStrided stride -> requireI32 stride
    mapM_ requireI32 [w, i, elems_per_thread]
  GetSize {} -> pure ()
  GetSizeMax {} -> pure ()
  CmpSizeLe _ _ x -> requireI32 x
  CalcNumGroups w _ group_size -> do
    TC.require [Prim int64] w
    requireI32 group_size
  where
    requireI32 = TC.require [Prim int32]

-- | A host-level operation; parameterised by what else it can do.
data HostOp lore op
  = SegOp (SegOp SegLevel lore)
    -- ^ A segmented operation.
  | SizeOp SizeOp
    -- ^ A size query or computation; see 'SizeOp'.
  | OtherOp op
    -- ^ An operation of the type parameter @op@, i.e. whatever else
    -- this 'HostOp' has been parameterised to support.
  deriving (Eq, Ord, Show)

-- | Substitution is pushed into whichever operation the 'HostOp' wraps.
instance (ASTLore lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hostop = case hostop of
    SegOp op -> SegOp (substituteNames substs op)
    OtherOp op -> OtherOp (substituteNames substs op)
    SizeOp op -> SizeOp (substituteNames substs op)

-- | Renaming is pushed into whichever operation the 'HostOp' wraps.
instance (ASTLore lore, Rename op) => Rename (HostOp lore op) where
  rename hostop = case hostop of
    SegOp op -> fmap SegOp (rename op)
    OtherOp op -> fmap OtherOp (rename op)
    SizeOp op -> fmap SizeOp (rename op)

-- | Safety and cheapness are those of the wrapped operation.
instance (ASTLore lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hostop = case hostop of
    SegOp op -> safeOp op
    OtherOp op -> safeOp op
    SizeOp op -> safeOp op

  cheapOp hostop = case hostop of
    SegOp op -> cheapOp op
    OtherOp op -> cheapOp op
    SizeOp op -> cheapOp op

-- | The result type is that of the wrapped operation.
instance TypedOp op => TypedOp (HostOp lore op) where
  opType hostop = case hostop of
    SegOp op -> opType op
    OtherOp op -> opType op
    SizeOp op -> opType op

-- | Aliasing and consumption information for a 'HostOp' is taken directly
-- from the operation it wraps.
instance (Aliased lore, AliasedOp op, ASTLore lore) => AliasedOp (HostOp lore op) where
  opAliases hostop =
    case hostop of
      SegOp op -> opAliases op
      OtherOp op -> opAliases op
      SizeOp op -> opAliases op

  consumedInOp hostop =
    case hostop of
      SegOp op -> consumedInOp op
      OtherOp op -> consumedInOp op
      SizeOp op -> consumedInOp op

-- | The free variables of a 'HostOp' are those of the operation it wraps.
instance (ASTLore lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hostop =
    case hostop of
      SegOp op -> freeIn' op
      OtherOp op -> freeIn' op
      SizeOp op -> freeIn' op

-- | Alias annotations are added to, or stripped from, the wrapped operation.
-- A 'SizeOp' payload does not mention the lore, so it is passed through
-- unchanged; it is still re-wrapped because the 'HostOp' type parameters
-- differ on the two sides of the conversion.
instance (CanBeAliased (Op lore), CanBeAliased op, ASTLore lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases hostop =
    case hostop of
      SegOp op -> SegOp (addOpAliases op)
      OtherOp op -> OtherOp (addOpAliases op)
      SizeOp op -> SizeOp op

  removeOpAliases hostop =
    case hostop of
      SegOp op -> SegOp (removeOpAliases op)
      OtherOp op -> OtherOp (removeOpAliases op)
      SizeOp op -> SizeOp op

-- | Wisdom is stripped from the wrapped operation.  A 'SizeOp' payload does
-- not mention the lore, so it is passed through and merely re-wrapped at the
-- new 'HostOp' type.
instance (CanBeWise (Op lore), CanBeWise op, ASTLore lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom hostop =
    case hostop of
      SegOp op -> SegOp (removeOpWisdom op)
      OtherOp op -> OtherOp (removeOpWisdom op)
      SizeOp op -> SizeOp op

-- | Index queries are forwarded to a wrapped 'SegOp' or 'OtherOp'; any other
-- host operation answers 'Nothing'.
instance (ASTLore lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hostop is =
    case hostop of
      SegOp op -> ST.indexOp vtable k op is
      OtherOp op -> ST.indexOp vtable k op is
      _ -> Nothing

-- | A 'HostOp' is printed as the operation it wraps, with no extra syntax
-- for the 'HostOp' constructor itself.
instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr hostop =
    case hostop of
      SegOp op -> ppr op
      OtherOp op -> ppr op
      SizeOp op -> ppr op

-- | Metrics for a 'HostOp' are gathered from the operation it wraps.
instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics hostop =
    case hostop of
      SegOp op -> opMetrics op
      OtherOp op -> opMetrics op
      SizeOp op -> opMetrics op

-- | Check the 'SegLevel' of a 'SegOp' against the level of its enclosing
-- 'SegOp', if any.
--
-- * With no enclosing level ('Nothing'), the number of groups and the group
--   size are each checked against @[Prim int32]@ with 'TC.require'.
-- * Nothing may be nested inside a 'SegThread'.
-- * A nested level must not be equal to its parent level, yet it must have
--   the same number of groups and group size as the parent.
checkSegLevel :: TC.Checkable lore =>
                 Maybe SegLevel -> SegLevel -> TC.TypeM lore ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int32] $ unCount $ segNumGroups lvl
  TC.require [Prim int32] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread{}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y = TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      return ()

-- | Type-check a 'HostOp'.
--
-- * For a 'SegOp', the first argument is applied to that 'SegOp''s own level.
--   The resulting checker is installed with 'TC.checkOpWith' while the
--   'SegOp' is checked, presumably to check ops nested in its body.
--   The 'SegOp''s level is validated against the enclosing level (the
--   'Maybe SegLevel' argument) via 'checkSegLevel'.
-- * An 'OtherOp' is handed to the third argument.
-- * A 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp :: TC.Checkable lore =>
                   (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ())
                -> Maybe SegLevel
                -> (op -> TC.TypeM lore ())
                -> HostOp (Aliases lore) op
                -> TC.TypeM lore ()
typeCheckHostOp checker lvl checkOther hostop =
  case hostop of
    SegOp op ->
      let innerChecker = checker (segLevel op)
      in TC.checkOpWith innerChecker $ typeCheckSegOp (checkSegLevel lvl) op
    OtherOp op -> checkOther op
    SizeOp op -> typeCheckSizeOp op