{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Futhark.IR.GPU.Op
  ( -- * Size operations
    SizeOp (..),

    -- * Host operations
    HostOp (..),
    typeCheckHostOp,

    -- * SegOp refinements
    SegLevel (..),

    -- * Reexports
    module Futhark.IR.GPU.Sizes,
    module Futhark.IR.SegOp,
  )
where

import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.GPU.Sizes
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util.Pretty
  ( commasep,
    parens,
    ppr,
    text,
    (<+>),
  )
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | At which level the *body* of a t'SegOp' executes.
--
-- Both constructors carry the same three fields; only the constructor
-- itself says whether the body runs per thread or per workgroup.
data SegLevel
  = -- | The body executes at the level of individual threads.
    SegThread
      { -- | How many workgroups to launch.
        segNumGroups :: Count NumGroups SubExp,
        -- | How many threads each workgroup contains.
        segGroupSize :: Count GroupSize SubExp,
        -- | The virtualisation mode of this level.
        segVirt :: SegVirt
      }
  | -- | The body executes at the level of whole workgroups.
    SegGroup
      { segNumGroups :: Count NumGroups SubExp,
        segGroupSize :: Count GroupSize SubExp,
        segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- | Renders a level roughly as
-- @(thread; #groups=N; groupsize=M; virtualise)@, where the trailing
-- virtualisation component is omitted for 'SegNoVirt'.
instance PP.Pretty SegLevel where
  ppr lvl =
    PP.parens $
      levelName <> PP.semi
        <+> text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
        <+> text "groupsize=" <> ppr (segGroupSize lvl) <> virtSuffix
    where
      levelName = case lvl of
        SegThread {} -> "thread"
        SegGroup {} -> "group"
      -- Empty for the default (no virtualisation) mode.
      virtSuffix = case segVirt lvl of
        SegNoVirt -> mempty
        SegNoVirtFull -> PP.semi <+> text "full"
        SegVirt -> PP.semi <+> text "virtualise"

-- | Simplifies the group count and then the group size. The constructor
-- and the virtualisation mode are carried over unchanged.
instance Engine.Simplifiable SegLevel where
  simplify lvl = do
    num_groups <- traverse Engine.simplify (segNumGroups lvl)
    group_size <- traverse Engine.simplify (segGroupSize lvl)
    -- Record update keeps whichever constructor 'lvl' was built with.
    pure lvl {segNumGroups = num_groups, segGroupSize = group_size}

-- | Applies a name substitution to the group count and group size.
-- 'segVirt' is left untouched.
instance Substitute SegLevel where
  substituteNames substs lvl =
    lvl
      { segNumGroups = substituteNames substs (segNumGroups lvl),
        segGroupSize = substituteNames substs (segGroupSize lvl)
      }

-- | A 'SegLevel' binds no names of its own, so renaming is implemented
-- via its 'Substitute' instance.
instance Rename SegLevel where
  rename lvl = substituteRename lvl

-- | The free variables are those of the group count and group size;
-- 'segVirt' contributes nothing.
instance FreeIn SegLevel where
  freeIn' lvl = freeIn' (segNumGroups lvl) <> freeIn' (segGroupSize lvl)

-- | A simple size-level query or computation.
data SizeOp
  = -- | @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', the thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
    SplitSpace SplitOrdering SubExp SubExp SubExp
  | -- | Produce some runtime-configurable size.
    GetSize Name SizeClass
  | -- | The maximum size of some class.
    GetSizeMax SizeClass
  | -- | Compare size (likely a threshold) with some integer value.
    CmpSizeLe Name SizeClass SubExp
  | -- | @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @Name@ is a size name.  Note that @w@ is an i64 to avoid
    -- overflow issues.
    CalcNumGroups SubExp Name SubExp
  deriving (Eq, Ord, Show)

-- | Applies a name substitution to every 'SubExp' operand (and to the
-- split ordering of 'SplitSpace'). Size names and size classes are
-- returned unchanged.
--
-- Every constructor is matched explicitly rather than through a
-- wildcard, so that adding a new constructor carrying 'SubExp's
-- triggers an incomplete-pattern warning instead of being silently
-- left unsubstituted.
instance Substitute SizeOp where
  substituteNames substs (SplitSpace o w i elems_per_thread) =
    SplitSpace
      (substituteNames substs o)
      (substituteNames substs w)
      (substituteNames substs i)
      (substituteNames substs elems_per_thread)
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass (substituteNames substs x)
  substituteNames substs (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups
      (substituteNames substs w)
      max_num_groups
      (substituteNames substs group_size)
  -- These constructors contain no 'SubExp's.
  substituteNames _ op@GetSize {} = op
  substituteNames _ op@GetSizeMax {} = op

-- | Renames every 'SubExp' operand (and the split ordering of
-- 'SplitSpace'). Size names and size classes are kept as they are.
--
-- Every constructor is matched explicitly rather than through a
-- wildcard, so that adding a new constructor triggers an
-- incomplete-pattern warning instead of being silently passed through.
instance Rename SizeOp where
  rename (SplitSpace o w i elems_per_thread) =
    SplitSpace
      <$> rename o
      <*> rename w
      <*> rename i
      <*> rename elems_per_thread
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups <$> rename w <*> pure max_num_groups <*> rename group_size
  -- These constructors contain nothing to rename.
  rename op@GetSize {} = pure op
  rename op@GetSizeMax {} = pure op

-- | Every size operation is reported as both safe and cheap,
-- regardless of its operands.
instance IsOp SizeOp where
  safeOp = const True
  cheapOp = const True

-- | Each size operation yields exactly one primitive result:
-- a boolean for 'CmpSizeLe', and an @i64@ for everything else.
instance TypedOp SizeOp where
  opType op = pure [Prim pt]
    where
      pt = case op of
        CmpSizeLe {} -> Bool
        SplitSpace {} -> int64
        GetSize {} -> int64
        GetSizeMax {} -> int64
        CalcNumGroups {} -> int64

-- | Size operations neither alias nor consume anything. The single
-- result gets an empty alias set.
instance AliasedOp SizeOp where
  opAliases = const [mempty]
  consumedInOp = const mempty

-- | The free variables of a size operation are those of its 'SubExp'
-- operands (plus the split ordering of 'SplitSpace'). Size names and
-- size classes contribute nothing.
--
-- Every constructor is matched explicitly rather than through a
-- wildcard, so that a new constructor carrying 'SubExp's triggers an
-- incomplete-pattern warning instead of silently reporting no free
-- variables.
instance FreeIn SizeOp where
  freeIn' (SplitSpace o w i elems_per_thread) =
    freeIn' o <> freeIn' [w, i, elems_per_thread]
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' (CalcNumGroups w _ group_size) = freeIn' w <> freeIn' group_size
  freeIn' GetSize {} = mempty
  freeIn' GetSizeMax {} = mempty

-- | Renders size operations in a function-call-like syntax, e.g.
-- @get_size(name, class)@. 'CmpSizeLe' additionally appends
-- @<= x@ after the call.
instance PP.Pretty SizeOp where
  ppr op = case op of
    SplitSpace SplitContiguous w i elems_per_thread ->
      call "split_space" [ppr w, ppr i, ppr elems_per_thread]
    SplitSpace (SplitStrided stride) w i elems_per_thread ->
      call "split_space_strided" [ppr stride, ppr w, ppr i, ppr elems_per_thread]
    GetSize name size_class ->
      call "get_size" [ppr name, ppr size_class]
    GetSizeMax size_class ->
      call "get_size_max" [ppr size_class]
    CmpSizeLe name size_class x ->
      call "cmp_size" [ppr name, ppr size_class] <+> text "<=" <+> ppr x
    CalcNumGroups w max_num_groups group_size ->
      call "calc_num_groups" [ppr w, ppr max_num_groups, ppr group_size]
    where
      -- @fname(arg1, arg2, ...)@
      call fname args = text fname <> parens (commasep args)

-- | Records one occurrence of the operation, labelled by its
-- constructor name.
instance OpMetrics SizeOp where
  opMetrics op = seen $ case op of
    SplitSpace {} -> "SplitSpace"
    GetSize {} -> "GetSize"
    GetSizeMax {} -> "GetSizeMax"
    CmpSizeLe {} -> "CmpSizeLe"
    CalcNumGroups {} -> "CalcNumGroups"

-- | Type-checks a 'SizeOp'. Every 'SubExp' operand, including the
-- stride of a strided split, must have type @i64@. Size names and
-- size classes are not checked here.
typeCheckSizeOp :: TC.Checkable rep => SizeOp -> TC.TypeM rep ()
typeCheckSizeOp op = case op of
  SplitSpace o w i elems_per_thread -> do
    case o of
      SplitContiguous -> pure ()
      SplitStrided stride -> requireI64 stride
    mapM_ requireI64 [w, i, elems_per_thread]
  GetSize {} -> pure ()
  GetSizeMax {} -> pure ()
  CmpSizeLe _ _ x -> requireI64 x
  CalcNumGroups w _ group_size -> mapM_ requireI64 [w, group_size]
  where
    requireI64 = TC.require [Prim int64]

-- | A host-level operation; parameterised by what else it can do.
data HostOp rep op
  = -- | A segmented operation.
    SegOp (SegOp SegLevel rep)
  | -- | A size-level query or computation.
    SizeOp SizeOp
  | -- | Some other operation, as given by the @op@ type parameter.
    OtherOp op
  deriving (Eq, Ord, Show)

-- | Delegates substitution to the wrapped operation.
instance (ASTRep rep, Substitute op) => Substitute (HostOp rep op) where
  substituteNames substs host_op = case host_op of
    SegOp op -> SegOp (substituteNames substs op)
    SizeOp op -> SizeOp (substituteNames substs op)
    OtherOp op -> OtherOp (substituteNames substs op)

-- | Delegates renaming to the wrapped operation.
instance (ASTRep rep, Rename op) => Rename (HostOp rep op) where
  rename host_op = case host_op of
    SegOp op -> SegOp <$> rename op
    SizeOp op -> SizeOp <$> rename op
    OtherOp op -> OtherOp <$> rename op

-- | Safety and cheapness are those of the wrapped operation.
instance (ASTRep rep, IsOp op) => IsOp (HostOp rep op) where
  safeOp host_op = case host_op of
    SegOp op -> safeOp op
    SizeOp op -> safeOp op
    OtherOp op -> safeOp op

  cheapOp host_op = case host_op of
    SegOp op -> cheapOp op
    SizeOp op -> cheapOp op
    OtherOp op -> cheapOp op

-- | The result types are those of the wrapped operation.
instance TypedOp op => TypedOp (HostOp rep op) where
  opType host_op = case host_op of
    SegOp op -> opType op
    SizeOp op -> opType op
    OtherOp op -> opType op

-- | Aliasing and consumption information of a 'HostOp' is taken verbatim
-- from the wrapped operation.
instance (Aliased rep, AliasedOp op, ASTRep rep) => AliasedOp (HostOp rep op) where
  opAliases hostOp =
    case hostOp of
      SegOp op -> opAliases op
      OtherOp op -> opAliases op
      SizeOp op -> opAliases op

  consumedInOp hostOp =
    case hostOp of
      SegOp op -> consumedInOp op
      OtherOp op -> consumedInOp op
      SizeOp op -> consumedInOp op

-- | The free variables of a 'HostOp' are those of the wrapped operation.
instance (ASTRep rep, FreeIn op) => FreeIn (HostOp rep op) where
  freeIn' hostOp =
    case hostOp of
      SegOp op -> freeIn' op
      OtherOp op -> freeIn' op
      SizeOp op -> freeIn' op

-- | Adding/removing alias annotations on a 'HostOp'.  'SegOp' and 'OtherOp'
-- recurse into their payload; a 'SizeOp' carries no representation-dependent
-- parts, so it is re-wrapped unchanged (the alias table is ignored).
instance (CanBeAliased (Op rep), CanBeAliased op, ASTRep rep) => CanBeAliased (HostOp rep op) where
  type OpWithAliases (HostOp rep op) = HostOp (Aliases rep) (OpWithAliases op)

  addOpAliases aliases hostOp =
    case hostOp of
      SegOp op -> SegOp (addOpAliases aliases op)
      OtherOp op -> OtherOp (addOpAliases aliases op)
      SizeOp op -> SizeOp op

  removeOpAliases hostOp =
    case hostOp of
      SegOp op -> SegOp (removeOpAliases op)
      OtherOp op -> OtherOp (removeOpAliases op)
      SizeOp op -> SizeOp op

-- | Stripping simplifier wisdom from a 'HostOp'.  'SegOp' and 'OtherOp'
-- recurse into their payload; a 'SizeOp' is re-wrapped unchanged.
instance (CanBeWise (Op rep), CanBeWise op, ASTRep rep) => CanBeWise (HostOp rep op) where
  type OpWithWisdom (HostOp rep op) = HostOp (Wise rep) (OpWithWisdom op)

  removeOpWisdom hostOp =
    case hostOp of
      SegOp op -> SegOp (removeOpWisdom op)
      OtherOp op -> OtherOp (removeOpWisdom op)
      SizeOp op -> SizeOp op

-- | Symbol-table indexing into the result of a 'HostOp'.  'SegOp' and
-- 'OtherOp' delegate to the wrapped operation; anything else (e.g. a
-- 'SizeOp') yields no indexing information.
instance (ASTRep rep, ST.IndexOp op) => ST.IndexOp (HostOp rep op) where
  indexOp vtable k hostOp is =
    case hostOp of
      SegOp op -> ST.indexOp vtable k op is
      OtherOp op -> ST.indexOp vtable k op is
      _ -> Nothing

-- | A 'HostOp' is printed as the wrapped operation, with no extra decoration.
instance (PrettyRep rep, PP.Pretty op) => PP.Pretty (HostOp rep op) where
  ppr hostOp =
    case hostOp of
      SegOp op -> ppr op
      OtherOp op -> ppr op
      SizeOp op -> ppr op

-- | Metrics of a 'HostOp' are gathered from the wrapped operation.
instance (OpMetrics (Op rep), OpMetrics op) => OpMetrics (HostOp rep op) where
  opMetrics hostOp =
    case hostOp of
      SegOp op -> opMetrics op
      OtherOp op -> opMetrics op
      SizeOp op -> opMetrics op

-- | Check a 'SegLevel' against the level of the enclosing 'SegOp', if any.
--
-- * With no enclosing level ('Nothing'), the number of groups and the group
--   size must both be @i64@ sub-expressions.
-- * A 'SegOp' may not appear inside a 'SegThread'-level body.
-- * A nested level must not be identical to its parent level, but it must
--   agree with the parent on number of groups and group size.
checkSegLevel ::
  TC.Checkable rep =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM rep ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int64] $ unCount $ segNumGroups lvl
  TC.require [Prim int64] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread {}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y = TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  -- The physical layout (groups/group size) must be inherited from the parent.
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      pure ()

-- | Type-check a 'HostOp'.
--
-- * A 'SegOp' is checked with 'typeCheckSegOp'.  Its level is validated by
--   'checkSegLevel' against the enclosing level @lvl@.  The check runs under
--   'TC.checkOpWith', with the first argument specialised to the 'SegOp''s
--   own level.
-- * An 'OtherOp' is handed to the third argument.
-- * A 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp ::
  TC.Checkable rep =>
  (SegLevel -> OpWithAliases (Op rep) -> TC.TypeM rep ()) ->
  Maybe SegLevel ->
  (op -> TC.TypeM rep ()) ->
  HostOp (Aliases rep) op ->
  TC.TypeM rep ()
typeCheckHostOp checker lvl checkOther hostOp =
  case hostOp of
    SegOp segop ->
      let innerChecker = checker (segLevel segop)
       in TC.checkOpWith innerChecker $
            typeCheckSegOp (checkSegLevel lvl) segop
    OtherOp op -> checkOther op
    SizeOp op -> typeCheckSizeOp op