{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Representation.Kernels.Kernel
( Kernel(..)
, kernelType
, kernelSpace
, KernelDebugHints(..)
, GenReduceOp(..)
, SegRedOp(..)
, segRedResults
, KernelBody(..)
, KernelSpace(..)
, spaceDimensions
, SpaceStructure(..)
, scopeOfKernelSpace
, KernelResult(..)
, kernelResultSubExp
, KernelPath
, chunkedKernelNonconcatOutputs
, typeCheckKernel
, KernelMapper(..)
, identityKernelMapper
, mapKernelM
, KernelWalker(..)
, identityKernelWalker
, walkKernelM
, HostOp(..)
, typeCheckHostOp
)
where
import Control.Arrow (first)
import Control.Monad.Writer hiding (mapM_)
import Control.Monad.Identity hiding (mapM_)
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List
import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
((</>), (<+>), ppr, commasep, Pretty, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges
(Ranges, removeLambdaRanges, removeBodyRanges, mkBodyRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
(Aliases, removeLambdaAliases, removeBodyAliases, removeStmAliases)
import Futhark.Representation.Kernels.KernelExp (SplitOrdering(..))
import Futhark.Representation.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import Futhark.Tools (partitionChunkedKernelLambdaParameters)
import qualified Futhark.Analysis.Range as Range
import Futhark.Util (maybeNth)
-- | Descriptive metadata attached to a general 'Kernel': a name plus a list
-- of named 'SubExp' hints.
data KernelDebugHints = KernelDebugHints
  { kernelName  :: String
    -- ^ Name printed after the @kernel@ keyword.
  , kernelHints :: [(String, SubExp)]
    -- ^ Named 'SubExp' values associated with the kernel.
  }
  deriving (Eq, Ord, Show)
-- | One operator of a 'SegGenRed'.
data GenReduceOp lore = GenReduceOp
  { genReduceWidth   :: SubExp
    -- ^ Size of the destination dimension that follows the segment
    -- dimensions.
  , genReduceDest    :: [VName]
    -- ^ Destination arrays; these are consumed by the 'SegGenRed'.
  , genReduceNeutral :: [SubExp]
    -- ^ Neutral elements, paired with the destinations in order.
  , genReduceShape   :: Shape
    -- ^ Additional inner (vector) dimensions of the operator.
  , genReduceOp      :: LambdaT lore
    -- ^ The combining operator.
  }
  deriving (Eq, Ord, Show)
-- | One reduction operator of a 'SegRed'.
data SegRedOp lore = SegRedOp
  { segRedComm    :: Commutativity
    -- ^ Whether the operator is marked commutative.
  , segRedLambda  :: Lambda lore
    -- ^ The reduction operator.
  , segRedNeutral :: [SubExp]
    -- ^ Neutral elements; their types must match the operator's return type.
  , segRedShape   :: Shape
    -- ^ Extra (vector) dimensions added to every reduction result.
  }
  deriving (Eq, Ord, Show)
-- | Number of results produced by a list of reduction operators: one per
-- neutral element.
segRedResults :: [SegRedOp lore] -> Int
segRedResults ops = sum [ length (segRedNeutral op) | op <- ops ]
-- | A kernel-level operation.  Every constructor carries a 'KernelSpace',
-- the per-result element types, and a 'KernelBody'.
data Kernel lore
  = Kernel KernelDebugHints KernelSpace [Type] (KernelBody lore)
    -- ^ General kernel; result shapes follow 'kernelResultShape'.
  | SegMap KernelSpace [Type] (KernelBody lore)
    -- ^ Result shapes follow 'kernelResultShape'.
  | SegRed KernelSpace [SegRedOp lore] [Type] (KernelBody lore)
    -- ^ The leading body results are reduced by the operators; the
    -- remaining ones are returned as in 'SegMap'.
  | SegScan KernelSpace (Lambda lore) [SubExp] [Type] (KernelBody lore)
    -- ^ Scan with an operator and its neutral elements.
  | SegGenRed KernelSpace [GenReduceOp lore] [Type] (KernelBody lore)
    -- ^ Generalised reduction into the operators' destination arrays.
  deriving (Eq, Show, Ord)
-- | Extract the 'KernelSpace' of any kernel constructor.
kernelSpace :: Kernel lore -> KernelSpace
kernelSpace k =
  case k of
    Kernel _ space _ _    -> space
    SegMap space _ _      -> space
    SegRed space _ _ _    -> space
    SegScan space _ _ _ _ -> space
    SegGenRed space _ _ _ -> space
-- | The iteration space of a kernel: thread and group identifiers, launch
-- sizes, and the per-dimension structure of the thread space.
data KernelSpace = KernelSpace
  { spaceGlobalId      :: VName
    -- ^ Global thread ID; bound as an i32 index by 'scopeOfKernelSpace'.
  , spaceLocalId       :: VName
    -- ^ Local thread ID; bound by 'scopeOfKernelSpace'.
  , spaceGroupId       :: VName
    -- ^ Group ID; bound by 'scopeOfKernelSpace'.
  , spaceNumThreads    :: SubExp
    -- ^ Number of threads (must be i32, see 'checkSpace').
  , spaceNumGroups     :: SubExp
    -- ^ Number of groups; also the outer size of 'GroupsReturn' results.
  , spaceGroupSize     :: SubExp
    -- ^ Group size.
  , spaceNumVirtGroups :: SubExp
    -- ^ Number of virtual groups.
  , spaceStructure     :: SpaceStructure
    -- ^ Per-dimension thread indices and sizes.
  }
  deriving (Eq, Show, Ord)
-- | The dimensions of a kernel's thread space.
data SpaceStructure
  = FlatThreadSpace [(VName, SubExp)]
    -- ^ One (global thread index, dimension size) pair per dimension.
  | NestedThreadSpace [(VName, SubExp, VName, SubExp)]
    -- ^ One (global index, global size, local index, local size) tuple per
    -- dimension.
  deriving (Eq, Show, Ord)
-- | The (global index, global size) pair of every dimension of the space,
-- outermost first.  For nested spaces the local components are dropped.
spaceDimensions :: KernelSpace -> [(VName, SubExp)]
spaceDimensions space =
  case spaceStructure space of
    FlatThreadSpace dims   -> dims
    NestedThreadSpace dims -> [ (gtid, gsize) | (gtid, gsize, _, _) <- dims ]
-- | The body of a kernel: a lore attribute, a sequence of statements, and
-- the kernel results.
data KernelBody lore = KernelBody
  { kernelBodyLore   :: BodyAttr lore
    -- ^ Lore-specific body annotation.
  , kernelBodyStms   :: Stms lore
    -- ^ Statements executed by the body.
  , kernelBodyResult :: [KernelResult]
    -- ^ Results; these may refer to names bound by 'kernelBodyStms'.
  }

deriving instance Annotations lore => Ord (KernelBody lore)
deriving instance Annotations lore => Show (KernelBody lore)
deriving instance Annotations lore => Eq (KernelBody lore)
-- | How a kernel body produces one of its results.
data KernelResult
  = ThreadsReturn SubExp
    -- ^ A value per thread; the result has every space dimension as an
    -- outer dimension.
  | GroupsReturn SubExp
    -- ^ A value per group; the result's outer size is 'spaceNumGroups'.
  | WriteReturn [SubExp] VName [([SubExp], SubExp)]
    -- ^ @WriteReturn dims arr writes@: each @(indices, value)@ in @writes@
    -- targets @arr@, whose shape is @dims@.  @arr@ is consumed.
  | ConcatReturns SplitOrdering SubExp SubExp (Maybe SubExp) VName
    -- ^ @ConcatReturns ordering w per_thread_elems offset arr@: a result
    -- with outer size @w@, built from the array @arr@.
  deriving (Eq, Show, Ord)
-- | The 'SubExp' a kernel result refers to: the returned value for
-- thread/group returns, or the array variable for writes and concats.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp kres =
  case kres of
    ThreadsReturn se              -> se
    GroupsReturn se               -> se
    WriteReturn _ arr _           -> Var arr
    ConcatReturns _ _ _ _ arr     -> Var arr
-- | Monadic functions used by 'mapKernelM' to rebuild a 'Kernel' of lore
-- @flore@ as a 'Kernel' of lore @tlore@.
data KernelMapper flore tlore m = KernelMapper
  { mapOnKernelSubExp     :: SubExp -> m SubExp
    -- ^ Applied to sizes, neutral elements, type dimensions and hints.
  , mapOnKernelLambda     :: Lambda flore -> m (Lambda tlore)
    -- ^ Applied to operator lambdas.
  , mapOnKernelBody       :: Body flore -> m (Body tlore)
    -- ^ Mapping for 'Body' values (not invoked by 'mapKernelM' itself).
  , mapOnKernelVName      :: VName -> m VName
    -- ^ Applied to the destination arrays of 'GenReduceOp's.
  , mapOnKernelLParam     :: LParam flore -> m (LParam tlore)
    -- ^ Mapping for lambda parameters (not invoked by 'mapKernelM' itself).
  , mapOnKernelKernelBody :: KernelBody flore -> m (KernelBody tlore)
    -- ^ Applied to the kernel body.
  }
-- | A 'KernelMapper' that returns every component unchanged.  The fields
-- are, in order: SubExp, Lambda, Body, VName, LParam and KernelBody.
identityKernelMapper :: Monad m => KernelMapper lore lore m
identityKernelMapper =
  KernelMapper pure pure pure pure pure pure
-- | Rebuild a 'Kernel' by applying a 'KernelMapper' to its components:
-- space sizes (via 'mapOnKernelSpace'), result types, operators, neutral
-- elements, vector shapes, debug hints and the kernel body.  Effects run
-- in the same order as the fields appear in each constructor.
mapKernelM :: (Applicative m, Monad m) =>
              KernelMapper flore tlore m -> Kernel flore -> m (Kernel tlore)
mapKernelM tv kernel =
  case kernel of
    SegMap space ts kbody ->
      SegMap <$> onSpace space <*> onTypes ts <*> onKBody kbody
    SegRed space reds ts kbody ->
      SegRed <$> onSpace space
             <*> mapM onRedOp reds
             <*> onTypes ts
             <*> onKBody kbody
    SegScan space scan_op nes ts kbody ->
      SegScan <$> onSpace space
              <*> onLambda scan_op
              <*> onSubExps nes
              <*> onTypes ts
              <*> onKBody kbody
    SegGenRed space ops ts kbody ->
      SegGenRed <$> onSpace space
                <*> mapM onGenRedOp ops
                <*> onTypes ts
                <*> onKBody kbody
    Kernel hints space ts kbody ->
      Kernel <$> onHints hints
             <*> onSpace space
             <*> mapM (mapOnKernelType tv) ts
             <*> onKBody kbody
  where onSubExp = mapOnKernelSubExp tv
        onSubExps = mapM onSubExp
        onLambda = mapOnKernelLambda tv
        onSpace = mapOnKernelSpace tv
        onKBody = mapOnKernelKernelBody tv
        onTypes = mapM (mapOnType onSubExp)
        onShape = fmap Shape . onSubExps . shapeDims

        onRedOp (SegRedOp comm lam nes shape) =
          SegRedOp comm <$> onLambda lam <*> onSubExps nes <*> onShape shape

        onGenRedOp (GenReduceOp w dests nes shape op) =
          GenReduceOp <$> onSubExp w
                      <*> mapM (mapOnKernelVName tv) dests
                      <*> onSubExps nes
                      <*> onShape shape
                      <*> onLambda op

        -- Only the SubExp half of each hint is mapped; keys are kept.
        onHints (KernelDebugHints name kvs) =
          KernelDebugHints name <$> mapM onHint kvs
        onHint (key, v) = (,) key <$> onSubExp v
-- | Apply a 'KernelMapper' to the sizes of a 'KernelSpace'.  The global,
-- local and group IDs, and the per-dimension index names, are copied
-- unchanged ('scopeOfKernelSpace' treats them as names bound by the space).
-- In a nested structure, all global sizes are mapped before all local sizes.
mapOnKernelSpace :: Monad f =>
                    KernelMapper flore tlore f -> KernelSpace -> f KernelSpace
mapOnKernelSpace tv (KernelSpace gtid ltid gid nthreads ngroups gsize nvirt structure) =
  KernelSpace gtid ltid gid
  <$> onSubExp nthreads
  <*> onSubExp ngroups
  <*> onSubExp gsize
  <*> onSubExp nvirt
  <*> onStructure structure
  where onSubExp = mapOnKernelSubExp tv

        onStructure (FlatThreadSpace dims) = do
          let (gtids, sizes) = unzip dims
          sizes' <- mapM onSubExp sizes
          return $ FlatThreadSpace $ zip gtids sizes'
        onStructure (NestedThreadSpace dims) = do
          let (gtids, gsizes, ltids, lsizes) = unzip4 dims
          gsizes' <- mapM onSubExp gsizes
          lsizes' <- mapM onSubExp lsizes
          return $ NestedThreadSpace $ zip4 gtids gsizes' ltids lsizes'
-- | Apply the mapper's 'SubExp' function to the dimensions of an array
-- type.  Primitive and memory types are returned unchanged.
mapOnKernelType :: Monad m =>
                   KernelMapper flore tlore m -> Type -> m Type
mapOnKernelType tv t =
  case t of
    Prim pt -> pure $ Prim pt
    Mem s   -> pure $ Mem s
    Array pt (Shape dims) u -> do
      dims' <- mapM (mapOnKernelSubExp tv) dims
      return $ Array pt (Shape dims') u
-- | Free variables of a kernel.  'mapKernelM' is run in a Writer monad that
-- records the free names of every visited component and returns each
-- component unchanged.
instance (Attributes lore, FreeIn (LParamAttr lore)) =>
         FreeIn (Kernel lore) where
  freeIn k = execWriter $ mapKernelM collector k
    where record x = x <$ tell (freeIn x)
          collector = KernelMapper record record record record record record
-- | Monadic visitors used by 'walkKernelM'.  Each one observes a component
-- and returns nothing.
data KernelWalker lore m = KernelWalker
  { walkOnKernelSubExp     :: SubExp -> m ()
    -- ^ Visits sizes, neutral elements, type dimensions and hints.
  , walkOnKernelLambda     :: Lambda lore -> m ()
    -- ^ Visits operator lambdas.
  , walkOnKernelBody       :: Body lore -> m ()
    -- ^ Visitor for 'Body' values.
  , walkOnKernelVName      :: VName -> m ()
    -- ^ Visits the destination arrays of 'GenReduceOp's.
  , walkOnKernelLParam     :: LParam lore -> m ()
    -- ^ Visitor for lambda parameters.
  , walkOnKernelKernelBody :: KernelBody lore -> m ()
    -- ^ Visits the kernel body.
  }
-- | A 'KernelWalker' whose visitors all do nothing.
identityKernelWalker :: Monad m => KernelWalker lore m
identityKernelWalker = KernelWalker skip skip skip skip skip skip
  where skip _ = return ()
-- | Turn a 'KernelWalker' into a 'KernelMapper'.  Each field runs the
-- corresponding visitor and then returns its argument unchanged.
walkKernelMapper :: forall lore m. Monad m =>
                    KernelWalker lore m -> KernelMapper lore lore m
walkKernelMapper walker =
  KernelMapper
    (passThrough walkOnKernelSubExp)
    (passThrough walkOnKernelLambda)
    (passThrough walkOnKernelBody)
    (passThrough walkOnKernelVName)
    (passThrough walkOnKernelLParam)
    (passThrough walkOnKernelKernelBody)
  where passThrough :: (KernelWalker lore m -> a -> m ()) -> a -> m a
        passThrough visitor x = x <$ visitor walker x
-- | Run a 'KernelWalker' over a kernel and discard the rebuilt kernel.
walkKernelM :: Monad m => KernelWalker lore m -> Kernel lore -> m ()
walkKernelM walker k = () <$ mapKernelM (walkKernelMapper walker) k
-- | Free variables of a kernel result: the union of the free variables of
-- all of its fields.
instance FreeIn KernelResult where
  freeIn kres =
    case kres of
      GroupsReturn se -> freeIn se
      ThreadsReturn se -> freeIn se
      WriteReturn dims arr writes ->
        mconcat [freeIn dims, freeIn arr, freeIn writes]
      ConcatReturns ordering w per_thread moffset arr ->
        mconcat [ freeIn ordering, freeIn w, freeIn per_thread
                , freeIn moffset, freeIn arr ]
-- | Free variables of a kernel body: everything free in the attribute,
-- statements and results, minus the names the statements bind.
instance Attributes lore => FreeIn (KernelBody lore) where
  freeIn (KernelBody lore_attr kstms kres) =
    used `S.difference` bound
    where used = freeIn lore_attr <> foldMap freeIn kstms <> freeIn kres
          bound = foldMap boundByStm kstms
-- | Apply a name substitution to every part of a kernel body.
instance Attributes lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody lore_attr kstms kres) =
    KernelBody (sub lore_attr) (sub kstms) (sub kres)
    where sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Apply a name substitution to every field of a kernel result.
instance Substitute KernelResult where
  substituteNames subst kres =
    case kres of
      GroupsReturn se -> GroupsReturn (sub se)
      ThreadsReturn se -> ThreadsReturn (sub se)
      WriteReturn dims arr writes ->
        WriteReturn (sub dims) (sub arr) (sub writes)
      ConcatReturns ordering w per_thread moffset arr ->
        ConcatReturns (sub ordering) (sub w) (sub per_thread)
                      (sub moffset) (sub arr)
    where sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Apply a name substitution to every field of a kernel space, including
-- the thread and group identifiers.
instance Substitute KernelSpace where
  substituteNames subst (KernelSpace gtid ltid gid nthreads ngroups gsize nvirt structure) =
    KernelSpace (sub gtid) (sub ltid) (sub gid)
                (sub nthreads) (sub ngroups) (sub gsize) (sub nvirt)
                (sub structure)
    where sub :: Substitute a => a -> a
          sub = substituteNames subst
-- | Apply a name substitution to every dimension tuple of a space structure.
instance Substitute SpaceStructure where
  substituteNames subst structure =
    case structure of
      FlatThreadSpace dims ->
        FlatThreadSpace [ substituteNames subst dim | dim <- dims ]
      NestedThreadSpace dims ->
        NestedThreadSpace [ substituteNames subst dim | dim <- dims ]
-- | Apply a name substitution to a kernel.
--
-- For the general 'Kernel' constructor, the whole space is substituted
-- (including its thread/group identifiers), as are the types, the body and
-- the 'SubExp's in the debug hints.  The hints are also visited by
-- 'mapKernelM' (and hence by 'freeIn'), so leaving them unsubstituted would
-- keep references to the old names.  All other constructors go through
-- 'mapKernelM'.
instance Attributes lore => Substitute (Kernel lore) where
  substituteNames subst (Kernel desc space ts kbody) =
    Kernel (substituteInHints desc)
    (substituteNames subst space)
    (substituteNames subst ts)
    (substituteNames subst kbody)
    where substituteInHints (KernelDebugHints name kvs) =
            KernelDebugHints name [ (key, substituteNames subst v) | (key, v) <- kvs ]
  substituteNames subst k = runIdentity $ mapKernelM substitute k
    where substitute =
            KernelMapper { mapOnKernelSubExp = return . substituteNames subst
                         , mapOnKernelLambda = return . substituteNames subst
                         , mapOnKernelBody = return . substituteNames subst
                         , mapOnKernelVName = return . substituteNames subst
                         , mapOnKernelLParam = return . substituteNames subst
                         , mapOnKernelKernelBody = return . substituteNames subst
                         }
-- | Rename a kernel body.  The results are renamed inside the scope that
-- 'renamingStms' sets up for the statements, so they see the renamed
-- bindings.
instance Attributes lore => Rename (KernelBody lore) where
  rename (KernelBody lore_attr kstms kres) = do
    lore_attr' <- rename lore_attr
    renamingStms kstms $ \kstms' -> do
      kres' <- rename kres
      return $ KernelBody lore_attr' kstms' kres'
-- | Kernel results bind no names, so renaming them amounts to substitution.
instance Rename KernelResult where
  rename kres = substituteRename kres
-- | The names bound by a kernel space, each as an i32 index: the global,
-- local and group IDs, followed by the per-dimension thread indices (global
-- indices, then local indices for nested spaces).
scopeOfKernelSpace :: KernelSpace -> Scope lore
scopeOfKernelSpace (KernelSpace gtid ltid gid _ _ _ _ structure) =
  M.fromList [ (name, IndexInfo Int32) | name <- bound ]
  where bound = gtid : ltid : gid : structureIndices structure
        structureIndices (FlatThreadSpace dims) = map fst dims
        structureIndices (NestedThreadSpace dims) =
          [ g | (g, _, _, _) <- dims ] ++ [ l | (_, _, l, _) <- dims ]
-- | Rename a kernel by renaming each component visited by 'mapKernelM'.
instance Attributes lore => Rename (Kernel lore) where
  rename k = mapKernelM renamer k
    where renamer = KernelMapper { mapOnKernelSubExp = rename
                                 , mapOnKernelLambda = rename
                                 , mapOnKernelBody = rename
                                 , mapOnKernelVName = rename
                                 , mapOnKernelLParam = rename
                                 , mapOnKernelKernelBody = rename
                                 }
-- | Full type of one kernel result, given its element type @t@.
kernelResultShape :: KernelSpace -> Type -> KernelResult -> Type
kernelResultShape space t kres =
  case kres of
    -- Shape of the destination array.
    WriteReturn dims _ _ -> t `arrayOfShape` Shape dims
    -- One row per group.
    GroupsReturn _ -> t `arrayOfRow` spaceNumGroups space
    -- One dimension per space dimension, outermost first.
    ThreadsReturn _ ->
      foldr (\(_, d) acc -> acc `arrayOfRow` d) t (spaceDimensions space)
    ConcatReturns _ w _ _ _ -> t `arrayOfRow` w
-- | The result types of a kernel.
--
-- * 'Kernel' and 'SegMap': per-result shapes from 'kernelResultShape'.
-- * 'SegRed': each operator's return types, extended with all but the
--   innermost space dimension and the operator shape, followed by the
--   remaining (mapped) results.
-- * 'SegScan': every type extended with all space dimensions.
-- * 'SegGenRed': each operator's return types, extended with all but the
--   innermost space dimension, the operator width and the operator shape.
kernelType :: Kernel lore -> [Type]
kernelType k =
  case k of
    Kernel _ space ts body -> threadResults space ts (kernelBodyResult body)
    SegMap space ts body -> threadResults space ts (kernelBodyResult body)
    SegRed space reds ts body ->
      let segment_dims = init $ map snd $ spaceDimensions space
          red_ts = [ t `arrayOfShape` (Shape segment_dims <> segRedShape op)
                   | op <- reds, t <- lambdaReturnType (segRedLambda op) ]
          num_red = length red_ts
      in red_ts ++
         threadResults space (drop num_red ts) (drop num_red (kernelBodyResult body))
    SegScan space _ _ ts _ ->
      map (`arrayOfShape` Shape (map snd $ spaceDimensions space)) ts
    SegGenRed space ops _ _ ->
      let segment_dims = init $ map snd $ spaceDimensions space
      in [ t `arrayOfShape`
             (Shape (segment_dims <> [genReduceWidth op]) <> genReduceShape op)
         | op <- ops, t <- lambdaReturnType (genReduceOp op) ]
  where threadResults space = zipWith (kernelResultShape space)
-- | Count the leading return types of a chunked kernel lambda whose
-- outermost dimension is not the lambda's chunk-size parameter.
chunkedKernelNonconcatOutputs :: Lambda lore -> Int
chunkedKernelNonconcatOutputs fun =
  length $ takeWhile notChunked $ lambdaReturnType fun
  where (_, chunk, _) = partitionChunkedKernelLambdaParameters $ lambdaParams fun
        chunk_size = Var (paramName chunk)
        notChunked t = arraySize 0 t /= chunk_size
-- | Kernel result types are fully static; see 'kernelType'.
instance TypedOp (Kernel lore) where
  opType k = pure (staticShapes (kernelType k))
-- | Kernel results alias nothing.
--
-- Consumption covers the body's statements plus every array targeted by a
-- 'WriteReturn' result, for every constructor.  'checkKernelBody' consumes
-- 'WriteReturn' destinations regardless of the constructor, so alias
-- analysis must report them for all constructors, not only 'Kernel'.
-- 'SegGenRed' additionally consumes its operators' destination arrays.
instance (Attributes lore, Aliased lore) => AliasedOp (Kernel lore) where
  opAliases = map (const mempty) . kernelType

  consumedInOp (Kernel _ _ _ kbody) = consumedByKernelBody kbody
  consumedInOp (SegMap _ _ kbody) = consumedByKernelBody kbody
  consumedInOp (SegRed _ _ _ kbody) = consumedByKernelBody kbody
  consumedInOp (SegScan _ _ _ _ kbody) = consumedByKernelBody kbody
  consumedInOp (SegGenRed _ ops _ kbody) =
    S.fromList (concatMap genReduceDest ops) <> consumedByKernelBody kbody

-- | Names consumed by a kernel body: those consumed by its statements plus
-- the destination arrays of its 'WriteReturn' results.
consumedByKernelBody :: Aliased lore => KernelBody lore -> Names
consumedByKernelBody kbody =
  consumedInKernelBody kbody <> foldMap consumedByReturn (kernelBodyResult kbody)
  where consumedByReturn (WriteReturn _ arr _) = S.singleton arr
        consumedByReturn _ = mempty
-- | Alias-annotate a kernel body by wrapping its attribute and statements in
-- an ordinary 'Body' with no results, analysing that, and keeping the
-- original kernel results.
aliasAnalyseKernelBody :: (Attributes lore,
                           CanBeAliased (Op lore)) =>
                          KernelBody lore
                       -> KernelBody (Aliases lore)
aliasAnalyseKernelBody (KernelBody lore_attr kstms kres) =
  KernelBody lore_attr' kstms' kres
  where Body lore_attr' kstms' _ = Alias.analyseBody $ Body lore_attr kstms []
-- | Adding or removing alias annotations on a kernel means doing so for its
-- lambdas, bodies and kernel body.  All other components are unchanged.
instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (Kernel lore) where
  type OpWithAliases (Kernel lore) = Kernel (Aliases lore)

  addOpAliases = runIdentity . mapKernelM annotate
    where annotate = KernelMapper
            { mapOnKernelSubExp = return
            , mapOnKernelLambda = return . Alias.analyseLambda
            , mapOnKernelBody = return . Alias.analyseBody
            , mapOnKernelVName = return
            , mapOnKernelLParam = return
            , mapOnKernelKernelBody = return . aliasAnalyseKernelBody
            }

  removeOpAliases = runIdentity . mapKernelM strip
    where strip = KernelMapper
            { mapOnKernelSubExp = return
            , mapOnKernelLambda = return . removeLambdaAliases
            , mapOnKernelBody = return . removeBodyAliases
            , mapOnKernelVName = return
            , mapOnKernelLParam = return
            , mapOnKernelKernelBody = return . removeKernelBodyAliases
            }
-- | Drop alias annotations from a kernel body: the alias half of the body
-- attribute and the aliases on every statement.
removeKernelBodyAliases :: KernelBody (Aliases lore)
                        -> KernelBody lore
removeKernelBodyAliases (KernelBody (_aliases, lore_attr) kstms kres) =
  KernelBody lore_attr (removeStmAliases <$> kstms) kres
-- | Every kernel is considered safe.  Only the general 'Kernel' constructor
-- is considered expensive.
instance Attributes lore => IsOp (Kernel lore) where
  safeOp _ = True
  cheapOp k =
    case k of
      Kernel{} -> False
      _        -> True
-- | Nothing is known about the range of any kernel result.
instance Ranged inner => RangedOp (Kernel inner) where
  opRanges k = map (const unknownRange) (kernelType k)
-- | Adding or removing range annotations on a kernel.
--
-- 'removeOpRanges' previously used a placeholder
-- @removeKernelBodyRanges = error ...@ for the kernel body.  Because
-- 'mapKernelM' visits the kernel body of every constructor, any forced
-- result of 'removeOpRanges' crashed.  The kernel body is now stripped the
-- same way 'removeKernelBodyWisdom' strips wisdom: the attribute and
-- statements are wrapped in an ordinary 'Body' and passed to
-- 'removeBodyRanges', and the results are kept as they are.
instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (Kernel lore) where
  type OpWithRanges (Kernel lore) = Kernel (Ranges lore)

  removeOpRanges = runIdentity . mapKernelM remove
    where remove = KernelMapper return (return . removeLambdaRanges)
                   (return . removeBodyRanges) return return
                   (return . removeKernelBodyRanges)
          -- 'lore' is the instance-head variable (ScopedTypeVariables).
          removeKernelBodyRanges :: KernelBody (Ranges lore) -> KernelBody lore
          removeKernelBodyRanges (KernelBody attr stms res) =
            let Body attr' stms' _ = removeBodyRanges $ Body attr stms []
            in KernelBody attr' stms' res

  addOpRanges = Range.runRangeM . mapKernelM add
    where add = KernelMapper return Range.analyseLambda
                Range.analyseBody return return addKernelBodyRanges
          -- Analyse the statements; the body's range annotation is built
          -- from the original statements and the kernel results' SubExps.
          addKernelBodyRanges (KernelBody attr stms res) =
            Range.analyseStms stms $ \stms' -> do
              let attr' = (mkBodyRanges stms $ map kernelResultSubExp res, attr)
              return $ KernelBody attr' stms' res
-- | Strip simplifier wisdom from a kernel's lambdas, bodies and kernel
-- body.  All other components are unchanged.
instance (Attributes lore, CanBeWise (Op lore)) => CanBeWise (Kernel lore) where
  type OpWithWisdom (Kernel lore) = Kernel (Wise lore)

  removeOpWisdom = runIdentity . mapKernelM unwise
    where unwise = KernelMapper
            { mapOnKernelSubExp = return
            , mapOnKernelLambda = return . removeLambdaWisdom
            , mapOnKernelBody = return . removeBodyWisdom
            , mapOnKernelVName = return
            , mapOnKernelLParam = return
            , mapOnKernelKernelBody = return . removeKernelBodyWisdom
            }
-- | Strip wisdom from a kernel body by routing its attribute and statements
-- through 'removeBodyWisdom'.  The results are kept as they are.
removeKernelBodyWisdom :: KernelBody (Wise lore)
                       -> KernelBody lore
removeKernelBodyWisdom (KernelBody lore_attr kstms kres) =
  KernelBody lore_attr' kstms' kres
  where Body lore_attr' kstms' _ = removeBodyWisdom $ Body lore_attr kstms []
-- | Symbolic indexing into the @k@-th result of a general 'Kernel'.
--
-- This only succeeds when that result is a 'ThreadsReturn' of a variable
-- and the number of indices equals the number of space dimensions.  The
-- global thread indices are mapped to the given index expressions.  Each
-- single-name statement of the body whose expression converts to a
-- 'PrimExp' is then added to the table, and the returned variable is
-- looked up.  Variables not in the table are accepted as leaves only if
-- the symbol table says they have a primitive type.
--
-- The table is built with a strict left fold ('foldl''), so map insertions
-- are not left as a chain of thunks over the statements.
instance Attributes lore => ST.IndexOp (Kernel lore) where
  indexOp vtable k (Kernel _ space _ kbody) is = do
      ThreadsReturn se <- maybeNth k $ kernelBodyResult kbody
      let (gtids, _) = unzip $ spaceDimensions space
      guard $ length gtids == length is
      let prim_table = M.fromList $ zip gtids $ zip is $ repeat mempty
          prim_table' = foldl' expandPrimExpTable prim_table $ kernelBodyStms kbody
      case se of
        Var v -> M.lookup v prim_table'
        _ -> Nothing
    where expandPrimExpTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe,cs) <-
                runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
                M.insert v (pe, stmCerts stm <> cs) table
            | otherwise =
                table

          asPrimExp table v
            | Just (e,cs) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
-- | Names consumed by the statements of a kernel body.  The body is wrapped
-- in an ordinary 'Body' with no results, so kernel results are not
-- considered here.
consumedInKernelBody :: Aliased lore =>
                        KernelBody lore -> Names
consumedInKernelBody (KernelBody lore_attr kstms _) =
  consumedInBody (Body lore_attr kstms mempty)
-- | Type-check a 'Kernel' whose lore has been annotated with aliases.
--
-- Every constructor checks its 'KernelSpace' (via 'checkSpace') and its
-- declared result types.  It then checks the 'KernelBody' with the names
-- of 'scopeOfKernelSpace' in scope.  'SegRed' and 'SegScan' do all of this
-- through 'checkScanRed'.
typeCheckKernel :: TC.Checkable lore => Kernel (Aliases lore) -> TC.TypeM lore ()
typeCheckKernel (SegMap space ts kbody) = do
checkSpace space
mapM_ TC.checkType ts
TC.binding (scopeOfKernelSpace space) $ checkKernelBody ts kbody
-- Each reduction operator is passed to 'checkScanRed' as a
-- (lambda, neutral elements, vector shape) triple.
typeCheckKernel (SegRed space reds ts body) =
checkScanRed space reds' ts body
where reds' = zip3
(map segRedLambda reds)
(map segRedNeutral reds)
(map segRedShape reds)
-- A scan is checked like a single reduction operator with an empty
-- vector shape.
typeCheckKernel (SegScan space scan_op nes ts body) =
checkScanRed space [(scan_op, nes, mempty)] ts body
typeCheckKernel (SegGenRed space ops ts body) = do
checkSpace space
mapM_ TC.checkType ts
TC.binding (scopeOfKernelSpace space) $ do
-- For each operator, check:
-- * the width and the vector dimensions are i32;
-- * the operator lambda accepts the neutral types (twice) with
--   shapeRank dimensions stripped;
-- * the return type equals the neutral types;
-- * each destination has type segment_dims ++ [width] ++ shape.
-- Destinations are consumed.  The neutral types, extended with the
-- vector shape, are collected in nes_ts.
-- NOTE(review): the lambda parameters use the stripped neutral types,
-- while the return type is compared with the unstripped ones.  Confirm
-- this is intended when the shape is non-empty.
-- NOTE(review): `zip nes_t dests` silently ignores surplus neutral
-- elements or destinations; their counts are not compared.
nes_ts <- forM ops $ \(GenReduceOp dest_w dests nes shape op) -> do
TC.require [Prim int32] dest_w
nes' <- mapM TC.checkArg nes
mapM_ (TC.require [Prim int32]) $ shapeDims shape
let stripVecDims = stripArray $ shapeRank shape
TC.checkLambda op $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
let nes_t = map TC.argType nes'
unless (nes_t == lambdaReturnType op) $
TC.bad $ TC.TypeError $ "SegGenRed operator has return type " ++
prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
prettyTuple nes_t
let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
forM_ (zip nes_t dests) $ \(t, dest) -> do
TC.requireI [t `arrayOfShape` dest_shape] dest
TC.consume =<< TC.lookupAliases dest
return $ map (`arrayOfShape` shape) nes_t
checkKernelBody ts body
-- The body must return one i32 value per operator, followed by the
-- collected (shape-extended) neutral types.
let bucket_ret_t = replicate (length ops) (Prim int32) ++ concat nes_ts
unless (bucket_ret_t == ts) $
TC.bad $ TC.TypeError $ "SegGenRed body has return type " ++
prettyTuple ts ++ " but should have type " ++
prettyTuple bucket_ret_t
-- All space dimensions except the innermost one.
-- NOTE(review): `init` is partial and fails for a space with no dimensions.
where segment_dims = init $ map snd $ spaceDimensions space
typeCheckKernel (Kernel _ space kts kbody) = do
checkSpace space
mapM_ TC.checkType kts
-- NOTE(review): this looks redundant, since checkSpace already requires
-- these dimension sizes to be i32.
mapM_ (TC.require [Prim int32] . snd) $ spaceDimensions space
TC.binding (scopeOfKernelSpace space) $
checkKernelBody kts kbody
-- | Check a kernel body against the expected per-result types @ts@.
--
-- The body attribute and statements are checked first.  The results are
-- then checked in the scope of those statements.  The number of results
-- must equal the number of expected types.
checkKernelBody :: TC.Checkable lore =>
[Type] -> KernelBody (Aliases lore) -> TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, attr) stms kres) = do
TC.checkBodyLore attr
TC.checkStms stms $ do
unless (length ts == length kres) $
TC.bad $ TC.TypeError $ "Kernel return type is " ++ prettyTuple ts ++
", but body returns " ++ show (length kres) ++ " values."
zipWithM_ checkKernelResult kres ts
-- Thread and group returns: the returned value must have the expected type.
where checkKernelResult (GroupsReturn what) t =
TC.require [t] what
checkKernelResult (ThreadsReturn what) t =
TC.require [t] what
-- Writes: the destination dims and every index must be i32, and every
-- written value must have type t.  The destination array is consumed.
-- NOTE(review): the destination-type check is inside the forM_ loop, so
-- it is skipped when there are no writes.  Also, the length of each index
-- list is never compared with the length of rws.
checkKernelResult (WriteReturn rws arr res) t = do
mapM_ (TC.require [Prim int32]) rws
arr_t <- lookupType arr
forM_ res $ \(is, e) -> do
mapM_ (TC.require [Prim int32]) is
TC.require [t] e
unless (arr_t == t `arrayOfShape` Shape rws) $
TC.bad $ TC.TypeError $ "WriteReturn returning " ++
pretty e ++ " of type " ++ pretty t ++ ", shape=" ++ pretty rws ++
", but destination array has type " ++ pretty arr_t
TC.consume =<< TC.lookupAliases arr
-- Concats: the stride, sizes and optional offset must be i32, and v must
-- have type t with one extra outer dimension (of any size).
checkKernelResult (ConcatReturns o w per_thread_elems moffset v) t = do
case o of
SplitContiguous -> return ()
SplitStrided stride -> TC.require [Prim int32] stride
TC.require [Prim int32] w
TC.require [Prim int32] per_thread_elems
mapM_ (TC.require [Prim int32]) moffset
vt <- lookupType v
unless (vt == t `arrayOfRow` arraySize 0 vt) $
TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v
-- | Shared checker for 'SegRed' and 'SegScan'.
--
-- Each operator is a (lambda, neutral elements, vector shape) triple.  The
-- vector dimensions must be i32.  The lambda is checked against the
-- neutral types (twice) with shapeRank dimensions stripped, and its return
-- type must equal the neutral types.  The first results in @ts@ must be
-- the neutral types extended with each operator's shape, in operator
-- order.  The whole body is then checked against @ts@.
-- NOTE(review): the lambda parameters use the stripped neutral types, while
-- the return type is compared with the unstripped ones.  Confirm this is
-- intended when the shape is non-empty.
checkScanRed :: TC.Checkable lore =>
KernelSpace
-> [(Lambda (Aliases lore), [SubExp], Shape)]
-> [Type]
-> KernelBody (Aliases lore)
-> TC.TypeM lore ()
checkScanRed space ops ts kbody = do
checkSpace space
mapM_ TC.checkType ts
TC.binding (scopeOfKernelSpace space) $ do
ne_ts <- forM ops $ \(lam, nes, shape) -> do
mapM_ (TC.require [Prim int32]) $ shapeDims shape
nes' <- mapM TC.checkArg nes
let stripVecDims = stripArray $ shapeRank shape
TC.checkLambda lam $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
let nes_t = map TC.argType nes'
unless (lambdaReturnType lam == nes_t) $
TC.bad $ TC.TypeError "wrong type for operator or neutral elements."
return $ map (`arrayOfShape` shape) nes_t
-- Only a prefix of ts is compared.  The remaining results are checked
-- by checkKernelBody alone.
let expecting = concat ne_ts
got = take (length expecting) ts
unless (expecting == got) $
TC.bad $ TC.TypeError $
"Wrong return for body (does not match neutral elements; expected " ++
pretty expecting ++ "; found " ++
pretty got ++ ")"
checkKernelBody ts kbody
-- | Require every size in a kernel space to be i32: the thread count, group
-- count, group size and virtual group count, then every dimension size of
-- the structure (for nested spaces, all global sizes before all local
-- sizes).
checkSpace :: TC.Checkable lore => KernelSpace -> TC.TypeM lore ()
checkSpace (KernelSpace _ _ _ nthreads ngroups gsize nvirt structure) =
  mapM_ (TC.require [Prim int32]) $
    [nthreads, ngroups, gsize, nvirt] ++ structureSizes structure
  where structureSizes (FlatThreadSpace dims) = map snd dims
        structureSizes (NestedThreadSpace dims) =
          [ gsz | (_, gsz, _, _) <- dims ] ++ [ lsz | (_, _, _, lsz) <- dims ]
-- | Metrics for a kernel are recorded under the constructor's name.  They
-- cover each operator lambda (where present) and then the kernel body.
instance OpMetrics (Op lore) => OpMetrics (Kernel lore) where
  opMetrics k =
    case k of
      Kernel _ _ _ kbody ->
        inside "Kernel" $ kernelBodyMetrics kbody
      SegMap _ _ kbody ->
        inside "SegMap" $ kernelBodyMetrics kbody
      SegRed _ reds _ kbody ->
        inside "SegRed" $ do
          forM_ reds $ lambdaMetrics . segRedLambda
          kernelBodyMetrics kbody
      SegScan _ scan_op _ _ kbody ->
        inside "SegScan" $ do
          lambdaMetrics scan_op
          kernelBodyMetrics kbody
      SegGenRed _ ops _ kbody ->
        inside "SegGenRed" $ do
          forM_ ops $ lambdaMetrics . genReduceOp
          kernelBodyMetrics kbody
-- | Record metrics for every statement of a kernel body, in order.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics kbody = forM_ (kernelBodyStms kbody) bindingMetrics
-- | Pretty-print a kernel.  Each form starts with a keyword (@kernel NAME@,
-- @segmap@, @segred@, @segscan@ or @seggenred@).  Any operators follow in
-- parentheses, then the space, @:@ and the result types, and finally the
-- body in braces.
instance PrettyLore lore => PP.Pretty (Kernel lore) where
ppr (Kernel desc space ts body) =
text "kernel" <+> text (kernelName desc) <>
PP.align (ppr space) <+>
PP.colon <+> ppTuple' ts <+> PP.nestedBlock "{" "}" (ppr body)
ppr (SegMap space ts body) =
text "segmap" <>
PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
PP.nestedBlock "{" "}" (ppr body)
-- Operators are printed as a brace-enclosed list separated by a comma and
-- a line break.
ppr (SegRed space reds ts body) =
text "segred" <>
PP.parens (PP.braces (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp reds)) </>
PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
PP.nestedBlock "{" "}" (ppr body)
-- One reduction operator: {neutral elements}, shape, optional
-- "commutative " marker, then the lambda.
where ppOp (SegRedOp comm lam nes shape) =
PP.braces (PP.commasep $ map ppr nes) <> PP.comma </>
ppr shape <> PP.comma </>
comm' <> ppr lam
where comm' = case comm of Commutative -> text "commutative "
Noncommutative -> mempty
ppr (SegScan space scan_op nes ts body) =
text "segscan" <> PP.parens (ppr scan_op <> PP.comma </>
PP.braces (PP.commasep $ map ppr nes)) </>
PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
PP.nestedBlock "{" "}" (ppr body)
ppr (SegGenRed space ops ts body) =
text "seggenred" <>
PP.parens (PP.braces (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)) </>
PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
PP.nestedBlock "{" "}" (ppr body)
-- One histogram operator: width, {destinations}, {neutral elements},
-- shape, then the lambda.
where ppOp (GenReduceOp w dests nes shape op) =
ppr w <> PP.comma </>
PP.braces (PP.commasep $ map ppr dests) <> PP.comma </>
PP.braces (PP.commasep $ map ppr nes) <> PP.comma </>
ppr shape <> PP.comma </>
ppr op
-- | Pretty-print a kernel space.  The output is a parenthesised, labelled
-- list of the sizes and identifiers, followed by the structure.  A flat
-- structure prints as @(i < d, ...)@; a nested one prints as
-- @((gtid, ltid) < (gd, ld), ...)@.
instance Pretty KernelSpace where
ppr (KernelSpace f_gtid f_ltid gid num_threads num_groups group_size virt_groups structure) =
parens (commasep [text "num groups:" <+> ppr num_groups,
text "group size:" <+> ppr group_size,
text "virt_num_groups:" <+> ppr virt_groups,
text "num threads:" <+> ppr num_threads,
text "global TID ->" <+> ppr f_gtid,
text "local TID ->" <+> ppr f_ltid,
text "group ID ->" <+> ppr gid]) </> structure'
where structure' =
case structure of
FlatThreadSpace dims -> flat dims
NestedThreadSpace space ->
parens (commasep $ do
(gtid,gd,ltid,ld) <- space
return $ ppr (gtid,ltid) <+> "<" <+> ppr (gd,ld))
flat dims = parens $ commasep $ do
(i,d) <- dims
return $ ppr i <+> "<" <+> ppr d
-- | Pretty-print a kernel body: the statements stacked vertically, then
-- @return {r1, r2, ...}@.  The lore attribute is not printed.
instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr kbody =
    PP.stack (map ppr (stmsToList (kernelBodyStms kbody))) </>
    text "return" <+> PP.braces (PP.commasep (map ppr (kernelBodyResult kbody)))
-- | Pretty-print a kernel result:
--
-- * @group returns x@ / @thread returns x@;
-- * @arr with [i < d, ...] <- e ...@ for writes (each index is paired with
--   the corresponding destination dimension);
-- * @concat[Strided(s)](w, n[, offset=o]) v@ for concats.
instance Pretty KernelResult where
ppr (GroupsReturn what) =
text "group returns" <+> ppr what
ppr (ThreadsReturn what) =
text "thread returns" <+> ppr what
ppr (WriteReturn rws arr res) =
ppr arr <+> text "with" <+> PP.apply (map ppRes res)
where ppRes (is, e) =
PP.brackets (PP.commasep $ zipWith f is rws) <+> text "<-" <+> ppr e
f i rw = ppr i <+> text "<" <+> ppr rw
ppr (ConcatReturns o w per_thread_elems offset v) =
text "concat" <> suff <>
parens (commasep [ppr w, ppr per_thread_elems] <> offset_text) <+>
ppr v
where suff = case o of SplitContiguous -> mempty
SplitStrided stride -> text "Strided" <> parens (ppr stride)
offset_text = case offset of Nothing -> ""
Just se -> "," <+> "offset=" <> ppr se
-- | Operations available on the host side: size queries and comparisons,
-- plus an embedded @inner@ operation.  @lore@ is only used as a type index.
data HostOp lore inner
  = GetSize Name SizeClass
    -- ^ Produces an i32 size.
  | GetSizeMax SizeClass
    -- ^ Produces an i32 maximum size for the class.
  | CmpSizeLe Name SizeClass SubExp
    -- ^ Produces a boolean; printed as @get_size(name, class) <= x@.
  | HostOp inner
    -- ^ Wraps an inner operation.
  deriving (Eq, Ord, Show)
-- | Substitution reaches the inner operation and the compared 'SubExp' of
-- 'CmpSizeLe'.  The size queries contain no variables.
instance Substitute inner => Substitute (HostOp lore inner) where
  substituteNames substs op =
    case op of
      HostOp inner_op -> HostOp (substituteNames substs inner_op)
      CmpSizeLe name sclass x -> CmpSizeLe name sclass (substituteNames substs x)
      _ -> op
-- | Renaming reaches the inner operation and the compared 'SubExp' of
-- 'CmpSizeLe'.  The other forms are returned unchanged.
instance Rename inner => Rename (HostOp lore inner) where
  rename op =
    case op of
      HostOp inner_op -> HostOp <$> rename inner_op
      CmpSizeLe name sclass x -> CmpSizeLe name sclass <$> rename x
      _ -> pure op
-- | Size operations are safe and cheap.  Wrapped operations defer to the
-- inner instance.
instance IsOp inner => IsOp (HostOp lore inner) where
  safeOp op =
    case op of
      HostOp inner_op -> safeOp inner_op
      _ -> True
  cheapOp op =
    case op of
      HostOp inner_op -> cheapOp inner_op
      _ -> True
-- | Size queries return one i32, comparisons return one boolean, and
-- wrapped operations defer to the inner instance.
instance TypedOp inner => TypedOp (HostOp lore inner) where
  opType op =
    case op of
      HostOp inner_op -> opType inner_op
      CmpSizeLe{} -> pure [Prim Bool]
      GetSize{} -> pure [Prim int32]
      GetSizeMax{} -> pure [Prim int32]
-- | The single result of a size operation aliases nothing and consumes
-- nothing.  Wrapped operations defer to the inner instance.
instance AliasedOp inner => AliasedOp (HostOp lore inner) where
  opAliases op =
    case op of
      HostOp inner_op -> opAliases inner_op
      _ -> [mempty]
  consumedInOp op =
    case op of
      HostOp inner_op -> consumedInOp inner_op
      _ -> mempty
-- | The single result of a size operation has an unknown range.  Wrapped
-- operations defer to the inner instance.
instance RangedOp inner => RangedOp (HostOp lore inner) where
  opRanges op =
    case op of
      HostOp inner_op -> opRanges inner_op
      _ -> [unknownRange]
-- | Only the inner operation and the compared 'SubExp' of 'CmpSizeLe' can
-- have free variables.
instance FreeIn inner => FreeIn (HostOp lore inner) where
  freeIn op =
    case op of
      HostOp inner_op -> freeIn inner_op
      CmpSizeLe _ _ x -> freeIn x
      _ -> mempty
-- | Alias annotations only affect the inner operation.  The size forms are
-- rebuilt at the new lore with the same fields.
instance CanBeAliased inner => CanBeAliased (HostOp lore inner) where
  type OpWithAliases (HostOp lore inner) = HostOp (Aliases lore) (OpWithAliases inner)

  addOpAliases op =
    case op of
      HostOp inner_op -> HostOp (addOpAliases inner_op)
      GetSize name sclass -> GetSize name sclass
      GetSizeMax sclass -> GetSizeMax sclass
      CmpSizeLe name sclass x -> CmpSizeLe name sclass x

  removeOpAliases op =
    case op of
      HostOp inner_op -> HostOp (removeOpAliases inner_op)
      GetSize name sclass -> GetSize name sclass
      GetSizeMax sclass -> GetSizeMax sclass
      CmpSizeLe name sclass x -> CmpSizeLe name sclass x
-- | Range annotations only affect the inner operation.  The size forms are
-- rebuilt at the new lore with the same fields.
instance CanBeRanged inner => CanBeRanged (HostOp lore inner) where
  type OpWithRanges (HostOp lore inner) = HostOp (Ranges lore) (OpWithRanges inner)

  addOpRanges op =
    case op of
      HostOp inner_op -> HostOp (addOpRanges inner_op)
      GetSize name sclass -> GetSize name sclass
      GetSizeMax sclass -> GetSizeMax sclass
      CmpSizeLe name sclass x -> CmpSizeLe name sclass x

  removeOpRanges op =
    case op of
      HostOp inner_op -> HostOp (removeOpRanges inner_op)
      GetSize name sclass -> GetSize name sclass
      GetSizeMax sclass -> GetSizeMax sclass
      CmpSizeLe name sclass x -> CmpSizeLe name sclass x
-- | Removing wisdom only affects the inner operation.  The size forms are
-- rebuilt at the new lore with the same fields.
instance CanBeWise inner => CanBeWise (HostOp lore inner) where
  type OpWithWisdom (HostOp lore inner) = HostOp (Wise lore) (OpWithWisdom inner)

  removeOpWisdom op =
    case op of
      HostOp inner_op -> HostOp (removeOpWisdom inner_op)
      GetSize name sclass -> GetSize name sclass
      GetSizeMax sclass -> GetSizeMax sclass
      CmpSizeLe name sclass x -> CmpSizeLe name sclass x
-- | Only wrapped operations can be indexed symbolically.  Size operations
-- yield 'Nothing'.
instance ST.IndexOp op => ST.IndexOp (HostOp lore op) where
  indexOp vtable k hop is =
    case hop of
      HostOp inner_op -> ST.indexOp vtable k inner_op is
      _ -> Nothing
-- | Pretty-print host operations as @get_size(name, class)@,
-- @get_size_max(class)@, @get_size(name, class) <= x@, or the wrapped
-- operation itself.
instance PP.Pretty inner => PP.Pretty (HostOp lore inner) where
  ppr op =
    case op of
      HostOp inner_op -> ppr inner_op
      GetSizeMax size_class ->
        text "get_size_max" <> parens (ppr size_class)
      GetSize name size_class ->
        text "get_size" <> parens (commasep [ppr name, ppr size_class])
      CmpSizeLe name size_class x ->
        text "get_size" <> parens (commasep [ppr name, ppr size_class]) <+>
        text "<=" <+> ppr x
-- | Size operations are counted by constructor name.  Wrapped operations
-- defer to the inner instance.
instance OpMetrics inner => OpMetrics (HostOp lore inner) where
  opMetrics op =
    case op of
      HostOp inner_op -> opMetrics inner_op
      GetSize{} -> seen "GetSize"
      GetSizeMax{} -> seen "GetSizeMax"
      CmpSizeLe{} -> seen "CmpSizeLe"
-- | Type-check a host operation.  The given checker handles wrapped inner
-- operations.  'CmpSizeLe' requires its compared value to be i32, and the
-- size queries always pass.
typeCheckHostOp :: TC.Checkable lore =>
                   (inner -> TC.TypeM lore ())
                -> HostOp (Aliases lore) inner
                -> TC.TypeM lore ()
typeCheckHostOp checkInner op =
  case op of
    HostOp inner_op -> checkInner inner_op
    CmpSizeLe _ _ x -> TC.require [Prim int32] x
    GetSize{} -> return ()
    GetSizeMax{} -> return ()