{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}

-- | In the context of this module, a "size" is any kind of tunable
-- (run-time) constant.
module Futhark.IR.Kernels.Sizes
  ( SizeClass (..),
    sizeDefault,
    KernelPath,
    Count (..),
    NumGroups,
    GroupSize,
    NumThreads,
  )
where

import Data.Int (Int64)
import Data.Traversable
import Futhark.IR.Prop.Names (FreeIn)
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp (IntegralExp)
import Futhark.Util.Pretty
import Language.Futhark.Core (Name)
import Prelude hiding (id, (.))

-- | An indication of which comparisons have been performed to get to
-- this point, as well as the result of each comparison.
--
-- Each element is a @('Name', 'Bool')@ pair; presumably the 'Name'
-- identifies the comparison and the 'Bool' is its outcome.
type KernelPath = [(Name, Bool)]

-- | The class of some kind of configurable size.  Each class may
-- impose constraints on the valid values.
--
-- The derived 'Ord' instance follows the order in which the
-- constructors are listed below.
data SizeClass
  = -- | A threshold with an optional default.  The 'KernelPath'
    -- records the comparisons leading to this threshold; a default of
    -- 'Nothing' means that no default value has been specified.
    SizeThreshold KernelPath (Maybe Int64)
  | SizeGroup
  | SizeNumGroups
  | SizeTile
  | SizeRegTile
  | -- | Likely not useful on its own, but querying the
    -- maximum can be handy.
    SizeLocalMemory
  | -- | A bespoke size with a default.  The 'Name' identifies the
    -- size and the 'Int64' is its default value.
    SizeBespoke Name Int64
  deriving (Eq, Ord, Show)

-- | Render a 'SizeClass' in its textual form.
--
-- * Thresholds are printed as @threshold(DEFAULT, STEPS)@, where
--   @DEFAULT@ is the default value (or the literal @def@ when there is
--   none) and @STEPS@ are the entries of the 'KernelPath' laid out with
--   'spread'.  A step whose comparison result is 'False' is prefixed
--   with @!@.
--
-- * Bespoke sizes are printed as @bespoke(NAME, DEFAULT)@.
--
-- * The remaining classes are printed as fixed keywords.
instance Pretty SizeClass where
  ppr (SizeThreshold path def) =
    "threshold" <> parens (def' <> comma <+> spread (map pStep path))
    where
      -- A step is shown as the bare name when the comparison was true,
      -- and as the negated name otherwise.
      pStep (v, True) = ppr v
      pStep (v, False) = "!" <> ppr v
      -- A missing default is rendered as the placeholder "def".
      def' = maybe "def" ppr def
  ppr SizeGroup = text "group_size"
  ppr SizeNumGroups = text "num_groups"
  ppr SizeTile = text "tile_size"
  ppr SizeRegTile = text "reg_tile_size"
  ppr SizeLocalMemory = text "local_memory"
  ppr (SizeBespoke k def) =
    text "bespoke" <> parens (ppr k <> comma <+> ppr def)

-- | The default value for the size.  If 'Nothing', that means the
-- backend gets to decide.
--
-- A 'SizeThreshold' yields its (optional) default as-is, a
-- 'SizeBespoke' always has a default, and every other class has none.
sizeDefault :: SizeClass -> Maybe Int64
sizeDefault (SizeThreshold _ x) = x
sizeDefault (SizeBespoke _ x) = Just x
sizeDefault _ = Nothing

-- | A wrapper supporting a phantom type for indicating what we are
-- counting.
--
-- The type parameter @u@ is the phantom tag and does not occur in the
-- representation; @e@ is the type of the wrapped value.  All the
-- derived instances are those of the underlying @e@ (the 'Num',
-- 'IntegralExp', 'FreeIn', 'Pretty' and 'Substitute' instances are
-- obtained through @GeneralizedNewtypeDeriving@).
newtype Count u e = Count
  { -- | Strip the tag and return the wrapped value.
    unCount :: e
  }
  deriving (Eq, Ord, Show, Num, IntegralExp, FreeIn, Pretty, Substitute)

-- | Mapping over a 'Count' transforms the wrapped value and keeps the
-- phantom tag.  Defined in terms of the 'Traversable' instance via
-- 'fmapDefault'.
instance Functor (Count u) where
  fmap = fmapDefault

-- | Folding a 'Count' visits its single wrapped value.  Defined in
-- terms of the 'Traversable' instance via 'foldMapDefault'.
instance Foldable (Count u) where
  foldMap = foldMapDefault

-- | Traversing a 'Count' applies the effectful function to the
-- wrapped value and re-wraps the result under the same phantom tag.
instance Traversable (Count u) where
  traverse f (Count x) = Count <$> f x

-- | Phantom type for the number of groups of some kernel.
--
-- It has no constructors, so it exists only at the type level;
-- presumably it is meant as the tag argument @u@ of 'Count'.
data NumGroups

-- | Phantom type for the group size of some kernel.
--
-- It has no constructors, so it exists only at the type level;
-- presumably it is meant as the tag argument @u@ of 'Count'.
data GroupSize

-- | Phantom type for number of threads.
--
-- It has no constructors, so it exists only at the type level;
-- presumably it is meant as the tag argument @u@ of 'Count'.
data NumThreads