{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup
       (intraGroupParallelise)
       where

import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.SOACS
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Analysis.DataDependencies
import qualified Futhark.Pass.ExtractKernels.Kernelise as Kernelise
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.BlockedKernel

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could
-- probably do something more clever.  Make sure that the amount of
-- parallelism to be exploited does not exceed the group size.
-- Further, as a hack we also consider the size of all intermediate
-- arrays as "parallelism to be exploited" to avoid exploding local
-- memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
intraGroupParallelise :: (MonadFreshNames m, LocalScope Out.Kernels m) =>
                         KernelNest -> Lambda
                      -> m (Maybe ((SubExp, SubExp), SubExp,
                                   Out.Stms Out.Kernels, Out.Stms Out.Kernels))
intraGroupParallelise knest lam = runMaybeT $ do
  (w_stms, w, ispace, inps, rts) <- lift $ flatKernel knest
  let num_groups = w
      body = lambdaBody lam

  ltid <- newVName "ltid"
  let group_variant = S.fromList [ltid]
  (wss_min, wss_avail, kbody) <-
    lift $ localScope (scopeOfLParams $ lambdaParams lam) $
    intraGroupParalleliseBody (dataDependencies body) group_variant ltid body

  known_outside <- lift $ M.keys <$> askScope
  unless (all (`elem` known_outside) $ freeIn $ wss_min ++ wss_avail) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $ runBinder $ do
    let foldBinOp' _ [] = eSubExp $ intConst Int32 0
        foldBinOp' bop (x:xs) = foldBinOp bop x xs
    ws_min <- mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int32)) $
              filter (not . null) wss_min
    ws_avail <- mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int32)) $
                filter (not . null) wss_avail

    -- The amount of parallelism available *in the worst case* is
    -- equal to the smallest parallel loop.
    intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int32) ws_avail

    -- The group size is either the maximum of the minimum parallelism
    -- exploited, or the desired parallelism (bounded by the max group
    -- size) in case there is no minimum.
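    -- For example (with purely hypothetical sizes): if ws_min is
    -- [n, m], the group size below becomes max(n, m); if ws_min is
    -- empty, it becomes min(max_group_size, intra_avail_par).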
    group_size <- letSubExp "computed_group_size" =<<
                  if null ws_min
                  then eBinOp (SMin Int32)
                       (eSubExp =<< letSubExp "max_group_size"
                        (Op $ Out.GetSizeMax Out.SizeGroup))
                       (eSubExp intra_avail_par)
                  else foldBinOp' (SMax Int32) ws_min

    let inputIsUsed input = kernelInputName input `S.member` freeInBody body
        used_inps = filter inputIsUsed inps

    addStms w_stms

    num_threads <- letSubExp "num_threads" $
                   BasicOp $ BinOp (Mul Int32) num_groups group_size

    let ksize = (num_groups, group_size, num_threads)

    kspace <- newKernelSpace ksize $ FlatThreadSpace $ ispace ++ [(ltid, group_size)]

    read_input_stms <- mapM readKernelInput used_inps

    return (intra_avail_par, kspace, read_input_stms)

  let kbody' = kbody { kernelBodyStms = stmsFromList read_input_stms <>
                                        kernelBodyStms kbody }

  -- The kernel itself is producing a "flat" result of shape
  -- [num_groups].  We must explicitly reshape it to match the shapes
  -- of our enclosing map-nests.
  let nested_pat = loopNestingPattern first_nest
      flatPatElem pat_elem = do
        let t' = arrayOfRow (length ispace `stripArray` patElemType pat_elem) num_groups
        name <- newVName $ baseString (patElemName pat_elem) ++ "_flat"
        return $ PatElem name t'
  flat_pat <- lift $ Pattern [] <$>
              mapM flatPatElem (patternValueElements nested_pat)

  let kstm = Let flat_pat (StmAux cs ()) $ Op $
             Kernel (KernelDebugHints "map_intra_group" []) kspace rts kbody'
      reshapeStm nested_pe flat_pe =
        Let (Pattern [] [nested_pe]) (StmAux cs ()) $
        BasicOp $ Reshape (map DimNew $ arrayDims $ patElemType nested_pe) $
        patElemName flat_pe
      reshape_stms = zipWith reshapeStm (patternElements nested_pat)
                                        (patternElements flat_pat)

  let intra_min_par = intra_avail_par
  return ((intra_min_par, intra_avail_par), spaceGroupSize kspace,
          prelude_stms, oneStm kstm <> stmsFromList reshape_stms)
  where first_nest = fst knest
        cs = loopNestingCertificates first_nest

-- | The environment for group-level code generation: the local
-- thread ID, the data dependencies of the original body, and the set
-- of group-variant names.
data Env = Env { _localTID :: VName
               , _dataDeps :: Dependencies
               , _groupVariant :: Names
               }

-- | The writer component accumulates, respectively, the minimum
-- parallelism that must be exploited and the parallelism that is
-- merely available.
type IntraGroupM = BinderT Out.InKernel (RWS Env (S.Set [SubExp], S.Set [SubExp]) VNameSource)

runIntraGroupM :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                  Env -> IntraGroupM ()
               -> m ([[SubExp]], [[SubExp]], Out.Stms Out.InKernel)
runIntraGroupM env m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    let (((), kstms), src', (ws_min, ws_avail)) =
          runRWS (runBinderT m scope) env src
    in ((S.toList ws_min, S.toList ws_avail, kstms), src')

-- | Record parallelism that must be exploited (and hence is also
-- available).
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws = tell (S.singleton ws, S.singleton ws)

-- | Record parallelism that is available, but need not be exploited.
parallelAvail :: [SubExp] -> IntraGroupM ()
parallelAvail ws = tell (mempty, S.singleton ws)

intraGroupBody :: Body -> IntraGroupM (Out.Body Out.InKernel)
intraGroupBody body = do
  stms <- collectStms_ $ mapM_ intraGroupStm $ bodyStms body
  return $ mkBody stms $ bodyResult body

-- | Translate a single statement, parallelising it across the group
-- where possible and sequentialising it otherwise.
intraGroupStm :: Stm -> IntraGroupM ()
intraGroupStm stm@(Let pat _ e) = do
  Env ltid deps group_variant <- ask
  let groupInvariant (Var v) =
        S.null . S.intersection group_variant .
        flip (M.findWithDefault mempty) deps $ v
      groupInvariant Constant{} = True
  case e of
    DoLoop ctx val (ForLoop i it bound inps) loopbody
      | groupInvariant bound ->
          localScope (scopeOf form) $
          localScope (scopeOfFParams $ map fst $ ctx ++ val) $ do
            loopbody' <- intraGroupBody loopbody
            letBind_ pat $ DoLoop ctx val form loopbody'
      where form = ForLoop i it bound inps

    If cond tbody fbody ifattr
      | groupInvariant cond -> do
          tbody' <- intraGroupBody tbody
          fbody' <- intraGroupBody fbody
          letBind_ pat $ If cond tbody' fbody' ifattr

    Op (Screma w form arrs)
      | Just fun <- isMapSOAC form -> do
          body_stms <- collectStms_ $ do
            forM_ (zip (lambdaParams fun) arrs) $ \(p, arr) -> do
              arr_t <- lookupType arr
              letBindNames [paramName p] $ BasicOp $ Index arr $
                fullSlice arr_t [DimFix $ Var ltid]
            Kernelise.transformStms $ bodyStms $ lambdaBody fun
          let comb_body = mkBody body_stms $ bodyResult $ lambdaBody fun
          ctid <- newVName "ctid"
          letBind_ pat $ Op $
            Out.Combine (Out.combineSpace [(ctid, w)]) (lambdaReturnType fun) [] comb_body
          mapM_ (parallelMin . arrayDims) $ patternTypes pat
          parallelMin [w]

    Op (Screma w form arrs)
      | Just (scanfun, nes, foldfun) <- isScanomapSOAC form -> do
          let (scan_pes, map_pes) = splitAt (length nes) $ patternElements pat
          scan_input <- procInput ltid (Pattern [] map_pes) w foldfun nes arrs

          scanfun' <- Kernelise.transformLambda scanfun

          -- A GroupScan lambda needs two more parameters.
          my_index <- newVName "my_index"
          other_index <- newVName "other_index"
          let my_index_param = Param my_index (Prim int32)
              other_index_param = Param other_index (Prim int32)
              scanfun'' = scanfun' { lambdaParams = my_index_param :
                                                    other_index_param :
                                                    lambdaParams scanfun' }
          letBind_ (Pattern [] scan_pes) $ Op $
            Out.GroupScan w scanfun'' $ zip nes scan_input
          parallelMin [w]

    Op (Screma w form arrs)
      | Just (_, redfun, nes, foldfun) <- isRedomapSOAC form -> do
          let (red_pes, map_pes) = splitAt (length nes) $ patternElements pat
          red_input <- procInput ltid (Pattern [] map_pes) w foldfun nes arrs

          redfun' <- Kernelise.transformLambda redfun

          -- A GroupReduce lambda needs two more parameters.
          my_index <- newVName "my_index"
          other_index <- newVName "other_index"
          let my_index_param = Param my_index (Prim int32)
              other_index_param = Param other_index (Prim int32)
              redfun'' = redfun' { lambdaParams = my_index_param :
                                                  other_index_param :
                                                  lambdaParams redfun' }
          letBind_ (Pattern [] red_pes) $ Op $
            Out.GroupReduce w redfun'' $ zip nes red_input
          parallelMin [w]

    Op (Stream w (Sequential accs) lam arrs)
      | chunk_size_param : _ <- lambdaParams lam -> do
          types <- asksScope castScope
          ((), stream_bnds) <-
            runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
          let replace (Var v) | v == paramName chunk_size_param = w
              replace se = se
              replaceSets (x, y) = (S.map (map replace) x, S.map (map replace) y)
          censor replaceSets $ mapM_ intraGroupStm stream_bnds

    Op (Scatter w lam ivs dests) -> do
      parallelMin [w]

      ctid <- newVName "ctid"
      let cspace = Out.CombineSpace dests [(ctid, w)]
      body_stms <- collectStms_ $ do
        forM_ (zip (lambdaParams lam) ivs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames [paramName p] $ BasicOp $ Index arr $
            fullSlice arr_t [DimFix $ Var ltid] -- ltid on purpose to enable hoisting.
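        -- The rest of the write lambda is sequentialised; every
        -- thread in the combine space executes it independently.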
        Kernelise.transformStms $ bodyStms $ lambdaBody lam
      let body = mkBody body_stms $ bodyResult $ lambdaBody lam
      letBind_ pat $ Op $ Out.Combine cspace (lambdaReturnType lam) mempty body

    BasicOp (Update dest slice (Var v)) -> do
      let ws = sliceDims slice
          activeForDim w i = BasicOp $ CmpOp (CmpSlt Int32) i w
      parallelMin ws
      dest' <- letExp "update_inp" $ Op $ Out.Barrier [Var dest]
      let new_inds = unflattenIndex (map (primExpFromSubExp int32) ws)
                                    (primExpFromSubExp int32 $ Var ltid)
      new_inds' <- mapM (letSubExp "i" <=< toExp) new_inds
      active <- letSubExp "active" =<<
                foldBinOp LogAnd (constant True) =<<
                mapM (letSubExp "active") (zipWith activeForDim ws new_inds')
      (active_res, active_stms) <- collectStms $ do
        slice' <- mapM (letSubExp "j" <=< toExp) $
                  fixSlice (map (fmap $ primExpFromSubExp int32) slice) new_inds
        letInPlace "update_res" dest' (map DimFix slice') $
          BasicOp $ Index v $ map DimFix new_inds'
      sync <- letSubExp "update_res" =<<
              eIf (eSubExp active)
              (pure $ mkBody active_stms [Var active_res])
              (pure $ mkBody mempty [Var dest'])
      letBind_ pat $ Op $ Out.Barrier [sync]

    BasicOp (Copy arr) -> do
      arr_t <- lookupType arr
      let w = arraySize 0 arr_t
      ctid <- newVName "copy_ctid"
      letBind_ pat . Op .
        Out.Combine (Out.combineSpace [(ctid, w)]) [rowType arr_t] [] <=<
        localScope (M.singleton ctid $ IndexInfo Int32) $ insertStmsM $
        resultBodyM . pure <=< letSubExp "v" $ BasicOp $ Index arr $
        fullSlice arr_t [DimFix $ Var ctid]

    BasicOp (Replicate (Shape outer_ws) se)
      | [inner_ws] <- map (drop (length outer_ws) . arrayDims) $ patternTypes pat -> do
          let ws = outer_ws ++ inner_ws
          new_inds' <- replicateM (length ws) $ newVName "new_local_index"
          let inner_inds' = drop (length outer_ws) new_inds'
              space = Out.combineSpace $ zip new_inds' ws
              index = case se of
                        Var v -> BasicOp $ Index v $ map (DimFix . Var) inner_inds'
                        Constant{} -> BasicOp $ SubExp se
          body <- runBodyBinder $ eBody [pure index]
          letBind_ pat $ Op $ Out.Combine space
            (map (Prim . elemType) $ patternTypes pat) [] body
          mapM_ (parallelAvail . arrayDims) $ patternTypes pat

    _ ->
      Kernelise.transformStm stm

  where
    -- Bind the inputs of the map part of a scanomap/redomap across
    -- the group with a Combine, returning the names bound to the
    -- operator inputs.
    procInput :: VName
              -> Out.Pattern Out.InKernel -> SubExp -> Lambda -> [SubExp] -> [VName]
              -> IntraGroupM [VName]
    procInput ltid map_pat w map_fun nes arrs = do
      fold_stms <- collectStms_ $ do
        forM_ (zip (lambdaParams map_fun) arrs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames_ [paramName p] $ BasicOp $ Index arr $
            fullSlice arr_t [DimFix $ Var ltid]

        Kernelise.transformStms $ bodyStms $ lambdaBody map_fun
      let fold_body = mkBody fold_stms $ bodyResult $ lambdaBody map_fun

      op_inps <- replicateM (length nes) (newVName "op_input")
      ctid <- newVName "ctid"
      letBindNames_ (op_inps ++ patternNames map_pat) $ Op $
        Out.Combine (Out.combineSpace [(ctid, w)]) (lambdaReturnType map_fun) [] fold_body
      return op_inps

-- | Convert a body to in-kernel statements, returning also the
-- minimum and available parallelism recorded along the way.
intraGroupParalleliseBody :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                             Dependencies -> Names -> VName -> Body
                          -> m ([[SubExp]], [[SubExp]], Out.KernelBody Out.InKernel)
intraGroupParalleliseBody deps group_variant ltid body = do
  (min_ws, avail_ws, kstms) <-
    runIntraGroupM (Env ltid deps group_variant) $
    mapM_ intraGroupStm $ bodyStms body
  return (min_ws, avail_ws,
          KernelBody () kstms $ map (ThreadsReturn OneResultPerGroup) $ bodyResult body)