-- | In the context of this module, a "size" is any kind of tunable
-- (run-time) constant.
module Futhark.IR.GPU.Sizes
  ( SizeClass (..),
    sizeDefault,
    KernelPath,
    Count (..),
    NumGroups,
    GroupSize,
    NumThreads,
  )
where

import Data.Int (Int64)
import Data.Traversable
import Futhark.IR.Prop.Names (FreeIn)
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp (IntegralExp)
import Futhark.Util.Pretty
import Language.Futhark.Core (Name)
import Prelude hiding (id, (.))

-- | An indication of which comparisons have been performed to get to
-- this point, as well as the result of each comparison.
--
-- Each entry pairs a 'Name' (presumably identifying the threshold
-- that was compared against -- confirm against the producers of this
-- list) with the 'Bool' outcome of that comparison.
type KernelPath = [(Name, Bool)]

-- | The class of some kind of configurable size.  Each class may
-- impose constraints on the valid values.
--
-- Only 'SizeThreshold' and 'SizeBespoke' carry data; the remaining
-- constructors are plain tags.  The derived 'Ord' instance compares
-- constructors in declaration order first and then compares their
-- fields left to right.
data SizeClass
  = -- | A threshold with an optional default.
    SizeThreshold KernelPath (Maybe Int64)
  | SizeGroup
  | SizeNumGroups
  | SizeTile
  | SizeRegTile
  | -- | Likely not useful on its own, but querying the
    -- maximum can be handy.
    SizeLocalMemory
  | -- | A bespoke size with a default.
    SizeBespoke Name Int64
  deriving (Eq, Ord, Show)

-- | Textual rendering of a 'SizeClass'.
--
-- The data-less constructors render as fixed keywords.  A
-- 'SizeThreshold' renders as @threshold(DEFAULT, STEP ...)@ and a
-- 'SizeBespoke' as @bespoke(NAME, DEFAULT)@.
instance Pretty SizeClass where
  pretty (SizeThreshold path def) =
    "threshold" <> parens (def' <> comma <+> hsep (map pStep path))
    where
      -- A path step whose comparison came out 'False' is marked with
      -- a leading "!"; a 'True' step is printed as the bare name.
      pStep (v, True) = pretty v
      pStep (v, False) = "!" <> pretty v
      -- A missing default is printed as the literal word "def".
      def' = maybe "def" pretty def
  pretty SizeGroup = "group_size"
  pretty SizeNumGroups = "num_groups"
  pretty SizeTile = "tile_size"
  pretty SizeRegTile = "reg_tile_size"
  pretty SizeLocalMemory = "local_memory"
  pretty (SizeBespoke k def) =
    "bespoke" <> parens (pretty k <> comma <+> pretty def)

-- | The default value for the size.  If 'Nothing', that means the
-- backend gets to decide.
--
-- A 'SizeThreshold' yields its (optional) default unchanged, a
-- 'SizeBespoke' always yields its default, and every other class has
-- no default.
sizeDefault :: SizeClass -> Maybe Int64
sizeDefault (SizeThreshold _ x) = x
sizeDefault (SizeBespoke _ x) = Just x
sizeDefault _ = Nothing

-- | A wrapper supporting a phantom type for indicating what we are counting.
--
-- The type parameter @u@ does not occur in the representation; only
-- the wrapped value of type @e@ is stored, and 'unCount' extracts it.
-- Each derived instance requires the corresponding instance for @e@.
newtype Count u e = Count {unCount :: e}
  deriving (Eq, Ord, Show, Num, IntegralExp, FreeIn, Pretty, Substitute)

-- | Mapping over the single wrapped value; defined in terms of the
-- 'Traversable' instance via 'fmapDefault'.
instance Functor (Count u) where
  fmap = fmapDefault

-- | Folding over the single wrapped value; defined in terms of the
-- 'Traversable' instance via 'foldMapDefault'.
instance Foldable (Count u) where
  foldMap = foldMapDefault

-- | Run the effectful function on the single wrapped value and put
-- the result back inside 'Count', keeping the phantom index @u@.
instance Traversable (Count u) where
  traverse f (Count x) = Count <$> f x

-- | Phantom type for the number of groups of some kernel.
--
-- It has no constructors, so it is only usable at the type level
-- (presumably as the @u@ index of 'Count' -- confirm against use sites).
data NumGroups

-- | Phantom type for the group size of some kernel.
--
-- It has no constructors, so it is only usable at the type level
-- (presumably as the @u@ index of 'Count' -- confirm against use sites).
data GroupSize

-- | Phantom type for number of threads.
--
-- It has no constructors, so it is only usable at the type level
-- (presumably as the @u@ index of 'Count' -- confirm against use sites).
data NumThreads