{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Representation.Kernels.Kernel
       ( GenReduceOp(..)
       , SegRedOp(..)
       , segRedResults
       , KernelBody(..)
       , KernelResult(..)
       , kernelResultSubExp
       , SplitOrdering(..)

       -- * Segmented operations
       , SegOp(..)
       , SegLevel(..)
       , SegVirt(..)
       , segLevel
       , segSpace
       , typeCheckSegOp
       , SegSpace(..)
       , scopeOfSegSpace
       , segSpaceDims

       -- ** Generic traversal
       , SegOpMapper(..)
       , identitySegOpMapper
       , mapSegOpM
       , SegOpWalker(..)
       , identitySegOpWalker
       , walkSegOpM

       -- * Host operations
       , HostOp(..)
       , typeCheckHostOp

       -- * Reexports
       , module Futhark.Representation.Kernels.Sizes
       )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import Control.Monad.Writer hiding (mapM_)
import Control.Monad.Identity hiding (mapM_)
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List

import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
  ((</>), (<+>), ppr, commasep, Pretty, parens, text)
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges
  (Ranges, removeLambdaRanges, removeStmRanges, mkBodyRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
  (Aliases, removeLambdaAliases, removeStmAliases)
import Futhark.Representation.Kernels.Sizes
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.Range as Range
import Futhark.Util (maybeNth)

-- | How an array is split into chunks.
data SplitOrdering = SplitContiguous
                   | SplitStrided SubExp
                   deriving (Eq, Ord, Show)

-- Only the stride of a strided split can mention variables.
instance FreeIn SplitOrdering where
  freeIn' SplitContiguous = mempty
  freeIn' (SplitStrided stride) = freeIn' stride

instance Substitute SplitOrdering where
  substituteNames _ SplitContiguous =
    SplitContiguous
  substituteNames subst (SplitStrided stride) =
    SplitStrided $ substituteNames subst stride

instance Rename SplitOrdering where
  rename SplitContiguous =
    pure SplitContiguous
  rename (SplitStrided stride) =
    SplitStrided <$> rename stride

-- | A single operator of a 'SegGenRed': the width of the destination
-- arrays, the destination arrays themselves, the neutral elements, the
-- vectorisation shape, and the combining lambda.
data GenReduceOp lore =
  GenReduceOp { genReduceWidth :: SubExp
              , genReduceDest :: [VName]
              , genReduceNeutral :: [SubExp]
              , genReduceShape :: Shape
                -- ^ In case this operator is semantically a
                -- vectorised operator (corresponding to a perfect map
                -- nest in the SOACS representation), these are the
                -- logical "dimensions".  This is used to generate
                -- more efficient code.
              , genReduceOp :: LambdaT lore
              }
  deriving (Eq, Ord, Show)

-- | A single reduction operator of a 'SegRed'.
data SegRedOp lore =
  SegRedOp { segRedComm :: Commutativity
           , segRedLambda :: Lambda lore
           , segRedNeutral :: [SubExp]
           , segRedShape :: Shape
             -- ^ In case this operator is semantically a vectorised
             -- operator (corresponding to a perfect map nest in the
             -- SOACS representation), these are the logical
             -- "dimensions".  This is used to generate more efficient
             -- code.
           }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegRedOp's?
-- This is the total number of neutral elements across all operators.
segRedResults :: [SegRedOp lore] -> Int
segRedResults = sum . map (length . segRedNeutral)

-- | The body of a 'Kernel'.
data KernelBody lore = KernelBody { kernelBodyLore :: BodyAttr lore
                                  , kernelBodyStms :: Stms lore
                                  , kernelBodyResult :: [KernelResult]
                                  }

deriving instance Annotations lore => Ord (KernelBody lore)
deriving instance Annotations lore => Show (KernelBody lore)
deriving instance Annotations lore => Eq (KernelBody lore)

-- | A single result of a 'KernelBody'.
data KernelResult = Returns SubExp
                    -- ^ Each "worker" in the kernel returns this.
                    -- Whether this is a result-per-thread or a
                    -- result-per-group depends on the 'SegLevel'.
                  | WriteReturns
                    [SubExp] -- Size of array.  Must match number of dims.
                    VName -- Which array
                    [([SubExp], SubExp)]
                    -- Arbitrary number of index/value pairs.
                  | ConcatReturns
                    SplitOrdering -- Permuted?
                    SubExp -- The final size.
                    SubExp -- Per-thread/group (max) chunk size.
                    VName -- Chunk by this worker.
                  | TileReturns
                    [(SubExp, SubExp)] -- Total/tile for each dimension
                    VName -- Tile written by this worker.
                    -- The TileReturns must not expect more than one
                    -- result to be written per physical thread.
                  deriving (Eq, Show, Ord)

-- | The 'SubExp' underlying a result: the returned value itself for
-- 'Returns', and the array variable for every other constructor.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns se) = se
kernelResultSubExp (WriteReturns _ arr _) = Var arr
kernelResultSubExp (ConcatReturns _ _ _ v) = Var v
kernelResultSubExp (TileReturns _ v) = Var v

instance FreeIn KernelResult where
  freeIn' (Returns what) = freeIn' what
  freeIn' (WriteReturns rws arr res) = freeIn' rws <> freeIn' arr <> freeIn' res
  freeIn' (ConcatReturns o w per_thread_elems v) =
    freeIn' o <> freeIn' w <> freeIn' per_thread_elems <> freeIn' v
  freeIn' (TileReturns dims v) =
    freeIn' dims <> freeIn' v

-- Names bound by the statements of the body are not free in the body.
instance Attributes lore => FreeIn (KernelBody lore) where
  freeIn' (KernelBody attr stms res) =
    fvBind bound_in_stms $ freeIn' attr <> freeIn' stms <> freeIn' res
    where bound_in_stms = foldMap boundByStm stms

instance Attributes lore => Substitute (KernelBody lore) where
  substituteNames subst (KernelBody attr stms res) =
    KernelBody
    (substituteNames subst attr)
    (substituteNames subst stms)
    (substituteNames subst res)

instance Substitute KernelResult where
  substituteNames subst (Returns se) =
    Returns $ substituteNames subst se
  substituteNames subst (WriteReturns rws arr res) =
    WriteReturns
    (substituteNames subst rws) (substituteNames subst arr)
    (substituteNames subst res)
  substituteNames subst (ConcatReturns o w per_thread_elems v) =
    ConcatReturns
    (substituteNames subst o)
    (substituteNames subst w)
    (substituteNames subst per_thread_elems)
    (substituteNames subst v)
  substituteNames subst (TileReturns dims v) =
    TileReturns (substituteNames subst dims) (substituteNames subst v)

instance Attributes lore => Rename (KernelBody lore) where
  rename (KernelBody attr stms res) = do
    attr' <- rename attr
    -- The results are renamed inside the scope of the renamed statements.
    renamingStms stms $ \stms' ->
      KernelBody attr' stms' <$> rename res

instance Rename KernelResult where
  rename = substituteRename

-- | Run alias analysis on the statements of a kernel body by wrapping
-- them in an ordinary 'Body' with an empty result; the kernel results
-- are carried over unchanged.
aliasAnalyseKernelBody :: (Attributes lore,
                           CanBeAliased (Op lore)) =>
                          KernelBody lore
                       -> KernelBody (Aliases lore)
aliasAnalyseKernelBody (KernelBody attr stms res) =
  let Body attr' stms' _ = Alias.analyseBody $ Body attr stms []
  in KernelBody attr' stms' res

-- | Strip alias information from a kernel body.
removeKernelBodyAliases :: CanBeAliased (Op lore) =>
                           KernelBody (Aliases lore) -> KernelBody lore
removeKernelBodyAliases (KernelBody (_, attr) stms res) =
  KernelBody attr (fmap removeStmAliases stms) res

-- | Annotate a kernel body with range information.
addKernelBodyRanges :: (Attributes lore, CanBeRanged (Op lore)) =>
                       KernelBody lore -> Range.RangeM (KernelBody (Ranges lore))
addKernelBodyRanges (KernelBody attr stms res) =
  Range.analyseStms stms $ \stms' -> do
  let attr' = (mkBodyRanges stms $ map kernelResultSubExp res, attr)
  return $ KernelBody attr' stms' res

-- | Strip range information from a kernel body.
removeKernelBodyRanges :: CanBeRanged (Op lore) =>
                          KernelBody (Ranges lore) -> KernelBody lore
removeKernelBodyRanges (KernelBody (_, attr) stms res) =
  KernelBody attr (fmap removeStmRanges stms) res

-- | Strip simplifier "wisdom" from a kernel body, by way of an
-- ordinary 'Body' with an empty result.
removeKernelBodyWisdom :: CanBeWise (Op lore) =>
                          KernelBody (Wise lore) -> KernelBody lore
removeKernelBodyWisdom (KernelBody attr stms res) =
  let Body attr' stms' _ = removeBodyWisdom $ Body attr stms []
  in KernelBody attr' stms' res

-- | The names consumed in a kernel body: whatever the statements
-- consume, plus the destination array of every 'WriteReturns'.
consumedInKernelBody :: Aliased lore =>
                        KernelBody lore -> Names
consumedInKernelBody (KernelBody attr stms res) =
  consumedInBody (Body attr stms []) <> mconcat (map consumedByReturn res)
  where consumedByReturn (WriteReturns _ a _) = oneName a
        consumedByReturn _                    = mempty

-- | Type-check a kernel body against the expected per-result types.
-- The number of results must equal the number of types, and each
-- result is checked according to its constructor.
checkKernelBody :: TC.Checkable lore =>
                   [Type] -> KernelBody (Aliases lore) -> TC.TypeM lore ()
checkKernelBody ts (KernelBody (_, attr) stms kres) = do
  TC.checkBodyLore attr
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad $ TC.TypeError $ "Kernel return type is " ++ prettyTuple ts ++
      ", but body returns " ++ show (length kres) ++ " values."
    zipWithM_ checkKernelResult kres ts

  where checkKernelResult (Returns what) t =
          TC.require [t] what
        checkKernelResult (WriteReturns rws arr res) t = do
          mapM_ (TC.require [Prim int32]) rws
          arr_t <- lookupType arr
          forM_ res $ \(is, e) -> do
            mapM_ (TC.require [Prim int32]) is
            TC.require [t] e
            unless (arr_t == t `arrayOfShape` Shape rws) $
              TC.bad $ TC.TypeError $ "WriteReturns returning " ++
              pretty e ++ " of type " ++ pretty t ++ ", shape=" ++ pretty rws ++
              ", but destination array has type " ++ pretty arr_t
          -- The destination array is consumed by the write.
          TC.consume =<< TC.lookupAliases arr
        checkKernelResult (ConcatReturns o w per_thread_elems v) t = do
          case o of
            SplitContiguous     -> return ()
            SplitStrided stride -> TC.require [Prim int32] stride
          TC.require [Prim int32] w
          TC.require [Prim int32] per_thread_elems
          vt <- lookupType v
          unless (vt == t `arrayOfRow` arraySize 0 vt) $
            TC.bad $ TC.TypeError $ "Invalid type for ConcatReturns " ++ pretty v
        checkKernelResult (TileReturns dims v) t = do
          forM_ dims $ \(dim, tile) -> do
            TC.require [Prim int32] dim
            TC.require [Prim int32] tile
          vt <- lookupType v
          unless (vt == t `arrayOfShape` Shape (map snd dims)) $
            TC.bad $ TC.TypeError $ "Invalid type for TileReturns " ++ pretty v

-- | Record metrics for every statement of a kernel body.
kernelBodyMetrics :: OpMetrics (Op lore) => KernelBody lore -> MetricsM ()
kernelBodyMetrics = mapM_ bindingMetrics . kernelBodyStms

instance PrettyLore lore => Pretty (KernelBody lore) where
  ppr (KernelBody _ stms res) =
    PP.stack (map ppr (stmsToList stms)) PP.</>
    text "return" <+> PP.braces (PP.commasep $ map ppr res)

instance Pretty KernelResult where
  ppr (Returns what) =
    text "thread returns" <+> ppr what
  ppr (WriteReturns rws arr res) =
    ppr arr <+> text "with" <+> PP.apply (map ppRes res)
    where ppRes (is, e) =
            PP.brackets (PP.commasep $ zipWith f is rws) <+> text "<-" <+> ppr e
          f i rw = ppr i <+> text "<" <+> ppr rw
  ppr (ConcatReturns o w per_thread_elems v) =
    text "concat" <> suff <>
    parens (commasep [ppr w, ppr per_thread_elems]) <+> ppr v
    where suff = case o of
                   SplitContiguous     -> mempty
                   SplitStrided stride -> text "Strided" <> parens (ppr stride)
  ppr (TileReturns dims v) =
    text "tile" <> parens (commasep $ map onDim dims) <+> ppr v
    where onDim (dim, tile) = ppr dim <+> text "/" <+> ppr tile

--- Segmented operations

-- | Do we need group-virtualisation when generating code for the
-- segmented operation?  In most cases, we do, but for some simple
-- kernels, we compute the full number of groups in advance, and then
-- virtualisation is an unnecessary (but generally very small)
-- overhead.  This only really matters for fairly trivial but very
-- wide @map@ kernels where each thread performs constant-time work on
-- scalars.
data SegVirt = SegVirt | SegNoVirt
             deriving (Eq, Ord, Show)

-- | At which level the *body* of a 'SegOp' executes.
data SegLevel = SegThread { segNumGroups :: Count NumGroups SubExp
                          , segGroupSize :: Count GroupSize SubExp
                          , segVirt :: SegVirt }
              | SegGroup { segNumGroups :: Count NumGroups SubExp
                         , segGroupSize :: Count GroupSize SubExp
                         , segVirt :: SegVirt }
              | SegThreadScalar { segNumGroups :: Count NumGroups SubExp
                                , segGroupSize :: Count GroupSize SubExp
                                , segVirt :: SegVirt }
                -- ^ Like 'SegThread', but with the invariant that the
                -- results produced are only used within the same
                -- physical thread later on, and can thus be kept in
                -- registers.  May only occur within an enclosing
                -- 'SegGroup' construct.
              deriving (Eq, Ord, Show)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace { segFlat :: VName
                         -- ^ Flat physical index corresponding to the
                         -- dimensions (at code generation used for a
                         -- thread ID or similar).
                         , unSegSpace :: [(VName, SubExp)]
                         }
              deriving (Eq, Ord, Show)

-- | The sizes of the dimensions of the index space, outermost first.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims (SegSpace _ space) = map snd space

-- | The scope introduced by an index space: the flat index and every
-- per-dimension index, all bound as 'Int32' indices.
scopeOfSegSpace :: SegSpace -> Scope lore
scopeOfSegSpace (SegSpace phys space) =
  M.fromList $ zip (phys : map fst space) $ repeat $ IndexInfo Int32

-- | Every dimension size of the index space must be an @int32@.
checkSegSpace :: TC.Checkable lore => SegSpace -> TC.TypeM lore ()
checkSegSpace (SegSpace _ dims) =
  mapM_ (TC.require [Prim int32] . snd) dims

-- | A segmented operation.
data SegOp lore = SegMap SegLevel SegSpace [Type] (KernelBody lore)
                | SegRed SegLevel SegSpace [SegRedOp lore] [Type] (KernelBody lore)
                  -- ^ The KernelSpace must always have at least two dimensions,
                  -- implying that the result of a SegRed is always an array.
                | SegScan SegLevel SegSpace (Lambda lore) [SubExp] [Type] (KernelBody lore)
                | SegGenRed SegLevel SegSpace [GenReduceOp lore] [Type] (KernelBody lore)
                deriving (Eq, Ord, Show)

-- | The level at which the body of the operation executes.
segLevel :: SegOp lore -> SegLevel
segLevel (SegMap lvl _ _ _) = lvl
segLevel (SegRed lvl _ _ _ _) = lvl
segLevel (SegScan lvl _ _ _ _ _) = lvl
segLevel (SegGenRed lvl _ _ _ _) = lvl

-- | The index space of the operation.
segSpace :: SegOp lore -> SegSpace
segSpace (SegMap _ space _ _) = space
segSpace (SegRed _ space _ _ _) = space
segSpace (SegScan _ space _ _ _ _) = space
segSpace (SegGenRed _ space _ _ _) = space

-- | The type of the array produced for one kernel result, given the
-- index space and the element type declared for that result.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape _ t (WriteReturns rws _ _) =
  t `arrayOfShape` Shape rws
segResultShape space t (Returns _) =
  foldr (flip arrayOfRow) t $ segSpaceDims space
segResultShape _ t (ConcatReturns _ w _ _) =
  t `arrayOfRow` w
segResultShape _ t (TileReturns dims _) =
  t `arrayOfShape` Shape (map fst dims)

-- | The result types of a segmented operation.
--
-- NOTE(review): the 'SegRed' and 'SegGenRed' cases apply 'init' to the
-- space dimensions and so presuppose a non-empty index space.
segOpType :: SegOp lore -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts $ kernelBodyResult kbody
segOpType (SegRed _ space reds ts kbody) =
  red_ts ++
  zipWith (segResultShape space) map_ts
  (drop (length red_ts) $ kernelBodyResult kbody)
  where map_ts = drop (length red_ts) ts
        segment_dims = init $ segSpaceDims space
        red_ts = do
          op <- reds
          let shape = Shape segment_dims <> segRedShape op
          map (`arrayOfShape` shape) (lambdaReturnType $ segRedLambda op)
segOpType (SegScan _ space _ nes ts kbody) =
  map (`arrayOfShape` Shape dims) scan_ts ++
  zipWith (segResultShape space) map_ts
  (drop (length scan_ts) $ kernelBodyResult kbody)
  where dims = segSpaceDims space
        (scan_ts, map_ts) = splitAt (length nes) ts
segOpType (SegGenRed _ space ops _ _) = do
  op <- ops
  let shape = Shape (segment_dims <> [genReduceWidth op]) <> genReduceShape op
  map (`arrayOfShape` shape) (lambdaReturnType $ genReduceOp op)
  where dims = segSpaceDims space
        segment_dims = init dims

instance TypedOp (SegOp lore) where
  opType = pure . staticShapes . segOpType

instance (Attributes lore, Aliased lore) => AliasedOp (SegOp lore) where
  -- No result of a segmented operation aliases anything.
  opAliases = map (const mempty) . segOpType

  consumedInOp (SegMap _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegRed _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegScan _ _ _ _ _ kbody) =
    consumedInKernelBody kbody
  consumedInOp (SegGenRed _ _ ops _ kbody) =
    namesFromList (concatMap genReduceDest ops) <> consumedInKernelBody kbody

-- | Check that a 'SegOp' at the second level may occur inside the
-- (optional) enclosing level given as the first argument.
checkSegLevel :: Maybe SegLevel -> SegLevel -> TC.TypeM lore ()
checkSegLevel Nothing SegThreadScalar{} =
  TC.bad $ TC.TypeError "SegThreadScalar at top level."
checkSegLevel Nothing _ =
  return ()
checkSegLevel (Just SegThread{}) _ =
  TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level."
checkSegLevel (Just x) y
  | x == y = TC.bad $ TC.TypeError $ "Already at level " ++ pretty x
  | segNumGroups x /= segNumGroups y || segGroupSize x /= segGroupSize y =
      TC.bad $ TC.TypeError "Physical layout for SegLevel does not match parent SegLevel."
  | otherwise =
      return ()

-- | The checks shared by every kind of 'SegOp': level nesting, index
-- space, and well-formedness of the declared types.
checkSegBasics :: TC.Checkable lore =>
                  Maybe SegLevel -> SegLevel -> SegSpace -> [Type] -> TC.TypeM lore ()
checkSegBasics cur_lvl lvl space ts = do
  checkSegLevel cur_lvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

-- | Type-check a segmented operation.  The first argument is the level
-- of the enclosing 'SegOp', if any.
typeCheckSegOp :: TC.Checkable lore =>
                  Maybe SegLevel -> SegOp (Aliases lore) -> TC.TypeM lore ()
typeCheckSegOp cur_lvl (SegMap lvl space ts kbody) =
  checkScanRed cur_lvl lvl space [] ts kbody

typeCheckSegOp cur_lvl (SegRed lvl space reds ts body) =
  checkScanRed cur_lvl lvl space reds' ts body
  where reds' = zip3
                (map segRedLambda reds)
                (map segRedNeutral reds)
                (map segRedShape reds)

typeCheckSegOp cur_lvl (SegScan lvl space scan_op nes ts body) =
  checkScanRed cur_lvl lvl space [(scan_op, nes, mempty)] ts body

typeCheckSegOp cur_lvl (SegGenRed lvl space ops ts kbody) = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    nes_ts <- forM ops $ \(GenReduceOp dest_w dests nes shape op) -> do
      TC.require [Prim int32] dest_w
      nes' <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int32]) $ shapeDims shape

      -- Operator type must match the type of neutral elements.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda op $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'
      unless (nes_t == lambdaReturnType op) $
        TC.bad $ TC.TypeError $ "SegGenRed operator has return type " ++
        prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
        prettyTuple nes_t

      -- Arrays must have proper type.
      let dest_shape = Shape (segment_dims <> [dest_w]) <> shape
      forM_ (zip nes_t dests) $ \(t, dest) -> do
        TC.requireI [t `arrayOfShape` dest_shape] dest
        TC.consume =<< TC.lookupAliases dest

      return $ map (`arrayOfShape` shape) nes_t

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let bucket_ret_t = replicate (length ops) (Prim int32) ++ concat nes_ts
    unless (bucket_ret_t == ts) $
      TC.bad $ TC.TypeError $ "SegGenRed body has return type " ++
      prettyTuple ts ++ " but should have type " ++
      prettyTuple bucket_ret_t

  where segment_dims = init $ segSpaceDims space

-- | Shared checker for 'SegMap' (no operators), 'SegRed' and 'SegScan'.
-- Each operator is given as its lambda, neutral elements, and
-- vectorisation shape.
checkScanRed :: TC.Checkable lore =>
                Maybe SegLevel -> SegLevel
             -> SegSpace
             -> [(Lambda (Aliases lore), [SubExp], Shape)]
             -> [Type]
             -> KernelBody (Aliases lore)
             -> TC.TypeM lore ()
checkScanRed cur_lvl lvl space ops ts kbody = do
  checkSegBasics cur_lvl lvl space ts

  TC.binding (scopeOfSegSpace space) $ do
    ne_ts <- forM ops $ \(lam, nes, shape) -> do
      mapM_ (TC.require [Prim int32]) $ shapeDims shape
      nes' <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      let stripVecDims = stripArray $ shapeRank shape
      TC.checkLambda lam $ map (TC.noArgAliases . first stripVecDims) $ nes' ++ nes'
      let nes_t = map TC.argType nes'

      unless (lambdaReturnType lam == nes_t) $
        TC.bad $ TC.TypeError "wrong type for operator or neutral elements."

      return $ map (`arrayOfShape` shape) nes_t

    -- The leading declared types must be those of the operator results.
    let expecting = concat ne_ts
        got = take (length expecting) ts
    unless (expecting == got) $
      TC.bad $ TC.TypeError $
      "Wrong return for body (does not match neutral elements; expected " ++
      pretty expecting ++ "; found " ++
      pretty got ++ ")"

    checkKernelBody ts kbody

-- | Like 'Mapper', but just for 'SegOp's.
data SegOpMapper flore tlore m = SegOpMapper {
    mapOnSegOpSubExp :: SubExp -> m SubExp
  , mapOnSegOpLambda :: Lambda flore -> m (Lambda tlore)
  , mapOnSegOpBody :: KernelBody flore -> m (KernelBody tlore)
  , mapOnSegOpVName :: VName -> m VName
  }

-- | A mapper that simply returns the 'SegOp' verbatim.
identitySegOpMapper :: Monad m => SegOpMapper lore lore m
identitySegOpMapper = SegOpMapper { mapOnSegOpSubExp = return
                                  , mapOnSegOpLambda = return
                                  , mapOnSegOpBody = return
                                  , mapOnSegOpVName = return
                                  }

-- | Map over the dimension sizes of an index space.  The flat index
-- and the per-dimension index names are left untouched.
mapOnSegSpace :: Monad f =>
                 SegOpMapper flore tlore f -> SegSpace -> f SegSpace
mapOnSegSpace tv (SegSpace phys dims) =
  SegSpace phys <$> traverse (traverse $ mapOnSegOpSubExp tv) dims

-- | Apply a 'SegOpMapper' to the immediate components of a 'SegOp'.
mapSegOpM :: (Applicative m, Monad m) =>
             SegOpMapper flore tlore m -> SegOp flore -> m (SegOp tlore)
mapSegOpM tv (SegMap lvl space ts body) =
  SegMap
  <$> mapOnSegLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM (mapOnSegOpType tv) ts
  <*> mapOnSegOpBody tv body
mapSegOpM tv (SegRed lvl space reds ts lam) =
  SegRed
  <$> mapOnSegLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM onSegOp reds
  <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
  <*> mapOnSegOpBody tv lam
  where onSegOp (SegRedOp comm red_op nes shape) =
          SegRedOp comm
          <$> mapOnSegOpLambda tv red_op
          <*> mapM (mapOnSegOpSubExp tv) nes
          <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
mapSegOpM tv (SegScan lvl space scan_op nes ts body) =
  SegScan
  <$> mapOnSegLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapOnSegOpLambda tv scan_op
  <*> mapM (mapOnSegOpSubExp tv) nes
  <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
  <*> mapOnSegOpBody tv body
mapSegOpM tv (SegGenRed lvl space ops ts body) =
  SegGenRed
  <$> mapOnSegLevel tv lvl
  <*> mapOnSegSpace tv space
  <*> mapM onGenRedOp ops
  <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts
  <*> mapOnSegOpBody tv body
  where onGenRedOp (GenReduceOp w arrs nes shape op) =
          GenReduceOp
          <$> mapOnSegOpSubExp tv w
          <*> mapM (mapOnSegOpVName tv) arrs
          <*> mapM (mapOnSegOpSubExp tv) nes
          <*> (Shape <$> mapM (mapOnSegOpSubExp tv) (shapeDims shape))
          <*> mapOnSegOpLambda tv op

-- | Map over the group count and group size of a level; the
-- virtualisation flag is kept as-is.
mapOnSegLevel :: Monad m =>
                 SegOpMapper flore tlore m -> SegLevel -> m SegLevel
mapOnSegLevel tv (SegThread num_groups group_size virt) =
  SegThread
  <$> traverse (mapOnSegOpSubExp tv) num_groups
  <*> traverse (mapOnSegOpSubExp tv) group_size
  <*> pure virt
mapOnSegLevel tv (SegGroup num_groups group_size virt) =
  SegGroup
  <$> traverse (mapOnSegOpSubExp tv) num_groups
  <*> traverse (mapOnSegOpSubExp tv) group_size
  <*> pure virt
mapOnSegLevel tv (SegThreadScalar num_groups group_size virt) =
  SegThreadScalar
  <$> traverse (mapOnSegOpSubExp tv) num_groups
  <*> traverse (mapOnSegOpSubExp tv) group_size
  <*> pure virt

-- | Map over the dimensions of an array type; primitive and memory
-- types are returned unchanged.
mapOnSegOpType :: Monad m =>
                  SegOpMapper flore tlore m -> Type -> m Type
mapOnSegOpType _tv (Prim pt) = pure $ Prim pt
mapOnSegOpType tv (Array pt shape u) = Array pt <$> f shape <*> pure u
  where f (Shape dims) = Shape <$> mapM (mapOnSegOpSubExp tv) dims
mapOnSegOpType _tv (Mem s) = pure $ Mem s

-- | Like 'Walker', but just for 'SegOp's.
data SegOpWalker lore m = SegOpWalker {
    walkOnSegOpSubExp :: SubExp -> m ()
  , walkOnSegOpLambda :: Lambda lore -> m ()
  , walkOnSegOpBody :: KernelBody lore -> m ()
  , walkOnSegOpVName :: VName -> m ()
  }

-- | A no-op traversal.
identitySegOpWalker :: Monad m => SegOpWalker lore m
identitySegOpWalker = SegOpWalker {
    walkOnSegOpSubExp = const $ return ()
  , walkOnSegOpLambda = const $ return ()
  , walkOnSegOpBody = const $ return ()
  , walkOnSegOpVName = const $ return ()
  }

-- | Turn a walker into a mapper that runs the walker's action on each
-- component and then returns the component unchanged.
walkSegOpMapper :: forall lore m. Monad m =>
                   SegOpWalker lore m -> SegOpMapper lore lore m
walkSegOpMapper f = SegOpMapper {
    mapOnSegOpSubExp = wrap walkOnSegOpSubExp
  , mapOnSegOpLambda = wrap walkOnSegOpLambda
  , mapOnSegOpBody = wrap walkOnSegOpBody
  , mapOnSegOpVName = wrap walkOnSegOpVName
  }
  where wrap :: (SegOpWalker lore m -> a -> m ()) -> a -> m a
        wrap op k = op f k >> return k

-- | As 'mapSegOpM', but ignoring the results.
walkSegOpM :: Monad m => SegOpWalker lore m -> SegOp lore -> m ()
walkSegOpM f = void . mapSegOpM m
  where m = walkSegOpMapper f

instance Attributes lore => Substitute (SegOp lore) where
  substituteNames subst = runIdentity . mapSegOpM substitute
    where substitute =
            SegOpMapper { mapOnSegOpSubExp = return . substituteNames subst
                        , mapOnSegOpLambda = return . substituteNames subst
                        , mapOnSegOpBody = return . substituteNames subst
                        , mapOnSegOpVName = return . substituteNames subst
                        }

instance Attributes lore => Rename (SegOp lore) where
  rename = mapSegOpM renamer
    where renamer = SegOpMapper rename rename rename rename

instance (Attributes lore, FreeIn (LParamAttr lore)) =>
         FreeIn (SegOp lore) where
  -- Accumulate the free names of every component in a state monad.
  freeIn' e = flip execState mempty $ mapSegOpM free e
    where walk f x = modify (<>f x) >> return x
          free = SegOpMapper { mapOnSegOpSubExp = walk freeIn'
                             , mapOnSegOpLambda = walk freeIn'
                             , mapOnSegOpBody = walk freeIn'
                             , mapOnSegOpVName = walk freeIn'
                             }

instance OpMetrics (Op lore) => OpMetrics (SegOp lore) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" $ kernelBodyMetrics body
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $ do mapM_ (lambdaMetrics . segRedLambda) reds
                         kernelBodyMetrics body
  opMetrics (SegScan _ _ scan_op _ _ body) =
    inside "SegScan" $ lambdaMetrics scan_op >> kernelBodyMetrics body
  opMetrics (SegGenRed _ _ ops _ body) =
    inside "SegGenRed" $ do mapM_ (lambdaMetrics . genReduceOp) ops
                            kernelBodyMetrics body

instance Pretty SegSpace where
  ppr (SegSpace phys dims) =
    parens (commasep [ ppr i <+> "<" <+> ppr d | (i, d) <- dims ]) <+>
    parens (text "~" <> ppr phys)

instance PP.Pretty SegLevel where
  ppr SegThread{} = "thread"
  ppr SegThreadScalar{} = "scalar"
  ppr SegGroup{} = "group"

-- | Pretty-print the physical layout of a level: number of groups,
-- group size, and whether virtualisation is requested.
ppSegLevel :: SegLevel -> PP.Doc
ppSegLevel lvl =
  PP.parens $
  text "#groups=" <> ppr (segNumGroups lvl) <> PP.semi <+>
  text "groupsize=" <> ppr (segGroupSize lvl) <>
  case segVirt lvl of
    SegNoVirt -> mempty
    SegVirt   -> PP.semi <+> text "virtualise"

instance PrettyLore lore => PP.Pretty (SegOp lore) where
  ppr (SegMap lvl space ts body) =
    text "segmap_" <> ppr lvl PP.</>
    ppSegLevel lvl PP.</>
    PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
    PP.nestedBlock "{" "}" (ppr body)

  ppr (SegRed lvl space reds ts body) =
    text "segred_" <> ppr lvl PP.</>
    ppSegLevel lvl PP.</>
    PP.parens (PP.braces (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp reds)) PP.</>
    PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
    PP.nestedBlock "{" "}" (ppr body)
    where ppOp (SegRedOp comm lam nes shape) =
            PP.braces (PP.commasep $ map ppr nes) <> PP.comma PP.</>
            ppr shape <> PP.comma PP.</>
            comm' <> ppr lam
            where comm' = case comm of
                            Commutative    -> text "commutative "
                            Noncommutative -> mempty

  ppr (SegScan lvl space scan_op nes ts body) =
    text "segscan_" <> ppr lvl PP.</>
    ppSegLevel lvl PP.</>
    PP.parens (ppr scan_op <> PP.comma PP.</>
               PP.braces (PP.commasep $ map ppr nes)) PP.</>
    PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
    PP.nestedBlock "{" "}" (ppr body)

  ppr (SegGenRed lvl space ops ts body) =
    text "seggenred_" <> ppr lvl PP.</>
    ppSegLevel lvl PP.</>
    PP.parens (PP.braces (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)) PP.</>
    PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
    PP.nestedBlock "{" "}" (ppr body)
    where ppOp (GenReduceOp w dests nes shape op) =
            ppr w <> PP.comma PP.</>
            PP.braces (PP.commasep $ map ppr dests) <> PP.comma PP.</>
            PP.braces (PP.commasep $ map ppr nes) <> PP.comma PP.</>
            ppr shape <> PP.comma PP.</>
            ppr op

instance Attributes inner => RangedOp (SegOp inner) where
  -- Nothing is known about the range of any result.
  opRanges op = replicate (length $ segOpType op) unknownRange

instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (SegOp lore) where
  type OpWithRanges (SegOp lore) = SegOp (Ranges lore)

  removeOpRanges = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return (return . removeLambdaRanges)
                   (return . removeKernelBodyRanges) return
  addOpRanges = Range.runRangeM . mapSegOpM add
    where add = SegOpMapper return Range.analyseLambda
                addKernelBodyRanges return

instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SegOp lore) where
  type OpWithAliases (SegOp lore) = SegOp (Aliases lore)

  addOpAliases = runIdentity . mapSegOpM alias
    where alias = SegOpMapper return (return . Alias.analyseLambda)
                  (return . aliasAnalyseKernelBody) return

  removeOpAliases = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return (return . removeLambdaAliases)
                   (return . removeKernelBodyAliases) return

instance (CanBeWise (Op lore), Attributes lore) => CanBeWise (SegOp lore) where
  type OpWithWisdom (SegOp lore) = SegOp (Wise lore)

  removeOpWisdom = runIdentity . mapSegOpM remove
    where remove = SegOpMapper return
                   (return . removeLambdaWisdom)
                   (return . removeKernelBodyWisdom)
                   return

instance Attributes lore => ST.IndexOp (SegOp lore) where
  -- Only a 'SegMap' whose @k@th result is a plain 'Returns' of a
  -- variable can be indexed symbolically.  The thread indices are bound
  -- to the given index expressions, and each single-name statement of
  -- the body that converts to a 'PrimExp' extends the table.
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns se <- maybeNth k $ kernelBodyResult kbody
    let (gtids, _) = unzip $ unSegSpace space
    guard $ length gtids == length is
    let prim_table = M.fromList $ zip gtids $ zip is $ repeat mempty
        -- Strict left fold: avoids building a chain of thunks over the
        -- statements before the table is first inspected.
        prim_table' = foldl' expandPrimExpTable prim_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v prim_table'
      _     -> Nothing
    where expandPrimExpTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe,cs) <-
                  runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
                M.insert v (pe, stmCerts stm <> cs) table
            | otherwise =
                table

          asPrimExp table v
            | Just (e,cs) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing

  indexOp _ _ _ _ = Nothing

instance Attributes lore => IsOp (SegOp lore) where
  cheapOp _ = False
  safeOp _ = True

--- Host operations

-- | A host-level operation; parameterised by what else it can do.
data HostOp lore op
  = SplitSpace SplitOrdering SubExp SubExp SubExp
    -- ^ @SplitSpace o w i elems_per_thread@.
    --
    -- Computes how to divide array elements to
    -- threads in a kernel.  Returns the number of
    -- elements in the chunk that the current thread
    -- should take.
    --
    -- @w@ is the length of the outer dimension in
    -- the array. @i@ is the current thread
    -- index. Each thread takes at most
    -- @elems_per_thread@ elements.
    --
    -- If the order @o@ is 'SplitContiguous', thread with index @i@
    -- should receive elements
    -- @i*elems_per_thread, i*elems_per_thread + 1,
    -- ..., i*elems_per_thread + (elems_per_thread-1)@.
    --
    -- If the order @o@ is @'SplitStrided' stride@,
    -- the thread will receive elements @i,
    -- i+stride, i+2*stride, ...,
    -- i+(elems_per_thread-1)*stride@.
  | GetSize Name SizeClass
    -- ^ Produce some runtime-configurable size.
  | GetSizeMax SizeClass
    -- ^ The maximum size of some class.
  | CmpSizeLe Name SizeClass SubExp
    -- ^ Compare size (likely a threshold) with some Int32 value.
  | SegOp (SegOp lore)
    -- ^ A segmented operation.
  | OtherOp op
  deriving (Eq, Ord, Show)

instance (Attributes lore, Substitute op) => Substitute (HostOp lore op) where
  substituteNames substs (SegOp op) =
    SegOp $ substituteNames substs op
  substituteNames substs (OtherOp op) =
    OtherOp $ substituteNames substs op
  substituteNames subst (SplitSpace o w i elems_per_thread) =
    SplitSpace
    (substituteNames subst o)
    (substituteNames subst w)
    (substituteNames subst i)
    (substituteNames subst elems_per_thread)
  substituteNames substs (CmpSizeLe name sclass x) =
    CmpSizeLe name sclass $ substituteNames substs x
  -- 'GetSize' and 'GetSizeMax' contain no substitutable names.
  substituteNames _ x = x

instance (Attributes lore, Rename op) => Rename (HostOp lore op) where
  rename (SplitSpace o w i elems_per_thread) =
    SplitSpace
    <$> rename o
    <*> rename w
    <*> rename i
    <*> rename elems_per_thread
  rename (SegOp op) = SegOp <$> rename op
  rename (OtherOp op) = OtherOp <$> rename op
  rename (CmpSizeLe name sclass x) = CmpSizeLe name sclass <$> rename x
  rename x = pure x

instance (Attributes lore, IsOp op) => IsOp (HostOp lore op) where
  safeOp (SegOp op) = safeOp op
  safeOp (OtherOp op) = safeOp op
  safeOp _ = True

  cheapOp (SegOp op) = cheapOp op
  cheapOp (OtherOp op) = cheapOp op
  cheapOp _ = True

instance TypedOp op => TypedOp (HostOp lore op) where
  opType SplitSpace{} = pure [Prim int32]
  opType GetSize{} = pure [Prim int32]
  opType GetSizeMax{} = pure [Prim int32]
  opType CmpSizeLe{} = pure [Prim Bool]
  opType (SegOp op) = opType op
  opType (OtherOp op) = opType op

instance (Aliased lore, AliasedOp op, Attributes lore) => AliasedOp (HostOp lore op) where
  opAliases (SegOp op) = opAliases op
  opAliases (OtherOp op) = opAliases op
  opAliases _ = [mempty]

  consumedInOp (SegOp op) = consumedInOp op
  consumedInOp (OtherOp op) = consumedInOp op
  consumedInOp _ = mempty

instance (Attributes lore, RangedOp op) => RangedOp (HostOp lore op) where
  -- The result of a 'SplitSpace' lies between 0 and @elems_per_thread@.
  opRanges (SplitSpace _ _ _ elems_per_thread) =
    [(Just (ScalarBound 0),
      Just (ScalarBound (SE.subExpToScalExp elems_per_thread int32)))]
  opRanges (SegOp op) = opRanges op
  opRanges (OtherOp op) = opRanges op
  opRanges _ = [unknownRange]

instance (Attributes lore, FreeIn op) => FreeIn (HostOp lore op) where
  freeIn' (SplitSpace o w i elems_per_thread) =
    freeIn' o <> freeIn' [w, i, elems_per_thread]
  freeIn' (SegOp op) = freeIn' op
  freeIn' (OtherOp op) = freeIn' op
  freeIn' (CmpSizeLe _ _ x) = freeIn' x
  freeIn' _ = mempty

instance (CanBeAliased (Op lore), CanBeAliased op, Attributes lore) =>
         CanBeAliased (HostOp lore op) where
  type OpWithAliases (HostOp lore op) = HostOp (Aliases lore) (OpWithAliases op)

  addOpAliases (SplitSpace o w i elems_per_thread) =
    SplitSpace o w i elems_per_thread
  addOpAliases (SegOp op) = SegOp $ addOpAliases op
  addOpAliases (OtherOp op) = OtherOp $ addOpAliases op
  addOpAliases (GetSize name sclass) = GetSize name sclass
  addOpAliases (GetSizeMax sclass) = GetSizeMax sclass
  addOpAliases (CmpSizeLe name sclass x) = CmpSizeLe name sclass x

  removeOpAliases (SplitSpace o w i elems_per_thread) =
    SplitSpace o w i elems_per_thread
  removeOpAliases (SegOp op) = SegOp $ removeOpAliases op
  removeOpAliases (OtherOp op) = OtherOp $ removeOpAliases op
  removeOpAliases (GetSize name sclass) = GetSize name sclass
  removeOpAliases (GetSizeMax sclass) = GetSizeMax sclass
  removeOpAliases (CmpSizeLe name sclass x) = CmpSizeLe name sclass x

instance (CanBeRanged (Op lore), CanBeRanged op, Attributes lore) =>
         CanBeRanged (HostOp lore op) where
  type OpWithRanges (HostOp lore op) = HostOp (Ranges lore) (OpWithRanges op)

  addOpRanges (SplitSpace o w i elems_per_thread) =
    SplitSpace o w i elems_per_thread
  addOpRanges (SegOp op) = SegOp $ addOpRanges op
  addOpRanges (OtherOp op) = OtherOp $ addOpRanges op
  addOpRanges (GetSize name sclass) = GetSize name sclass
  addOpRanges (GetSizeMax sclass) = GetSizeMax sclass
  addOpRanges (CmpSizeLe name sclass x) = CmpSizeLe name sclass x

  removeOpRanges (SplitSpace o w i elems_per_thread) =
    SplitSpace o w i elems_per_thread
  removeOpRanges (SegOp op) = SegOp $ removeOpRanges op
  removeOpRanges (OtherOp op) = OtherOp $ removeOpRanges op
  removeOpRanges (GetSize name sclass) = GetSize name sclass
  removeOpRanges (GetSizeMax sclass) = GetSizeMax sclass
  removeOpRanges (CmpSizeLe name sclass x) = CmpSizeLe name sclass x

instance (CanBeWise (Op lore), CanBeWise op, Attributes lore) =>
         CanBeWise (HostOp lore op) where
  type OpWithWisdom (HostOp lore op) = HostOp (Wise lore) (OpWithWisdom op)

  removeOpWisdom (SplitSpace o w i elems_per_thread) =
    SplitSpace o w i elems_per_thread
  removeOpWisdom (SegOp op) = SegOp $ removeOpWisdom op
  removeOpWisdom (GetSize name sclass) = GetSize name sclass
  removeOpWisdom (GetSizeMax sclass) = GetSizeMax sclass
  removeOpWisdom (CmpSizeLe name sclass x) = CmpSizeLe name sclass x
  removeOpWisdom (OtherOp op) = OtherOp $ removeOpWisdom op

instance (Attributes lore, ST.IndexOp op) => ST.IndexOp (HostOp lore op) where
  indexOp vtable k (SegOp op) is = ST.indexOp vtable k op is
  indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is
  indexOp _ _ _ _ = Nothing

instance (PrettyLore lore, PP.Pretty op) => PP.Pretty (HostOp lore op) where
  ppr (SplitSpace o w i elems_per_thread) =
    text "splitSpace" <> suff <>
    parens (commasep [ppr w, ppr i, ppr elems_per_thread])
    where suff = case o of
                   SplitContiguous     -> mempty
                   SplitStrided stride -> text "Strided" <> parens (ppr stride)

  ppr (GetSize name size_class) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class])

  ppr (GetSizeMax size_class) =
    text "get_size_max" <> parens (ppr size_class)

  ppr (CmpSizeLe name size_class x) =
    text "get_size" <> parens (commasep [ppr name, ppr size_class]) <+>
    text "<=" <+> ppr x

  ppr (SegOp op) = ppr op

  ppr (OtherOp op) = ppr op

instance (OpMetrics (Op lore), OpMetrics op) => OpMetrics (HostOp lore op) where
  opMetrics SplitSpace{} = seen "SplitSpace"
  opMetrics GetSize{} = seen "GetSize"
  opMetrics GetSizeMax{} = seen "GetSizeMax"
  opMetrics CmpSizeLe{} = seen "CmpSizeLe"
  opMetrics (SegOp op) = opMetrics op
  opMetrics (OtherOp op) = opMetrics op

-- | Type-check a host operation.
--
-- The first argument builds the checker for operations nested inside a
-- 'SegOp', given the level of that 'SegOp'.  The second is the level
-- of the enclosing 'SegOp', if any.  The third checks an 'OtherOp'.
typeCheckHostOp :: TC.Checkable lore =>
                   (SegLevel -> OpWithAliases (Op lore) -> TC.TypeM lore ())
                -> Maybe SegLevel
                -> (op -> TC.TypeM lore ())
                -> HostOp (Aliases lore) op
                -> TC.TypeM lore ()
typeCheckHostOp _ _ _ (SplitSpace o w i elems_per_thread) = do
  case o of
    SplitContiguous     -> return ()
    SplitStrided stride -> TC.require [Prim int32] stride
  mapM_ (TC.require [Prim int32]) [w, i, elems_per_thread]
typeCheckHostOp _ _ _ GetSize{} = return ()
typeCheckHostOp _ _ _ GetSizeMax{} = return ()
typeCheckHostOp _ _ _ (CmpSizeLe _ _ x) = TC.require [Prim int32] x
typeCheckHostOp checker lvl _ (SegOp op) =
  TC.checkOpWith (checker $ segLevel op) $
  typeCheckSegOp lvl op
typeCheckHostOp _ _ f (OtherOp op) = f op