module DDC.Core.Flow.Transform.Forward ( forwardProcesses ) where import DDC.Core.Flow.Profile import DDC.Core.Flow.Prim import DDC.Core.Flow.Prim.KiConFlow import DDC.Core.Flow.Prim.TyConFlow import DDC.Core.Exp.Annot import DDC.Core.Module import qualified DDC.Core.Simplifier as C import qualified DDC.Core.Transform.Forward as Forward import qualified DDC.Core.Transform.TransformModX as T -- | Find all top-level Process bindings, and forward all non-series operators. -- This is a bit of a hack, because lower doesn't accept any non-series bindings. forwardProcesses :: Module () Name -> Module () Name forwardProcesses mm = T.transformModLet forwardBind mm -- | Forward a single process binding forwardBind :: Bind Name -> Exp () Name -> Exp () Name forwardBind b xx -- If the result type of a top-level binding is a Process, -- we must prepare it for the lowering transform. -- Forward everything we can, while leaving series operators at the top. | isProcessType $ snd $ takeTFunAllArgResult $ typeOfBind b = C.result $ Forward.forwardX profile conf_process xx -- Otherwise do minimal forwarding, except for pushing any rate-valued functions -- into their runKernel#. | otherwise = C.result $ Forward.forwardX profile conf_nonproc xx where conf_process = Forward.Config isFloatable_process False conf_nonproc = Forward.Config isFloatable_nonproc False -- Deny forwarding of flow primitives. -- Force anything else that's used only once. -- -- For lower to work, we need to forward everything except primitives, -- but that duplicates work. Lower should probably be changed. isFloatable_process lts = case lts of LLet (BName _ _) x | Just (n,_) <- takeXPrimApps x -> case n of NameOpConcrete _ -> Forward.FloatDeny NameOpControl _ -> Forward.FloatDeny NameOpSeries _ -> Forward.FloatDeny NameOpStore _ -> Forward.FloatDeny NameOpVector _ -> Forward.FloatDeny _ -> Forward.FloatForceUsedOnce _ -> Forward.FloatForceUsedOnce -- Forward any Process functions - they will have Rate foralls inside them. isFloatable_nonproc lts = case lts of LLet _ x | Just (lams,_) <- takeXLamFlags x , any (\(_,bo) -> typeOfBind bo == kRate) lams -> Forward.FloatForce _ -> Forward.FloatAllow