{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} module Futhark.IR.Kernels.Kernel ( -- * Size operations SizeOp(..) -- * Host operations , HostOp(..) , typeCheckHostOp -- * SegOp refinements , SegLevel(..) -- * Reexports , module Futhark.IR.Kernels.Sizes , module Futhark.IR.SegOp ) where import Futhark.IR import qualified Futhark.Analysis.SymbolTable as ST import qualified Futhark.Util.Pretty as PP import Futhark.Util.Pretty ((), (<+>), ppr, commasep, parens, text) import Futhark.Transform.Substitute import Futhark.Transform.Rename import Futhark.Optimise.Simplify.Lore import qualified Futhark.Optimise.Simplify.Engine as Engine import Futhark.IR.Prop.Aliases import Futhark.IR.Aliases (Aliases) import Futhark.IR.SegOp import Futhark.IR.Kernels.Sizes import qualified Futhark.TypeCheck as TC import Futhark.Analysis.Metrics -- | At which level the *body* of a t'SegOp' executes. 
data SegLevel
  = SegThread
      { segNumGroups :: Count NumGroups SubExp
        -- ^ Number of workgroups (printed as @#groups=@).
      , segGroupSize :: Count GroupSize SubExp
        -- ^ Size of each workgroup (printed as @groupsize=@).
      , segVirt :: SegVirt
        -- ^ Virtualisation mode of the operation.
      }
  | SegGroup
      { segNumGroups :: Count NumGroups SubExp
      , segGroupSize :: Count GroupSize SubExp
      , segVirt :: SegVirt
      }
  deriving (Eq, Ord, Show)

-- Renders e.g. @_thread (#groups=n; groupsize=m; virtualise)@.  The level
-- tag and the parenthesised layout are joined with a soft line break.
instance PP.Pretty SegLevel where
  ppr lvl =
    lvl'
      PP.</> PP.parens
        ( text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi
            <+> text "groupsize=" <> ppr (segGroupSize lvl)
            <> virt
        )
    where
      lvl' = case lvl of
        SegThread {} -> "_thread"
        SegGroup {} -> "_group"
      -- Only non-default virtualisation modes are printed.
      virt = case segVirt lvl of
        SegNoVirt -> mempty
        SegNoVirtFull -> PP.semi <+> text "full"
        SegVirt -> PP.semi <+> text "virtualise"

-- Simplify the size expressions; the virtualisation mode is left unchanged.
instance Engine.Simplifiable SegLevel where
  simplify (SegThread num_groups group_size virt) =
    SegThread
      <$> traverse Engine.simplify num_groups
      <*> traverse Engine.simplify group_size
      <*> pure virt
  simplify (SegGroup num_groups group_size virt) =
    SegGroup
      <$> traverse Engine.simplify num_groups
      <*> traverse Engine.simplify group_size
      <*> pure virt

instance Substitute SegLevel where
  substituteNames substs (SegThread num_groups group_size virt) =
    SegThread
      (substituteNames substs num_groups)
      (substituteNames substs group_size)
      virt
  substituteNames substs (SegGroup num_groups group_size virt) =
    SegGroup
      (substituteNames substs num_groups)
      (substituteNames substs group_size)
      virt

instance Rename SegLevel where
  rename = substituteRename

-- Only the group count and group size can mention free variables.
instance FreeIn SegLevel where
  freeIn' (SegThread num_groups group_size _) =
    freeIn' num_groups <> freeIn' group_size
  freeIn' (SegGroup num_groups group_size _) =
    freeIn' num_groups <> freeIn' group_size

-- | A simple size-level query or computation.
data SizeOp
  = SplitSpace SplitOrdering SubExp SubExp SubExp
    -- ^ @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
  | GetSize Name SizeClass
    -- ^ Produce some runtime-configurable size.
  | GetSizeMax SizeClass
    -- ^ The maximum size of some class.
  | CmpSizeLe Name SizeClass SubExp
    -- ^ Compare the runtime-configurable size @name@ (likely a
    -- threshold) with some integer value, yielding a boolean that
    -- is true when the size is less than or equal to the value.
  | CalcNumGroups SubExp Name SubExp
    -- ^ @CalcNumGroups w max_num_groups group_size@ calculates the
    -- number of GPU workgroups to use for an input of the given size.
    -- The @Name@ is a size name.  Note that @w@ is an i64 to avoid
    -- overflow issues.
  deriving (Eq, Ord, Show)

-- Size names and size classes are not variables, so they are never
-- substituted; only the 'SubExp' operands are.
instance Substitute SizeOp where
  substituteNames substs (SplitSpace o w i elems_per_thread) =
    SplitSpace
      (substituteNames substs o)
      (substituteNames substs w)
      (substituteNames substs i)
      (substituteNames substs elems_per_thread)
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass (substituteNames substs x)
  substituteNames substs (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups
      (substituteNames substs w)
      max_num_groups
      (substituteNames substs group_size)
  substituteNames _ op = op

instance Rename SizeOp where
  rename (SplitSpace o w i elems_per_thread) =
    SplitSpace
      <$> rename o
      <*> rename w
      <*> rename i
      <*> rename elems_per_thread
  rename (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass <$> rename x
  rename (CalcNumGroups w max_num_groups group_size) =
    CalcNumGroups <$> rename w <*> pure max_num_groups <*> rename group_size
  rename x = pure x

-- Every size operation is considered both safe and cheap.
instance IsOp SizeOp where
  safeOp _ = True
  cheapOp _ = True

-- Every size operation produces a single primitive scalar.
instance TypedOp SizeOp where
  opType SplitSpace {} = pure [Prim int32]
  opType (GetSize _ _) = pure [Prim int32]
  opType (GetSizeMax _) = pure [Prim int32]
  opType CmpSizeLe {} = pure [Prim Bool]
  opType CalcNumGroups {} = pure [Prim int32]

-- The single scalar result aliases nothing, and nothing is consumed.
instance AliasedOp SizeOp where
  opAliases _ = [mempty]
  consumedInOp _ = mempty

instance FreeIn SizeOp where
  freeIn' (SplitSpace o w i elems_per_thread) =
    freeIn' o <> freeIn' [w, i, elems_per_thread]
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' (CalcNumGroups w _ group_size) = freeIn' w <> freeIn' group_size
  freeIn' _ = mempty

instance PP.Pretty SizeOp where
  ppr (SplitSpace o w i elems_per_thread) =
    text "splitSpace" <> suff
      <> parens (commasep [ppr w, ppr i, ppr elems_per_thread])
    where
      suff = case o of
        SplitContiguous -> mempty
        SplitStrided stride -> text "Strided" <> parens (ppr stride)
  ppr (GetSize name size_class) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])
  ppr (GetSizeMax size_class) =
    text "get_size_max" <> parens (commasep [ppr size_class])
  ppr (CmpSizeLe name size_class x) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])
      <+> text "<="
      <+> ppr x
  ppr (CalcNumGroups w max_num_groups group_size) =
    text "calc_num_groups"
      <> parens (commasep [ppr w, ppr max_num_groups, ppr group_size])

instance OpMetrics SizeOp where
  opMetrics SplitSpace {} = seen "SplitSpace"
  opMetrics GetSize {} = seen "GetSize"
  opMetrics GetSizeMax {} = seen "GetSizeMax"
  opMetrics CmpSizeLe {} = seen "CmpSizeLe"
  opMetrics CalcNumGroups {} = seen "CalcNumGroups"

-- | Type-check a 'SizeOp'.  Every 'SubExp' operand must be an @i32@,
-- except the @w@ of 'CalcNumGroups', which must be an @i64@.
-- 'GetSize' and 'GetSizeMax' have no operands to check.
typeCheckSizeOp :: TC.Checkable lore => SizeOp -> TC.TypeM lore ()
typeCheckSizeOp (SplitSpace o w i elems_per_thread) = do
  case o of
    SplitContiguous -> return ()
    SplitStrided stride -> TC.require [Prim int32] stride
  mapM_ (TC.require [Prim int32]) [w, i, elems_per_thread]
typeCheckSizeOp GetSize {} = return ()
typeCheckSizeOp GetSizeMax {} = return ()
typeCheckSizeOp (CmpSizeLe _ _ x) = TC.require [Prim int32] x
typeCheckSizeOp (CalcNumGroups w _ group_size) = do
  TC.require [Prim int64] w
  TC.require [Prim int32] group_size

-- | A host-level operation; parameterised by what else it can do.
data HostOp lore op
  = SegOp (SegOp SegLevel lore)
    -- ^ A segmented operation.
  | SizeOp SizeOp
    -- ^ A size-level query or computation.
  | OtherOp op
    -- ^ Some other operation, of the parameter type @op@.
  deriving (Eq, Ord, Show)

-- Each of the instances below dispatches on the constructor and delegates
-- to the corresponding instance of the wrapped operation.

instance (ASTLore lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs hop =
    case hop of
      SegOp op -> SegOp (substituteNames substs op)
      SizeOp op -> SizeOp (substituteNames substs op)
      OtherOp op -> OtherOp (substituteNames substs op)

instance (ASTLore lore, Rename op) => Rename (HostOp lore op) where
  rename hop =
    case hop of
      SegOp op -> fmap SegOp (rename op)
      SizeOp op -> fmap SizeOp (rename op)
      OtherOp op -> fmap OtherOp (rename op)

instance (ASTLore lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp hop =
    case hop of
      SegOp op -> safeOp op
      SizeOp op -> safeOp op
      OtherOp op -> safeOp op
  cheapOp hop =
    case hop of
      SegOp op -> cheapOp op
      SizeOp op -> cheapOp op
      OtherOp op -> cheapOp op

instance TypedOp op => TypedOp (HostOp lore op) where
  opType hop =
    case hop of
      SegOp op -> opType op
      SizeOp op -> opType op
      OtherOp op -> opType op

instance (Aliased lore, AliasedOp op, ASTLore lore) => AliasedOp (HostOp lore op) where
  opAliases hop =
    case hop of
      SegOp op -> opAliases op
      SizeOp op -> opAliases op
      OtherOp op -> opAliases op
  consumedInOp hop =
    case hop of
      SegOp op -> consumedInOp op
      SizeOp op -> consumedInOp op
      OtherOp op -> consumedInOp op

instance (ASTLore lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' hop =
    case hop of
      SegOp op -> freeIn' op
      SizeOp op -> freeIn' op
      OtherOp op -> freeIn' op

-- 'SizeOp' carries no lore-specific information, so it is rebuilt as-is
-- (the rebuild is still needed because the result type changes).
instance (CanBeAliased (Op lore), CanBeAliased op, ASTLore lore) => CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases hop =
    case hop of
      SegOp op -> SegOp (addOpAliases op)
      SizeOp op -> SizeOp op
      OtherOp op -> OtherOp (addOpAliases op)

  removeOpAliases hop =
    case hop of
      SegOp op -> SegOp (removeOpAliases op)
      SizeOp op -> SizeOp op
      OtherOp op -> OtherOp (removeOpAliases op)
-- 'SizeOp' carries no lore-specific information, so it is rebuilt as-is.
instance (CanBeWise (Op lore), CanBeWise op, ASTLore lore) => CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom (SegOp op) = SegOp $ removeOpWisdom op
  removeOpWisdom (OtherOp op) = OtherOp $ removeOpWisdom op
  removeOpWisdom (SizeOp op) = SizeOp op

instance (ASTLore lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k (SegOp op) is = ST.indexOp vtable k op is
  indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is
  -- Size operations only produce primitive scalars, so there is nothing
  -- to index into.
  indexOp _ _ SizeOp {} _ = Nothing

instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr (SegOp op) = ppr op
  ppr (OtherOp op) = ppr op
  ppr (SizeOp op) = ppr op

instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics (SegOp op) = opMetrics op
  opMetrics (OtherOp op) = opMetrics op
  opMetrics (SizeOp op) = opMetrics op

-- | Check the level of a 'SegOp', given the level of the enclosing
-- 'SegOp', if there is one.
--
-- Without an enclosing level, the number of groups and the group size
-- must both be @i32@.  With an enclosing level, several conditions are
-- rejected as type errors:
--
-- * the enclosing level is thread level;
-- * the nested level is identical to the enclosing one;
-- * the nested level has a different number of groups or group size.
checkSegLevel ::
  TC.Checkable lore =>
  Maybe SegLevel ->
  SegLevel ->
  TC.TypeM lore ()
checkSegLevel Nothing lvl = do
  TC.require [Prim int32] $ unCount $ segNumGroups lvl
  TC.require [Prim int32] $ unCount $ segGroupSize lvl
checkSegLevel (Just SegThread {}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y =
    TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
    TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
    return ()

-- | Type-check a 'HostOp'.
--
-- * @checker@ is applied to the 'segLevel' of a 'SegOp' and handed to
--   'TC.checkOpWith'.  Presumably it checks the operations nested inside
--   that 'SegOp'; confirm against callers.
-- * @lvl@ is the level of the enclosing 'SegOp', if any, and is validated
--   with 'checkSegLevel'.
-- * @f@ checks an 'OtherOp' payload.
--
-- A 'SizeOp' is checked with 'typeCheckSizeOp'.
typeCheckHostOp ::
  TC.Checkable lore =>
  (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ()) ->
  Maybe SegLevel ->
  (op -> TC.TypeM lore ()) ->
  HostOp (Aliases lore) op ->
  TC.TypeM lore ()
typeCheckHostOp checker lvl _ (SegOp op) =
  TC.checkOpWith (checker $ segLevel op) $
    typeCheckSegOp (checkSegLevel lvl) op
typeCheckHostOp _ _ f (OtherOp op) = f op
typeCheckHostOp _ _ _ (SizeOp op) = typeCheckSizeOp op