-- | Convert the vector operations in a module's top-level functions into
--   series processes.
--
--   Each function body (which must be in a-normal form) is split into its
--   let-bindings, the bindings are scheduled into loops, and every vector
--   @reduce@, @map@ and @filter@ in a loop is rewritten to the corresponding
--   series operator, run under a single series run-process operator.
module DDC.Core.Flow.Transform.Rates.SeriesOfVector
        ( seriesOfVectorModule
        , seriesOfVectorFunction )
where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Transform.Rates.Constraints
import DDC.Core.Flow.Transform.Rates.Fail
import DDC.Core.Flow.Transform.Rates.Graph
import DDC.Core.Module
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate
import qualified DDC.Type.Env           as Env

import           Control.Applicative
import           Control.Monad
import           Data.List              (intersect, nub)
import qualified Data.Map               as Map
import           Data.Maybe             (catMaybes)
import qualified Data.Set               as Set


-- | Convert every top-level let-binding of a module with 'seriesOfVectorLets'.
--
--   Annotations are stripped from the module body before conversion and the
--   result is re-annotated with @()@. Returns the rewritten module together
--   with the failures reported for each named binding.
seriesOfVectorModule :: ModuleF -> (ModuleF, [(Name,Fail)])
seriesOfVectorModule mm
 = let  body       = deannotate (const Nothing) $ moduleBody mm
        (lets, xx) = splitXLets body
        letsErrs   = map seriesOfVectorLets lets
        lets'      = map fst letsErrs
        errs       = concatMap snd letsErrs
        body'      = annotate () $ xLets lets' xx
   in   -- trace ("ORIGINAL:"++ show (ppr $ moduleBody mm))
        -- trace ("MODULE:" ++ show (ppr body'))
        (mm { moduleBody = body' }, errs)


-- | Convert a single top-level binding.
--
--   * A non-recursive named binding is converted with
--     'seriesOfVectorFunction', and each failure is tagged with the
--     binder's name.
--
--   * Every right-hand side of a recursive group is converted too, but the
--     failures are currently discarded.
--
--   * Any other binding is returned unchanged.
seriesOfVectorLets :: LetsF -> (LetsF, [(Name,Fail)])
seriesOfVectorLets ll
 | LLet b@(BName n _) x <- ll
 , (x', errs)           <- seriesOfVectorFunction x
 = (LLet b x', map (\f -> (n, f)) errs)

 | LRec bxs             <- ll
 , (bs, xs)             <- unzip bxs
 , (xs', _errs)         <- unzip $ map seriesOfVectorFunction xs
 = (LRec (bs `zip` xs'), [])
   -- We still need to produce errors if this doesn't work.

 | otherwise
 = (ll, [])


-- | Convert a single function. The function body must be in a-normal form:
--   after its leading lambdas it is a sequence of let-bindings followed by
--   a result expression.
--
--   Returns the rewritten function together with every failure logged
--   (via 'warn') while converting it.
seriesOfVectorFunction :: ExpF -> (ExpF, [Fail])
seriesOfVectorFunction fun
 = run $ do
        -- Peel off the lambdas.
        let (lams, body) = takeXLamFlags_safe fun

        -- This assumes the body is already in a-normal form.
        let (lets, xx)   = splitXLets body

        -- Split into names and values, warning about bindings we cannot
        -- handle (anonymous, recursive, region-related).
        binds           <- takeLets lets
        let tymap        = takeTypes (concatMap valwitBindsOfLets lets ++ map snd lams)

        -- Assumes the binds only use vector primitives,
        -- OR if not vector primitives, do not refer to bound vectors.
        let names        = map fst binds

        -- Make sure names are unique.
        when (length names /= length (nub names))
         $ warn FailNamesNotUnique

        (constrs, equivs) <- checkBindConstraints binds

        -- The lambda-bound names are passed to the graph and to
        -- getMaxSize as "extras".
        let extras       = catMaybes $ map (takeNameOfBind . snd) lams
        let graph        = graphOfBinds binds extras

        -- Variables free in the result expression are the returned values.
        let rets         = catMaybes $ map takeNameOfBound
                         $ Set.toList $ freeX Env.empty xx

        loops           <- schedule graph equivs rets
        binds'          <- orderBinds binds loops

        -- True <- trace ("TYMAP:" ++ show tymap) return True
        -- True <- trace ("NAMES,LOOPS,NAMES':" ++ show (names, loops, map (map fst) binds'))
        --         return True

        let outputs      = map lOutputs loops
        let inputs       = map lInputs  loops
        let getMax       = getMaxSize constrs equivs extras

        return $ construct getMax lams (zip3 binds' outputs inputs) equivs tymap xx


-- | Peel the leading lambdas off an expression.
--
--   Returns the lambda binders, each paired with the flag 'takeXLamFlags'
--   attaches (it distinguishes type-level from value-level binders, cf.
--   @kFlags@ in 'convertToSeries'), and the remaining body. An expression
--   with no leading lambdas yields no binders and itself as the body.
takeXLamFlags_safe :: ExpF -> ([(Bool, BindF)], ExpF)
takeXLamFlags_safe x
 | Just (binds, body) <- takeXLamFlags x
 = (binds, body)
 | otherwise
 = ([], x)


-- | Split let-bindings into (name, value) pairs.
--
--   Bindings to 'BNone' are dropped silently. Anonymous, recursive,
--   private-region and with-region bindings are dropped with a warning.
takeLets :: [LetsF] -> LogFailures [(Name, ExpF)]
takeLets lets
 = concat <$> mapM get lets
 where
  get (LLet (BName n _) x) = return [(n, x)]
  get (LLet (BNone _)   _) = return []
  get (LLet (BAnon _)   _) = w FailNoDeBruijnAllowed
  get (LRec _)             = w FailRecursiveBindings
  get (LPrivate _ _ _)     = w FailLetRegionNotHandled
  get (LWithRegion _)      = w FailLetRegionNotHandled

  w err = warn err >> return []


-- | Build a map from each named binder to its type.
--   Binders without a name are skipped.
takeTypes :: [Bind Name] -> Map.Map Name TypeF
takeTypes binds
 = Map.fromList $ concatMap get binds
 where
  get (BName n t) = [(n, t)]
  get _           = []


-- | A group of bindings scheduled into the same loop.
data Loop
 = Loop
 { lBindings :: [Name]
   -- ^ Names computed in this loop, in the original graph's topological
   --   order (see 'scheduleAll').
 , lOutputs  :: [Name]
   -- ^ Names computed in this loop that cross its boundary (needed by
   --   other nodes, or returned); they must be reified.
 , lInputs   :: [Name]
   -- ^ Names outside this loop that its members depend on.
 }
 deriving (Eq, Show)


-- | Group the nodes of the dependency graph into loops, and work out which
--   names flow into and out of each loop. @rets@ are the names returned by
--   the function body.
schedule :: Graph -> EquivClass -> [Name] -> LogFailures [Loop]
schedule graph equivs rets
 = let  -- The canonical name of each equivalence class.
        type_order    = map (canonName equivs . Set.findMin) equivs
        -- minimumBy length $ map scheduleTypes $ permutations type_order
        (wts, graph') = scheduleTypes graph equivs type_order
        loops         = scheduleAll (map snd wts) graph graph'

        -- Use the original graph to find vars that cross loop boundaries.
        outputs       = scheduleOutputs loops graph rets
        inputs        = scheduleInputs  loops graph
   in   -- trace ("GRAPH,GRAPH',WTS,EQUIVS:" ++ show (graph, graph', wts, equivs))
        return $ zipWith3 Loop loops outputs inputs


-- | For each class name in @type_order@, in turn, compute its weight map
--   with 'typedTraversal' on the current graph and merge those weights into
--   the graph with 'mergeWeights'.
--
--   Returns the per-class weight maps (most recently computed first) and
--   the final merged graph.
scheduleTypes :: Graph -> EquivClass -> [Name] -> ([(Name, Map.Map Name Int)], Graph)
scheduleTypes graph types type_order
 = foldl go ([], graph) type_order
 where
  go (w, g) ty
   = let w' = typedTraversal g types ty
         g' = mergeWeights g w'
     in  ((ty, w') : w, g')


-- | Turn the merged graph into loops: one loop per node of @graph'@, in the
--   merged graph's topological order.
--
--   A node's loop holds the names that the first weight map giving the node
--   a weight lists under that same weight (via the inverted map from
--   'invertMap'), ordered as in the original graph's topological order.
scheduleAll :: [Map.Map Name Int] -> Graph -> Graph -> [[Name]]
scheduleAll weights graph graph'
 = map loopOf (graphTopoOrder graph')
 where
  weights'       = map invertMap weights
  original_order = graphTopoOrder graph

  loopOf n       = inOriginalOrder $ cluster n (weights `zip` weights')

  -- Cheesy hack to get ns in the same order as the original graph's topo:
  -- filter topo to only those elements in ns.
  inOriginalOrder ns = filter (`elem` ns) original_order

  -- Search the weight maps in order for one that gives n a weight, and
  -- return the names its inverted map lists under that weight.
  cluster _ [] = []
  cluster n ((w, w') : rest)
   | Just i  <- n `Map.lookup` w
   , Just ns <- i `Map.lookup` w'
   = ns
   | otherwise
   = cluster n rest


-- | Find any variables that cross loop boundaries - they must be reified.
--
--   For each loop, these are the names it computes that appear among the
--   edges of some graph node outside the loop, plus the names it computes
--   that are returned (@rets@).
--
--   The result is deduplicated (as 'scheduleInputs' is). Without this, a
--   name needed by several outside nodes, or both needed outside and
--   returned, would be listed repeatedly and 'convertToSeries' would join
--   its process more than once.
scheduleOutputs :: [[Name]] -> Graph -> [Name] -> [[Name]]
scheduleOutputs loops graph rets
 = map output loops
 where
  output ns
   = nub (graphOuts ns ++ filter (`elem` ns) rets)

  graphOuts ns
   = concatMap (\(k, es) -> if k `elem` ns
                               then []
                               else ns `intersect` map fst es)
   $ Map.toList graph


-- | Find any variables that cross loop boundaries - they must be reified.
--
--   For each loop, these are the names its members have edges to in the
--   graph that are not themselves in the loop, without duplicates.
scheduleInputs :: [[Name]] -> Graph -> [[Name]]
scheduleInputs loops graph
 = map input loops
 where
  input ns    = filter (`notElem` ns) $ graphIns ns
  graphIns ns = nub $ concatMap (map fst . mlookup "graphIns" graph) ns


-- | Weight map for the class @current_type@.
--
--   Runs 'traversal' over the graph with a weight function that yields 1
--   when the edge entry @(u, fusible)@ is in the current class and either
--   @v@ is in a different class or the edge is not fusible, and 0 otherwise.
--   Only the nodes belonging to @current_type@ are kept in the result.
typedTraversal :: Graph -> EquivClass -> Name -> Map.Map Name Int
typedTraversal graph types current_type
 = restrictTypes types current_type
 $ traversal graph w
 where
  w u v = if w' u v then 1 else 0

  w' (u, fusible) v
   | canonName types u == current_type
   = canonName types v /= current_type || not fusible
   | otherwise
   = False


-- | Keep only the entries whose key belongs to the class @current_type@.
restrictTypes :: EquivClass -> Name -> Map.Map Name Int -> Map.Map Name Int
restrictTypes types current_type weights
 = Map.filterWithKey restrict weights
 where
  restrict n _ = canonName types n == current_type


-- | Look up the bindings of each loop, in loop order.
--   Loop members that have no binding in @binds@ are skipped.
orderBinds :: [(Name,ExpF)] -> [Loop] -> LogFailures [[(Name,ExpF)]]
orderBinds binds loops
 = return $ map (concatMap get . lBindings) loops
 where
  bindsM = Map.fromList binds

  get k
   | Just v <- Map.lookup k bindsM
   = [(k, v)]
   | otherwise
   = []


-- | Rebuild the function. The result has the original lambdas, then the lets
--   produced by 'convertToSeries' for each loop in turn, then the original
--   result expression.
construct
 :: (Name -> Name)                        -- ^ name whose length sizes a new output vector
 -> [(Bool, BindF)]                       -- ^ leading lambda binders
 -> [([(Name, ExpF)], [Name], [Name])]    -- ^ per loop: bindings, outputs, inputs
 -> EquivClass
 -> Map.Map Name TypeF                    -- ^ types of bound names
 -> ExpF                                  -- ^ result expression
 -> ExpF
construct getMax lams loops equivs tys xx
 = makeXLamFlags lams
 $ xLets (concatMap convert loops) xx
 where
  convert (binds, outputs, inputs)
   = convertToSeries getMax binds outputs inputs equivs tys


-- | Convert one loop into let-bindings. The result consists of:
--
--   1. set-up bindings for the loop's outputs (see 'setreadSeriesX'),
--   2. one binding that runs the loop's processes (under a series
--      run-process over the loop's vector inputs, if it has any), and
--   3. read-back bindings for the loop's outputs.
--
--   We still need to join procs,
--   split output procs into separate functions.
convertToSeries
 :: (Name -> Name)
 -> [(Name,ExpF)] -> [Name] -> [Name]
 -> EquivClass -> Map.Map Name TypeF
 -> [LetsF]
convertToSeries getMax binds outputs inputs equivs tys
 =  concat setups
 ++ [LLet (BNone tBool) (runprocs inputs' processes)]
 ++ concat readrefs
 where
  -- Run the body over the series of each vector input. The rate variable
  -- is named after the canonical name of the first input's class.
  runprocs :: [(Name,TypeF)] -> ExpF -> ExpF
  runprocs vecs@((cn, _) : _) body
   = let cnn    = canonName equivs cn
         kN     = NameVarMod cnn "k"
         kFlags = [ (True,  BName kN kRate)
                  , (False, BNone $ tRateNat $ TVar $ UName kN) ]
         vFlags = map (\(n, t) -> ( False
                                  , BName (NameVarMod n "s")
                                          (tSeries (TVar (UName kN)) t)))
                      vecs
     in  xApps (xVarOpSeries (OpSeriesRunProcess $ length vecs))
               (  map (XType . snd) vecs
               ++ map (XVar . UName . fst) vecs
               ++ [makeXLamFlags (kFlags ++ vFlags) body])
  -- Should we introduce a rate parameter for generates?
  runprocs [] body = body

  -- The inputs whose type is a vector, paired with their element type.
  inputs' :: [(Name,TypeF)]
  inputs' = concatMap filterInputs inputs

  filterInputs inp
   | tyI <- mlookup "collectKloks" tys inp
   , Just (_tcVec, [tyA]) <- takeTyConApps tyI
   , tyI == tVector tyA
   = [(inp, tyA)]
   | otherwise
   = []

  -- Wrap each binding's series form around the join of the output processes.
  processes        = foldr wrap joins binds
  wrap (n, x) body = wrapSeriesX equivs outputs n (mlookup "wrap" tys n) x body

  joins
   | not $ null outputs
   = foldl1 mkJoin
   $ map (\n -> XVar $ UName $ NameVarMod n "proc") outputs
   | otherwise
   = xUnit -- ???

  mkJoin p q = xApps (xVarOpSeries OpSeriesJoin) [p, q]

  -- Fill vectors and read references.
  (setups, readrefs) = unzip $ map setread $ filter (flip elem outputs . fst) binds
  setread (n, x)     = setreadSeriesX getMax tys n (mlookup "setread" tys n) x


-- | Set-up and read-back bindings for an output binding @name@ of type @ty@
--   whose value is @xx@.
--
--   * A vector reduce allocates a reference initialised with the reduce's
--     zero argument before the loop, and reads it into @name@ afterwards.
--
--   * Any other vector operation allocates a new vector before the loop.
--     This requires @ty@, and the type of @getMax name@, to both be
--     one-argument type applications. The new vector's length is that of
--     the vector named by @getMax name@.
--
--   * Anything else needs no set-up or read-back.
setreadSeriesX
 :: (Name -> Name) -> Map.Map Name TypeF -> Name -> TypeF -> ExpF
 -> ([LetsF], [LetsF])
setreadSeriesX getMax tys name ty xx
 | Just (f, args)                   <- takeXApps xx
 , XVar (UPrim (NameOpVector ov) _) <- f
 = case ov of
    -- Any folds MUST be known as outputs, so this is safe.
    OpVectorReduce
     | [_tA, _f, z, _vA] <- args
     -> ( [ LLet (BName (nm "ref") (tRef ty)) (xNew ty z) ]
        , [ LLet (BName name ty) (xRead ty (vr $ nm "ref")) ])

    _
     | [_vecR, tyR]  <- takeTApps ty
     , v             <- getMax name -- canonName equivs name
     , [_vecV, tyCR] <- takeTApps $ mlookup "setreadSeriesX" tys v
     -> let vl = xApps (xVarOpVector OpVectorLength) [XType tyCR, XVar $ UName v]
        in  ( [ LLet (BName name $ tBot kData) $ xNewVector tyR vl ]
            , [])

    _ -> ([], [])

 | otherwise
 = ([], [])
 where
  nm s = NameVarMod name s
  vr n = XVar $ UName n


-- | Wrap the series form of one binding @name = xx@ (of type @ty@) around
--   @wrap@, the expression for the rest of the loop.
--
--   * A vector reduce becomes a series reduce into the reference
--     @name'ref@, bound as the process @name'proc@.
--
--   * A vector map becomes a series map bound to @name's@.
--
--   * A vector filter maps the predicate over the input series to get
--     @name'flags@. It builds a selector from those flags and packs the
--     input series into @name's@ at the new rate.
--
--   When a map or filter result is in @outputs@, the vector @name@ is also
--   filled from @name's@ as the process @name'proc@.
--
--   NOTE(review): any other binding yields @xx@ itself and @wrap@ is dropped.
--   Confirm that only vector operations can reach this point.
wrapSeriesX :: EquivClass -> [Name] -> Name -> TypeF -> ExpF -> ExpF -> ExpF
wrapSeriesX equivs outputs name ty xx wrap
 | Just (op, args)                  <- takeXApps xx
 , XVar (UPrim (NameOpVector ov) _) <- op
 = case ov of
    OpVectorReduce
     | [_tA, f, z, vA]  <- args
     , XVar (UName nvA) <- vA
     , kA               <- klok nvA
     -> XLet (LLet (BName name'proc tProcess)
                   $ xApps (xVarOpSeries OpSeriesReduce)
                           [kA, XType ty, XVar (UName name'ref), f, z, modNameX "s" vA])
             wrap

    OpVectorMap n
     | (tys, f : rest)  <- splitAt (n + 1) args
     , length rest == n
     , kT               <- klok name
     , rest'            <- map (modNameX "s") rest
     -> XLet (LLet (BName name's $ tBot kData)
                   $ xApps (xVarOpSeries (OpSeriesMap n))
                           ([kT] ++ tys ++ [f] ++ rest'))
             wrap'fill

    OpVectorFilter
     | [tA, p, vA]      <- args
     , XVar (UName nvA) <- vA
     , tkA              <- klokT nvA
     , kA               <- klok nvA
     , TVar (UName nkT) <- klokT name
     , tkT              <- klokT name
     -> XLet (LLet (BName name'flags (tBot kData))
                   $ xApps (xVarOpSeries (OpSeriesMap 1))
                           [kA, tA, XType tBool, p, modNameX "s" vA])
      $ xApps (xVarOpSeries (OpSeriesMkSel 1))
              [ kA
              , XVar (UName name'flags)
              , XLAM (BName nkT kRate)
                $ XLam (BName name'sel (tSel1 tkA tkT))
                $ XLet (LLet (BName name's (tBot kData))
                             $ xApps (xVarOpSeries OpSeriesPack)
                                     [kA, XType tkT, tA, XVar (UName name'sel), modNameX "s" vA])
                       wrap'fill
              ]

    _ -> xx

 | otherwise
 = xx
 where
  name'flags = NameVarMod name "flags"
  name'proc  = NameVarMod name "proc"
  name'ref   = NameVarMod name "ref"
  name's     = NameVarMod name "s"
  name'sel   = NameVarMod name "sel"

  -- Rate type variable for a name: the canonical name of its equivalence
  -- class, suffixed with "k".
  klokT n    = TVar $ UName $ NameVarMod (canonName equivs n) "k"
  klok  n    = XType $ klokT n

  -- Argument of ty, when ty is a one-argument type application.
  tyR
   | [_vec, tyR'] <- takeTApps ty = Just tyR'
   | otherwise                    = Nothing

  -- For a loop output, also fill the vector from its series.
  wrap'fill
   | name `elem` outputs
   , Just tyR' <- tyR
   = XLet (LLet (BName name'proc tProcess)
                $ xApps fillV [klok name, XType tyR', vr name, vr name's])
          wrap
   | otherwise
   = wrap

  fillV = xVarOpSeries OpSeriesFill
  vr n  = XVar $ UName n


-- | Variable referring to a series primitive, carrying its type.
xVarOpSeries :: OpSeries -> ExpF
xVarOpSeries n = XVar (UPrim (NameOpSeries n) (typeOpSeries n))


-- | Variable referring to a vector primitive, carrying its type.
xVarOpVector :: OpVector -> ExpF
xVarOpVector n = XVar (UPrim (NameOpVector n) (typeOpVector n))


-- | Apply the name modifier @s@ ('NameVarMod') to a named variable;
--   any other expression is returned unchanged.
modNameX :: String -> ExpF -> ExpF
modNameX s xx
 = case xx of
    XVar (UName n) -> XVar (UName (NameVarMod n s))
    _              -> xx


{-
 \as,bs...
 cs = map as
 ds = filter as
 n  = fold ds
 es = map3 bs cs
 return es

==>
 schedule graph equivs [es]

==>
 [ [ds, n]
 , [cs, es] ]
-}