{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | It is well known that fully parallel loops can always be
-- interchanged inwards with a sequential loop.  This module
-- implements that transformation.
--
-- This is also where we implement loop-switching (for branches),
-- which is semantically similar to interchange.
module Futhark.Pass.ExtractKernels.Interchange
       (
         SeqLoop (..)
       , interchangeLoops
       , Branch (..)
       , interchangeBranch
       ) where

import Control.Monad.RWS.Strict
import qualified Data.Set as S
import Data.Maybe
import Data.List

import Futhark.Pass.ExtractKernels.Distribution
  (LoopNesting(..), KernelNest, kernelNestLoops)
import Futhark.Representation.SOACS
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Tools

-- | An encoding of a sequential do-loop with no existential context,
-- alongside its result pattern.
--
-- The fields are, in order: a permutation (applied with
-- 'rearrangeShape' to the value elements of the enclosing nesting's
-- pattern when the loop is interchanged), the result pattern, the
-- merge parameters with their initial values, the loop form, and the
-- loop body.
data SeqLoop = SeqLoop [Int] Pattern [(FParam, SubExp)] (LoopForm SOACS) Body

-- | Turn a 'SeqLoop' back into a statement: a 'DoLoop' with an empty
-- context, bound to the stored pattern.  The permutation is ignored.
seqLoopStm :: SeqLoop -> Stm
seqLoopStm (SeqLoop _ pat merge form body) =
  Let pat (defAux ()) $ DoLoop [] merge form body

-- | Interchange a sequential loop with a single enclosing map
-- nesting.  Every merge parameter is expanded to an array with outer
-- size @w@, and the new loop body consists of a single map (a
-- 'Screma' with a 'mapSOAC' form) whose lambda body is the original
-- loop body.  The resulting 'SeqLoop' is bound to the (permuted)
-- pattern of the nesting and carries the identity permutation.
interchangeLoop :: (MonadBinder m, LocalScope SOACS m) =>
                   SeqLoop -> LoopNesting
                -> m SeqLoop
interchangeLoop
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat cs w params_and_arrs) = do
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
      mapM expand merge

    let loop_pat_expanded =
          Pattern [] $ map expandPatElem $ patternElements loop_pat
        -- The original merge parameters become (non-unique) lambda
        -- parameters of the inner map.
        new_params = [ Param pname $ fromDecl ptype
                     | (Param pname ptype, _) <- merge ]
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patternTypes loop_pat_expanded

    -- If the map consumes something that is bound outside the loop
    -- (i.e. is not a merge parameter), we have to copy() it.  As a
    -- small simplification, we just remove the parameter outright if
    -- it is not used anymore.  This might happen if the parameter was
    -- used just as the initial value of a merge parameter.
    --
    -- NOTE: as written, 'copyOrRemoveParam' only performs the
    -- removal; it never inserts a copy, so it does not itself add
    -- any statements to 'pre_copy_bnds'.
    ((params', arrs'), pre_copy_bnds) <-
      runBinder $ localScope (scopeOfLParams new_params) $
      unzip . catMaybes <$> mapM copyOrRemoveParam params_and_arrs

    body' <- mkDummyStms (params'<>new_params) body

    let lam = Lambda (params'<>new_params) body' rettype
        map_bnd = Let loop_pat_expanded (StmAux cs ()) $
                  Op $ Screma w (mapSOAC lam) $ arrs' <> new_arrs
        res = map Var $ patternNames loop_pat_expanded
        pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat

    return $
      SeqLoop [0..patternSize pat-1] pat' merge_expanded form $
      mkBody (pre_copy_bnds<>oneStm map_bnd) res
  where
    free_in_body = freeInBody body

    -- Keep a map parameter (and its array) only if the loop body
    -- actually refers to it.
    copyOrRemoveParam (param, arr)
      | not (paramName param `S.member` free_in_body) =
          return Nothing
      | otherwise =
          return $ Just (param, arr)

    -- If the initial value of a merge parameter is itself one of the
    -- map parameters, the expanded initial value is simply the
    -- corresponding input array.  Otherwise the initial value is
    -- replicated @w@ times.
    expandedInit _ (Var v)
      | Just (_, arr) <- find ((==v).paramName.fst) params_and_arrs =
          return $ Var arr
    expandedInit param_name se =
      letSubExp (param_name <> "_expanded_init") $
        BasicOp $ Replicate (Shape [w]) se

    -- Create a fresh merge parameter whose type is the original type
    -- with an extra outer dimension of size @w@ (keeping the
    -- uniqueness of the original), paired with its expanded initial
    -- value.
    expand (merge_param, merge_init) = do
      expanded_param <-
        newParam (param_name <> "_expanded") $
        arrayOf (paramDeclType merge_param) (Shape [w]) $
        uniqueness $ declTypeOf merge_param
      expanded_init <- expandedInit param_name merge_init
      return (expanded_param, expanded_init)
        where param_name = baseString $ paramName merge_param

    expandPatElem (PatElem name t) =
      PatElem name $ arrayOfRow t w

    -- The kernel extractor cannot handle identity mappings, so
    -- insert dummy statements for body results that are just a
    -- lambda parameter.
    mkDummyStms params (Body () stms res) = do
      (res', extra_stms) <- unzip <$> mapM dummyStm res
      return $ Body () (stms<>mconcat extra_stms) res'
      where dummyStm (Var v)
              | Just p <- find ((==v) . paramName) params = do
                  dummy <- newVName (baseString v ++ "_dummy")
                  return (Var dummy,
                          oneStm $
                            Let (Pattern [] [PatElem dummy $ paramType p])
                                (defAux ()) $
                                BasicOp $ SubExp $ Var $ paramName p)
            dummyStm se = return (se, mempty)

-- | Given a (parallel) map nesting and an inner sequential loop, move
-- the maps inside the sequential loop.  The result is several
-- statements - one of these will be the loop, which will then contain
-- statements with map expressions.
--
-- The nestings returned by 'kernelNestLoops' are processed in
-- reverse order, one 'interchangeLoop' step per nesting.
interchangeLoops :: (MonadFreshNames m, HasScope SOACS m) =>
                    KernelNest -> SeqLoop
                 -> m (Stms SOACS)
interchangeLoops nest loop = do
  (loop', bnds) <-
    runBinder $ foldM interchangeLoop loop $ reverse $ kernelNestLoops nest
  return $ bnds <> oneStm (seqLoopStm loop')

-- | An encoding of a branch alongside its result pattern.
--
-- The fields are, in order: a permutation (used like the one in
-- 'SeqLoop'), the result pattern, the condition, the true branch, the
-- false branch, and the branch attribute (return type and sort).
data Branch = Branch [Int] Pattern SubExp Body Body (IfAttr (BranchType SOACS))

-- | Turn a 'Branch' back into an 'If' statement bound to the stored
-- pattern.  The permutation is ignored.
branchStm :: Branch -> Stm
branchStm (Branch _ pat cond tbranch fbranch ret) =
  Let pat (defAux ()) $ If cond tbranch fbranch ret

-- | Switch a branch with a single enclosing map nesting.  Each of the
-- two branch bodies is replaced by a body containing a single map
-- over the arrays of the nesting, whose lambda body is the original
-- branch body.  The branch return types gain an outer dimension of
-- size @w@, and the resulting 'Branch' carries the identity
-- permutation.
interchangeBranch1 :: (MonadBinder m, LocalScope SOACS m) =>
                      Branch -> LoopNesting -> m Branch
interchangeBranch1
  (Branch perm branch_pat cond tbranch fbranch (IfAttr ret if_sort))
  (MapNesting pat cs w params_and_arrs) = do
    let ret' = map (`arrayOfRow` Free w) ret
        pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat

        (params, arrs) = unzip params_and_arrs
        lam_ret = map rowType $ patternTypes pat

        branch_pat' =
          Pattern [] $ map (fmap (`arrayOfRow` w)) $ patternElements branch_pat

        -- The same 'params' and 'branch_pat'' are used for both
        -- branches; the constructed body is passed through
        -- 'renameBody' (presumably so that the two branches do not
        -- end up binding the same names - confirm against
        -- "Futhark.Transform.Rename").
        mkBranch branch = (renameBody=<<) $ do
          branch' <- if null $ bodyStms branch
                     then runBodyBinder $
                          -- XXX: We need a temporary dummy binding to
                          -- prevent an empty map body.  The kernel
                          -- extractor does not like empty map bodies.
                          resultBody <$> mapM dummyBind (bodyResult branch)
                     else return branch
          let lam = Lambda params branch' lam_ret
              res = map Var $ patternNames branch_pat'
              map_bnd = Let branch_pat' (StmAux cs ()) $
                        Op $ Screma w (mapSOAC lam) arrs
          return $ mkBody (oneStm map_bnd) res

    tbranch' <- mkBranch tbranch
    fbranch' <- mkBranch fbranch
    return $ Branch [0..patternSize pat-1] pat' cond tbranch' fbranch' $
      IfAttr ret' if_sort
  where
    -- Bind a result subexpression to a fresh name and return that
    -- name instead.
    dummyBind se = do
      dummy <- newVName "dummy"
      letBindNames_ [dummy] (BasicOp $ SubExp se)
      return $ Var dummy

-- | Given a (parallel) map nesting and an inner branch, move the maps
-- inside the branch.  The result is the statements produced while
-- interchanging, followed by the 'If' statement itself.
--
-- The nestings returned by 'kernelNestLoops' are processed in
-- reverse order, one 'interchangeBranch1' step per nesting.
interchangeBranch :: (MonadFreshNames m, HasScope SOACS m) =>
                     KernelNest -> Branch -> m (Stms SOACS)
interchangeBranch nest branch = do
  (branch', bnds) <-
    runBinder $ foldM interchangeBranch1 branch $ reverse $ kernelNestLoops nest
  return $ bnds <> oneStm (branchStm branch')