-- | Simple code motion transformation performing common sub-expression
-- elimination and variable hoisting. Note that the implementation is very
-- inefficient.
--
-- The code is based on an implementation by Gergely Dévai.

module Language.Syntactic.Sharing.SimpleCodeMotion
    ( BindDict (..)
    , codeMotion
    , defaultBindDict
    , reifySmart
    ) where



import Control.Monad.State
import Data.Set as Set
import Data.Typeable

import Data.Proxy

import Language.Syntactic
import Language.Syntactic.Constructs.Binding
import Language.Syntactic.Constructs.Binding.HigherOrder



-- | Interface for binding constructs
data BindDict ctx dom = BindDict
    { -- | Extract the identifier if the symbol is a variable
      prjVariable :: forall a . dom a -> Maybe VarId
      -- | Extract the bound identifier if the symbol is a lambda
    , prjLambda   :: forall a . dom a -> Maybe VarId
      -- | Make a variable symbol; the expression argument is the one the
      -- variable stands for
    , injVariable :: forall a . (Sat ctx a, Typeable a)
                  => ASTF dom a -> VarId -> dom (Full a)
      -- | Make a lambda symbol; the expression argument is the lambda body
    , injLambda   :: forall a b . (Sat ctx a, Typeable a, Sat ctx b)
                  => ASTF dom b -> VarId -> dom (b :-> Full (a -> b))
      -- | Make a let symbol; the expression argument is the let body
    , injLet      :: forall a b . (Sat ctx a, Sat ctx b)
                  => ASTF dom b -> dom (a :-> (a -> b) :-> Full b)
    }
  -- TODO `injLambda` has more constraints than the `Lambda` constructor. This
  --      is demanded by the Feldspar implementation. One way to make things
  --      more consistent would be to add an extra `ctx` parameter to `Lambda`
  --      (like `Let`).



-- | Substituting a sub-expression. Assumes no variable capturing in the
-- expressions involved.
substitute :: forall dom a b
    .  (Typeable a, Typeable b, AlphaEq dom dom dom [(VarId,VarId)])
    => ASTF dom a  -- ^ Sub-expression to be replaced
    -> ASTF dom a  -- ^ Replacing sub-expression
    -> ASTF dom b  -- ^ Whole expression
    -> ASTF dom b
substitute old new whole =
    -- The replacement is only usable at this node when its type can be cast
    -- to the type of the node; the cast is attempted before the comparison.
    case gcast new of
      Just new' | alphaEq old whole -> new'
      _                             -> subst whole
  where
    -- Rebuild the application spine, substituting inside every argument.
    subst :: Typeable c => AST dom c -> AST dom c
    subst (fun :$ arg) = subst fun :$ substitute old new arg
    subst leaf         = leaf



-- | Count the number of occurrences of a sub-expression
count :: forall dom a b
    .  AlphaEq dom dom dom [(VarId,VarId)]
    => ASTF dom a  -- ^ Expression to count
    -> ASTF dom b  -- ^ Expression to count in
    -> Int
count needle haystack =
    -- A matching node counts once; its sub-expressions are not inspected.
    if alphaEq needle haystack
      then 1
      else inArgs haystack
  where
    -- Sum the occurrences found in each argument of an application spine.
    inArgs :: AST dom c -> Int
    inArgs (fun :$ arg) = inArgs fun + count needle arg
    inArgs _            = 0



-- | Whether the expression is an application (as opposed to a lone symbol)
nonTerminal :: AST dom a -> Bool
nonTerminal expr = case expr of
    _ :$ _ -> True
    _      -> False



-- | An expression with its result type hidden, packaged together with the
-- 'Sat' and 'Typeable' evidence for that type
data SomeAST ctx dom
  where
    SomeAST :: (Sat ctx a, Typeable a) => ASTF dom a -> SomeAST ctx dom



-- | Environment for the expression in the 'choose' function
data Env ctx dom = Env
    { inLambda :: Bool
        -- ^ Whether the current expression is inside a lambda
    , canShare :: forall a . dom a -> Bool
        -- ^ Whether a given symbol can be shared
    , counter :: SomeAST ctx dom -> Int
        -- ^ Counting the number of occurrences of an expression in the
        -- environment
    , dependencies :: Set VarId
        -- ^ The set of variables that are not allowed to occur in the chosen
        -- expression
    }



-- | Checks that none of the variables in the 'dependencies' of the environment
-- occur in the expression
independent :: BindDict ctx dom -> Env ctx dom -> AST dom a -> Bool
independent bindDict env (Sym (prjVariable bindDict -> Just v)) =
    Set.notMember v (dependencies env)
independent bindDict env (fun :$ arg) =
    independent bindDict env fun && independent bindDict env arg
independent _ _ _ = True



-- | Checks whether a sub-expression in a given environment can be lifted out
liftable :: (Sat ctx a, Typeable a) =>
    BindDict ctx dom -> Env ctx dom -> ASTF dom a -> Bool
liftable bindDict env a = independent bindDict env a && heuristic
    -- Lifting dependent expressions is semantically incorrect
  where
    heuristic    = symbolOk && nonTerminal a && worthSharing
    -- The symbol at the head of the expression must be shareable
    symbolOk     = queryNodeSimple (const . canShare env) a
    -- Inside a lambda any candidate is taken; otherwise it must occur more
    -- than once
    worthSharing = inLambda env || (counter env (SomeAST a) > 1)



-- | Choose a sub-expression to share
choose
    :: ( AlphaEq dom dom dom [(VarId,VarId)]
       , MaybeWitnessSat ctx dom
       , Typeable a
       )
    => BindDict ctx dom
    -> (forall a . dom a -> Bool)
    -> ASTF dom a
    -> Maybe (SomeAST ctx dom)
choose bindDict canShr a = chooseEnv bindDict topEnv a
  where
    -- Start outside any lambda, with no forbidden variables, counting
    -- occurrences in the whole expression.
    topEnv = Env
        { inLambda     = False
        , canShare     = canShr
        , counter      = \(SomeAST b) -> count b a
        , dependencies = Set.empty
        }



-- | Choose a sub-expression to share in an 'Env' environment
chooseEnv :: forall ctx dom a . (MaybeWitnessSat ctx dom, Typeable a)
    => BindDict ctx dom
    -> Env ctx dom
    -> ASTF dom a
    -> Maybe (SomeAST ctx dom)
chooseEnv bindDict env a =
    case maybeWitnessSat (Proxy :: Proxy ctx) a of
      Just SatWit | liftable bindDict env a -> Just (SomeAST a)
      _                                     -> chooseEnvSub bindDict env a



-- | Like 'chooseEnv', but does not consider the top expression for sharing
chooseEnvSub :: MaybeWitnessSat ctx dom
    => BindDict ctx dom
    -> Env ctx dom
    -> AST dom a
    -> Maybe (SomeAST ctx dom)
chooseEnvSub bindDict env (Sym (prjLambda bindDict -> Just v) :$ body) =
    chooseEnv bindDict underLambda body
  where
    -- Below the lambda, the bound variable must not occur in a chosen
    -- expression.
    underLambda = env
        { inLambda     = True
        , dependencies = Set.insert v (dependencies env)
        }
chooseEnvSub bindDict env (fun :$ arg) =
    -- Candidates from the function part take precedence over the argument.
    case chooseEnvSub bindDict env fun of
      Nothing -> chooseEnv bindDict env arg
      found   -> found
chooseEnvSub _ _ _ = Nothing



-- | Perform common sub-expression elimination and variable hoisting
codeMotion :: forall ctx dom a
    .  ( AlphaEq dom dom dom [(VarId,VarId)]
       , MaybeWitnessSat ctx dom
       , Typeable a
       )
    => BindDict ctx dom
    -> (forall a . dom a -> Bool)
    -> ASTF dom a
    -> State VarId (ASTF dom a)
codeMotion bindDict canShr a
    | Just SatWit <- maybeWitnessSat ctx a
    , Just chosen <- choose bindDict canShr a
    = share chosen
    | otherwise = descend a
  where
    ctx = Proxy :: Proxy ctx

    -- Bind the chosen sub-expression with a let around the whole expression.
    share :: Sat ctx a => SomeAST ctx dom -> State VarId (ASTF dom a)
    share (SomeAST shared) = do
        shared' <- codeMotion bindDict canShr shared
        -- Take the current identifier for the new variable and advance the
        -- supply.
        v <- get
        put (v+1)
        let var = Sym (injVariable bindDict shared v)
        body <- codeMotion bindDict canShr (substitute shared var a)
        return
            $  Sym (injLet bindDict body)
            :$ shared'
            :$ (Sym (injLambda bindDict body v) :$ body)

    -- Nothing to share at this node: transform each argument in turn.
    descend :: AST dom b -> State VarId (AST dom b)
    descend (fun :$ arg) = do
        fun' <- descend fun
        arg' <- codeMotion bindDict canShr arg
        return (fun' :$ arg')
    descend leaf = return leaf



-- | A 'BindDict' built from the 'Variable', 'Lambda' and 'Let' constructs
defaultBindDict :: forall ctx dom
    .  ( Variable ctx :<: dom
       , Lambda ctx   :<: dom
       , Let ctx ctx  :<: dom
       )
    => BindDict ctx dom
defaultBindDict = BindDict
    { prjVariable = \sym -> do
        Variable v <- prjCtx ctx sym
        return v

    , prjLambda = \sym -> do
        Lambda v <- prjCtx ctx sym
        return v

    , injVariable = \_ v -> inj (Variable v `withContext` ctx)
    , injLambda   = \_ v -> inj (Lambda v `withContext` ctx)
    , injLet      = \_   -> inj (letBind ctx)
    }
  where
    ctx = Proxy :: Proxy ctx



-- | Like 'reify' but with common sub-expression elimination and variable
-- hoisting
reifySmart :: forall ctx dom a
    .  ( Let ctx ctx :<: dom
       , AlphaEq dom dom (Lambda ctx :+: Variable ctx :+: dom) [(VarId,VarId)]
       , MaybeWitnessSat ctx dom
       , Syntactic a (HODomain ctx dom)
       )
    => (forall a . (Lambda ctx :+: Variable ctx :+: dom) a -> Bool)
    -> a
    -> ASTF (Lambda ctx :+: Variable ctx :+: dom) (Internal a)
reifySmart canShr prog =
    -- Reification and code motion run in the same state, starting from 0.
    evalState (reifyM (desugar prog) >>= codeMotion dict canShr) 0
  where
    dict = defaultBindDict :: BindDict ctx (Lambda ctx :+: Variable ctx :+: dom)