{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Simplification of programs in the 'Kernels' representation, and
-- of the 'InKernel' code nested inside kernel operations.  Besides
-- the entry points, this module defines the orphan
-- 'Engine.Simplifiable' and 'BinderOps' instances (hence the
-- @-fno-warn-orphans@ flag) and the simplification rules used at host
-- level and inside kernels.
module Futhark.Representation.Kernels.Simplify
       ( simplifyKernels
       , simplifyLambda

       -- * Building blocks
       , simplifyKernelOp
       , simplifyKernelExp
       )
where

import Control.Monad
import Data.Either
import Data.Foldable
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Futhark.Representation.Kernels
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Lore
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass
import qualified Futhark.Optimise.Simplify as Simplify
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.Rephrase (castStm)
import Futhark.Util (chunks)

-- | Simplifier operations for the host-level 'Kernels' lore.  Ops are
-- handled by 'simplifyKernelOp', which in turn simplifies kernel
-- bodies using 'simpleInKernel' and 'inKernelEnv'.
simpleKernels :: Simplify.SimpleOps Kernels
simpleKernels = Simplify.bindableSimpleOps (simplifyKernelOp simpleInKernel inKernelEnv)

-- | Simplifier operations for the 'InKernel' lore, parameterised by
-- the space of the enclosing kernel.  Ops are handled by
-- 'simplifyKernelExp'.
simpleInKernel :: KernelSpace -> Simplify.SimpleOps InKernel
simpleInKernel = Simplify.bindableSimpleOps . simplifyKernelExp

-- | Simplify an entire 'Kernels' program with 'simpleKernels' and
-- 'kernelRules', adding no extra hoist blockers.
simplifyKernels :: Prog Kernels -> PassM (Prog Kernels)
simplifyKernels =
  Simplify.simplifyProg simpleKernels kernelRules Simplify.noExtraHoistBlockers

-- | Simplify a single 'InKernel' lambda in the context of the given
-- kernel space, using 'inKernelRules' and no extra hoist blockers.
-- The @[Maybe VName]@ argument is passed on unchanged to
-- 'Simplify.simplifyLambda'.
simplifyLambda :: (HasScope InKernel m, MonadFreshNames m) =>
                  KernelSpace -> Lambda InKernel -> [Maybe VName] -> m (Lambda InKernel)
simplifyLambda kspace =
  Simplify.simplifyLambda (simpleInKernel kspace) inKernelRules Engine.noExtraHoistBlockers

-- | Simplify a host-level operation.  The first argument constructs
-- the simplifier operations for the inner lore from a kernel space,
-- and the second is the environment in which inner code is
-- simplified.  Returns the simplified operation together with the
-- statements that were hoisted out of it, rephrased to the outer lore
-- (see 'processHoistedStm').
simplifyKernelOp :: (Engine.SimplifiableLore lore,
                     Engine.SimplifiableLore outerlore,
                     BodyAttr outerlore ~ (), BodyAttr lore ~ (),
                     ExpAttr lore ~ ExpAttr outerlore,
                     SameScope lore outerlore,
                     RetType lore ~ RetType outerlore,
                     BranchType lore ~ BranchType outerlore) =>
                    (KernelSpace -> Engine.SimpleOps lore) -> Engine.Env lore
                 -> HostOp outerlore (Kernel lore)
                 -> Engine.SimpleM outerlore (HostOp (Wise outerlore) (Kernel (Wise lore)),
                                              Stms (Wise outerlore))

simplifyKernelOp mk_ops env (HostOp (Kernel desc space ts kbody)) = do
  space' <- Engine.simplify space
  ts' <- mapM Engine.simplify ts
  -- This case used to carry its own copy of the body-hoisting logic;
  -- it is the same as 'hoistFromBody'.  Note that the body is
  -- simplified with respect to the original, unsimplified @space@,
  -- exactly as before.
  (kbody', kbody_hoisted) <- hoistFromBody space (mk_ops space) env kbody
  return (HostOp $ Kernel desc space' ts' kbody',
          kbody_hoisted)

simplifyKernelOp mk_ops env (HostOp (SegMap space ts body)) = do
  space' <- Engine.simplify space
  ts' <- mapM Engine.simplify ts

  (body', body_hoisted) <- hoistFromBody space' (mk_ops space') env body

  return (HostOp $ SegMap space' ts' body',
          body_hoisted)

simplifyKernelOp mk_ops env (HostOp (SegRed space reds ts body)) = do
  ts' <- mapM Engine.simplify ts
  space' <- Engine.simplify space
  outer_vtable <- Engine.askVtable

  -- Each reduction operator is simplified in the inner lore, with the
  -- kernel space added to the symbol table.
  (reds', reds_hoisted) <- fmap unzip $ forM reds $ \(SegRedOp comm lam nes shape) -> do
    (lam', hoisted) <-
      Engine.subSimpleM (mk_ops space) env outer_vtable $
      Engine.localVtable (<>scope_vtable) $
      Engine.simplifyLambda lam $ replicate (length nes * 2) Nothing
    shape' <- Engine.simplify shape
    nes' <- mapM Engine.simplify nes
    return (SegRedOp comm lam' nes' shape', hoisted)

  red_op_hoisted' <- mapM processHoistedStm $ mconcat reds_hoisted

  (body', body_hoisted) <- hoistFromBody space' (mk_ops space') env body

  return (HostOp $ SegRed space' reds' ts' body',
          red_op_hoisted' <> body_hoisted)
  where scope_vtable = ST.fromScope scope
        scope = scopeOfKernelSpace space

simplifyKernelOp mk_ops env (HostOp (SegScan space scan_op nes ts body)) = do
  outer_vtable <- Engine.askVtable
  space' <- Engine.simplify space

  (scan_op', scan_op_hoisted) <-
    Engine.subSimpleM (mk_ops space) env outer_vtable $
    Engine.localVtable (<>scope_vtable) $
    Engine.simplifyLambda scan_op $ replicate (length nes * 2) Nothing
  scan_op_hoisted' <- mapM processHoistedStm scan_op_hoisted

  nes' <- mapM Engine.simplify nes
  ts' <- mapM Engine.simplify ts

  (body', body_hoisted) <- hoistFromBody space' (mk_ops space') env body

  return (HostOp $ SegScan space' scan_op' nes' ts' body',
          scan_op_hoisted' <> body_hoisted)
  where scope_vtable = ST.fromScope scope
        scope = scopeOfKernelSpace space

simplifyKernelOp mk_ops env (HostOp (SegGenRed space ops ts body)) = do
  outer_vtable <- Engine.askVtable
  space' <- Engine.simplify space
  ts' <- mapM Engine.simplify ts

  (ops', ops_hoisted) <- fmap unzip $ forM ops $
    \(GenReduceOp w arrs nes dims lam) -> do
      w' <- Engine.simplify w
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.subSimpleM (mk_ops space) env outer_vtable $
        Engine.localVtable (<>scope_vtable) $
        Engine.simplifyLambda lam $
        replicate (length nes * 2) Nothing
      return (GenReduceOp w' arrs' nes' dims' lam',
              op_hoisted)

  red_op_hoisted' <- mapM processHoistedStm $ mconcat ops_hoisted

  (body', body_hoisted) <- hoistFromBody space' (mk_ops space') env body

  return (HostOp $ SegGenRed space' ops' ts' body',
          red_op_hoisted' <> body_hoisted)
  where scope_vtable = ST.fromScope scope
        scope = scopeOfKernelSpace space

simplifyKernelOp _ _ (GetSize key size_class) =
  return (GetSize key size_class, mempty)

simplifyKernelOp _ _ (GetSizeMax size_class) =
  return (GetSizeMax size_class, mempty)

simplifyKernelOp _ _ (CmpSizeLe key size_class x) = do
  x' <- Engine.simplify x
  return (CmpSizeLe key size_class x', mempty)

-- | Simplify a kernel body in the inner lore and return it together
-- with the statements hoisted out of it, rephrased to the outer lore.
-- The body is simplified with the scope of the given kernel space
-- added to the symbol table.  A statement is blocked from hoisting
-- when any of the following predicates holds for it:
-- @Engine.hasFree@ on the names bound by the kernel space,
-- @Engine.isOp@, the engine's @blockHoistPar@ blocker, or
-- @Engine.isConsumed@.
hoistFromBody :: (Engine.SimplifiableLore lore, SameScope lore outerlore,
                  BodyAttr outerlore ~ (), BodyAttr lore ~ (),
                  ExpAttr lore ~ ExpAttr outerlore,
                  RetType lore ~ RetType outerlore,
                  BranchType lore ~ BranchType outerlore) =>
                 KernelSpace -> Simplify.SimpleOps lore -> Engine.Env lore
              -> KernelBody lore
              -> Engine.SimpleM outerlore (KernelBody (Wise lore), Stms (Wise outerlore))
hoistFromBody kspace ops env kbody = do
  outer_vtable <- Engine.askVtable

  ((body_stms, body_res), body_hoisted) <-
    Engine.subSimpleM ops env outer_vtable $ do
      par_blocker <- Engine.asksEngineEnv $ Engine.blockHoistPar . Engine.envHoistBlockers
      Engine.localVtable (<>scope_vtable) $
        Engine.blockIf (Engine.hasFree bound_here
                        `Engine.orIf` Engine.isOp
                        `Engine.orIf` par_blocker
                        `Engine.orIf` Engine.isConsumed) $
        simplifyKernelBodyM kbody

  body_hoisted' <- mapM processHoistedStm body_hoisted

  return (mkWiseKernelBody () body_stms body_res,
          body_hoisted')
  where scope_vtable = ST.fromScope scope
        scope = scopeOfKernelSpace kspace
        bound_here = S.fromList $ M.keys scope

-- | Rephrase a hoisted statement from one lore to another using
-- 'castStm'.  Calls 'fail' with a message containing the
-- pretty-printed statement if the cast is not possible.
processHoistedStm :: (Monad m,
                      PrettyLore from,
                      ExpAttr from ~ ExpAttr to,
                      BodyAttr from ~ BodyAttr to,
                      RetType from ~ RetType to,
                      BranchType from ~ BranchType to,
                      LetAttr from ~ LetAttr to,
                      FParamAttr from ~ FParamAttr to,
                      LParamAttr from ~ LParamAttr to) =>
                     Stm from -> m (Stm to)
processHoistedStm bnd
  | Just bnd' <- castStm bnd = return bnd'
  | otherwise                = fail $ "Cannot hoist binding: " ++ pretty bnd

-- | Construct a 'KernelBody' in a 'Wise' lore from statements and
-- kernel results.  The body attribute is taken from the ordinary
-- 'Body' that 'mkWiseBody' builds from the same statements and the
-- subexpressions of the kernel results.
mkWiseKernelBody :: (Attributes lore, CanBeWise (Op lore)) =>
                    BodyAttr lore -> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody attr bnds res =
  let Body attr' _ _ = mkWiseBody attr bnds res_vs
  in KernelBody attr' bnds res
  where res_vs = map kernelResultSubExp res

-- | The simplification environment used for code inside kernels:
-- 'inKernelRules' and no extra hoist blockers.
inKernelEnv :: Engine.Env InKernel
inKernelEnv = Engine.emptyEnv inKernelRules Simplify.noExtraHoistBlockers

instance Engine.Simplifiable SplitOrdering where
  simplify SplitContiguous =
    return SplitContiguous
  simplify (SplitStrided stride) =
    SplitStrided <$> Engine.simplify stride

instance Engine.Simplifiable CombineSpace where
  simplify (CombineSpace scatter cspace) =
    CombineSpace <$> mapM Engine.simplify scatter
                 <*> mapM (traverse Engine.simplify) cspace

-- | Simplify an in-kernel operation.  The 'KernelSpace' is that of
-- the enclosing kernel; it is only inspected in the 'Combine' case.
-- Returns the simplified operation and any statements hoisted out of
-- its nested body or lambda.
simplifyKernelExp :: Engine.SimplifiableLore lore =>
                     KernelSpace -> KernelExp lore
                  -> Engine.SimpleM lore (KernelExp (Wise lore), Stms (Wise lore))

simplifyKernelExp _ (Barrier se) =
  (,) <$> (Barrier <$> Engine.simplify se) <*> pure mempty

simplifyKernelExp _ (SplitSpace o w i elems_per_thread) =
  (,) <$> (SplitSpace <$> Engine.simplify o <*> Engine.simplify w
           <*> Engine.simplify i <*> Engine.simplify elems_per_thread)
      <*> pure mempty

simplifyKernelExp kspace (Combine cspace ts active body) = do
  ((body_stms', body_res'), hoisted) <-
    wrapbody $
    Engine.blockIf (Engine.hasFree bound_here `Engine.orIf` maybeBlockUnsafe) $
    localScope (scopeOfCombineSpace cspace) $
    Engine.simplifyBody (map (const Observe) ts) body
  body' <- Engine.constructBody body_stms' body_res'
  (,) <$> (Combine <$> Engine.simplify cspace
           <*> mapM Engine.simplify ts
           <*> mapM Engine.simplify active
           <*> pure body')
      <*> pure hoisted
  where bound_here = S.fromList $ M.keys $ scopeOfCombineSpace cspace

        -- Post-process the hoisted statements: if any of them has an
        -- expression that is not 'safeExp', compute the activity
        -- condition and pass every hoisted statement through
        -- 'Engine.protectIf' with it; otherwise emit them unchanged.
        protectCombineHoisted checkIfActive m = do
          (x, stms) <- m
          runBinder $ do
            if any (not . safeExp . stmExp) stms
              then do is_active <- checkIfActive
                      mapM_ (Engine.protectIf (not . safeExp) is_active) stms
              else addStms stms
            return x

        -- When the combine space has exactly one dimension and it
        -- equals the group size of the kernel, unsafe statements are
        -- not blocked but protected after hoisting; otherwise they are
        -- blocked outright via 'Engine.isNotSafe'.
        (maybeBlockUnsafe, wrapbody)
          | [d] <- map snd $ cspaceDims cspace,
            d == spaceGroupSize kspace =
              (Engine.isFalse True,
               protectCombineHoisted $
               letSubExp "active" =<<
               foldBinOp LogAnd (constant True) =<<
               mapM (uncurry check) active)
          | otherwise =
              (Engine.isNotSafe, id)

        check v se =
          letSubExp "is_active" $ BasicOp $ CmpOp (CmpSlt Int32) (Var v) se

simplifyKernelExp _ (GroupReduce w lam input) = do
  arrs' <- mapM Engine.simplify arrs
  nes' <- mapM Engine.simplify nes
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.simplifyLambdaSeq lam (map (const Nothing) arrs')
  return (GroupReduce w' lam' $ zip nes' arrs', hoisted)
  where (nes,arrs) = unzip input

simplifyKernelExp _ (GroupScan w lam input) = do
  w' <- Engine.simplify w
  nes' <- mapM Engine.simplify nes
  arrs' <- mapM Engine.simplify arrs
  (lam', hoisted) <- Engine.simplifyLambdaSeq lam (map (const Nothing) arrs')
  return (GroupScan w' lam' $ zip nes' arrs', hoisted)
  where (nes,arrs) = unzip input

simplifyKernelExp _ (GroupGenReduce w dests op bucket vs locks) = do
  w' <- Engine.simplify w
  dests' <- mapM Engine.simplify dests
  (op', hoisted) <- Engine.simplifyLambdaSeq op (map (const Nothing) vs)
  bucket' <- Engine.simplify bucket
  vs' <- mapM Engine.simplify vs
  locks' <- Engine.simplify locks
  return (GroupGenReduce w' dests' op' bucket' vs' locks', hoisted)

simplifyKernelExp _ (GroupStream w maxchunk lam accs arrs) = do
  w' <- Engine.simplify w
  maxchunk' <- Engine.simplify maxchunk
  accs' <- mapM Engine.simplify accs
  arrs' <- mapM Engine.simplify arrs
  (lam', hoisted) <- simplifyGroupStreamLambda lam w' maxchunk' arrs'
  return (GroupStream w' maxchunk' lam' accs' arrs', hoisted)

-- | Simplify the statements of a kernel body and then its results.
-- The usage table reported for the results is @UT.usages@ applied to
-- the free names of the simplified results.
simplifyKernelBodyM :: Engine.SimplifiableLore lore =>
                       KernelBody lore
                    -> Engine.SimpleM lore (Engine.SimplifiedBody lore [KernelResult])
simplifyKernelBodyM (KernelBody _ stms res) =
  Engine.simplifyStms stms $ do
    res' <- mapM Engine.simplify res
    return ((res', UT.usages $ freeIn res'), mempty)

-- | Simplify the lambda of a 'GroupStream'.  The arguments after the
-- lambda are the (already simplified) stream width, maximum chunk
-- size and input arrays.  Statements that mention the chunk size, the
-- chunk offset or any lambda parameter, or for which
-- @Engine.isConsumed@ holds, are blocked from hoisting.
simplifyGroupStreamLambda :: Engine.SimplifiableLore lore =>
                             GroupStreamLambda lore
                          -> SubExp -> SubExp -> [VName]
                          -> Engine.SimpleM lore (GroupStreamLambda (Wise lore), Stms (Wise lore))
simplifyGroupStreamLambda lam w max_chunk arrs = do
  let GroupStreamLambda block_size block_offset acc_params arr_params body = lam
      bound_here = S.fromList $ block_size : block_offset :
                   map paramName (acc_params ++ arr_params)
  ((body_stms', body_res'), hoisted) <-
    Engine.enterLoop $
    Engine.bindLoopVar block_size Int32 max_chunk $
    Engine.bindLoopVar block_offset Int32 w $
    Engine.bindLParams acc_params $
    Engine.bindChunkLParams block_offset (zip arr_params arrs) $
    Engine.blockIf (Engine.hasFree bound_here `Engine.orIf` Engine.isConsumed) $
    Engine.simplifyBody (replicate (length (bodyResult body)) Observe) body
  acc_params' <- mapM (Engine.simplifyParam Engine.simplify) acc_params
  arr_params' <- mapM (Engine.simplifyParam Engine.simplify) arr_params
  body' <- Engine.constructBody body_stms' body_res'
  return (GroupStreamLambda block_size block_offset acc_params' arr_params' body', hoisted)

instance Engine.Simplifiable KernelSpace where
  simplify (KernelSpace gtid ltid gid num_threads num_groups group_size virt_groups structure) =
    KernelSpace gtid ltid gid
    <$> Engine.simplify num_threads
    <*> Engine.simplify num_groups
    <*> Engine.simplify group_size
    <*> Engine.simplify virt_groups
    <*> Engine.simplify structure

instance Engine.Simplifiable SpaceStructure where
  simplify (FlatThreadSpace dims) =
    FlatThreadSpace <$> (zip gtids <$> mapM Engine.simplify gdims)
    where (gtids, gdims) = unzip dims
  simplify (NestedThreadSpace dims) =
    NestedThreadSpace
    <$> (zip4 gtids
         <$> mapM Engine.simplify gdims
         <*> pure ltids
         <*> mapM Engine.simplify ldims)
    where (gtids, gdims, ltids, ldims) = unzip4 dims

instance Engine.Simplifiable KernelResult where
  simplify (GroupsReturn what) =
    GroupsReturn <$> Engine.simplify what
  simplify (ThreadsReturn what) =
    ThreadsReturn <$> Engine.simplify what
  simplify (WriteReturn ws a res) =
    WriteReturn <$> Engine.simplify ws <*> Engine.simplify a <*> Engine.simplify res
  simplify (ConcatReturns o w pte moffset what) =
    ConcatReturns
    <$> Engine.simplify o
    <*> Engine.simplify w
    <*> Engine.simplify pte
    <*> Engine.simplify moffset
    <*> Engine.simplify what

instance BinderOps (Wise Kernels) where
  mkExpAttrB = bindableMkExpAttrB
  mkBodyB = bindableMkBodyB
  mkLetNamesB = bindableMkLetNamesB

instance BinderOps (Wise InKernel) where
  mkExpAttrB = bindableMkExpAttrB
  mkBodyB = bindableMkBodyB
  mkLetNamesB = bindableMkLetNamesB

-- | The rules applied at host level: the standard rules, the top-down
-- rules 'removeInvariantKernelResults' and 'mergeSegRedOps', and the
-- bottom-up rules 'distributeKernelResults' and
-- 'removeUnnecessaryCopy'.
kernelRules :: RuleBook (Wise Kernels)
kernelRules = standardRules <>
              ruleBook [ RuleOp removeInvariantKernelResults
                       , RuleOp mergeSegRedOps]
                       [ RuleOp distributeKernelResults
                       , RuleBasicOp removeUnnecessaryCopy]

-- | If exactly one of the array inputs of a 'GroupStream' is bound to
-- an 'Iota' in the symbol table, drop that input and instead bind the
-- corresponding chunk parameter inside the lambda body to an 'Iota'
-- of the chunk size whose start is adjusted by the chunk offset times
-- the stride.
fuseStreamIota :: TopDownRuleOp (Wise InKernel)
fuseStreamIota vtable pat _ (GroupStream w max_chunk lam accs arrs)
  | ([(iota_cs, iota_param, iota_start, iota_stride, iota_t)], params_and_arrs) <-
      partitionEithers $ zipWith (isIota vtable) (groupStreamArrParams lam) arrs = do

      let (arr_params', arrs') = unzip params_and_arrs
          chunk_size = groupStreamChunkSize lam
          offset = groupStreamChunkOffset lam

      body' <- insertStmsM $ inScopeOf lam $ certifying iota_cs $ do
        -- Convert index to appropriate type.
        offset' <- asIntS iota_t $ Var offset
        offset'' <- letSubExp "offset_by_stride" $
                    BasicOp $ BinOp (Mul iota_t) offset' iota_stride
        start <- letSubExp "iota_start" $
                 BasicOp $ BinOp (Add iota_t) offset'' iota_start
        letBindNames_ [paramName iota_param] $
          BasicOp $ Iota (Var chunk_size) start iota_stride iota_t
        return $ groupStreamLambdaBody lam

      let lam' = lam { groupStreamArrParams = arr_params'
                     , groupStreamLambdaBody = body'
                     }
      letBind_ pat $ Op $ GroupStream w max_chunk lam' accs arrs'
fuseStreamIota _ _ _ _ = cannotSimplify

-- | Classify an array by whether the symbol table binds it to an
-- 'Iota'.  If so, return 'Left' with the certificates, the
-- accompanying value, and the start, stride and integer type of the
-- iota; otherwise return 'Right' with the value and the array name
-- unchanged.
isIota :: ST.SymbolTable lore -> a -> VName
       -> Either (Certificates, a, SubExp, SubExp, IntType) (a, VName)
isIota vtable chunk arr
  | Just (BasicOp (Iota _ x s it), cs) <- ST.lookupExp arr vtable =
      Left (cs, chunk, x, s, it)
  | otherwise =
      Right (chunk, arr)

-- If a kernel produces something invariant to the kernel, turn it
-- into a replicate.
removeInvariantKernelResults :: TopDownRuleOp (Wise Kernels)
removeInvariantKernelResults vtable (Pattern [] kpes) attr
                             (HostOp (Kernel desc space ts (KernelBody _ kstms kres))) = do
  (ts', kpes', kres') <-
    unzip3 <$> filterM checkForInvarianceResult (zip3 ts kpes kres)

  -- Check if we did anything at all.
  when (kres == kres')
    cannotSimplify

  addStm $ Let (Pattern [] kpes') attr $ Op $ HostOp $ Kernel desc space ts' $
    mkWiseKernelBody () kstms kres'
  where -- A result is invariant if it is a constant or a variable
        -- that is present in the symbol table of the rule.
        isInvariant Constant{} = True
        isInvariant (Var v) = isJust $ ST.lookup v vtable

        space_dims = map snd $ spaceDimensions space

        -- Returns False (dropping the result from the kernel) after
        -- binding the pattern element to nested replicates, one per
        -- space dimension, innermost dimension first.
        checkForInvarianceResult (_, pe, ThreadsReturn se)
          | isInvariant se = do
              let rep a d = BasicOp . Replicate (Shape [d]) <$> letSubExp "rep" a
              letBindNames_ [patElemName pe] =<<
                foldM rep (BasicOp (SubExp se)) (reverse space_dims)
              return False
        checkForInvarianceResult _ =
          return True
removeInvariantKernelResults _ _ _ _ = cannotSimplify

-- Some kernel results can be moved outside the kernel, which can
-- simplify further analysis.
-- This bottom-up rule looks at each statement of a 'Kernel' body in
-- turn.  An 'Index' statement is re-created outside the kernel (as a
-- slice over the full kernel space) when its slice starts with the
-- thread indices of the kernel space, the array and the remaining
-- slice only mention names found in the symbol table, and the indexed
-- value is returned by exactly one 'ThreadsReturn' result.
distributeKernelResults :: BottomUpRuleOp (Wise Kernels)
distributeKernelResults (vtable, used)
  (Pattern [] kpes) attr (HostOp (Kernel desc kspace kts (KernelBody _ kstms kres))) = do
  -- Fold over the statements, accumulating the pattern elements,
  -- types and results that stay in the kernel, together with the
  -- retained statements in reverse order.
  (kpes', kts', kres', kstms_rev) <-
    localScope (scopeOfKernelSpace kspace) $
    foldM distribute (kpes, kts, kres, []) kstms

  -- Nothing was moved out, so the rule does not apply.
  when (kpes' == kpes) cannotSimplify

  addStm $ Let (Pattern [] kpes') attr $ Op $ HostOp $
    Kernel desc kspace kts' $
    mkWiseKernelBody () (stmsFromList $ reverse kstms_rev) kres'
  where
    free_in_kstms = foldMap freeIn kstms

    space_dims = spaceDimensions kspace

    -- Indexing with exactly the thread indices of the kernel space.
    kspace_slice = map (DimFix . Var . fst) space_dims

    -- A full slice (offset 0, stride 1) across every space dimension.
    outer_slice = [ DimSlice (constant (0::Int32)) d (constant (1::Int32))
                  | (_, d) <- space_dims ]

    availableOutside = isJust . flip ST.lookup vtable

    distribute (kpes', kts', kres', kstms_rev) bnd
      | Let (Pattern [] [pe]) _ (BasicOp (Index arr slice)) <- bnd,
        Just remaining_slice <- stripPrefix kspace_slice slice,
        all availableOutside $ S.toList $ freeIn arr <> freeIn remaining_slice,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let index kpe' =
                letBind_ (Pattern [] [kpe']) $
                BasicOp $ Index arr $ outer_slice <> remaining_slice
          -- If the kernel result is consumed later, bind the slice to
          -- a fresh name and make the result a copy of it.
          if patElemName kpe `UT.isConsumed` used
            then do precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
                    index kpe { patElemName = precopy }
                    letBind_ (Pattern [] [kpe]) $ BasicOp $ Copy precopy
            else index kpe
          -- The in-kernel statement is retained only if its name is
          -- free in the statements of the kernel body.
          return (kpes'', kts'', kres'',
                  if patElemName pe `S.member` free_in_kstms
                  then bnd : kstms_rev
                  else kstms_rev)

    distribute (kpes', kts', kres', kstms_rev) bnd =
      return (kpes', kts', kres', bnd : kstms_rev)

    -- Succeeds only when exactly one result is a 'ThreadsReturn' of
    -- the given pattern element; returns the corresponding kernel
    -- pattern element along with everything else.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], others) ->
          let (kpes'', kts'', kres'') = unzip3 others
          in Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where matches (_, _, kre) = kre == ThreadsReturn (Var $ patElemName pe)
distributeKernelResults _ _ _ _ = cannotSimplify

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more local memory usage.
mergeSegRedOps :: TopDownRuleOp (Wise Kernels)
mergeSegRedOps _ (Pattern [] pes) _ (HostOp (SegRed space ops ts kbody))
  | length ops > 1,
    op_groupings <- groupBy sameShape $ zip ops $
                    chunks (map (length . segRedNeutral) ops) $
                    zip3 red_pes red_ts red_res,
    any ((>1) . length) op_groupings = do
      let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
          (red_pes', red_ts', red_res') = unzip3 $ concat aux
          pes' = red_pes' ++ map_pes
          ts' = red_ts' ++ map_ts
          kbody' = kbody { kernelBodyResult = red_res' ++ map_res }
      letBind_ (Pattern [] pes') $ Op $ HostOp $ SegRed space ops' ts' kbody'
  where
    -- The first 'segRedResults' entries belong to the reductions; the
    -- remainder is left untouched.
    num_red_res = segRedResults ops
    (red_pes, map_pes) = splitAt num_red_res pes
    (red_ts, map_ts) = splitAt num_red_res ts
    (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody

    sameShape (op1, _) (op2, _) = segRedShape op1 == segRedShape op2

    combineOps :: [(SegRedOp (Wise InKernel), [a])]
               -> Maybe (SegRedOp (Wise InKernel), [a])
    combineOps [] = Nothing
    combineOps (x:xs) = Just $ foldl' combine x xs

    -- Fuse two operators into one whose lambda takes the x-parameters
    -- of both followed by the y-parameters of both, and whose body is
    -- the concatenation of the two bodies.
    combine (op_a, aux_a) (op_b, aux_b) =
      let lam_a = segRedLambda op_a
          lam_b = segRedLambda op_b
          body_a = lambdaBody lam_a
          body_b = lambdaBody lam_b
          (a_xparams, a_yparams) =
            splitAt (length (segRedNeutral op_a)) $ lambdaParams lam_a
          (b_xparams, b_yparams) =
            splitAt (length (segRedNeutral op_b)) $ lambdaParams lam_b
          lam = Lambda { lambdaParams = a_xparams ++ b_xparams ++
                                        a_yparams ++ b_yparams
                       , lambdaReturnType = lambdaReturnType lam_a ++
                                            lambdaReturnType lam_b
                       , lambdaBody =
                           mkBody (bodyStms body_a <> bodyStms body_b) $
                           bodyResult body_a <> bodyResult body_b
                       }
      in (SegRedOp { segRedComm = segRedComm op_a <> segRedComm op_b
                   , segRedLambda = lam
                   , segRedNeutral = segRedNeutral op_a ++ segRedNeutral op_b
                   , segRedShape = segRedShape op_a -- Same as shape of op_b due to the grouping.
                   },
          aux_a ++ aux_b)
mergeSegRedOps _ _ _ _ = cannotSimplify

-- Remove GroupStreams over single-element arrays.  Not much to stream
-- here, and no information to exploit.
simplifyKnownIterationStream :: TopDownRuleOp (Wise InKernel)
simplifyKnownIterationStream _ pat _ (GroupStream (Constant v) _ lam accs arrs)
  | oneIsh v = do
      let GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam

      -- A single iteration: the chunk has size one and starts at
      -- offset zero.
      letBindNames_ [chunk_size] $ BasicOp $ SubExp $ constant (1::Int32)
      letBindNames_ [chunk_offset] $ BasicOp $ SubExp $ constant (0::Int32)

      -- Accumulator parameters become the initial accumulators.
      forM_ (zip acc_params accs) $ \(param, acc) ->
        letBindNames_ [paramName param] $ BasicOp $ SubExp acc

      -- Array parameters become slices of the input arrays.
      forM_ (zip arr_params arrs) $ \(param, arr) ->
        letBindNames_ [paramName param] $
        BasicOp $ Index arr $ fullSlice (paramType param)
        [DimSlice (Var chunk_offset) (Var chunk_size) (constant (1::Int32))]

      -- Inline the body and bind its results to the pattern.
      res <- bodyBind body
      forM_ (zip (patternElements pat) res) $ \(pe, r) ->
        letBindNames_ [patElemName pe] $ BasicOp $ SubExp r
simplifyKnownIterationStream _ _ _ _ = cannotSimplify

-- Drop the array inputs of a GroupStream whose chunk parameter does
-- not occur free in the body of the stream lambda.
removeUnusedStreamInputs :: TopDownRuleOp (Wise InKernel)
removeUnusedStreamInputs _ pat _ (GroupStream w maxchunk lam accs arrs)
  | (used, unused) <- partition (isUsed . paramName . fst) $ zip arr_params arrs,
    not $ null unused = do
      let (arr_params', arrs') = unzip used
          lam' = GroupStreamLambda chunk_size chunk_offset acc_params arr_params' body
      letBind_ pat $ Op $ GroupStream w maxchunk lam' accs arrs'
  where
    GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam

    isUsed = (`S.member` freeIn body)
removeUnusedStreamInputs _ _ _ _ = cannotSimplify

-- The rules applied inside kernels: the standard rules plus the
-- top-down GroupStream rules; there are no extra bottom-up rules.
inKernelRules :: RuleBook (Wise InKernel)
inKernelRules = standardRules <>
                ruleBook [ RuleOp fuseStreamIota
                         , RuleOp simplifyKnownIterationStream
                         , RuleOp removeUnusedStreamInputs
                         ]
                         []