{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.SOACS.SOAC
       ( SOAC(..)
       , StreamForm(..)
       , ScremaForm(..)
       , GenReduceOp(..)
       , Scan
       , Reduce

         -- * Utility
       , getStreamOrder
       , getStreamAccums
       , scremaType
       , soacType

       , typeCheckSOAC

       , mkIdentityLambda
       , isIdentityLambda
       , composeLambda
       , nilFn
       , scanomapSOAC
       , redomapSOAC
       , scanSOAC
       , reduceSOAC
       , mapSOAC
       , isScanomapSOAC
       , isRedomapSOAC
       , isScanSOAC
       , isReduceSOAC
       , isMapSOAC

       , ppScrema
       , ppGenReduce

         -- * Generic traversal
       , SOACMapper(..)
       , identitySOACMapper
       , mapSOACM
       )
       where

import Control.Monad.Writer
import Control.Monad.Identity
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List

import Futhark.Representation.AST
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty
  (ppr, Doc, Pretty, parens, comma, (</>), commasep, text)
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Ranges (Ranges, removeLambdaRanges)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Representation.Aliases (Aliases, removeLambdaAliases)
import Futhark.Analysis.Usage
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import qualified Futhark.Analysis.Range as Range
import Futhark.Construct
import Futhark.Util (maybeNth, chunks, splitAt3)

-- | The operations defined by this module.  The requirements that
-- each constructor must satisfy are spelled out by 'typeCheckSOAC'.
data SOAC lore =
    Stream SubExp (StreamForm lore) (LambdaT lore) [VName]
    -- The fields are, in order: the size of the input arrays, the
    -- stream form (which carries the accumulators), the lambda, and
    -- the input arrays.
  | Scatter SubExp (LambdaT lore) [VName] [(SubExp, Int, VName)]
    -- The fields are, in order: the length, the lambda, the index
    -- and value arrays, and the arrays that are written to (each
    -- paired with its size and the number of values written to it).
    --
    -- The length is the length of each index array and value array,
    -- since they all must be the same length for any fusion to make
    -- sense.  If you have a list of index-value array pairs of
    -- different sizes, you need to use multiple writes instead.
    --
    -- The lambda body returns the output in this manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_n]
    --
    -- This must be consistent along all Scatter-related optimisations.
  | GenReduce SubExp [GenReduceOp lore] (LambdaT lore) [VName]
    -- The first SubExp is the length of the input arrays. The first
    -- list describes the operations to perform.  The 'LambdaT' is the
    -- bucket function.  Finally comes the input images.
  | Screma SubExp (ScremaForm lore) [VName]
    -- ^ A combination of scan, reduction, and map.  The first
    -- 'SubExp' is the size of the input arrays.  The first
    -- 'Lambda'/'SubExp' pair is for scan and its neutral elements.
    -- The second is for the reduction.  The final lambda is for the
    -- map part, and finally comes the input arrays.
  | CmpThreshold SubExp String
    deriving (Eq, Ord, Show)

-- | A single operation of a 'GenReduce'.
data GenReduceOp lore =
  GenReduceOp { genReduceWidth :: SubExp
                -- ^ Outer size of each destination array.
              , genReduceDest :: [VName]
                -- ^ The destination arrays; these are consumed.
              , genReduceNeutral :: [SubExp]
                -- ^ Neutral elements of the operator.
              , genReduceOp :: LambdaT lore
                -- ^ The operator.  Its return type must equal the
                -- types of the neutral elements.
              }
  deriving (Eq, Ord, Show)

-- | The form of a 'Stream'.  Both forms carry the accumulators (see
-- 'getStreamAccums'); the parallel form additionally carries an
-- ordering, a commutativity, and the reduce lambda.
data StreamForm lore  =
    Parallel StreamOrd Commutativity (LambdaT lore) [SubExp]
  | Sequential [SubExp]
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).
data ScremaForm lore =
  ScremaForm (Scan lore) (Reduce lore) (LambdaT lore)
  deriving (Eq, Ord, Show)

-- | The scan part of a 'ScremaForm': the operator and its neutral
-- elements.
type Scan lore = (LambdaT lore, [SubExp])

-- | The reduction part of a 'ScremaForm': commutativity, the
-- operator, and its neutral elements.
type Reduce lore = (Commutativity, LambdaT lore, [SubExp])

-- | The result types of a 'Screma' with the given input size and
-- form.  Scan results and map results are arrays of outer size @w@;
-- reduction results are returned as they are.  The map results are
-- whatever the map lambda returns beyond the values that feed the
-- scan and the reduction.
scremaType :: SubExp -> ScremaForm lore -> [Type]
scremaType w (ScremaForm (scan_lam, _) (_, red_lam, _) map_lam) =
  map asArray scan_ts ++ red_ts ++ map asArray map_ts
  where scan_ts = lambdaReturnType scan_lam
        red_ts = lambdaReturnType red_lam
        num_acc_ts = length scan_ts + length red_ts
        map_ts = drop num_acc_ts $ lambdaReturnType map_lam
        asArray t = t `arrayOfRow` w

-- | Construct a lambda that takes parameters of the given types and
-- simply returns them unchanged.
mkIdentityLambda :: (Bindable lore, MonadFreshNames m) =>
                    [Type] -> m (Lambda lore)
mkIdentityLambda ts = identityOver <$> mapM (newParam "x") ts
  where identityOver params =
          Lambda params (mkBody mempty $ map (Var . paramName) params) ts

-- | Is the given lambda an identity lambda?  That is, does its body
-- result consist of exactly its parameters, in order?
isIdentityLambda :: Lambda lore -> Bool
isIdentityLambda lam = bodyResult body == map (Var . paramName) params
  where body = lambdaBody lam
        params = lambdaParams lam

-- | Fuse a scan operator, a reduction operator, and a map function
-- into one lambda.  The statements of the map function come first;
-- its leading results are then bound to the second halves of the
-- scan and reduction operator parameters, after which the operator
-- bodies follow.  The resulting lambda takes the first halves of the
-- operator parameters followed by the map function's parameters, and
-- keeps the map function's return type.
composeLambda :: (Bindable lore, BinderOps lore, MonadFreshNames m,
                  HasScope somelore m, SameScope somelore lore) =>
                 Lambda lore -> Lambda lore -> Lambda lore
              -> m (Lambda lore)
composeLambda scan_fun red_fun map_fun = do
  body <- runBodyBinder $
          inScopeOf scan_fun $ inScopeOf red_fun $ inScopeOf map_fun $ do
    mapM_ addStm $ bodyStms map_body
    let (scan_res, red_res, map_res) =
          splitAt3 num_scan num_red $ bodyResult map_body

    -- The map results play the role of the "new element" operands.
    zipWithM_ bindParam scan_ys scan_res
    zipWithM_ bindParam red_ys red_res

    mapM_ addStm $ bodyStms scan_body
    mapM_ addStm $ bodyStms red_body

    resultBodyM $ bodyResult scan_body ++ bodyResult red_body ++ map_res

  return $ Lambda (scan_xs ++ red_xs ++ lambdaParams map_fun)
                  body
                  (lambdaReturnType map_fun)
  where map_body = lambdaBody map_fun
        scan_body = lambdaBody scan_fun
        red_body = lambdaBody red_fun

        num_scan = length $ lambdaReturnType scan_fun
        num_red = length $ lambdaReturnType red_fun

        (scan_xs, scan_ys) = splitAt num_scan $ lambdaParams scan_fun
        (red_xs, red_ys) = splitAt num_red $ lambdaParams red_fun

        bindParam p se = letBindNames_ [paramName p] $ BasicOp $ SubExp se

-- | A lambda with no parameters that returns no values.
nilFn :: Bindable lore => LambdaT lore
nilFn = Lambda mempty (mkBody mempty mempty) mempty

-- | Does the lambda have no parameters, no return types, no
-- statements and no results?
isNilFn :: LambdaT lore -> Bool
isNilFn (Lambda ps body ts) =
  null ps && null ts &&
  null (bodyStms body) && null (bodyResult body)

-- | A 'ScremaForm' with a scan part and a map part; the reduction
-- part is 'nilFn' with no neutral elements.
scanomapSOAC :: Bindable lore =>
                Lambda lore -> [SubExp] -> Lambda lore -> ScremaForm lore
scanomapSOAC lam nes = ScremaForm (lam, nes) (mempty, nilFn, mempty)

-- | A 'ScremaForm' with a reduction part and a map part; the scan
-- part is 'nilFn' with no neutral elements.
redomapSOAC :: Bindable lore =>
               Commutativity -> Lambda lore -> [SubExp] -> Lambda lore
            -> ScremaForm lore
redomapSOAC comm lam nes = ScremaForm (nilFn, mempty) (comm, lam, nes)

-- | Like 'scanomapSOAC', but with an identity lambda as the map part.
scanSOAC :: (Bindable lore, MonadFreshNames m) =>
            Lambda lore -> [SubExp] -> m (ScremaForm lore)
scanSOAC lam nes =
  scanomapSOAC lam nes <$> mkIdentityLambda (lambdaReturnType lam)

-- | Like 'redomapSOAC', but with an identity lambda as the map part.
reduceSOAC :: (Bindable lore, MonadFreshNames m) =>
              Commutativity -> Lambda lore -> [SubExp]
           -> m (ScremaForm lore)
reduceSOAC comm lam nes =
  redomapSOAC comm lam nes <$> mkIdentityLambda (lambdaReturnType lam)

-- | A 'ScremaForm' with only a map part.
mapSOAC :: Bindable lore => Lambda lore -> ScremaForm lore
mapSOAC = ScremaForm (nilFn, mempty) (mempty, nilFn, mempty)

-- | Succeeds if the form has scan neutral elements but no reduction
-- neutral elements.
isScanomapSOAC :: ScremaForm lore -> Maybe (Lambda lore, [SubExp], Lambda lore)
isScanomapSOAC (ScremaForm (scan_lam, scan_nes) (_, _, red_nes) map_lam) = do
  guard $ null red_nes
  guard $ not $ null scan_nes
  return (scan_lam, scan_nes, map_lam)

-- | Like 'isScanomapSOAC', but additionally requires the map part to
-- be an identity lambda.
isScanSOAC :: ScremaForm lore -> Maybe (Lambda lore, [SubExp])
isScanSOAC form = do
  (scan_lam, scan_nes, map_lam) <- isScanomapSOAC form
  guard $ isIdentityLambda map_lam
  return (scan_lam, scan_nes)

-- | Succeeds if the form has reduction neutral elements but no scan
-- neutral elements.
isRedomapSOAC :: ScremaForm lore
              -> Maybe (Commutativity, Lambda lore, [SubExp], Lambda lore)
isRedomapSOAC (ScremaForm (_, scan_nes) (comm, red_lam, red_nes) map_lam) = do
  guard $ null scan_nes
  guard $ not $ null red_nes
  return (comm, red_lam, red_nes, map_lam)

-- | Like 'isRedomapSOAC', but additionally requires the map part to
-- be an identity lambda.
isReduceSOAC :: ScremaForm lore -> Maybe (Commutativity, Lambda lore, [SubExp])
isReduceSOAC form = do
  (comm, red_lam, red_nes, map_lam) <- isRedomapSOAC form
  guard $ isIdentityLambda map_lam
  return (comm, red_lam, red_nes)

-- | Succeeds if the form has neither scan nor reduction neutral
-- elements.
isMapSOAC :: ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC (ScremaForm (_, scan_nes) (_, _, red_nes) map_lam) = do
  guard $ null scan_nes
  guard $ null red_nes
  return map_lam

-- | Like 'Mapper', but just for 'SOAC's.
data SOACMapper flore tlore m = SOACMapper {
    mapOnSOACSubExp :: SubExp -> m SubExp
  , mapOnSOACLambda :: Lambda flore -> m (Lambda tlore)
  , mapOnSOACVName :: VName -> m VName
  }

-- | A mapper that simply returns the SOAC verbatim.
identitySOACMapper :: Monad m => SOACMapper lore lore m
identitySOACMapper = SOACMapper { mapOnSOACSubExp = return
                                , mapOnSOACLambda = return
                                , mapOnSOACVName = return
                                }

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right.
mapSOACM :: (Applicative m, Monad m) =>
            SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM tv (Stream size form lam arrs) =
  Stream
  <$> mapOnSOACSubExp tv size
  <*> mapOnStreamForm form
  <*> mapOnSOACLambda tv lam
  <*> mapM (mapOnSOACVName tv) arrs
  where mapOnStreamForm (Parallel o comm lam0 acc) =
          Parallel <$> pure o <*> pure comm
          <*> mapOnSOACLambda tv lam0
          <*> mapM (mapOnSOACSubExp tv) acc
        mapOnStreamForm (Sequential acc) =
          Sequential <$> mapM (mapOnSOACSubExp tv) acc
mapSOACM tv (Scatter len lam ivs as) =
  Scatter
  <$> mapOnSOACSubExp tv len
  <*> mapOnSOACLambda tv lam
  <*> mapM (mapOnSOACVName tv) ivs
  <*> mapM (\(aw,an,a) -> (,,) <$> mapOnSOACSubExp tv aw
                               <*> pure an
                               <*> mapOnSOACVName tv a) as
mapSOACM tv (GenReduce len ops bucket_fun imgs) =
  GenReduce
  <$> mapOnSOACSubExp tv len
  <*> mapM (\(GenReduceOp e arrs nes op) ->
              GenReduceOp <$> mapOnSOACSubExp tv e
              <*> mapM (mapOnSOACVName tv) arrs
              <*> mapM (mapOnSOACSubExp tv) nes
              <*> mapOnSOACLambda tv op) ops
  <*> mapOnSOACLambda tv bucket_fun
  <*> mapM (mapOnSOACVName tv) imgs
mapSOACM tv (Screma w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs) =
  Screma <$> mapOnSOACSubExp tv w <*>
  (ScremaForm <$>
   ((,) <$> mapOnSOACLambda tv scan_lam <*> mapM (mapOnSOACSubExp tv) scan_nes) <*>
   ((,,) comm <$> mapOnSOACLambda tv red_lam <*> mapM (mapOnSOACSubExp tv) red_nes) <*>
   mapOnSOACLambda tv map_lam)
  <*> mapM (mapOnSOACVName tv) arrs
mapSOACM tv (CmpThreshold what s) =
  CmpThreshold <$> mapOnSOACSubExp tv what <*> pure s

instance Attributes lore => FreeIn (SOAC lore) where
  freeIn = execWriter . mapSOACM free
    where walk f x = tell (f x) >> return x
          free = SOACMapper { mapOnSOACSubExp = walk freeIn
                            , mapOnSOACLambda = walk freeInLambda
                            , mapOnSOACVName = walk freeIn
                            }

instance Attributes lore => Substitute (SOAC lore) where
  substituteNames subst =
    runIdentity . mapSOACM substitute
    where substitute =
            SOACMapper { mapOnSOACSubExp = return . substituteNames subst
                       , mapOnSOACLambda = return . substituteNames subst
                       , mapOnSOACVName = return . substituteNames subst
                       }

instance Attributes lore => Rename (SOAC lore) where
  rename = mapSOACM renamer
    where renamer = SOACMapper rename rename rename

-- | The result types of a SOAC.
soacType :: SOAC lore -> [Type]
soacType (Stream outersize form lam _) =
  map (substNamesInType substs) rtp
  where -- The first parameter and the accumulator parameters are
        -- replaced, in the return type, by the outer size and the
        -- accumulators themselves.
        nms = map paramName $ take (1 + length accs) params
        substs = M.fromList $ zip nms (outersize:accs)
        Lambda params _ rtp = lam
        accs = getStreamAccums form
soacType (Scatter _w lam _ivs as) =
  zipWith arrayOfRow val_ts ws
  where -- Skip the index types, then take one value type for each
        -- output array.
        val_ts = concatMap (take 1) $ chunks ns $
                 drop (sum ns) $ lambdaReturnType lam
        (ws, ns, _) = unzip3 as
soacType (GenReduce _len ops _bucket_fun _imgs) = do
  op <- ops
  map (`arrayOfRow` genReduceWidth op) (lambdaReturnType $ genReduceOp op)
soacType (Screma w form _arrs) =
  scremaType w form
soacType CmpThreshold{} = [Prim Bool]

instance TypedOp (SOAC lore) where
  opType = pure . staticShapes . soacType

instance (Attributes lore, Aliased lore) => AliasedOp (SOAC lore) where
  opAliases = map (const mempty) . soacType

  -- Only map functions can consume anything.  The operands to scan
  -- and reduce functions are always considered "fresh".
  consumedInOp (Screma _ (ScremaForm _ _ map_lam) arrs) =
    S.map consumedArray $ consumedByLambda map_lam
    where consumedArray v = fromMaybe v $ lookup v params_to_arrs
          params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  consumedInOp (Stream _ form lam arrs) =
    S.fromList $ subExpVars $
    map consumedArray $ S.toList $ consumedByLambda lam
    where consumedArray v = fromMaybe (Var v) $ lookup v params_to_inputs
          -- Drop the chunk parameter, which cannot alias anything.
          params_to_inputs =
            zip (map paramName $ drop 1 $ lambdaParams lam)
                (getStreamAccums form ++ map Var arrs)
  consumedInOp (Scatter _ _ _ as) =
    S.fromList $ map (\(_, _, a) -> a) as
  consumedInOp (GenReduce _ ops _ _) =
    S.fromList $ concatMap genReduceDest ops
  consumedInOp CmpThreshold{} =
    mempty

-- | Apply a function to the operator lambda of a 'GenReduceOp',
-- leaving the other fields untouched.
mapGenReduceOp :: (LambdaT flore -> LambdaT tlore)
               -> GenReduceOp flore -> GenReduceOp tlore
mapGenReduceOp f (GenReduceOp w dests nes lam) =
  GenReduceOp w dests nes $ f lam

instance (Attributes lore,
          Attributes (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SOAC lore) where
  type OpWithAliases (SOAC lore) = SOAC (Aliases lore)

  addOpAliases (Stream size form lam arr) =
    Stream size (analyseStreamForm form)
    (Alias.analyseLambda lam) arr
    where analyseStreamForm (Parallel o comm lam0 acc) =
            Parallel o comm (Alias.analyseLambda lam0) acc
          analyseStreamForm (Sequential acc) = Sequential acc
  addOpAliases (Scatter len lam ivs as) =
    Scatter len (Alias.analyseLambda lam) ivs as
  addOpAliases (GenReduce len ops bucket_fun imgs) =
    GenReduce len (map (mapGenReduceOp Alias.analyseLambda) ops)
    (Alias.analyseLambda bucket_fun) imgs
  addOpAliases (Screma w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs) =
    Screma w (ScremaForm
               (Alias.analyseLambda scan_lam, scan_nes)
               (comm, Alias.analyseLambda red_lam, red_nes)
               (Alias.analyseLambda map_lam))
              arrs
  addOpAliases (CmpThreshold what s) = CmpThreshold what s

  removeOpAliases = runIdentity . mapSOACM remove
    where remove = SOACMapper return (return . removeLambdaAliases) return

instance Attributes lore => IsOp (SOAC lore) where
  safeOp CmpThreshold{} = True
  safeOp _ = False
  cheapOp _ = True

-- | Replace the variables occurring in the dimensions of a type (or
-- in the size of a memory block) according to the given map.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType _ tp@(Prim _) = tp
substNamesInType subs (Mem se space) =
  Mem (substNamesInSubExp subs se) space
substNamesInType subs (Array btp shp u) =
  let shp' = Shape $ map (substNamesInSubExp subs) (shapeDims shp)
  in  Array btp shp' u

-- | Replace a variable by its entry in the map, if it has one.
-- Constants are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp _ e@(Constant _) = e
substNamesInSubExp subs (Var idd) =
  M.findWithDefault (Var idd) idd subs

instance (Ranged inner) => RangedOp (SOAC inner) where
  opRanges op = replicate (length $ soacType op) unknownRange

instance (Attributes lore, CanBeRanged (Op lore)) => CanBeRanged (SOAC lore) where
  type OpWithRanges (SOAC lore) = SOAC (Ranges lore)

  removeOpRanges = runIdentity . mapSOACM remove
    where remove = SOACMapper return (return . removeLambdaRanges) return

  addOpRanges (Stream w form lam arr) =
    Stream w
    (Range.runRangeM $ analyseStreamForm form)
    (Range.runRangeM $ Range.analyseLambda lam)
    arr
    where analyseStreamForm (Sequential acc) =
            return $ Sequential acc
          analyseStreamForm (Parallel o comm lam0 acc) = do
            lam0' <- Range.analyseLambda lam0
            return $ Parallel o comm lam0' acc
  addOpRanges (Scatter len lam ivs as) =
    Scatter len (Range.runRangeM $ Range.analyseLambda lam) ivs as
  addOpRanges (GenReduce len ops bucket_fun imgs) =
    GenReduce len
    (map (mapGenReduceOp $ Range.runRangeM . Range.analyseLambda) ops)
    (Range.runRangeM $ Range.analyseLambda bucket_fun) imgs
  addOpRanges (Screma w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs) =
    Screma w (ScremaForm
               (Range.runRangeM $ Range.analyseLambda scan_lam, scan_nes)
               (comm, Range.runRangeM $ Range.analyseLambda red_lam, red_nes)
               (Range.runRangeM $ Range.analyseLambda map_lam))
              arrs
  addOpRanges (CmpThreshold what s) = CmpThreshold what s

instance (Attributes lore, CanBeWise (Op lore)) => CanBeWise (SOAC lore) where
  type OpWithWisdom (SOAC lore) = SOAC (Wise lore)

  removeOpWisdom = runIdentity . mapSOACM remove
    where remove = SOACMapper return (return . removeLambdaWisdom) return

instance Annotations lore => ST.IndexOp (SOAC lore) where
  indexOp vtable k soac [i] = do
    (lam,se,arr_params,arrs) <- lambdaAndSubExp soac
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        arr_indexes' = foldl' expandPrimExpTable arr_indexes $
                       bodyStms $ lambdaBody lam
    case se of
      Var v -> M.lookup v arr_indexes'
      _ -> Nothing
    where lambdaAndSubExp (Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) map_lam) arrs) =
            nthMapOut (length scan_nes + length red_nes) map_lam arrs
          lambdaAndSubExp _ =
            Nothing

          -- NOTE(review): the lambda parameters are paired with
          -- 'arrs' only after dropping @num_accs@ of them; confirm
          -- that this offset is intended when the form has scan or
          -- reduction parts.
          nthMapOut num_accs lam arrs = do
            se <- maybeNth (num_accs+k) $ bodyResult $ lambdaBody lam
            return (lam, se, drop num_accs $ lambdaParams lam, arrs)

          arrIndex p arr = do
            (pe,cs) <- ST.index' arr [i] vtable
            return (paramName p, (pe,cs))

          -- Extend the table with a statement that binds a single
          -- name, converts to a 'PrimExp', and whose certificates
          -- are all present in the symbol table.
          expandPrimExpTable table stm
            | [v] <- patternNames $ stmPattern stm,
              Just (pe,cs) <-
                runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
              all (`ST.elem` vtable) (unCertificates $ stmCerts stm) =
                M.insert v (pe, stmCerts stm <> cs) table
            | otherwise =
                table

          asPrimExp table v
            | Just (e,cs) <- M.lookup v table = tell cs >> return e
            | Just (Prim pt) <- ST.lookupType v vtable =
                return $ LeafExp v pt
            | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

instance Aliased lore => UsageInOp (SOAC lore) where
  usageInOp (Screma _ (ScremaForm _ _ f) arrs) = usageInLambda f arrs
  usageInOp _ = mempty

-- | Type-check a SOAC.  Failures are reported through 'TC.bad'.
typeCheckSOAC :: TC.Checkable lore => SOAC (Aliases lore) -> TC.TypeM lore ()
typeCheckSOAC (CmpThreshold what _) = TC.require [Prim int32] what

typeCheckSOAC (Stream size form lam arrexps) = do
  let accexps = getStreamAccums form
  TC.require [Prim int32] size
  accargs <- mapM TC.checkArg accexps
  arrargs <- mapM lookupType arrexps
  _ <- TC.checkSOACArrayArgs size arrexps
  -- The first lambda parameter is the chunk size.  Report a type
  -- error instead of crashing if the lambda has no parameters.
  chunk <- case lambdaParams lam of
             p : _ -> return p
             [] -> TC.bad $ TC.TypeError
                   "Stream lambda must take at least a chunk size parameter."
  let asArg t = (t, mempty)
      inttp   = Prim int32
      lamarrs'= map (`setOuterSize` Var (paramName chunk)) arrargs
  let acc_len= length accexps
  let lamrtp = take acc_len $ lambdaReturnType lam
  unless (map TC.argType accargs == lamrtp) $
    TC.bad $ TC.TypeError "Stream with inconsistent accumulator type in lambda."
  -- check reduce's lambda, if any
  _ <- case form of
        Parallel _ _ lam0 _ -> do
          let acct = map TC.argType accargs
              outerRetType = lambdaReturnType lam0
          TC.checkLambda lam0 $ map TC.noArgAliases $ accargs ++ accargs
          unless (acct == outerRetType) $
            TC.bad $ TC.TypeError $
            "Initial value is of type " ++ prettyTuple acct ++
            ", but stream's reduce lambda returns type " ++ prettyTuple outerRetType ++ "."
        _ -> return ()
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' = map asArg lamarrs'
  TC.checkLambda lam $ asArg inttp : accargs ++ fake_lamarrs'

typeCheckSOAC (Scatter w lam ivs as) = do
  -- Requirements:
  --
  --   0. @lambdaReturnType@ of @lam@ must be a list
  --      [index types..., value types].
  --
  --   1. The number of index types must be equal to the number of value types
  --      and the number of writes to arrays in @as@.
  --
  --   2. Each index type must have the type i32.
  --
  --   3. Each array in @as@ and the value types must have the same type
  --
  --   4. Each array in @as@ is consumed.  This is not really a check, but more
  --      of a requirement, so that e.g. the source is not hoisted out of a
  --      loop, which will mean it cannot be consumed.
  --
  --   5. Each of ivs must be an array matching a corresponding lambda
  --      parameters.
  --
  -- Code:

  -- First check the input size.
  TC.require [Prim int32] w

  -- 0.
  let (_as_ws, as_ns, _as_vs) = unzip3 as
      rts = lambdaReturnType lam
      rtsLen = length rts `div` 2
      rtsI = take rtsLen rts
      rtsV = drop rtsLen rts

  -- 1.  Also reject an odd number of return types; otherwise the
  -- surplus value type would be silently ignored.
  unless (length rts == 2 * rtsLen && rtsLen == sum as_ns) $
    TC.bad $ TC.TypeError "Scatter: Uneven number of index types, value types, and arrays outputs."

  -- 2.
  forM_ rtsI $ \rtI -> unless (Prim int32 == rtI) $
    TC.bad $ TC.TypeError "Scatter: Index return type must be i32."

  forM_ (zip (chunks as_ns rtsV) as) $ \(rtVs, (aw, _, a)) -> do
    -- All lengths must have type i32.
    TC.require [Prim int32] aw

    -- 3.
    forM_ rtVs $ \rtV -> TC.requireI [rtV `arrayOfRow` aw] a

    -- 4.
    TC.consume =<< TC.lookupAliases a

  -- 5.
  arrargs <- TC.checkSOACArrayArgs w ivs
  TC.checkLambda lam arrargs

typeCheckSOAC (GenReduce len ops bucket_fun imgs) = do
  TC.require [Prim int32] len

  -- Check the operators.
  forM_ ops $ \(GenReduceOp dest_w dests nes op) -> do
    nes' <- mapM TC.checkArg nes
    TC.require [Prim int32] dest_w

    -- Operator type must match the type of neutral elements.
    TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
    let nes_t = map TC.argType nes'
    unless (nes_t == lambdaReturnType op) $
      TC.bad $ TC.TypeError $ "Operator has return type " ++
      prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
      prettyTuple nes_t

    -- Arrays must have proper type.
    forM_ (zip nes_t dests) $ \(t, dest) -> do
      TC.requireI [t `arrayOfRow` dest_w] dest
      TC.consume =<< TC.lookupAliases dest

  -- Types of input arrays must equal parameter types for bucket function.
  img' <- TC.checkSOACArrayArgs len imgs
  TC.checkLambda bucket_fun img'

  -- Return type of bucket function must be an index for each
  -- operation followed by the values to write.
  nes_ts <- concat <$> mapM (mapM subExpType . genReduceNeutral) ops
  let bucket_ret_t = replicate (length ops) (Prim int32) ++ nes_ts
  unless (bucket_ret_t == lambdaReturnType bucket_fun) $
    TC.bad $ TC.TypeError $ "Bucket function has return type " ++
    prettyTuple (lambdaReturnType bucket_fun) ++ " but should have type " ++
    prettyTuple bucket_ret_t

typeCheckSOAC (Screma w (ScremaForm (scan_lam, scan_nes) (_, red_lam, red_nes) map_lam) arrs) = do
  TC.require [Prim int32] w
  arrs' <- TC.checkSOACArrayArgs w arrs
  scan_nes' <- mapM TC.checkArg scan_nes
  red_nes' <- mapM TC.checkArg red_nes
  TC.checkLambda map_lam $ map TC.noArgAliases arrs'
  TC.checkLambda scan_lam $ map TC.noArgAliases $ scan_nes' ++ scan_nes'
  TC.checkLambda red_lam $ map TC.noArgAliases $ red_nes' ++ red_nes'
  let scan_t = map TC.argType scan_nes'
      red_t = map TC.argType red_nes'
      map_lam_ts = lambdaReturnType map_lam

  unless (scan_t == lambdaReturnType scan_lam) $
    TC.bad $ TC.TypeError $ "Scan function returns type " ++
    prettyTuple (lambdaReturnType scan_lam) ++ " but neutral element has type " ++
    prettyTuple scan_t

  unless (red_t == lambdaReturnType red_lam) $
    TC.bad $ TC.TypeError $ "Reduce function returns type " ++
    prettyTuple (lambdaReturnType red_lam) ++ " but neutral element has type " ++
    prettyTuple red_t

  unless (take (length scan_nes + length red_nes) map_lam_ts ==
          map TC.argType (scan_nes'++ red_nes')) $
    TC.bad $ TC.TypeError $ "Map function return type " ++ prettyTuple map_lam_ts ++
    " wrong for given scan and reduction functions."
-- | Get Stream's accumulators as a sub-expression list
getStreamAccums :: StreamForm lore -> [SubExp]
getStreamAccums (Parallel _ _ _ accs) = accs
getStreamAccums (Sequential accs) = accs

-- | The ordering of a stream.  A sequential stream is always
-- 'InOrder'.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder (Parallel o _ _ _) = o
getStreamOrder (Sequential _) = InOrder

instance OpMetrics (Op lore) => OpMetrics (SOAC lore) where
  opMetrics (Stream _ _ lam _) =
    inside "Stream" $ lambdaMetrics lam
  opMetrics (Scatter _len lam _ivs _as) =
    inside "Scatter" $ lambdaMetrics lam
  opMetrics (GenReduce _len ops bucket_fun _imgs) =
    inside "GenReduce" $ mapM_ (lambdaMetrics . genReduceOp) ops >> lambdaMetrics bucket_fun
  opMetrics (Screma _ (ScremaForm (scan_lam, _) (_, red_lam, _) map_lam) _) =
    inside "Screma" $ lambdaMetrics scan_lam >> lambdaMetrics red_lam >> lambdaMetrics map_lam
  opMetrics CmpThreshold{} = seen "CmpThreshold"

-- The line-break combinator is written qualified (@PP.</>@)
-- throughout, so that this code only relies on the qualified import
-- of the pretty-printing module.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr (Stream size form lam arrs) =
    case form of
       Parallel o comm lam0 acc ->
         let ord_str = if o == Disorder then "Per" else ""
             comm_str = case comm of Commutative -> "Comm"
                                     Noncommutative -> ""
         in  text ("streamPar"++ord_str++comm_str) <>
             parens (ppr size <> comma PP.</> ppr lam0 PP.</> comma PP.</> ppr lam PP.</>
                        commasep ( PP.braces (commasep $ map ppr acc) : map ppr arrs ))
       Sequential acc ->
             text "streamSeq" <>
             parens (ppr size <> comma PP.</> ppr lam <> comma PP.</>
                        commasep ( PP.braces (commasep $ map ppr acc) : map ppr arrs ))
  ppr (Scatter len lam ivs as) =
    ppSOAC "scatter" len [lam] (Just (map Var ivs)) (map (\(_,n,a) -> (n,a)) as)
  ppr (GenReduce len ops bucket_fun imgs) =
    ppGenReduce len ops bucket_fun imgs
  -- A Screma with a trivial scan and/or reduction part is printed
  -- under a more specific name.
  ppr (Screma w (ScremaForm (scan_lam, scan_nes) (_, red_lam, red_nes) map_lam) arrs)
    | isNilFn scan_lam, null scan_nes,
      isNilFn red_lam, null red_nes =
        text "map" <>
        parens (ppr w <> comma PP.</>
                ppr map_lam <> comma PP.</>
                commasep (map ppr arrs))

    | isNilFn scan_lam, null scan_nes =
        text "redomap" <>
        parens (ppr w <> comma PP.</>
                ppr red_lam <> comma PP.</>
                commasep (map ppr red_nes) <> comma PP.</>
                ppr map_lam <> comma PP.</>
                commasep (map ppr arrs))

    | isNilFn red_lam, null red_nes =
        text "scanomap" <>
        parens (ppr w <> comma PP.</>
                ppr scan_lam <> comma PP.</>
                commasep (map ppr scan_nes) <> comma PP.</>
                ppr map_lam <> comma PP.</>
                commasep (map ppr arrs))

  ppr (Screma w form arrs) = ppScrema w form arrs
  ppr (CmpThreshold what s) =
    text "cmpThreshold(" <> ppr what <> comma PP.<+> text (show s) <> text ")"

-- | Pretty-print a Screma in its general form, as @screma@ or
-- @scremaComm@ depending on the commutativity of the reduction.
--
-- NOTE(review): no comma is emitted after either of the
-- neutral-element groups; confirm that this is intended.
ppScrema :: (PrettyLore lore, Pretty inp) =>
            SubExp -> ScremaForm lore -> [inp] -> Doc
ppScrema w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs =
  text s <>
  parens (ppr w <> comma PP.</>
          ppr scan_lam <> comma PP.</>
          PP.braces (commasep $ map ppr scan_nes) PP.</>
          ppr red_lam <> comma PP.</>
          PP.braces (commasep $ map ppr red_nes) PP.</>
          ppr map_lam <> comma PP.</>
          commasep (map ppr arrs))
  where s = case comm of Noncommutative -> "screma"
                         Commutative -> "scremaComm"

-- | Pretty-print a @gen_reduce@ from its length, operations, bucket
-- function and inputs.
ppGenReduce :: (PrettyLore lore, Pretty inp) =>
               SubExp -> [GenReduceOp lore] -> Lambda lore -> [inp] -> Doc
ppGenReduce len ops bucket_fun imgs =
  text "gen_reduce" <>
  parens (ppr len <> comma PP.</>
          PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops) <> comma PP.</>
          ppr bucket_fun <> comma PP.</>
          commasep (map ppr imgs))
  where ppOp (GenReduceOp w dests nes op) =
          ppr w <> comma <> PP.braces (commasep $ map ppr dests) <> comma PP.</>
          PP.braces (commasep $ map ppr nes) <> comma PP.</> ppr op

-- | Pretty-print a SOAC from its name, size, functions, an optional
-- tuple of sub-expressions, and its remaining arguments.
ppSOAC :: (Pretty fn, Pretty v) =>
          String -> SubExp -> [fn] -> Maybe [SubExp] -> [v] -> Doc
ppSOAC name size funs es as =
  text name <> parens (ppr size <> comma PP.</>
                       ppList funs PP.</>
                       commasep (es' ++ map ppr as))
  where es' = maybe [] ((:[]) . ppTuple') es

-- | Pretty-print every element followed by a comma, separating the
-- elements with @PP.</>@.  An empty list produces an empty document.
ppList :: Pretty a => [a] -> Doc
ppList as = case map ppr as of
              []     -> mempty
              a':as' -> foldl' (PP.</>) (a' <> comma) $ map (<> comma) as'