{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Intragroup
(intraGroupParallelise)
where
import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Prelude hiding (log)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.SOACS
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Util (chunks)
import Futhark.Util.Log
-- | Try to turn the body of a kernel nest into a single group-level
-- 'SegMap' whose inner parallelism is expressed at thread level.
--
-- The result is 'Nothing' when one of the inner parallel sizes mentions
-- a name that is not in the scope this function is called in.
-- Otherwise it is a tuple of:
--
--   * the (minimum, available) intra-group parallelism -- both
--     components are currently the same value;
--
--   * the computed group size;
--
--   * the log collected while transforming the body;
--
--   * the prelude statements that compute the sizes above;
--
--   * the single statement that binds the resulting group-level kernel.
intraGroupParallelise :: (MonadFreshNames m, LocalScope Out.Kernels m) =>
                         KernelNest -> Lambda
                      -> m (Maybe ((SubExp, SubExp), SubExp, Log,
                                   Out.Stms Out.Kernels, Out.Stms Out.Kernels))
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the flattened outer dimensions.
  (num_groups, w_stms) <- lift $ runBinder $
    letSubExp "intra_num_groups" =<<
    foldBinOp (Mul Int32) (intConst Int32 1) (map snd ispace)

  let body = lambdaBody lam

  -- Only the name is created here; it is bound in the prelude below, once
  -- the parallelism found inside the body is known.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $ localScope (scopeOfLParams $ lambdaParams lam) $
    intraGroupParalleliseBody intra_lvl body

  -- Every size that was recorded must be available in the surrounding
  -- scope.  Membership is tested on the scope map itself rather than by
  -- scanning a list of its keys once per free name.
  outside_scope <- lift askScope
  unless (all (`M.member` outside_scope) $ namesToList $ freeIn $
          wss_min ++ wss_avail) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $ runBinder $ do
    -- Like 'foldBinOp', but seeded with the first element; an empty list
    -- yields the constant 0.
    let foldBinOp' _   []     = eSubExp $ intConst Int32 0
        foldBinOp' bop (x:xs) = foldBinOp bop x xs

    -- Each non-empty list of sizes is collapsed to its product.
    ws_min <- mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int32)) $
              filter (not . null) wss_min
    ws_avail <- mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int32)) $
                filter (not . null) wss_avail

    -- The available parallelism is the smallest of the products.
    intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int32) ws_avail

    -- With no minimum requirement the group size is the available
    -- parallelism capped by the maximum group size; otherwise it is the
    -- largest of the minimum requirements.
    letBindNames_ [group_size] =<<
      if null ws_min
      then eBinOp (SMin Int32)
           (eSubExp =<< letSubExp "max_group_size" (Op $ Out.GetSizeMax Out.SizeGroup))
           (eSubExp intra_avail_par)
      else foldBinOp' (SMax Int32) ws_min

    -- Only inputs that the body actually mentions are read.
    let inputIsUsed input = kernelInputName input `nameIn` freeIn body
        used_inps = filter inputIsUsed inps

    addStms w_stms
    read_input_stms <- runBinder_ $ mapM readKernelInput used_inps
    space <- mkSegSpace ispace
    return (intra_avail_par, space, read_input_stms)

  -- The input reads go in front of the transformed body.
  let kbody' = kbody { kernelBodyStms = read_input_stms <> kernelBodyStms kbody }

  let nested_pat = loopNestingPattern first_nest
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm = Let nested_pat (StmAux cs ()) $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  return ((intra_min_par, intra_avail_par), Var group_size, log,
          prelude_stms, oneStm kstm)
  where first_nest = fst knest
        cs = loopNestingCertificates first_nest
-- | The writer output of 'IntraGroupM'.
data Acc = Acc
  { accMinPar :: S.Set [SubExp]
    -- ^ Lists of sizes recorded by 'parallelMin'.
    -- 'intraGroupParallelise' multiplies each list together and uses the
    -- largest product as the group size.
  , accAvailPar :: S.Set [SubExp]
    -- ^ Lists of sizes recorded by 'parallelMin'.
    -- 'intraGroupParallelise' multiplies each list together and uses the
    -- smallest product as the available parallelism.
  , accLog :: Log
    -- ^ Log entries added through 'addLog'.
  }
-- | Combines two accumulators field by field: the size sets are unioned
-- and the logs are appended.
instance Semigroup Acc where
  Acc minL availL logL <> Acc minR availR logR =
    Acc { accMinPar = minL `S.union` minR
        , accAvailPar = availL `S.union` availR
        , accLog = logL <> logR
        }
-- | The empty accumulator: no recorded sizes and an empty log.
instance Monoid Acc where
  mempty = Acc { accMinPar = S.empty
               , accAvailPar = S.empty
               , accLog = mempty
               }
-- | The monad in which kernel bodies are transformed: a 'BinderT' for
-- 'Out.Kernels' statements on top of an 'RWS' with no reader
-- environment, 'Acc' as writer output and a 'VNameSource' as state.
type IntraGroupM = BinderT Out.Kernels (RWS () Acc VNameSource)
-- | Log entries are written into the 'accLog' field of the accumulator;
-- the size sets are left empty.
instance MonadLogger IntraGroupM where
  addLog entries = tell logOnly
    where logOnly = Acc { accMinPar = mempty
                        , accAvailPar = mempty
                        , accLog = entries
                        }
-- | Run an 'IntraGroupM' action in the caller's monad.  The caller's
-- scope (after 'castScope') is the initial scope of the binder, and the
-- caller's name source is threaded through as the 'RWS' state.  Returns
-- the accumulated 'Acc' together with the statements that were produced.
runIntraGroupM :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                  IntraGroupM () -> m (Acc, Out.Stms Out.Kernels)
runIntraGroupM action = do
  scope <- castScope <$> askScope
  modifyNameSource (execute scope)
  where
    -- Pure step: run the action from the given name source and hand back
    -- the result paired with the updated name source.
    execute scope src = ((acc, stms), src')
      where (((), stms), src', acc) = runRWS (runBinderT action scope) () src
-- | Record a list of sizes in both the minimum and the available
-- parallelism sets of the accumulator.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell $ Acc { accMinPar = recorded
             , accAvailPar = recorded
             , accLog = mempty
             }
  where recorded = S.singleton ws
-- | Transform every statement of a body with 'intraGroupStm', collect
-- the statements this produces, and rebuild a body around them with the
-- original result.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.Kernels)
intraGroupBody lvl body =
  (`mkBody` bodyResult body) <$>
    collectStms_ (intraGroupStms lvl (bodyStms body))
-- | Transform a single SOACS statement and add the resulting statements
-- to the binder.
--
--   * Loops and branches are rebuilt around their transformed bodies.
--
--   * A map is distributed with 'distributeMapBodyStms'; the sizes that
--     the distribution asks a level for are recorded with 'parallelMin'.
--
--   * Scanomap, single-operator redomap and scatter become thread-level
--     segmented operations, each recording @[w]@ with 'parallelMin'.
--
--   * A sequential stream is first rewritten by
--     'sequentialStreamWholeArray', and the result is transformed
--     recursively.
--
--   * Anything else is converted with 'soacsStmToKernels'.
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm lvl stm@(Let pat aux e) = do
  scope <- askScope
  let thread_lvl = SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt
      certs = stmAuxCerts aux

  case e of
    DoLoop ctx val form loopbody ->
      localScope (scopeOf form') $
        localScope (scopeOfFParams $ map fst $ ctx ++ val) $ do
          new_body <- intraGroupBody lvl loopbody
          certifying certs $
            letBind_ pat $ DoLoop ctx val form' new_body
      where
        -- Rebuilding the loop form constructor by constructor lets it be
        -- used at the representation of the output program.
        form' = case form of
          ForLoop i it bound inps -> ForLoop i it bound inps
          WhileLoop cond -> WhileLoop cond

    If cond tbody fbody ifattr -> do
      new_tbody <- intraGroupBody lvl tbody
      new_fbody <- intraGroupBody lvl fbody
      certifying certs $
        letBind_ pat $ If cond new_tbody new_fbody ifattr

    Op (Screma w form arrs)
      | Just lam <- isMapSOAC form -> do
          let nesting = MapNesting pat certs w $ zip (lambdaParams lam) arrs
              dist_env =
                DistEnv { distNest =
                            singleNesting $ Nesting mempty nesting
                        , distScope =
                            scopeOfPattern pat <>
                            scopeForKernels (scopeOf lam) <> scope
                        , distOnInnerMap =
                            distributeMap
                        , distOnTopLevelStms =
                            lift . collectStms_ . intraGroupStms lvl
                        , distSegLevel = \minw _ _ ->
                            lift (parallelMin minw) >> return lvl
                        }
              dist_acc =
                DistAcc { distTargets =
                            singleTarget (pat, bodyResult $ lambdaBody lam)
                        , distStms = mempty
                        }
          addStms =<<
            runDistNestT dist_env
              (distributeMapBodyStms dist_acc (bodyStms $ lambdaBody lam))

      | Just (scanfun, nes, mapfun) <- isScanomapSOAC form -> do
          certifying certs $
            addStms =<<
              segScan thread_lvl pat w
                (soacsLambdaToKernels scanfun)
                (soacsLambdaToKernels mapfun)
                nes arrs [] []
          parallelMin [w]

      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce comm red_lam nes <- singleReduce reds -> do
          let red_op = SegRedOp comm (soacsLambdaToKernels red_lam) nes mempty
          certifying certs $
            addStms =<<
              segRed thread_lvl pat w [red_op]
                (soacsLambdaToKernels map_lam) arrs [] []
          parallelMin [w]

    Op (Stream w (Sequential accs) lam arrs)
      | chunk_size_param : _ <- lambdaParams lam -> do
          types <- asksScope castScope
          ((), stream_stms) <-
            runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
          -- Any size recorded in terms of the chunk-size parameter is
          -- rewritten to the full width of the stream.
          let substChunk (Var v) | v == paramName chunk_size_param = w
              substChunk se = se
              substAcc (Acc mins avails msgs) =
                Acc (S.map (map substChunk) mins)
                    (S.map (map substChunk) avails)
                    msgs
          censor substAcc $ intraGroupStms lvl stream_stms

    Op (Scatter w lam ivs dests) -> do
      write_i <- newVName "write_i"
      space <- mkSegSpace [(write_i, w)]
      let lam' = soacsLambdaToKernels lam
          (dests_ws, dests_ns, dests_vs) = unzip3 dests
          (i_res, v_res) = splitAt (sum dests_ns) $ bodyResult $ lambdaBody lam'
          -- One 'WriteReturns' per destination array, each taking its
          -- share of the (index, value) result pairs.
          krets =
            [ WriteReturns [a_w] a [ ([i], v) | (i, v) <- is_vs ]
            | (a_w, a, is_vs) <-
                zip3 dests_ws dests_vs $ chunks dests_ns $ zip i_res v_res
            ]
          inputs =
            [ KernelInput (paramName p) (paramType p) p_a [Var write_i]
            | (p, p_a) <- zip (lambdaParams lam') ivs
            ]
          ts = map rowType $ patternTypes pat
      kstms <- runBinder_ $ localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inputs
        addStms $ bodyStms $ lambdaBody lam'
      certifying certs $
        letBind_ pat $ Op $ SegOp $
          SegMap thread_lvl space ts (KernelBody () kstms krets)
      parallelMin [w]

    _ ->
      addStm $ soacsStmToKernels stm
-- | Apply 'intraGroupStm' to each statement in order.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl stms = forM_ stms (intraGroupStm lvl)
-- | Transform the statements of a body with 'intraGroupStms' and wrap
-- them in a kernel body that returns each result of the original body.
-- Also returns the recorded minimum and available size lists (as lists
-- rather than sets) and the accumulated log.
intraGroupParalleliseBody :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                             SegLevel -> Body
                          -> m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.Kernels)
intraGroupParalleliseBody lvl body = do
  (Acc min_par avail_par msgs, stms) <-
    runIntraGroupM (intraGroupStms lvl (bodyStms body))
  let kbody = KernelBody () stms [ Returns se | se <- bodyResult body ]
  return (S.toList min_par, S.toList avail_par, msgs, kbody)