-- | In the context of this module, a "size" is any kind of tunable
-- (run-time) constant.
module Futhark.IR.GPU.Sizes
  ( SizeClass (..),
    sizeDefault,
    KernelPath,
    Count (..),
    NumGroups,
    GroupSize,
    NumThreads,
  )
where

import Data.Int (Int64)
import Data.Traversable
import Futhark.IR.Prop.Names (FreeIn)
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp (IntegralExp)
import Futhark.Util.Pretty
import Language.Futhark.Core (Name)
import Prelude hiding (id, (.))

-- | An indication of which comparisons have been performed to get to
-- this point, as well as the result of each comparison.
--
-- Each entry pairs a 'Name' with the 'Bool' outcome of the
-- corresponding comparison, in order.  When a 'SizeClass' is
-- pretty-printed, a 'True' step is rendered as the bare name and a
-- 'False' step as the name prefixed with @!@.
type KernelPath = [(Name, Bool)]

-- | The class of some kind of configurable size.  Each class may
-- impose constraints on the valid values.
--
-- Use 'sizeDefault' to obtain the default value (if any) associated
-- with a class.
data SizeClass
  = -- | A threshold with an optional default.  The 'KernelPath'
    -- records the comparisons that were taken to reach the threshold.
    SizeThreshold KernelPath (Maybe Int64)
  | -- | A group size.
    SizeGroup
  | -- | A number of groups.
    SizeNumGroups
  | -- | A tile size.
    SizeTile
  | -- | A register tile size.
    SizeRegTile
  | -- | Likely not useful on its own, but querying the
    -- maximum can be handy.
    SizeLocalMemory
  | -- | A bespoke size with a default.
    SizeBespoke Name Int64
  deriving (Eq, Ord, Show)

-- | Renders each size class as a lowercase keyword.  Classes that carry
-- data append their arguments in parentheses, separated by a comma.
instance Pretty SizeClass where
  pretty (SizeThreshold path def) =
    "threshold" <> parens (def' <> comma <+> hsep (map pStep path))
    where
      -- A comparison that went the 'False' way is marked with a "!".
      pStep (v, True) = pretty v
      pStep (v, False) = "!" <> pretty v
      -- An absent default is rendered as the literal keyword "def".
      def' = maybe "def" pretty def
  pretty SizeGroup = "group_size"
  pretty SizeNumGroups = "num_groups"
  pretty SizeTile = "tile_size"
  pretty SizeRegTile = "reg_tile_size"
  pretty SizeLocalMemory = "local_memory"
  pretty (SizeBespoke k def) =
    "bespoke" <> parens (pretty k <> comma <+> pretty def)

-- | The default value for the size.  If 'Nothing', that means the
-- backend gets to decide.
--
-- Only 'SizeThreshold' (whose default is itself optional) and
-- 'SizeBespoke' (which always has one) carry a default; every other
-- class yields 'Nothing'.
sizeDefault :: SizeClass -> Maybe Int64
sizeDefault (SizeThreshold _ x) = x
sizeDefault (SizeBespoke _ x) = Just x
sizeDefault _ = Nothing

-- | A wrapper supporting a phantom type for indicating what we are
-- counting.  The parameter @u@ appears only at the type level (e.g.
-- 'NumGroups', 'GroupSize', 'NumThreads'), so counts of different
-- things have different types; 'unCount' recovers the wrapped value.
--
-- Arithmetic ('Num', 'IntegralExp') and the IR utility classes
-- ('FreeIn', 'Pretty', 'Substitute') are lifted from the wrapped type
-- @e@, so e.g. pretty-printing a 'Count' prints just its contents.
newtype Count u e = Count {unCount :: e}
  deriving (Eq, Ord, Show, Num, IntegralExp, FreeIn, Pretty, Substitute)

-- | Maps over the wrapped value, leaving the phantom parameter alone.
-- Defined in terms of the 'Traversable' instance.
instance Functor (Count u) where
  fmap = fmapDefault

-- | Folds over the single wrapped value.  Defined in terms of the
-- 'Traversable' instance.
instance Foldable (Count u) where
  foldMap = foldMapDefault

-- | Applies an effectful function to the single wrapped value and
-- rewraps the result.  This is written directly, without going through
-- 'fmap' on 'Count', because the 'Functor' and 'Foldable' instances
-- are defined in terms of it.
instance Traversable (Count u) where
  traverse f (Count x) = Count <$> f x

-- | Phantom type for the number of groups of some kernel.  It has no
-- constructors; it is intended only as the unit parameter of 'Count'.
data NumGroups

-- | Phantom type for the group size of some kernel.  It has no
-- constructors; it is intended only as the unit parameter of 'Count'.
data GroupSize

-- | Phantom type for number of threads.  It has no constructors; it is
-- intended only as the unit parameter of 'Count'.
data NumThreads