{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Trustworthy #-}
-- | In the context of this module, a "size" is any kind of tunable
-- (run-time) constant.
module Futhark.IR.Kernels.Sizes
  ( SizeClass (..)
  , sizeDefault
  , KernelPath
  , Count(..)
  , NumGroups, GroupSize, NumThreads
  )
  where

import Data.Int (Int32)
import Data.Traversable

import Futhark.Util.Pretty
import Futhark.Transform.Substitute
import Language.Futhark.Core (Name)
import Futhark.Util.IntegralExp (IntegralExp)
import Futhark.IR.Prop.Names (FreeIn)

-- | An indication of which comparisons have been performed to get to
-- this point, as well as the result of each comparison.
--
-- Each entry pairs a 'Name' identifying a comparison (presumably the
-- name of a threshold parameter; confirm at call sites) with the
-- 'Bool' outcome of that comparison.  Entries are ordered as they
-- appear in the list; the 'Pretty' instance for 'SizeClass' renders
-- 'False' outcomes with a leading @!@.
type KernelPath = [(Name, Bool)]

-- | The class of some kind of configurable size.  Each class may
-- impose constraints on the valid values.
--
-- Constructor order is significant: it determines the derived 'Ord'
-- instance.
data SizeClass
  = -- | A threshold with an optional default.  The 'KernelPath'
    -- records the comparisons leading up to this threshold.
    SizeThreshold KernelPath (Maybe Int32)
  | SizeGroup
  | SizeNumGroups
  | SizeTile
  | -- | Likely not useful on its own, but querying the
    -- maximum can be handy.
    SizeLocalMemory
  | -- | A bespoke size with a default.
    SizeBespoke Name Int32
  deriving (Eq, Ord, Show)

-- Human-readable rendering of a size class.  A threshold is shown as
-- @threshold (...)@ listing its kernel path, where steps whose outcome
-- was 'False' are prefixed with @!@.  The default of a threshold is not
-- printed, and a bespoke size is printed as just its name.
instance Pretty SizeClass where
  ppr sizeClass =
    case sizeClass of
      SizeThreshold path _ ->
        text $ concat ["threshold (", unwords (map showStep path), ")"]
      SizeGroup -> text "group_size"
      SizeNumGroups -> text "num_groups"
      SizeTile -> text "tile_size"
      SizeLocalMemory -> text "local_memory"
      SizeBespoke bespokeName _ -> ppr bespokeName
    where
      -- One step of the kernel path; a 'False' outcome gets a '!' prefix.
      showStep (v, outcome)
        | outcome = pretty v
        | otherwise = '!' : pretty v

-- | The default value for the size.  If 'Nothing', that means the
-- backend gets to decide.
--
-- Only thresholds (whose default is itself optional) and bespoke sizes
-- carry a default; every other size class yields 'Nothing'.
sizeDefault :: SizeClass -> Maybe Int32
sizeDefault sizeClass =
  case sizeClass of
    SizeThreshold _ optionalDefault -> optionalDefault
    SizeBespoke _ bespokeDefault -> Just bespokeDefault
    _ -> Nothing

-- | A wrapper supporting a phantom type for indicating what we are
-- counting.
--
-- The @u@ parameter does not occur in the representation, so it only
-- tags the count at the type level (this module declares 'NumGroups',
-- 'GroupSize' and 'NumThreads' as phantom types for this purpose).  The
-- wrapped value is retrieved with 'unCount'.
--
-- With @GeneralizedNewtypeDeriving@ enabled, the arithmetic
-- ('Num', 'IntegralExp') and the 'FreeIn', 'Pretty' and 'Substitute'
-- instances are inherited directly from @e@.
newtype Count u e = Count {unCount :: e}
  deriving
    ( Eq,
      Ord,
      Show,
      Num,
      IntegralExp,
      FreeIn,
      Pretty,
      Substitute
    )

-- Mapping applies the function to the single wrapped value; the phantom
-- tag @u@ is carried over unchanged.  This agrees with 'fmapDefault'
-- over the 'Traversable' instance.
instance Functor (Count u) where
  fmap f (Count x) = Count (f x)

-- Folding visits exactly the one wrapped value.  The definition is
-- obtained from the 'Traversable' instance via 'foldMapDefault'; all
-- other 'Foldable' methods use their class defaults.
instance Foldable (Count u) where
  foldMap = foldMapDefault

-- Traversal runs the effectful function on the wrapped value and
-- re-wraps the result, keeping the phantom tag @u@.
instance Traversable (Count u) where
  traverse f (Count x) = fmap Count (f x)

-- | Phantom type for the number of groups of some kernel.  It has no
-- constructors and exists only as a type-level tag (presumably the @u@
-- argument of 'Count').
data NumGroups

-- | Phantom type for the group size of some kernel.  It has no
-- constructors and exists only as a type-level tag (presumably the @u@
-- argument of 'Count').
data GroupSize

-- | Phantom type for number of threads.  It has no constructors and
-- exists only as a type-level tag (presumably the @u@ argument of
-- 'Count').
data NumThreads