{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module Futhark.Pass.ExtractKernels.DistributeNests
( KernelsStms
, MapLoop(..)
, mapLoopStm
, bodyContainsParallelism
, lambdaContainsParallelism
, determineReduceOp
, incrementalFlattening
, genReduceKernel
, DistEnv (..)
, DistAcc (..)
, runDistNestT
, DistNestT
, distributeMap
, distribute
, distributeSingleStm
, distributeMapBodyStms
, postKernelsStms
, addStmsToKernel
, addStmToKernel
, permutationAndMissing
, addKernels
, addKernel
, inNesting
)
where
import Control.Arrow (first)
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Writer.Strict
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.List
import Futhark.Representation.SOACS
import qualified Futhark.Representation.SOACS.SOAC as SOAC
import Futhark.Representation.SOACS.Simplify (simpleSOACS)
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.CopyPropagate
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.BlockedKernel hiding (segThread, segThreadCapped)
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Util
import Futhark.Util.Log
-- | The parts of a map statement, kept separately: the result pattern, the
-- certificates of the statement, the width, the mapped lambda and the input
-- arrays (in constructor order).  'mapLoopStm' puts them back together.
data MapLoop = MapLoop Pattern Certificates SubExp Lambda [VName]
-- | Rebuild the statement described by a 'MapLoop': a 'Screma' with a pure
-- map form, bound to the loop's pattern and carrying its certificates.
mapLoopStm :: MapLoop -> Stm
mapLoopStm (MapLoop pat cs w lam arrs) =
  Let pat aux (Op soac)
  where
    aux = StmAux cs ()
    soac = Screma w (mapSOAC lam) arrs
-- | Shorthand for a sequence of statements in the 'Out.Kernels' representation.
type KernelsStms = Out.Stms Out.Kernels
-- | The read-only environment carried by 'DistNestT'.
data DistEnv m = DistEnv
  { distNest :: Nestings
    -- ^ The nestings we are currently inside; adjusted by 'withStm',
    -- 'mapNesting' and 'inNesting'.
  , distScope :: Scope Out.Kernels
    -- ^ The scope reported by 'askScope' and extended by 'localScope'.
  , distOnTopLevelStms :: Stms SOACS -> DistNestT m (Stms Out.Kernels)
    -- ^ Callback invoked by 'onTopLevelStms'; its result is recorded
    -- as a kernel.
  , distOnInnerMap :: MapLoop -> DistAcc -> DistNestT m DistAcc
    -- ^ Callback invoked by 'onInnerMap' when a nested map is found.
  , distSegLevel :: MkSegLevel m
    -- ^ How to construct segment levels in the underlying monad; see
    -- 'mkSegLevel' for the lifted version.
  }
-- | The accumulator threaded through distribution.
data DistAcc = DistAcc
  { distTargets :: Targets
    -- ^ The targets that still have to be produced.
  , distStms :: KernelsStms
    -- ^ Statements collected so far that have not yet been turned
    -- into a kernel.  New statements are prepended (see
    -- 'addStmsToKernel').
  }
-- | The write-only output of 'DistNestT': the kernels emitted so far,
-- plus any log messages.
data DistRes = DistRes
  { accPostKernels :: PostKernels
    -- ^ Kernels recorded through 'addKernels' / 'addKernel'.
  , accLog :: Log
    -- ^ Messages recorded through 'addLog'.
  }
-- | Results are combined field by field.
instance Semigroup DistRes where
  DistRes kernelsL logL <> DistRes kernelsR logR =
    DistRes { accPostKernels = kernelsL <> kernelsR
            , accLog = logL <> logR
            }

-- | The empty result: no kernels and no log messages.
instance Monoid DistRes where
  mempty = DistRes { accPostKernels = mempty
                   , accLog = mempty
                   }
-- | The statements making up one emitted kernel.
newtype PostKernel = PostKernel { unPostKernel :: KernelsStms }
-- | A list of emitted kernels.  Note that the 'Semigroup' instance places
-- the right operand's kernels first; 'postKernelsStms' concatenates the
-- list in its stored order.
newtype PostKernels = PostKernels [PostKernel]
-- | Combining two kernel lists puts the kernels of the /right/ operand
-- in front of those of the left operand.
instance Semigroup PostKernels where
  PostKernels earlier <> PostKernels later =
    PostKernels (later ++ earlier)

-- | The empty list of kernels.
instance Monoid PostKernels where
  mempty = PostKernels []
-- | Flatten a collection of kernels into one statement sequence, in the
-- order the kernels are stored in the list.
postKernelsStms :: PostKernels -> KernelsStms
postKernelsStms (PostKernels kernels) = foldMap unPostKernel kernels
-- | The scope introduced by the pattern of the outermost target of the
-- accumulator.
typeEnvFromDistAcc :: DistAcc -> Scope Out.Kernels
typeEnvFromDistAcc acc = scopeOfPattern outer_pat
  where
    (outer_pat, _) = outerTarget $ distTargets acc
-- | Prepend statements to those already collected in the accumulator.
addStmsToKernel :: KernelsStms -> DistAcc -> DistAcc
addStmsToKernel stms acc = acc { distStms = stms <> collected }
  where
    collected = distStms acc
-- | Convert a single SOACS statement with 'soacsStmToKernels' and
-- prepend it to the statements collected in the accumulator.
addStmToKernel :: Monad m => Stm -> DistAcc -> m DistAcc
addStmToKernel stm acc =
  return $ addStmsToKernel (oneStm $ soacsStmToKernels stm) acc
-- | The distribution monad transformer: a reader of 'DistEnv' stacked
-- on a writer of 'DistRes' over the underlying monad.
newtype DistNestT m a =
  DistNestT (ReaderT (DistEnv m) (WriterT DistRes m) a)
  deriving ( Functor, Applicative, Monad
           , MonadReader (DistEnv m)
           , MonadWriter DistRes
           )
-- | Lifting goes through both the writer and the reader layer.
instance MonadTrans DistNestT where
  lift action = DistNestT $ lift $ lift action
-- | Name sources are delegated to the monad below the reader layer.
instance MonadFreshNames m => MonadFreshNames (DistNestT m) where
  getNameSource = DistNestT (lift getNameSource)
  putNameSource src = DistNestT $ lift $ putNameSource src
-- | The current scope is the 'distScope' field of the environment.
instance Monad m => HasScope Out.Kernels (DistNestT m) where
  askScope = distScope <$> ask
-- | Locally extend 'distScope'; the new bindings are placed to the left
-- of the existing scope in the '<>'.
instance Monad m => LocalScope Out.Kernels (DistNestT m) where
  localScope types = local extend
    where
      extend env = env { distScope = types <> distScope env }
-- | Log messages are written to the 'accLog' field of the output.
instance Monad m => MonadLogger (DistNestT m) where
  addLog msgs = tell DistRes { accPostKernels = mempty
                             , accLog = msgs
                             }
-- | Run a distribution computation.  The accumulated log is forwarded to
-- the underlying monad, and the result is the emitted kernels followed by
-- one statement per element still left in the outermost target.
--
-- Each leftover element is bound either to a 'Copy' of the array that
-- feeds the corresponding parameter of the outermost loop (when the
-- result is such a parameter), or to a 'Replicate' of the result over the
-- width of the outermost loop.
runDistNestT :: MonadLogger m =>
                DistEnv m -> DistNestT m DistAcc -> m (Out.Stms Out.Kernels)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  let kernel_stms = postKernelsStms $ accPostKernels res
      leftover_stms = identityStms $ outerTarget $ distTargets acc
  return $ kernel_stms <> leftover_stms
  where
    -- The loop of the first of the outer nestings, or of the inner
    -- nesting itself when there are no outer ones.
    outermost = nestingLoop $
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest

    params_to_arrs =
      [ (paramName p, arr) | (p, arr) <- loopNestingParamsAndArrs outermost ]

    identityStms (rem_pat, res) =
      stmsFromList $ zipWith identityStm (patternValueElements rem_pat) res

    identityStm pe (Var v)
      | Just arr <- lookup v params_to_arrs =
          Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ Copy arr
    identityStm pe se =
      Let (Pattern [] [pe]) (defAux ()) $ BasicOp $
      Replicate (Shape [loopNestingWidth outermost]) se
-- | Record a collection of kernels in the output.
addKernels :: Monad m => PostKernels -> DistNestT m ()
addKernels ks = tell DistRes { accPostKernels = ks
                             , accLog = mempty
                             }
-- | Record the given statements as a single kernel in the output.
addKernel :: Monad m => KernelsStms -> DistNestT m ()
addKernel bnds = addKernels $ PostKernels [kernel]
  where
    kernel = PostKernel bnds
-- | Run an action with a statement taken into account: the names it
-- binds are added to the scope and registered as let-bound in the
-- innermost nesting.
withStm :: Monad m => Stm -> DistNestT m a -> DistNestT m a
withStm stm = local extend
  where
    bound = namesFromList $ patternNames $ stmPattern stm

    extend env =
      env { distScope = scopeForKernels (scopeOf stm) <> distScope env
          , distNest = letBindInInnerNesting bound (distNest env)
          }
-- | Run an action inside a new innermost map nesting built from the
-- given pattern, certificates, width, lambda parameters and arrays.  The
-- lambda's scope is added to the current scope for the duration.
mapNesting :: Monad m =>
              Pattern -> Certificates -> SubExp -> Lambda -> [VName]
           -> DistNestT m a
           -> DistNestT m a
mapNesting pat cs w lam arrs = local enter
  where
    params_and_arrs = zip (lambdaParams lam) arrs
    nest = Nesting mempty $ MapNesting pat cs w params_and_arrs

    enter env =
      env { distNest = pushInnerNesting nest $ distNest env
          , distScope = scopeForKernels (scopeOf lam) <> distScope env
          }
-- | Run an action as if inside the given kernel nest.  The current
-- nestings are replaced: the last loop of the nest becomes the inner
-- nesting and the remaining ones, outermost first, become the outer
-- nestings.  The scopes of all loops in the nest are added to the
-- current scope.
inNesting :: Monad m =>
             KernelNest -> DistNestT m a -> DistNestT m a
inNesting (outer, nests) = local enter
  where
    enter env =
      env { distNest = (innermost, enclosing)
          , distScope = mconcat (map scopeOf $ outer : nests) <> distScope env
          }

    wrap = Nesting mempty

    (innermost, enclosing) =
      case reverse nests of
        [] -> (wrap outer, [])
        deepest : rest -> (wrap deepest, map wrap $ outer : reverse rest)
-- | Does the body contain a top-level statement whose expression is an
-- 'Op'?  Only the body's own statements are inspected, not nested
-- bodies.
bodyContainsParallelism :: Body -> Bool
bodyContainsParallelism body = any isOp $ bodyStms body
  where
    isOp stm =
      case stmExp stm of
        Op{} -> True
        _ -> False
-- | 'bodyContainsParallelism' applied to the body of a lambda.
lambdaContainsParallelism :: Lambda -> Bool
lambdaContainsParallelism lam = bodyContainsParallelism (lambdaBody lam)
-- | True when 'unixEnvironment' has an entry for the variable
-- @FUTHARK_INCREMENTAL_FLATTENING@, whatever its value.
incrementalFlattening :: Bool
incrementalFlattening =
  any ((== "FUTHARK_INCREMENTAL_FLATTENING") . fst) unixEnvironment
-- | Called when leaving a map nesting.  The innermost target is popped.
-- If statements are still collected in the accumulator, they are wrapped
-- in a map (over only those lambda parameters that the statements
-- actually use) that produces the popped target, and that map becomes
-- the sole collected statement.  Fails if there is no target to pop.
leavingNesting :: Monad m => MapLoop -> DistAcc -> DistNestT m DistAcc
leavingNesting (MapLoop _ cs w lam arrs) acc =
  case popInnerTarget $ distTargets acc of
    Nothing ->
      fail "The kernel targets list is unexpectedly small"
    Just ((pat, res), newtargets)
      | null pending -> return acc'
      | otherwise ->
          return $ addStmsToKernel (oneStm map_stm) acc' { distStms = mempty }
      where
        acc' = acc { distTargets = newtargets }
        pending = distStms acc'

        body = Body () pending res
        used_in_body = freeIn body

        -- Keep only the parameters (and their arrays) that the
        -- leftover statements refer to.
        (used_params, used_arrs) =
          unzip [ (p, arr)
                | (p, arr) <- zip (lambdaParams lam) arrs
                , paramName p `nameIn` used_in_body
                ]

        lam' = Lambda { lambdaParams = used_params
                      , lambdaBody = body
                      , lambdaReturnType = map rowType $ patternTypes pat
                      }

        map_stm = Let pat (StmAux cs ()) $ Op $ OtherOp $
                  Screma w (mapSOAC lam') used_arrs
-- | Distribute the statements of a map body.  Statements are handled
-- from last to first: each statement is passed to 'maybeDistributeStm'
-- only after the statements following it have been processed, and all of
-- that happens with the statement in scope (via 'withStm').  Sequential
-- streams are first expanded with 'sequentialStreamWholeArray' and
-- copy-propagated, and the resulting statements are processed in their
-- place.  A final 'distribute' is run on the resulting accumulator.
distributeMapBodyStms :: MonadFreshNames m => DistAcc -> Stms SOACS -> DistNestT m DistAcc
distributeMapBodyStms orig_acc body_stms =
  distribute =<< onStms orig_acc (stmsToList body_stms)
  where
    onStms acc [] = return acc

    onStms acc (Let pat (StmAux cs _) (Op (Stream w (Sequential accs) lam arrs)) : rest) = do
      types <- asksScope scopeForSOACs
      stream_stms <-
        snd <$> runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
      stream_stms' <-
        runReaderT (copyPropagateInStms simpleSOACS stream_stms) types
      let certified = certify cs <$> stream_stms'
      onStms acc $ stmsToList certified ++ rest

    onStms acc (stm : rest) =
      withStm stm $ do
        acc' <- onStms acc rest
        maybeDistributeStm stm acc'
-- | Invoke the inner-map callback stored in the environment.
onInnerMap :: Monad m => MapLoop -> DistAcc -> DistNestT m DistAcc
onInnerMap loop acc =
  asks distOnInnerMap >>= \handler -> handler loop acc
-- | Transform statements with the top-level callback stored in the
-- environment and record the result as a kernel.
onTopLevelStms :: Monad m => Stms SOACS -> DistNestT m ()
onTopLevelStms stms = do
  handler <- asks distOnTopLevelStms
  transformed <- handler stms
  addKernel transformed
-- | Decide what to do with one statement of a map body.  Each clause
-- either distributes the statement (recording kernels with 'addKernels' /
-- 'addKernel' and returning the accumulator left after distribution), or
-- rewrites it into other statements that are processed recursively, or
-- falls back to 'addStmToKernel', which merely adds the statement to the
-- ones collected in the accumulator.  Clauses are tried top to bottom, and
-- several of them match the same constructor with different guards, so
-- their order matters.
maybeDistributeStm :: MonadFreshNames m => Stm -> DistAcc -> DistNestT m DistAcc
-- A map: if the collected statements can be distributed, pass the loop
-- to 'onInnerMap' and distribute again; otherwise keep it as a plain
-- statement.
maybeDistributeStm bnd@(Let pat _ (Op (Screma w form arrs))) acc
| Just lam <- isMapSOAC form =
distributeIfPossible acc >>= \case
Nothing -> addStmToKernel bnd acc
Just acc' -> distribute =<< onInnerMap (MapLoop pat (stmCerts bnd) w lam arrs) acc'
-- A for-loop whose first 'DoLoop' argument is empty, whose pattern has no
-- context elements and whose body contains parallelism: interchange it
-- with the surrounding nest, provided the loop form uses nothing bound in
-- that nest.
maybeDistributeStm bnd@(Let pat _ (DoLoop [] val form@ForLoop{} body)) acc
| null (patternContextElements pat), bodyContainsParallelism body =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| not $ freeIn form `namesIntersect` boundInKernelNest nest,
Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
addKernels kernels
nest' <- expandKernelNest pat_unused nest
types <- asksScope scopeForSOACs
bnds <- runReaderT
(interchangeLoops nest' (SeqLoop perm pat val form body)) types
onTopLevelStms bnds
return acc'
_ ->
addStmToKernel bnd acc
-- A branch with no context elements where either branch contains
-- parallelism or some return type is not primitive: interchange it with
-- the surrounding nest, provided neither the condition nor the return
-- types use anything bound in that nest.
maybeDistributeStm stm@(Let pat _ (If cond tbranch fbranch ret)) acc
| null (patternContextElements pat),
bodyContainsParallelism tbranch || bodyContainsParallelism fbranch ||
any (not . primType) (ifReturns ret) =
distributeSingleStm acc stm >>= \case
Just (kernels, res, nest, acc')
| not $
(freeIn cond <> freeIn ret) `namesIntersect` boundInKernelNest nest,
Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
nest' <- expandKernelNest pat_unused nest
addKernels kernels
types <- asksScope scopeForSOACs
let branch = Branch perm pat cond tbranch fbranch ret
stms <- runReaderT (interchangeBranch nest' branch) types
onTopLevelStms stms
return acc'
_ ->
addStmToKernel stm acc
-- A single reduction for which 'irwim' offers a rewrite: run the rewrite
-- and process the statements it produces instead.
maybeDistributeStm (Let pat (StmAux cs _) (Op (Screma w form arrs))) acc
| Just [Reduce comm lam nes] <- isReduceSOAC form,
Just m <- irwim pat w comm lam $ zip nes arrs = do
types <- asksScope scopeForSOACs
(_, bnds) <- runBinderT (certifying cs m) types
distributeMapBodyStms acc bnds
-- A scatter: turned into a kernel by 'segmentedScatterKernel'.
maybeDistributeStm bnd@(Let pat (StmAux cs _) (Op (Scatter w lam ivs as))) acc =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
nest' <- expandKernelNest pat_unused nest
let lam' = soacsLambdaToKernels lam
addKernels kernels
addKernel =<< segmentedScatterKernel nest' perm pat cs w lam' ivs as
return acc'
_ ->
addStmToKernel bnd acc
-- A generalised reduction: turned into a kernel by
-- 'segmentedGenReduceKernel'.
maybeDistributeStm bnd@(Let pat (StmAux cs _) (Op (GenReduce w ops lam as))) acc =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
let lam' = soacsLambdaToKernels lam
nest' <- expandKernelNest pat_unused nest
addKernels kernels
addKernel =<< segmentedGenReduceKernel nest' perm cs w ops lam' as
return acc'
_ ->
addStmToKernel bnd acc
-- A scanomap: try 'segmentedScanomapKernel'; 'kernelOrNot' falls back to
-- keeping the statement if no kernel could be produced.
maybeDistributeStm bnd@(Let pat (StmAux cs _) (Op (Screma w form arrs))) acc
| Just (lam, nes, map_lam) <- isScanomapSOAC form =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
nest' <- expandKernelNest pat_unused nest
let map_lam' = soacsLambdaToKernels map_lam
lam' = soacsLambdaToKernels lam
localScope (typeEnvFromDistAcc acc') $
segmentedScanomapKernel nest' perm w lam' map_lam' nes arrs >>=
kernelOrNot cs bnd acc kernels acc'
_ ->
addStmToKernel bnd acc
-- A redomap (its reductions combined by 'singleReduce') whose map part is
-- the identity, or any redomap when 'incrementalFlattening' is on: try
-- 'regularSegmentedRedomapKernel'.  The reduction is treated as
-- commutative if 'commutativeLambda' says so, whatever was declared.
maybeDistributeStm bnd@(Let pat (StmAux cs _) (Op (Screma w form arrs))) acc
| Just (reds, map_lam) <- isRedomapSOAC form,
Reduce comm lam nes <- singleReduce reds,
isIdentityLambda map_lam || incrementalFlattening =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| Just (perm, pat_unused) <- permutationAndMissing pat res ->
localScope (typeEnvFromDistAcc acc') $ do
nest' <- expandKernelNest pat_unused nest
let lam' = soacsLambdaToKernels lam
map_lam' = soacsLambdaToKernels map_lam
let comm' | commutativeLambda lam = Commutative
| otherwise = comm
regularSegmentedRedomapKernel nest' perm w comm' lam' map_lam' nes arrs >>=
kernelOrNot cs bnd acc kernels acc'
_ ->
addStmToKernel bnd acc
-- Any other Screma that is not a redomap (or any remaining Screma when
-- 'incrementalFlattening' is on): split it up with 'dissectScrema' and
-- process the pieces.
maybeDistributeStm (Let pat (StmAux cs _) (Op (Screma w form arrs))) acc
| incrementalFlattening || isNothing (isRedomapSOAC form) = do
scope <- asksScope scopeForSOACs
distributeMapBodyStms acc . fmap (certify cs) . snd =<<
runBinderT (dissectScrema pat w form arrs) scope
-- A replicate with at least one dimension and a single result: rewritten
-- as a map of width @d@ with no input arrays, whose body replicates over
-- the remaining dimensions; the new statement is then processed again.
maybeDistributeStm (Let pat aux (BasicOp (Replicate (Shape (d:ds)) v))) acc
| [t] <- patternTypes pat = do
tmp <- newVName "tmp"
let rowt = rowType t
newbnd = Let pat aux $ Op $ Screma d (mapSOAC lam) []
tmpbnd = Let (Pattern [] [PatElem tmp rowt]) aux $
BasicOp $ Replicate (Shape ds) v
lam = Lambda { lambdaReturnType = [rowt]
, lambdaParams = []
, lambdaBody = mkBody (oneStm tmpbnd) [Var tmp]
}
maybeDistributeStm newbnd acc
-- A copy: becomes a copy of the whole outer array.
maybeDistributeStm bnd@(Let _ aux (BasicOp Copy{})) acc =
distributeSingleUnaryStm acc bnd $ \_ outerpat arr ->
return $ oneStm $ Let outerpat aux $ BasicOp $ Copy arr
-- An opaque of a non-primitive value: also emitted as a copy of the outer
-- array.
maybeDistributeStm bnd@(Let (Pattern [] [pe]) aux (BasicOp Opaque{})) acc
| not $ primType $ typeOf pe =
distributeSingleUnaryStm acc bnd $ \_ outerpat arr ->
return $ oneStm $ Let outerpat aux $ BasicOp $ Copy arr
-- A rearrange: the permutation is shifted past the @r@ nest dimensions,
-- which stay in place, and applied to a fresh copy of the outer array.
maybeDistributeStm bnd@(Let _ aux (BasicOp (Rearrange perm _))) acc =
distributeSingleUnaryStm acc bnd $ \nest outerpat arr -> do
let r = length (snd nest) + 1
perm' = [0..r-1] ++ map (+r) perm
arr' <- newVName $ baseString arr
arr_t <- lookupType arr
return $ stmsFromList
[Let (Pattern [] [PatElem arr' arr_t]) aux $ BasicOp $ Copy arr,
Let outerpat aux $ BasicOp $ Rearrange perm' arr']
-- A reshape: the nest widths are prepended to the new dimensions.
maybeDistributeStm bnd@(Let _ aux (BasicOp (Reshape reshape _))) acc =
distributeSingleUnaryStm acc bnd $ \nest outerpat arr -> do
let reshape' = map DimNew (kernelNestWidths nest) ++
map DimNew (newDims reshape)
return $ oneStm $ Let outerpat aux $ BasicOp $ Reshape reshape' arr
-- A rotate: a zero rotation is prepended for every nest dimension.
maybeDistributeStm stm@(Let _ aux (BasicOp (Rotate rots _))) acc =
distributeSingleUnaryStm acc stm $ \nest outerpat arr -> do
let rots' = map (const $ intConst Int32 0) (kernelNestWidths nest) ++ rots
return $ oneStm $ Let outerpat aux $ BasicOp $ Rotate rots' arr
-- An update at a single fixed index of a rank-1 array: rewritten as a
-- scatter of width 1 whose parameterless lambda returns the index and the
-- value.  Only done when none of the collected statements is a 'DoLoop'
-- or an 'Op' (see @amortises@).
maybeDistributeStm (Let pat aux (BasicOp (Update arr [DimFix i] v))) acc
| [t] <- patternTypes pat,
arrayRank t == 1,
not $ any (amortises . stmExp) $ distStms acc = do
let w = arraySize 0 t
et = stripArray 1 t
lam = Lambda { lambdaParams = []
, lambdaReturnType = [Prim int32, et]
, lambdaBody = mkBody mempty [i, v] }
maybeDistributeStm (Let pat aux $ Op $ Scatter (intConst Int32 1) lam [] [(w, 1, arr)]) acc
where amortises DoLoop{} = True
amortises Op{} = True
amortises _ = False
-- A concat: handled through 'isSegmentedOp'; the concatenation dimension
-- is shifted past the nest dimensions.
maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d x xs w))) acc =
distributeSingleStm acc stm >>= \case
Just (kernels, _, nest, acc') ->
localScope (typeEnvFromDistAcc acc') $
segmentedConcat nest >>=
kernelOrNot (stmAuxCerts aux) stm acc kernels acc'
_ ->
addStmToKernel stm acc
where segmentedConcat nest =
isSegmentedOp nest [0] w mempty mempty [] (x:xs) $
\pat _ _ _ (x':xs') _ ->
let d' = d + length (snd nest) + 1
in addStm $ Let pat aux $ BasicOp $ Concat d' x' xs' w
-- Anything else is simply added to the collected statements.
maybeDistributeStm bnd acc =
addStmToKernel bnd acc
-- | Distribute a statement that operates on a single array.  The statement
-- is first distributed with 'distributeSingleStm'.  The result is accepted
-- only if (a) the result of distribution is exactly the names bound by the
-- statement, (b) the outermost loop of the nest has exactly one
-- parameter/array pair, and (c) that parameter is the only nest-bound name
-- free in the statement.  In that case the callback is given the nest, the
-- pattern of the outermost loop and the input array (after
-- @repeatMissing@), and the statements it returns are recorded as a
-- kernel.  Otherwise the statement is just added to the collected ones.
distributeSingleUnaryStm :: MonadFreshNames m =>
DistAcc -> Stm
-> (KernelNest -> Pattern -> VName -> DistNestT m (Stms Out.Kernels))
-> DistNestT m DistAcc
distributeSingleUnaryStm acc bnd f =
distributeSingleStm acc bnd >>= \case
Just (kernels, res, nest, acc')
| res == map Var (patternNames $ stmPattern bnd),
(outer, inners) <- nest,
[(arr_p, arr)] <- loopNestingParamsAndArrs outer,
boundInKernelNest nest `namesIntersection` freeIn bnd
== oneName (paramName arr_p) -> do
addKernels kernels
let outerpat = loopNestingPattern $ fst nest
localScope (typeEnvFromDistAcc acc') $ do
(arr', pre_stms) <- repeatMissing arr (outer:inners)
f_stms <- inScopeOf pre_stms $ f nest outerpat arr'
addKernel $ pre_stms <> f_stms
return acc'
_ -> addStmToKernel bnd acc
where
-- Produce the array to hand to the callback.  If every shape computed
-- by determineRepeats is empty the array is used as it is; otherwise a
-- fresh array is bound to a 'Repeat' of it.
repeatMissing arr inners = do
arr_t <- lookupType arr
let shapes = determineRepeats arr arr_t inners
if all (==Shape []) shapes then return (arr, mempty)
else do
let (outer_shapes, inner_shape) = repeatShapes shapes arr_t
arr_t' = repeatDims outer_shapes inner_shape arr_t
arr' <- newVName $ baseString arr
return (arr', oneStm $ Let (Pattern [] [PatElem arr' arr_t']) (defAux ()) $
BasicOp $ Repeat outer_shapes inner_shape arr)
-- Walk down the nest following the array: the widths of the loops
-- skipped before reaching the next loop that takes the array as its
-- only input form one shape, and the search then continues with that
-- loop's parameter.  When no such loop is left, the widths of the
-- remaining loops form a shape, followed by one empty shape per
-- remaining array dimension.
determineRepeats arr arr_t nests
| (skipped, arr_nest:nests') <- break (hasInput arr) nests,
[(arr_p, _)] <- loopNestingParamsAndArrs arr_nest =
Shape (map loopNestingWidth skipped) :
determineRepeats (paramName arr_p) (rowType arr_t) nests'
| otherwise =
Shape (map loopNestingWidth nests) : replicate (arrayRank arr_t) (Shape [])
-- Does the loop have exactly one input, and is it this array?
hasInput arr nest
| [(_, arr')] <- loopNestingParamsAndArrs nest, arr' == arr = True
| otherwise = False
-- | Distribute the collected statements if possible; otherwise return
-- the accumulator unchanged.
distribute :: MonadFreshNames m => DistAcc -> DistNestT m DistAcc
distribute acc = do
  distributed <- distributeIfPossible acc
  return $ fromMaybe acc distributed
-- | Lift the level constructor from the environment so that it can be
-- used in a binder over 'DistNestT'.  The lifted constructor runs the
-- original one in the current scope and re-adds the statements it
-- produced to the enclosing binder.
mkSegLevel :: MonadFreshNames m => DistNestT m (MkSegLevel (DistNestT m))
mkSegLevel =
  asks distSegLevel >>= \mk_lvl ->
    return $ \ws desc recommendation -> do
      scope <- askScope
      (lvl, stms) <-
        lift $ lift $ runBinderT (mk_lvl ws desc recommendation) scope
      addStms stms
      return lvl
-- | Try to distribute the statements collected in the accumulator with
-- 'tryDistribute'.  On success the resulting kernel is recorded and an
-- accumulator with the new targets and no collected statements is
-- returned.
distributeIfPossible :: MonadFreshNames m => DistAcc -> DistNestT m (Maybe DistAcc)
distributeIfPossible acc = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  attempt <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  case attempt of
    Nothing ->
      return Nothing
    Just (targets, kernel) -> do
      addKernel kernel
      return $ Just DistAcc { distTargets = targets
                            , distStms = mempty
                            }
-- | Try to distribute the collected statements and then the given
-- statement on its own.  Nothing is recorded in the output here: on
-- success the kernel for the collected statements is returned (for the
-- caller to record), together with the result of distributing the
-- statement, the kernel nest for it, and an accumulator with the new
-- targets and no collected statements.
distributeSingleStm :: MonadFreshNames m =>
                       DistAcc -> Stm
                    -> DistNestT m (Maybe (PostKernels, Result, KernelNest, DistAcc))
distributeSingleStm acc bnd = do
  nest <- asks distNest
  mk_lvl <- mkSegLevel
  distributed <- tryDistribute mk_lvl nest (distTargets acc) (distStms acc)
  case distributed of
    Nothing ->
      return Nothing
    Just (targets, distributed_bnds) -> do
      placed <- tryDistributeStm nest targets bnd
      return $ case placed of
        Nothing -> Nothing
        Just (res, targets', new_kernel_nest) ->
          Just ( PostKernels [PostKernel distributed_bnds]
               , res
               , new_kernel_nest
               , DistAcc { distTargets = targets'
                         , distStms = mempty
                         }
               )
-- | Build the statements for a scatter nested inside a kernel nest.  The
-- scatter itself is pushed onto the nest as a new innermost map nesting,
-- the nest is flattened with 'flatKernel', and a map kernel is built
-- whose body is the scatter lambda's body and whose results are in-place
-- writes to the destination arrays.  Fails if a destination array is not
-- among the kernel inputs.
segmentedScatterKernel :: MonadFreshNames m =>
KernelNest
-> [Int]
-> Pattern
-> Certificates
-> SubExp
-> Out.Lambda Out.Kernels
-> [VName] -> [(SubExp,Int,VName)]
-> DistNestT m KernelsStms
segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do
let nest' = pushInnerKernelNesting (scatter_pat, bodyResult $ lambdaBody lam)
(MapNesting scatter_pat cs scatter_w $ zip (lambdaParams lam) ivs) nest
(ispace, kernel_inps) <- flatKernel nest'
let (as_ws, as_ns, as) = unzip3 dests
as_inps <- mapM (findInput kernel_inps) as
mk_lvl <- mkSegLevel
-- The lambda returns (sum as_ns) indices followed by as many values.
-- One return type is kept per destination array (the first of each
-- chunk of value types).
let rts = concatMap (take 1) $ chunks as_ns $
drop (sum as_ns) $ lambdaReturnType lam
(is,vs) = splitAt (sum as_ns) $ bodyResult $ lambdaBody lam
k_body = KernelBody () (bodyStms $ lambdaBody lam) $
map (inPlaceReturn ispace) $
zip3 as_ws as_inps $ chunks as_ns $ zip is vs
(k, k_bnds) <- mapKernel mk_lvl ispace kernel_inps rts k_body
runBinder_ $ do
addStms k_bnds
-- The kernel is bound to the pattern of the outermost loop of the
-- original nest, with its elements rearranged by perm.
let pat = Pattern [] $ rearrangeShape perm $
patternValueElements $ loopNestingPattern $ fst nest
certifying cs $ letBind_ pat $ Op $ SegOp k
where findInput kernel_inps a =
maybe bad return $ find ((==a) . kernelInputName) kernel_inps
bad = fail "Ill-typed nested scatter encountered."
-- A write into the array behind a kernel input.  All dimensions of the
-- index space except the last are kept (widths and thread indices); the
-- last width is replaced by the destination's size, and the last index
-- by the index computed by the lambda.
inPlaceReturn ispace (aw, inp, is_vs) =
WriteReturns (init ws++[aw]) (kernelInputArray inp)
[ (map Var (init gtids)++[i], v) | (i,v) <- is_vs ]
where (gtids,ws) = unzip ispace
-- | Build the statements for a generalised reduction nested inside a
-- kernel nest.  The nest is flattened with 'flatKernel', each destination
-- of each operator is replaced by the array behind the kernel input of
-- that name, and the kernel itself is produced by 'genReduceKernel'.
-- Fails if a destination is not among the kernel inputs.
segmentedGenReduceKernel :: MonadFreshNames m =>
                            KernelNest
                         -> [Int]
                         -> Certificates
                         -> SubExp
                         -> [SOAC.GenReduceOp SOACS]
                         -> Out.Lambda Out.Kernels
                         -> [VName]
                         -> DistNestT m KernelsStms
segmentedGenReduceKernel nest perm cs genred_w ops lam arrs = do
  (ispace, inputs) <- flatKernel nest
  let outer_pes = patternValueElements $ loopNestingPattern $ fst nest
      orig_pat = Pattern [] $ rearrangeShape perm outer_pes

  ops' <- forM ops $ \(SOAC.GenReduceOp num_bins dests nes op) -> do
    dests' <- forM dests $ \dest ->
      kernelInputArray <$> findInput inputs dest
    return $ SOAC.GenReduceOp num_bins dests' nes op

  runBinder_ $ do
    stms <- genReduceKernel orig_pat ispace inputs cs genred_w ops' lam arrs
    addStms stms
  where
    findInput kernel_inps a =
      case find ((== a) . kernelInputName) kernel_inps of
        Just inp -> return inp
        Nothing -> fail "Ill-typed nested GenReduce encountered."
-- | Produce the statements of a generalised-reduction kernel.  Each
-- operator is first converted with 'determineReduceOp'.  Kernel inputs
-- whose underlying array is a destination of one of the operators are
-- dropped before 'segGenRed' is called.  The certificates are attached
-- to the statements produced by 'segGenRed'.
genReduceKernel :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                   Pattern -> [(VName, SubExp)] -> [KernelInput]
                -> Certificates -> SubExp -> [SOAC.GenReduceOp SOACS]
                -> Out.Lambda Out.Kernels -> [VName]
                -> m KernelsStms
genReduceKernel orig_pat ispace inputs cs genred_w ops lam arrs = runBinder_ $ do
  ops' <- forM ops $ \(SOAC.GenReduceOp num_bins dests nes op) -> do
    (op', nes', shape) <- determineReduceOp op nes
    return $ Out.GenReduceOp num_bins dests nes' shape op'

  let dest_arrs = concatMap Out.genReduceDest ops'
      inputs' = filter ((`notElem` dest_arrs) . kernelInputArray) inputs

  certifying cs $ do
    stms <- segGenRed orig_pat genred_w ispace inputs' ops' lam arrs
    addStms stms
-- | Convert a reduction operator and its neutral elements to the kernels
-- representation.
--
-- When every neutral element is a variable, 'isVectorMap' is used to peel
-- off the vectorised part of the operator; each neutral element is then
-- replaced by its element at index zero along the peeled dimensions, and
-- the peeled shape is returned.  Otherwise the operator is converted as
-- it is and the shape is empty.
determineReduceOp :: (MonadBinder m, Lore m ~ Out.Kernels) =>
                     Lambda -> [SubExp] -> m (Out.Lambda Out.Kernels, [SubExp], Shape)
determineReduceOp lam nes
  | Just ne_vs <- mapM subExpVar nes = do
      let (shape, inner_lam) = isVectorMap lam
          first_elem = replicate (shapeRank shape) $ DimFix $ intConst Int32 0
      nes' <- forM ne_vs $ \ne_v -> do
        ne_v_t <- lookupType ne_v
        letSubExp "genred_ne" $
          BasicOp $ Index ne_v $ fullSlice ne_v_t first_elem
      return (soacsLambdaToKernels inner_lam, nes', shape)
  | otherwise =
      return (soacsLambdaToKernels lam, nes, mempty)
-- | Peel off nested maps from a lambda.  If the body of the lambda is a
-- single map over exactly the lambda's own parameters, whose results are
-- returned unchanged, the width of that map is recorded and the mapped
-- lambda is inspected in turn.  Returns the collected widths (outermost
-- first) and the innermost lambda; for any other lambda the shape is
-- empty and the lambda is returned untouched.
isVectorMap :: Lambda -> (Shape, Lambda)
isVectorMap lam =
  fromMaybe (mempty, lam) $ do
    [Let (Pattern [] pes) _ (Op (Screma w form arrs))] <-
      Just $ stmsToList $ bodyStms body
    guard $ bodyResult body == map (Var . patElemName) pes
    map_lam <- isMapSOAC form
    guard $ arrs == map paramName (lambdaParams lam)
    let (shape, lam') = isVectorMap map_lam
    return (Shape [w] <> shape, lam')
  where
    body = lambdaBody lam
-- | Try to build a segmented scan kernel for a scanomap inside a kernel
-- nest.  The applicability checks and input preparation are done by
-- 'isSegmentedOp'; the result is 'Nothing' when they fail.  The
-- statements produced by 'segScan' are renamed before being added.
segmentedScanomapKernel :: MonadFreshNames m =>
                           KernelNest
                        -> [Int]
                        -> SubExp
                        -> Out.Lambda Out.Kernels -> Out.Lambda Out.Kernels
                        -> [SubExp] -> [VName]
                        -> DistNestT m (Maybe KernelsStms)
segmentedScanomapKernel nest perm segment_size lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs $
    \pat ispace inps nes' _ _ -> do
      let dims = segment_size : map snd ispace
      lvl <- mk_lvl dims "segscan" $ NoRecommendation SegNoVirt
      stms <- segScan lvl pat segment_size lam map_lam nes' arrs ispace inps
      addStms =<< traverse renameStm stms
-- | Try to build a segmented reduction kernel for a redomap inside a
-- kernel nest.  The applicability checks and input preparation are done
-- by 'isSegmentedOp'; the result is 'Nothing' when they fail.  The
-- statements produced by 'segRed' are renamed before being added.
regularSegmentedRedomapKernel :: MonadFreshNames m =>
                                 KernelNest
                              -> [Int]
                              -> SubExp -> Commutativity
                              -> Out.Lambda Out.Kernels -> Out.Lambda Out.Kernels
                              -> [SubExp] -> [VName]
                              -> DistNestT m (Maybe KernelsStms)
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs = do
  mk_lvl <- asks distSegLevel
  isSegmentedOp nest perm segment_size (freeIn lam) (freeIn map_lam) nes arrs $
    \pat ispace inps nes' _ _ -> do
      let red_op = SegRedOp comm lam nes' mempty
          dims = segment_size : map snd ispace
      lvl <- mk_lvl dims "segred" $ NoRecommendation SegNoVirt
      stms <- segRed lvl pat segment_size [red_op] map_lam arrs ispace inps
      addStms =<< traverse renameStm stms
-- | Check whether an operation inside a kernel nest can be expressed as a
-- segmented operation, and if so run the given builder to produce its
-- statements.
--
-- The result is 'Nothing' when any of the following holds:
--
-- * flattening the nest with 'flatKernel' fails;
--
-- * something free in the operator (@free_in_op@) is bound by the nest;
--
-- * a neutral element is a variable bound by the nest;
--
-- * an input array is neither usable as a kernel input (see @prepareArr@)
--   nor free in the nest.
--
-- Otherwise the builder is called with: the pattern of the outermost loop
-- (rearranged by @perm@), the flat index space, the kernel inputs, the
-- neutral elements, the prepared ("nested") input arrays, and those same
-- arrays with their outer dimensions flattened into one.  The free names
-- of the fold operator are accepted but not inspected.
isSegmentedOp :: MonadFreshNames m =>
                 KernelNest
              -> [Int]
              -> SubExp
              -> Names -> Names
              -> [SubExp] -> [VName]
              -> (Pattern
                  -> [(VName, SubExp)]
                  -> [KernelInput]
                  -> [SubExp] -> [VName] -> [VName]
                  -> BinderT Out.Kernels m ())
              -> DistNestT m (Maybe KernelsStms)
isSegmentedOp nest perm segment_size free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do
  let bound_by_nest = boundInKernelNest nest

  (ispace, kernel_inps) <- flatKernel nest

  -- Nothing free in the operator may be bound by the nest.
  when (free_in_op `namesIntersect` bound_by_nest) $
    fail "Non-fold lambda uses nest-bound parameters."

  let indices = map fst ispace

      prepareNe (Var v) | v `nameIn` bound_by_nest =
                          fail "Neutral element bound in nest"
      prepareNe ne = return ne

      -- Each array yields an action (run later, inside the binder)
      -- that produces the array to use.
      prepareArr arr =
        case find ((==arr) . kernelInputName) kernel_inps of
          Just inp
            -- Indexed by exactly the flat indices: use the underlying
            -- array directly.
            | kernelInputIndices inp == map Var indices ->
                return $ return $ kernelInputArray inp
            -- Underlying array not bound by the nest: repeat it over
            -- the dimensions it is not indexed by.
            | not (kernelInputArray inp `nameIn` bound_by_nest) ->
                return $ replicateMissing ispace inp
          -- Not a kernel input and not bound by the nest: replicate it
          -- over the whole index space.
          Nothing | not (arr `nameIn` bound_by_nest) ->
                      return $
                      letExp (baseString arr ++ "_repd")
                      (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr)
          _ ->
            fail "Input not free or outermost."

  nes' <- mapM prepareNe nes

  mk_arrs <- mapM prepareArr arrs
  scope <- lift askScope

  lift $ lift $ flip runBinderT_ scope $ do
    -- The product of the segment size and all widths of the index space.
    total_num_elements <-
      letSubExp "total_num_elements" =<<
      foldBinOp (Mul Int32) segment_size (map snd ispace)

    -- Collapse the outer (2 + number of inner loops) dimensions of an
    -- array into a single dimension of total_num_elements.
    let flatten arr = do
          arr_shape <- arrayShape <$> lookupType arr
          let reshape = reshapeOuter [DimNew total_num_elements]
                        (2+length (snd nest)) arr_shape
          letExp (baseString arr ++ "_flat") $
            BasicOp $ Reshape reshape arr

    nested_arrs <- sequence mk_arrs
    arrs' <- mapM flatten nested_arrs

    let pat = Pattern [] $ rearrangeShape perm $
              patternValueElements $ loopNestingPattern $ fst nest

    m pat ispace kernel_inps nes' nested_arrs arrs'

  where replicateMissing ispace inp = do
          t <- lookupType $ kernelInputArray inp
          let inp_is = kernelInputIndices inp
              shapes = determineRepeats ispace inp_is
              (outer_shapes, inner_shape) = repeatShapes shapes t
          letExp "repeated" $ BasicOp $
            Repeat outer_shapes inner_shape $ kernelInputArray inp

        -- For each index of the input, the widths of the index-space
        -- dimensions skipped before reaching that index; finally the
        -- widths of whatever part of the index space remains.
        determineRepeats ispace (i:is)
          | (skipped_ispace, ispace') <- span ((/=i) . Var . fst) ispace =
              Shape (map snd skipped_ispace) : determineRepeats (drop 1 ispace') is
        determineRepeats ispace _ =
          [Shape $ map snd ispace]
-- | Relate the value elements of a pattern to a result list.  Pattern
-- elements whose names do not occur free in the result are "unused";
-- they are appended (as variables) to the result, and the pattern's names
-- must then be a permutation of that extended result.  Returns the value
-- produced by 'isPermutationOf' together with the unused pattern
-- elements, or 'Nothing' when there is no such permutation.
permutationAndMissing :: Pattern -> [SubExp] -> Maybe ([Int], [PatElem])
permutationAndMissing pat res = do
  perm <- map asVar pes `isPermutationOf` (res ++ map asVar unused)
  return (perm, unused)
  where
    pes = patternValueElements pat
    asVar = Var . patElemName
    used_names = freeIn res
    unused = filter (not . (`nameIn` used_names) . patElemName) pes
-- | Extend the pattern of every loop in a kernel nest with fresh copies
-- of the given pattern elements.  In each loop the copies get new names
-- and their types are wrapped in that loop's width followed by the widths
-- of the loops inside it.  The pattern of each loop is rebuilt with no
-- context elements.
expandKernelNest :: MonadFreshNames m =>
                    [PatElem] -> KernelNest -> m KernelNest
expandKernelNest pes (outer_nest, inner_nests) = do
  outer_nest' <- expandWith outer_nest all_widths
  inner_nests' <- zipWithM expandWith inner_nests (tails inner_widths)
  return (outer_nest', inner_nests')
  where
    inner_widths = map loopNestingWidth inner_nests
    all_widths = loopNestingWidth outer_nest : inner_widths

    expandWith nest dims = do
      pes' <- mapM (expandPatElemWith dims) pes
      let old_pes = patternElements $ loopNestingPattern nest
      return nest { loopNestingPattern = Pattern [] $ old_pes <> pes' }

    expandPatElemWith dims pe = do
      name <- newVName $ baseString $ patElemName pe
      return pe { patElemName = name
                , patElemAttr = patElemType pe `arrayOfShape` Shape dims
                }
-- | Finish an attempt at building a kernel for a statement.  If no kernel
-- was produced, the (certified) statement is added to the statements of
-- the original accumulator.  Otherwise the pending kernels and the new
-- (certified) kernel are recorded and the accumulator left after
-- distribution is returned.
kernelOrNot :: MonadFreshNames m =>
               Certificates -> Stm -> DistAcc
            -> PostKernels -> DistAcc -> Maybe KernelsStms
            -> DistNestT m DistAcc
kernelOrNot cs bnd acc kernels acc' result =
  case result of
    Nothing ->
      addStmToKernel (certify cs bnd) acc
    Just bnds -> do
      addKernels kernels
      addKernel $ certify cs <$> bnds
      return acc'
-- | Distribute a map loop.  A new innermost target (the loop's pattern
-- and the result of its lambda) is pushed, and the statements of the
-- lambda body are distributed inside a nesting for the loop.  After
-- that the nesting is left with 'leavingNesting' and a final
-- 'distribute' is attempted.
distributeMap :: MonadFreshNames m => MapLoop -> DistAcc -> DistNestT m DistAcc
distributeMap maploop@(MapLoop pat cs w lam arrs) acc = do
  inner_acc <- mapNesting pat cs w lam arrs $
    distribute =<< distributeMapBodyStms acc' lam_bnds
  distribute =<< leavingNesting maploop inner_acc
  where
    lam_body = lambdaBody lam
    lam_bnds = bodyStms lam_body
    acc' = DistAcc { distTargets =
                       pushInnerTarget (pat, bodyResult lam_body) $
                       distTargets acc
                   , distStms = mempty
                   }