{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Multiversion segmented reduction.
module Futhark.Pass.ExtractKernels.Segmented
  ( regularSegmentedRedomap
  , regularSegmentedScan
  )
  where

import Control.Monad
import qualified Data.Map.Strict as M
import Data.Semigroup ((<>))

import Futhark.Transform.Rename
import Futhark.Representation.Kernels
import Futhark.Representation.SOACS.SOAC (nilFn)
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass.ExtractKernels.BlockedKernel

data SegmentedVersion = OneGroupOneSegment
                      | ManyGroupsOneSegment
                      deriving (Eq, Ord, Show)

-- | @regularSegmentedRedomap@ generates code for a segmented redomap using
-- two different strategies, dynamically deciding which one to use based on
-- the number of segments and the segment size. We use the (static)
-- @group_size@ to decide which of the following two strategies to choose:
--
-- * Large: uses one or more groups to process a single segment. If multiple
--   groups are used per segment, the intermediate reduction results must be
--   recursively reduced, until there is only a single value per segment.
--
--   Each thread /can/ read multiple elements, which will greatly increase
--   performance; however, if the reduction is non-commutative the input
--   array will be transposed (by the KernelBabysitter) to enable
--   memory-coalesced accesses. Currently we always make each thread read as
--   many elements as it can, but this /could/ be unfavorable because of the
--   transpose: in the case where each thread can only read 2 elements, the
--   cost of the transpose might not be worth the performance gained by
--   letting each thread read multiple elements. This could be investigated
--   more in depth in the future (TODO).
--
-- * Small: lets each group process /multiple/ segments. We only use this
--   approach when we can process at least two segments within a single
--   group. In those cases, the large strategy would normally allocate a
--   /whole/ group per segment, but at most 50% of the threads in the group
--   would have any element to read, which is highly inefficient.
regularSegmentedRedomap :: (HasScope Kernels m, MonadBinder m, Lore m ~ Kernels) =>
                           SubExp            -- segment_size
                        -> SubExp            -- num_segments
                        -> [SubExp]          -- nest_sizes = the sizes of the maps on "top" of this redomap
                        -> Pattern Kernels   -- flat_pat ... pat where each type is array with dim [w]
                        -> Pattern Kernels   -- pat
                        -> SubExp            -- w = total_num_elements
                        -> Commutativity     -- comm
                        -> Lambda InKernel   -- reduce_lam
                        -> Lambda InKernel   -- fold_lam = this lambda performs both the map-part and
                                             -- reduce-part of a redomap (described in redomap paper)
                        -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
                        -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
                        -> [SubExp]          -- nes
                        -> [VName]           -- arrs_flat
                        -> m ()
regularSegmentedRedomap segment_size num_segments nest_sizes flat_pat
                        pat w comm reduce_lam fold_lam ispace inps nes arrs_flat = do
  unless (null $ patternContextElements pat) $
    fail "regularSegmentedRedomap result pattern contains context elements, and Rasmus did not think this would ever happen."

  -- The result of the "map" part of a redomap has to be stored somewhere
  -- within the chunking loop of a kernel. The current way to do this is to
  -- make some scratch space initially, and let each thread get a part of it
  -- by splitting it. Finally it is returned as a result of the kernel (so as
  -- not to break functional semantics).
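  --
  -- A concrete illustration (hypothetical sizes): with nest_sizes = [n, m]
  -- and segment_size k (so w = n*m*k), a map result of row type t is
  -- allocated as scratch of shape [n][m][k]t, and the reshape below
  -- collapses the outer (length nest_sizes + 1) = 3 dimensions into the
  -- flat shape [w]t that the kernel splits among threads.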
  map_out_arrs <- forM (drop num_redres $ patternIdents pat) $ \(Ident name t) -> do
    tmp <- letExp (baseString name <> "_out_in") $
           BasicOp $ Scratch (elemType t) (arrayDims t)
    -- This reshape will not always work.
    letExp (baseString name ++ "_out_in") $
      BasicOp $ Reshape (reshapeOuter [DimNew w] (length nest_sizes+1) $ arrayShape t) tmp

  -- Check that we are only dealing with arrays with dimension [w].
  forM_ arrs_flat $ \arr -> do
    tp <- lookupType arr
    case tp of
      -- TODO: this won't work if the reduction operator works on lists...
      -- but they seem to be handled in some other way (which makes sense).
      Array _primtp (Shape (flatsize:_)) _uniqness ->
        when (flatsize /= w) $
          fail $ "regularSegmentedRedomap: first dimension of array has incorrect size " ++ pretty arr ++ ":" ++ pretty tp
      _ ->
        fail $ "regularSegmentedRedomap: non-array encountered " ++ pretty arr ++ ":" ++ pretty tp

  -- The pattern passed to chunkLambda must have exactly *one* array
  -- dimension, to get the correct size of [chunk_size]type.
  --
  -- TODO: not sure if this will work when the result of the map is
  -- multidimensional, or if the reduction operator uses lists... must check.
  chunk_pat <- fmap (Pattern []) $ forM (patternValueElements pat) $ \pat_e ->
    case patElemType pat_e of
      Array ty (Shape (dim0:_)) u -> do
        vn' <- newName $ patElemName pat_e
        return $ PatElem vn' $ Array ty (Shape [dim0]) u
      _ ->
        fail $ "segmentedRedomap: result pattern is not array " ++ pretty pat_e

  chunk_fold_lam <- chunkLambda chunk_pat nes fold_lam
  kern_chunk_fold_lam <- kerneliseLambda nes chunk_fold_lam

  let chunk_red_pat = Pattern [] $ take num_redres $ patternValueElements chunk_pat
  kern_chunk_reduce_lam <- kerneliseLambda nes =<< chunkLambda chunk_red_pat nes reduce_lam

  -- The lambda for a GroupReduce needs these two extra parameters.
  my_index <- newVName "my_index"
  other_offset <- newVName "other_offset"
  let my_index_param = Param my_index (Prim int32)
  let other_offset_param = Param other_offset (Prim int32)
  let reduce_lam' = reduce_lam { lambdaParams = my_index_param :
                                                other_offset_param :
                                                lambdaParams reduce_lam
                               }
  flag_reduce_lam <- addFlagToLambda nes reduce_lam
  let flag_reduce_lam' = flag_reduce_lam { lambdaParams = my_index_param :
                                                          other_offset_param :
                                                          lambdaParams flag_reduce_lam
                                         }

  -- TODO: 'blockedReductionStream' in BlockedKernel.hs, which is very
  -- similar, performs a copy here... however, I have not seen a need for it
  -- yet.

  group_size <- getSize "group_size" SizeGroup
  num_groups_hint <- getSize "num_groups_hint" SizeNumGroups

  -- Here we make a small optimization: if we will use the large kernel with
  -- only one group per segment, we can simplify the calculations within the
  -- kernel for the indexes of which segment it is working on; therefore we
  -- create two different kernels (this will increase the final code size a
  -- bit, though). TODO: test how much we win by doing this.
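  --
  -- Illustrative numbers (hypothetical, and ignoring the
  -- one-element-per-thread adjustment in
  -- calcGroupsPerSegmentAndElementsPerThread): with num_groups_hint = 128
  -- and num_segments = 256, the hint gives ceil(128/256) = 1 group per
  -- segment, so if the large strategy is chosen at runtime, the specialised
  -- one-group-per-segment kernel runs; with num_segments = 32 the hint is
  -- ceil(128/32) = 4, and the multi-group kernel (followed by its recursive
  -- reduction step) runs instead.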
  (num_groups_per_segment, _) <-
    calcGroupsPerSegmentAndElementsPerThread segment_size num_segments
                                             num_groups_hint group_size
                                             ManyGroupsOneSegment

  let all_arrs = arrs_flat ++ map_out_arrs
  (large_1_ses, large_1_stms) <- runBinder $
    useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam
  (large_m_ses, large_m_stms) <- runBinder $
    useLargeMultiRecursiveReduce group_size all_arrs reduce_lam'
                                 kern_chunk_fold_lam kern_chunk_reduce_lam
                                 flag_reduce_lam'

  let e_large_seg = eIf (eCmpOp (CmpEq $ IntType Int32)
                          (eSubExp num_groups_per_segment) (eSubExp one))
                        (mkBodyM large_1_stms large_1_ses)
                        (mkBodyM large_m_stms large_m_ses)

  (small_ses, small_stms) <- runBinder $
    useSmallKernel group_size map_out_arrs flag_reduce_lam'

  -- If (group_size/2) < segment_size, we will not be able to fit two
  -- segments into one group, and therefore we should not use the kernel
  -- that relies on this.
  e <- eIf (eCmpOp (CmpSlt Int32)
             (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
             (eSubExp segment_size))
           (eBody [e_large_seg])
           (mkBodyM small_stms small_ses)

  redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
    vn' <- newName $ patElemName pe
    return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe
  let mapres_pes = drop num_redres $ patternValueElements flat_pat
  let unreshaped_pat = Pattern [] $ redres_pes ++ mapres_pes

  letBind_ unreshaped_pat e

  forM_ (zip (patternValueElements unreshaped_pat)
             (patternValueElements pat)) $ \(kpe, pe) ->
    letBind_ (Pattern [] [pe]) $
      BasicOp $ Reshape [DimNew se | se <- arrayDims $ patElemAttr pe]
                        (patElemName kpe)

  where
    replaceSegmentDims d t =
      t `setArrayDims` (d : drop (length nest_sizes) (arrayDims t))

    one = constant (1 :: Int32)
    two = constant (2 :: Int32)

    -- Number of reduction results (tuple size for reduction operator).
    num_redres = length nes

    --------------------------------------------------------------------------
    -- The functions below generate all the needed code for the two different
    -- versions of segmented redomap (one group per segment, and many groups
    -- per segment).
    --
    -- We rename statements before adding them because the same lambdas
    -- (reduce/fold) are used multiple times, and we do not want to bind the
    -- same VName twice (as this is a type error).
    --------------------------------------------------------------------------
    useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam = do
      mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ patElemType pe

      (kernel, _, _) <-
        largeKernel group_size segment_size num_segments nest_sizes
                    all_arrs comm reduce_lam' kern_chunk_fold_lam nes w
                    OneGroupOneSegment ispace inps

      kernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe

      let kernel_pat = Pattern [] $ kernel_redres_pes ++ mapres_pes

      addStm =<< renameStm (Let kernel_pat (defAux ()) $ Op kernel)
      return $ map (Var . patElemName) $ patternValueElements kernel_pat

    --------------------------------------------------------------------------
    useLargeMultiRecursiveReduce group_size all_arrs reduce_lam'
                                 kern_chunk_fold_lam kern_chunk_reduce_lam
                                 flag_reduce_lam' = do
      mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ patElemType pe

      (firstkernel, num_groups_used, num_groups_per_segment) <-
        largeKernel group_size segment_size num_segments nest_sizes
                    all_arrs comm reduce_lam' kern_chunk_fold_lam nes w
                    ManyGroupsOneSegment ispace inps

      firstkernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
        vn' <- newName $ patElemName pe
        return $ PatElem vn' $ replaceSegmentDims num_groups_used $ patElemType pe

      let first_pat = Pattern [] $ firstkernel_redres_pes ++ mapres_pes
      addStm =<< renameStm (Let first_pat (defAux ()) $ Op firstkernel)

      let new_segment_size = num_groups_per_segment
      let new_total_elems = num_groups_used
      let tmp_redres = map patElemName firstkernel_redres_pes
      (finalredres, part_two_stms) <- runBinder $
        performFinalReduction new_segment_size new_total_elems tmp_redres
                              reduce_lam' kern_chunk_reduce_lam flag_reduce_lam'
      mapM_ (addStm <=< renameStm) part_two_stms

      return $ finalredres ++ map (Var . patElemName) mapres_pes

    --------------------------------------------------------------------------
    -- The "recursive" reduction step. However, this will always be done in
    -- exactly one extra step: either by using the small kernel, or by using
    -- the large kernel with one group per segment.
    performFinalReduction new_segment_size new_total_elems tmp_redres
                          reduce_lam' kern_chunk_reduce_lam flag_reduce_lam' = do
      group_size <- getSize "group_size" SizeGroup

      -- Large kernel, using one group per segment (ogps).
      (large_ses, large_stms) <- runBinder $ do
        (large_kernel, _, _) <-
          largeKernel group_size new_segment_size num_segments nest_sizes
                      tmp_redres comm reduce_lam' kern_chunk_reduce_lam
                      nes new_total_elems OneGroupOneSegment ispace inps
        letTupExp' "kernel_result" $ Op large_kernel

      -- Small kernel, using one group for many segments (ogms).
      (small_ses, small_stms) <- runBinder $ do
        red_scratch_arrs <- forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
          -- We construct a scratch array for writing the result, but we have
          -- to flatten the dimensions corresponding to the map nest, because
          -- multi-dimensional WriteReturns are/were not supported.
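          --
          -- E.g. (hypothetical sizes): with nest_sizes = [n, m], a reduction
          -- result of type [n][m]t is reshaped to [num_segments]t, where
          -- num_segments = n*m.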
          tmp <- letExp (baseString name <> "_redres_scratch") $
                 BasicOp $ Scratch (elemType t) (arrayDims t)
          let reshape = reshapeOuter [DimNew num_segments]
                        (length nest_sizes) $ arrayShape t
          letExp (baseString name ++ "_redres_scratch") $
            BasicOp $ Reshape reshape tmp
        kernel <- smallKernel group_size new_segment_size num_segments
                              tmp_redres red_scratch_arrs
                              comm flag_reduce_lam' reduce_lam
                              nes new_total_elems ispace inps
        letTupExp' "kernel_result" $ Op kernel

      e <- eIf (eCmpOp (CmpSlt Int32)
                 (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
                 (eSubExp new_segment_size))
               (mkBodyM large_stms large_ses)
               (mkBodyM small_stms small_ses)
      letTupExp' "step_two_kernel_result" e

    --------------------------------------------------------------------------
    useSmallKernel group_size map_out_arrs flag_reduce_lam' = do
      red_scratch_arrs <- forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
        tmp <- letExp (baseString name <> "_redres_scratch") $
               BasicOp $ Scratch (elemType t) (arrayDims t)
        let shape_change = reshapeOuter [DimNew num_segments]
                           (length nest_sizes) (arrayShape t)
        letExp (baseString name ++ "_redres_scratch") $
          BasicOp $ Reshape shape_change tmp

      let scratch_arrays = red_scratch_arrs ++ map_out_arrs

      kernel <- smallKernel group_size segment_size num_segments
                            arrs_flat scratch_arrays
                            comm flag_reduce_lam' fold_lam
                            nes w ispace inps
      letTupExp' "kernel_result" $ Op kernel

largeKernel :: (MonadBinder m, Lore m ~ Kernels) =>
               SubExp            -- group_size
            -> SubExp            -- segment_size
            -> SubExp            -- num_segments
            -> [SubExp]          -- nest sizes
            -> [VName]           -- all_arrs: flat arrays (also the "map_out" ones)
            -> Commutativity     -- comm
            -> Lambda InKernel   -- reduce_lam
            -> Lambda InKernel   -- kern_chunk_fold_lam
            -> [SubExp]          -- nes
            -> SubExp            -- w = total_num_elements
            -> SegmentedVersion  -- segver
            -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
            -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
            -> m (Kernel InKernel, SubExp, SubExp)
largeKernel group_size segment_size num_segments nest_sizes all_arrs comm
            reduce_lam' kern_chunk_fold_lam nes w segver ispace inps = do
  let num_redres = length nes -- number of reduction results (tuple size for
                              -- reduction operator)

  num_groups_hint <- getSize "num_groups_hint" SizeNumGroups

  (num_groups_per_segment, elements_per_thread) <-
    calcGroupsPerSegmentAndElementsPerThread segment_size num_segments
                                             num_groups_hint group_size segver

  num_groups <- letSubExp "num_groups" $
    case segver of
      OneGroupOneSegment -> BasicOp $ SubExp num_segments
      ManyGroupsOneSegment -> BasicOp $ BinOp (Mul Int32) num_segments
                                              num_groups_per_segment

  num_threads <- letSubExp "num_threads" $
    BasicOp $ BinOp (Mul Int32) num_groups group_size

  threads_within_segment <- letSubExp "threads_within_segment" $
    BasicOp $ BinOp (Mul Int32) group_size num_groups_per_segment

  gtid_vn <- newVName "gtid"
  gtid_ln <- newVName "gtid"

  -- The array passed here is the structure for how to lay out the kernel
  -- space.
  space <- newKernelSpace (num_groups, group_size, num_threads) $
    FlatThreadSpace $ ispace ++ [(gtid_vn, num_groups_per_segment), (gtid_ln, group_size)]

  let red_ts = take num_redres $ lambdaReturnType kern_chunk_fold_lam
  let map_ts = map rowType $ drop num_redres $ lambdaReturnType kern_chunk_fold_lam
  let kernel_return_types = red_ts ++ map_ts

  let ordering = case comm of
                   Commutative -> SplitStrided threads_within_segment
                   Noncommutative -> SplitContiguous

  let stride = case ordering of
                 SplitStrided s -> s
                 SplitContiguous -> one

  let each_thread = do
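        -- Illustrative thread layout (hypothetical numbers): with
        -- group_size = 128 and num_groups_per_segment = 2, groups 0 and 1
        -- cooperate on segment 0, groups 2 and 3 on segment 1, and so on;
        -- local thread 5 of group 1 computes index_within_segment
        -- 5 + 128 * (1 % 2) = 133 below.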
        segment_index <- letSubExp "segment_index" $
          BasicOp $ BinOp (SQuot Int32) (Var $ spaceGroupId space)
                                        num_groups_per_segment

        -- localId + (group_size * (groupId % num_groups_per_segment))
        index_within_segment <- letSubExp "index_within_segment" =<<
          eBinOp (Add Int32)
            (eSubExp $ Var gtid_ln)
            (eBinOp (Mul Int32)
               (eSubExp group_size)
               (eBinOp (SRem Int32) (eSubExp $ Var $ spaceGroupId space)
                                    (eSubExp num_groups_per_segment)))

        (in_segment_offset, offset) <-
          makeOffsetExp ordering index_within_segment elements_per_thread segment_index

        let (_, chunksize, [], arr_params) =
              partitionChunkedKernelFoldParameters 0 $ lambdaParams kern_chunk_fold_lam
        let chunksize_se = Var $ paramName chunksize

        patelems_res_of_split <- forM arr_params $ \arr_param -> do
          let chunk_t = paramType arr_param `setOuterSize` Var (paramName chunksize)
          return $ PatElem (paramName arr_param) chunk_t

        letBind_ (Pattern [] [PatElem (paramName chunksize) $ paramType chunksize]) $
          Op $ SplitSpace ordering segment_size index_within_segment elements_per_thread

        addKernelInputStms inps

        forM_ (zip all_arrs patelems_res_of_split) $ \(arr, pe) -> do
          let pe_t = patElemType pe
              segment_dims = nest_sizes ++ arrayDims (pe_t `setOuterSize` segment_size)
          arr_nested <- letExp (baseString arr ++ "_nested") $
            BasicOp $ Reshape (map DimNew segment_dims) arr
          arr_nested_t <- lookupType arr_nested
          let slice = fullSlice arr_nested_t $
                      map (DimFix . Var . fst) ispace ++
                      [DimSlice in_segment_offset chunksize_se stride]
          letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr_nested slice

        red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name red_t
        map_pes <- forM map_ts $ \map_t -> do
          pe_name <- newVName "chunk_fold_map"
          return $ PatElem pe_name $ map_t `arrayOfRow` chunksize_se

        -- We add the lets here, as we practically don't know if the
        -- resulting subexp is a Constant or a Var, so better be safe (?).
        addStms $ bodyStms (lambdaBody kern_chunk_fold_lam)
        addStms $ stmsFromList
          [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip (red_pes ++ map_pes)
                            (bodyResult $ lambdaBody kern_chunk_fold_lam) ]

        -- Combine the reduction results from each thread.
        -- This will put the results in local memory, so a GroupReduce can be
        -- performed on them.
        combine_red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        cids <- replicateM (length red_pes) $ newVName "cid"
        addStms $ stmsFromList
          [ Let (Pattern [] [pe']) (defAux ()) $
            Op $ Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
            Body () mempty [Var $ patElemName pe]
          | (cid, pe', pe) <- zip3 cids combine_red_pes red_pes ]

        final_red_pes <- forM (lambdaReturnType reduce_lam') $ \t -> do
          pe_name <- newVName "final_result"
          return $ PatElem pe_name t
        letBind_ (Pattern [] final_red_pes) $
          Op $ GroupReduce group_size reduce_lam' $
          zip nes (map patElemName combine_red_pes)

        return (final_red_pes, map_pes, offset)

  ((final_red_pes, map_pes, offset), stms) <- runBinder each_thread

  red_returns <- forM final_red_pes $ \pe ->
    return $ ThreadsReturn OneResultPerGroup $ Var $ patElemName pe
  map_returns <- forM map_pes $ \pe ->
    return $ ConcatReturns ordering w elements_per_thread (Just offset) $
             patElemName pe
  let kernel_returns = red_returns ++ map_returns

  let kerneldebughints = KernelDebugHints kernelname
                         [ ("num_segment", num_segments)
                         , ("segment_size", segment_size)
                         , ("num_groups", num_groups)
                         , ("group_size", group_size)
                         , ("elements_per_thread", elements_per_thread)
                         , ("num_groups_per_segment", num_groups_per_segment)
                         ]

  let kernel = Kernel kerneldebughints space kernel_return_types $
               KernelBody () stms kernel_returns

  return (kernel, num_groups, num_groups_per_segment)

  where
    one = constant (1 :: Int32)

    commname = case comm of
                 Commutative -> "comm"
                 Noncommutative -> "nocomm"

    kernelname = case segver of
                   OneGroupOneSegment -> "segmented_redomap__large_" ++ commname ++ "_one"
                   ManyGroupsOneSegment -> "segmented_redomap__large_" ++ commname ++ "_many"

    makeOffsetExp SplitContiguous index_within_segment elements_per_thread segment_index = do
      in_segment_offset <- letSubExp "in_segment_offset" $
        BasicOp $ BinOp (Mul Int32) elements_per_thread index_within_segment
      offset <- letSubExp "offset" =<<
        eBinOp (Add Int32)
          (eSubExp in_segment_offset)
          (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
      return (in_segment_offset, offset)
    makeOffsetExp (SplitStrided _) index_within_segment _elements_per_thread segment_index = do
      offset <- letSubExp "offset" =<<
        eBinOp (Add Int32)
          (eSubExp index_within_segment)
          (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
      return (index_within_segment, offset)

calcGroupsPerSegmentAndElementsPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
                                            SubExp
                                         -> SubExp
                                         -> SubExp
                                         -> SubExp
                                         -> SegmentedVersion
                                         -> m (SubExp, SubExp)
calcGroupsPerSegmentAndElementsPerThread segment_size num_segments
                                         num_groups_hint group_size segver = do
  num_groups_per_segment_hint <-
    letSubExp "num_groups_per_segment_hint" =<<
    case segver of
      OneGroupOneSegment -> eSubExp one
      ManyGroupsOneSegment -> eDivRoundingUp Int32 (eSubExp num_groups_hint)
                                                   (eSubExp num_segments)
  elements_per_thread <-
    letSubExp "elements_per_thread" =<<
    eDivRoundingUp Int32 (eSubExp segment_size)
                         (eBinOp (Mul Int32) (eSubExp group_size)
                                             (eSubExp num_groups_per_segment_hint))

  -- If we are using 1 element per thread, we might be launching too many
  -- groups. This expression remedies that.
  --
  -- For example, if there are 3 segments of size 512, the group size is 128,
  -- and @num_groups_hint@ is 256, then we would use 1 element per thread and
  -- launch 256 groups.
  -- However, we only need ceil(512/128) = 4 groups per segment to process
  -- all elements.
  num_groups_per_segment <-
    letSubExp "num_groups_per_segment" =<<
    case segver of
      OneGroupOneSegment -> eSubExp one
      ManyGroupsOneSegment ->
        eIf (eCmpOp (CmpEq $ IntType Int32)
              (eSubExp elements_per_thread) (eSubExp one))
            (eBody [eDivRoundingUp Int32 (eSubExp segment_size)
                                         (eSubExp group_size)])
            (mkBodyM mempty [num_groups_per_segment_hint])

  return (num_groups_per_segment, elements_per_thread)
  where
    one = constant (1 :: Int32)

smallKernel :: (MonadBinder m, Lore m ~ Kernels) =>
               SubExp            -- group_size
            -> SubExp            -- segment_size
            -> SubExp            -- num_segments
            -> [VName]           -- in_arrs: flat arrays (containing input to fold_lam)
            -> [VName]           -- scratch_arrs: preallocated space that we can write into
            -> Commutativity     -- comm
            -> Lambda InKernel   -- flag_reduce_lam'
            -> Lambda InKernel   -- fold_lam
            -> [SubExp]          -- nes
            -> SubExp            -- w = total_num_elements
            -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
            -> [KernelInput]     -- inps = inputs that can be looked up by using the gtids from ispace
            -> m (Kernel InKernel)
smallKernel group_size segment_size num_segments in_arrs scratch_arrs
            comm flag_reduce_lam' fold_lam_unrenamed nes w ispace inps = do
  let num_redres = length nes -- number of reduction results (tuple size for
                              -- reduction operator)

  fold_lam <- renameLambda fold_lam_unrenamed

  num_segments_per_group <- letSubExp "num_segments_per_group" $
    BasicOp $ BinOp (SQuot Int32) group_size segment_size

  num_groups <- letSubExp "num_groups" =<<
    eDivRoundingUp Int32 (eSubExp num_segments) (eSubExp num_segments_per_group)

  num_threads <- letSubExp "num_threads" $
    BasicOp $ BinOp (Mul Int32) num_groups group_size

  active_threads_per_group <- letSubExp "active_threads_per_group" $
    BasicOp $ BinOp (Mul Int32) segment_size num_segments_per_group

  let remainder_last_group = eBinOp (SRem Int32) (eSubExp num_segments)
                                                 (eSubExp num_segments_per_group)

  segments_in_last_group <- letSubExp "seg_in_last_group" =<<
    eIf (eCmpOp (CmpEq $ IntType Int32) remainder_last_group
          (eSubExp zero))
        (eBody [eSubExp num_segments_per_group])
        (eBody [remainder_last_group])

  active_threads_in_last_group <- letSubExp "active_threads_last_group" $
    BasicOp $ BinOp (Mul Int32) segment_size segments_in_last_group

  -- The array passed here is the structure for how to lay out the kernel
  -- space.
  space <- newKernelSpace (num_groups, group_size, num_threads) $ FlatThreadSpace []

  ------------------------------------------------------------------------------
  -- What follows are the statements used in the kernel.
  ------------------------------------------------------------------------------

  let lid = Var $ spaceLocalId space

  let (red_ts, map_ts) = splitAt num_redres $ lambdaReturnType fold_lam
  let kernel_return_types = red_ts ++ map_ts

  let wasted_thread_part1 = do
        let create_dummy_val (Prim ty) = return $ Constant $ blankPrimValue ty
            create_dummy_val (Array ty sh _) = letSubExp "dummy" $
              BasicOp $ Scratch ty (shapeDims sh)
            create_dummy_val Mem{} = fail "segredomap, 'Mem' used as result type"
        dummy_vals <- mapM create_dummy_val kernel_return_types
        return (negone : dummy_vals)

  let normal_thread_part1 = do
        segment_index <- letSubExp "segment_index" =<<
          eBinOp (Add Int32)
            (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space)
                                  (eSubExp segment_size))
            (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space)
                                (eSubExp num_segments_per_group))
        index_within_segment <- letSubExp "index_within_segment" =<<
          eBinOp (SRem Int32) (eSubExp $ Var $ spaceLocalId space)
                              (eSubExp segment_size)

        offset <- makeOffsetExp index_within_segment segment_index

        red_pes <- forM red_ts $ \red_t -> do
          pe_name <- newVName "fold_red"
          return $ PatElem pe_name red_t
        map_pes <- forM map_ts $ \map_t -> do
          pe_name <- newVName "fold_map"
          return $ PatElem pe_name map_t

        addManualIspaceCalcStms segment_index ispace

        addKernelInputStms inps

        -- Index the input arrays to get the arguments to fold_lam.
        let arr_params = drop num_redres $ lambdaParams fold_lam
        let nonred_lamparam_pes = map
              (\p -> PatElem (paramName p) (paramType p)) arr_params
        forM_ (zip in_arrs nonred_lamparam_pes) $ \(arr, pe) -> do
          tp <- lookupType arr
          let slice = fullSlice tp [DimFix offset]
          letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr slice

        -- Bind the neutral elements (they serve as the reduction arguments
        -- to fold_lam).
        forM_ (zip nes (take num_redres $ lambdaParams fold_lam)) $ \(ne, param) -> do
          let pe = PatElem (paramName param) (paramType param)
          letBind_ (Pattern [] [pe]) $ BasicOp $ SubExp ne

        addStms $ bodyStms $ lambdaBody fold_lam

        -- We add the lets here, as we practically don't know if the
        -- resulting subexp is a Constant or a Var, so better be safe (?).
        addStms $ stmsFromList
          [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
          | (pe, se) <- zip (red_pes ++ map_pes) (bodyResult $ lambdaBody fold_lam) ]

        let mapoffset = offset
        let mapret_elems = map (Var . patElemName) map_pes
        let redres_elems = map (Var . patElemName) red_pes
        return (mapoffset : redres_elems ++ mapret_elems)

  let all_threads red_pes = do
        isfirstinsegment <- letExp "isfirstinsegment" =<<
          eCmpOp (CmpEq $ IntType Int32)
            (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
            (eSubExp zero)

        -- We will perform a segmented scan, so all the prime variables here
        -- include the flag, which is the first argument to flag_reduce_lam.
        let red_pes_wflag = PatElem isfirstinsegment (Prim Bool) : red_pes
        let red_ts_wflag = Prim Bool : red_ts

        -- Combine the reduction results from each thread.
        -- This will put the results in local memory, so a
        -- GroupReduce/GroupScan can be performed on them.
        combine_red_pes' <- forM red_ts_wflag $ \red_t -> do
          pe_name <- newVName "chunk_fold_red"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        cids <- replicateM (length red_pes_wflag) $ newVName "cid"
        addStms $ stmsFromList
          [ Let (Pattern [] [pe']) (defAux ()) $
            Op $ Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
            Body () mempty [Var $ patElemName pe]
          | (cid, pe', pe) <- zip3 cids combine_red_pes' red_pes_wflag ]

        scan_red_pes_wflag <- forM red_ts_wflag $ \red_t -> do
          pe_name <- newVName "scanned"
          return $ PatElem pe_name $ red_t `arrayOfRow` group_size
        let scan_red_pes = drop 1 scan_red_pes_wflag
        letBind_ (Pattern [] scan_red_pes_wflag) $
          Op $ GroupScan group_size flag_reduce_lam' $
          zip (false:nes) (map patElemName combine_red_pes')

        return scan_red_pes

  let normal_thread_part2 scan_red_pes = do
        segment_index <- letSubExp "segment_index" =<<
          eBinOp (Add Int32)
            (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space)
                                  (eSubExp segment_size))
            (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space)
                                (eSubExp num_segments_per_group))

        islastinsegment <- letExp "islastinseg" =<<
          eCmpOp (CmpEq $ IntType Int32)
            (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
            (eBinOp (Sub Int32) (eSubExp segment_size) (eSubExp one))

        redoffset <- letSubExp "redoffset" =<<
          eIf (eSubExp $ Var islastinsegment)
              (eBody [eSubExp segment_index])
              (mkBodyM mempty [negone])

        redret_elems <- fmap (map Var) $ letTupExp "red_return_elem" =<<
          eIf (eSubExp $ Var islastinsegment)
              (eBody [ return $ BasicOp $ Index (patElemName pe)
                                (fullSlice (patElemType pe) [DimFix lid])
                     | pe <- scan_red_pes ])
              (mkBodyM mempty nes)

        return (redoffset : redret_elems)

  let picknchoose = do
        is_last_group <- letSubExp "islastgroup" =<<
          eCmpOp (CmpEq $ IntType Int32)
            (eSubExp $ Var $ spaceGroupId space)
            (eBinOp (Sub Int32) (eSubExp num_groups) (eSubExp one))

        active_threads_this_group <- letSubExp "active_thread_this_group" =<<
          eIf (eSubExp is_last_group)
              (eBody [eSubExp active_threads_in_last_group])
              (eBody [eSubExp active_threads_per_group])

        isactive <- letSubExp "isactive" =<<
          eCmpOp (CmpSlt Int32) (eSubExp lid) (eSubExp active_threads_this_group)

        -- Part 1: all active threads read an element from the input array
        -- and apply the fold function;
"wasted" threads will just create dummy values (normal_res1, normal_stms1) <- runBinder normal_thread_part1 (wasted_res1, wasted_stms1) <- runBinder wasted_thread_part1 -- we could just have used letTupExp, but this would not give as nice -- names in the generated code mapoffset_pe <- (`PatElem` i32) <$> newVName "mapoffset" redtmp_pes <- forM red_ts $ \red_t -> do pe_name <- newVName "redtmp_res" return $ PatElem pe_name red_t map_pes <- forM map_ts $ \map_t -> do pe_name <- newVName "map_res" return $ PatElem pe_name map_t e1 <- eIf (eSubExp isactive) (mkBodyM normal_stms1 normal_res1) (mkBodyM wasted_stms1 wasted_res1) letBind_ (Pattern [] (mapoffset_pe:redtmp_pes++map_pes)) e1 -- Part 2: All threads participate in Comine & GroupScan scan_red_pes <- all_threads redtmp_pes -- Part 3: Active thread that are the last element in segment, should -- write the element from local memory to the output array (normal_res2, normal_stms2) <- runBinder $ normal_thread_part2 scan_red_pes redoffset_pe <- (`PatElem` i32) <$> newVName "redoffset" red_pes <- forM red_ts $ \red_t -> do pe_name <- newVName "red_res" return $ PatElem pe_name red_t e2 <- eIf (eSubExp isactive) (mkBodyM normal_stms2 normal_res2) (mkBodyM mempty (negone : nes)) letBind_ (Pattern [] (redoffset_pe:red_pes)) e2 return $ map (Var . patElemName) $ redoffset_pe:mapoffset_pe:red_pes++map_pes (redoffset:mapoffset:redmapres, stms) <- runBinder picknchoose let (finalredvals, finalmapvals) = splitAt num_redres redmapres -- To be able to only return elements from some threads, we exploit the fact -- that WriteReturn with offset=-1, won't do anything. red_returns <- forM (zip finalredvals $ take num_redres scratch_arrs) $ \(se, scarr) -> return $ WriteReturn [num_segments] scarr [([redoffset], se)] map_returns <- forM (zip finalmapvals $ drop num_redres scratch_arrs) $ \(se, scarr) -> return $ WriteReturn [w] scarr [([mapoffset], se)] let kernel_returns = red_returns ++ map_returns let kerneldebughints = KernelDebugHints kernelname [ ("num_segment", num_segments) , ("segment_size", segment_size) , ("num_groups", num_groups) , ("group_size", group_size) , ("num_segments_per_group", num_segments_per_group) , ("active_threads_per_group", active_threads_per_group) ] let kernel = Kernel kerneldebughints space kernel_return_types $ KernelBody () stms kernel_returns return kernel where i32 = Prim $ IntType Int32 zero = constant (0 :: Int32) one = constant (1 :: Int32) negone = constant (-1 :: Int32) false = constant False commname = case comm of Commutative -> "comm" Noncommutative -> "nocomm" kernelname = "segmented_redomap__small_" ++ commname makeOffsetExp index_within_segment segment_index = do e <- eBinOp (Add Int32) (eSubExp index_within_segment) (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index)) letSubExp "offset" e addKernelInputStms :: (MonadBinder m, Lore m ~ InKernel) => [KernelInput] -> m () addKernelInputStms = mapM_ $ \kin -> do let pe = PatElem (kernelInputName kin) (kernelInputType kin) let arr = kernelInputArray kin arrtp <- lookupType arr let slice = fullSlice arrtp [DimFix se | se <- kernelInputIndices kin] letBind (Pattern [] [pe]) $ BasicOp $ Index arr slice -- | Manually calculate the values for the ispace identifiers, when the -- 'SpaceStructure' won't do. ispace is the dimensions of the overlaying maps. -- -- If the input is @i [(a_vn, a), (b_vn, b), (c_vn, c)]@ then @i@ should hit all -- the values [0,a*b*c). 
-- We can calculate the indexes for the other dimensions:
--
-- > c_vn = i % c
-- > b_vn = (i/c) % b
-- > a_vn = ((i/c)/b) % a
addManualIspaceCalcStms :: (MonadBinder m, Lore m ~ InKernel) =>
                           SubExp
                        -> [(VName, SubExp)]
                        -> m ()
addManualIspaceCalcStms outer_index ispace = do
  -- TODO: The ispace index is calculated in a slightly different way than it
  -- would have been if the ThreadSpace had been used. However, this works.
  -- Maybe ask Troels if doing it the other way has some benefit?
  let calc_ispace_index prev_val (vn, size) = do
        let pe = PatElem vn (Prim $ IntType Int32)
        letBind_ (Pattern [] [pe]) $
          BasicOp $ BinOp (SRem Int32) prev_val size
        letSubExp "tmp_val" $ BasicOp $ BinOp (SQuot Int32) prev_val size
  foldM_ calc_ispace_index outer_index (reverse ispace)

addFlagToLambda :: (MonadBinder m, Lore m ~ Kernels) =>
                   [SubExp] -> Lambda InKernel -> m (Lambda InKernel)
addFlagToLambda nes lam = do
  let num_accs = length nes
  x_flag <- newVName "x_flag"
  y_flag <- newVName "y_flag"
  let x_flag_param = Param x_flag $ Prim Bool
      y_flag_param = Param y_flag $ Prim Bool
      (x_params, y_params) = splitAt num_accs $ lambdaParams lam
      params = [x_flag_param] ++ x_params ++ [y_flag_param] ++ y_params

  body <- runBodyBinder $ localScope (scopeOfLParams params) $ do
    new_flag <- letSubExp "new_flag" $
                BasicOp $ BinOp LogOr (Var x_flag) (Var y_flag)
    lhs <- fmap (map Var) $ letTupExp "seg_lhs" $
           If (Var y_flag)
              (resultBody nes)
              (resultBody $ map (Var . paramName) x_params) $
           ifCommon $ map paramType x_params
    let rhs = map (Var . paramName) y_params

    lam' <- renameLambda lam -- avoid shadowing
    res <- eLambda lam' $ map eSubExp $ lhs ++ rhs

    return $ resultBody $ new_flag : res

  return Lambda { lambdaParams = params
                , lambdaBody = body
                , lambdaReturnType = Prim Bool : lambdaReturnType lam
                }

regularSegmentedScan :: (MonadBinder m, Lore m ~ Kernels) =>
                        SubExp
                     -> Pattern Kernels
                     -> SubExp
                     -> Lambda InKernel
                     -> Lambda InKernel
                     -> [(VName, SubExp)]
                     -> [KernelInput]
                     -> [SubExp]
                     -> [VName]
                     -> m ()
regularSegmentedScan segment_size pat w lam map_lam ispace inps nes arrs = do
  flags_i <- newVName "flags_i"
  unused_flag_array <- newVName "unused_flag_array"

  flags_body <-
    runBodyBinder $ localScope (M.singleton flags_i $ IndexInfo Int32) $ do
      segment_index <- letSubExp "segment_index" $
        BasicOp $ BinOp (SRem Int32) (Var flags_i) segment_size
      start_of_segment <- letSubExp "start_of_segment" $
        BasicOp $ CmpOp (CmpEq int32) segment_index zero
      let flag = start_of_segment
      return $ resultBody [flag]
  (mapk_bnds, mapk) <-
    mapKernelFromBody w (FlatThreadSpace [(flags_i, w)]) [] [Prim Bool] flags_body
  addStms mapk_bnds
  flags <- letExp "flags" $ Op mapk

  lam' <- addFlagToLambda nes lam

  flag_p <- newParam "flag" $ Prim Bool
  let map_lam' = map_lam { lambdaParams = flag_p : lambdaParams map_lam
                         , lambdaBody = (lambdaBody map_lam)
                           { bodyResult = Var (paramName flag_p) :
                                          bodyResult (lambdaBody map_lam) }
                         , lambdaReturnType = Prim Bool :
                                              lambdaReturnType map_lam
                         }

  let pat' = pat { patternValueElements = PatElem unused_flag_array
                                          (arrayOf (Prim Bool) (Shape [w]) NoUniqueness) :
                                          patternValueElements pat
                 }
  void $ blockedScan pat' w (lam', false:nes) (Commutative, nilFn, mempty)
                     map_lam' segment_size ispace inps (flags:arrs)
  where
    zero = constant (0 :: Int32)
    false = constant False
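
-- A worked example of the flag computation above (illustrative numbers):
-- with segment_size = 3 and w = 9, the map kernel generated by
-- 'regularSegmentedScan' produces
--
-- > flags = [true, false, false, true, false, false, true, false, false]
--
-- and the flag-carrying operator built by 'addFlagToLambda' computes
--
-- > (f_x, x) `op` (f_y, y) = (f_x || f_y, if f_y then ne `lam` y else x `lam` y)
--
-- which restarts the accumulator from the neutral element at every segment
-- boundary, and so keeps the scan from carrying values across segments.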