{-# LANGUAGE RecordWildCards #-}

-- | Simple code motion transformation performing common sub-expression
-- elimination and variable hoisting. Note that the implementation is very
-- inefficient.
--
-- The code is based on an implementation by Gergely Dévai.

module Language.Syntactic.Functional.Sharing
    ( -- * Interface
      InjDict (..)
    , CodeMotionInterface (..)
    , defaultInterface
    , defaultInterfaceDecor
      -- * Code motion
    , codeMotion
    ) where



-- 'liftM2' and 'mplus' are imported explicitly: newer versions of @mtl@ no
-- longer re-export "Control.Monad" from "Control.Monad.State".
import Control.Monad (liftM2, mplus)
import Control.Monad.State
import Data.Maybe (isNothing)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Typeable

import Data.Constraint (Dict (..))

import Language.Syntactic
import Language.Syntactic.Functional



--------------------------------------------------------------------------------
-- * Interface
--------------------------------------------------------------------------------

-- | Interface for injecting binding constructs
data InjDict sym a b = InjDict
    { injVariable :: Name -> sym (Full a)
        -- ^ Inject a variable
    , injLambda   :: Name -> sym (b :-> Full (a -> b))
        -- ^ Inject a lambda
    , injLet      :: sym (a :-> (a -> b) :-> Full b)
        -- ^ Inject a "let" symbol
    }

-- | Code motion interface
data CodeMotionInterface sym = Interface
    { mkInjDict  :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (InjDict sym a b)
        -- ^ Try to construct an 'InjDict'. The first argument is the expression
        -- to be shared, and the second argument the expression in which it will
        -- be shared. This function can be used to transfer information (e.g.
        -- from static analysis) from the shared expression to the introduced
        -- variable.
    , castExprCM :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (ASTF sym b)
        -- ^ Try to type cast an expression. The first argument is the
        -- expression to cast. The second argument can be used to construct a
        -- witness to support the casting. The resulting expression (if any)
        -- should be equal to the first argument.
    , hoistOver  :: forall c. ASTF sym c -> Bool
        -- ^ Whether a sub-expression can be hoisted over the given expression
    }

-- | Default 'CodeMotionInterface' for domains of the form
-- @`Typed` (... `:+:` `Binding` `:+:` ...)@.
defaultInterface :: forall binding sym symT
    .  ( binding :<: sym
       , Let     :<: sym
       , symT ~ Typed sym
       )
    => (forall a .   Typeable a => Name -> binding (Full a))
         -- ^ Variable constructor (e.g. 'Var' or 'VarT')
    -> (forall a b . Typeable a => Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor (e.g. 'Lam' or 'LamT')
    -> (forall a b . ASTF symT a -> ASTF symT b -> Bool)
         -- ^ Can the expression represented by the first argument be shared in
         -- the second argument?
    -> (forall a . ASTF symT a -> Bool)
         -- ^ Can we hoist over this expression?
    -> CodeMotionInterface symT
defaultInterface var lam sharable hoistOver = Interface {..}
  where
    -- The nested 'simpleMatch' calls only serve to pattern match on the
    -- 'Typed' wrappers of both expressions before building the dictionary.
    mkInjDict :: ASTF symT a -> ASTF symT b -> Maybe (InjDict symT a b)
    mkInjDict a b | not (sharable a b) = Nothing
    mkInjDict a b =
        simpleMatch
          (\(Typed _) _ -> simpleMatch
            (\(Typed _) _ ->
              let injVariable = Typed . inj . var
                  injLambda   = Typed . inj . lam
                  injLet      = Typed $ inj Let
              in  Just InjDict {..}
            ) b
          ) a

    castExprCM = castExpr

-- | Default 'CodeMotionInterface' for domains of the form
-- @(... `:&:` info)@, where @info@ can be used to witness type casting
defaultInterfaceDecor :: forall binding sym symI info
    .  ( binding :<: sym
       , Let     :<: sym
       , symI ~ (sym :&: info)
       )
    => (forall a b . info a -> info b -> Maybe (Dict (a ~ b)))
         -- ^ Construct a type equality witness
    -> (forall a b . info a -> info b -> info (a -> b))
         -- ^ Construct info for a function, given info for the argument and the
         -- result
    -> (forall a . info a -> Name -> binding (Full a))
         -- ^ Variable constructor
    -> (forall a b . info a -> info b -> Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor
    -> (forall a b . ASTF symI a -> ASTF symI b -> Bool)
         -- ^ Can the expression represented by the first argument be shared in
         -- the second argument?
    -> (forall a . ASTF symI a -> Bool)
         -- ^ Can we hoist over this expression?
    -> CodeMotionInterface symI
defaultInterfaceDecor teq mkFunInfo var lam sharable hoistOver = Interface {..}
  where
    -- The decorations of the shared expression (@aInfo@) and of the enclosing
    -- expression (@bInfo@) are reused to decorate the injected symbols.
    mkInjDict :: ASTF symI a -> ASTF symI b -> Maybe (InjDict symI a b)
    mkInjDict a b | not (sharable a b) = Nothing
    mkInjDict a b =
        simpleMatch
          (\(_ :&: aInfo) _ -> simpleMatch
            (\(_ :&: bInfo) _ ->
              let injVariable v = inj (var aInfo v) :&: aInfo
                  injLambda   v = inj (lam aInfo bInfo v) :&: mkFunInfo aInfo bInfo
                  injLet        = inj Let :&: bInfo
              in  Just InjDict {..}
            ) b
          ) a

    -- The cast succeeds exactly when the equality witness can be constructed
    -- from the two decorations; the expression itself is returned unchanged.
    castExprCM :: ASTF symI a -> ASTF symI b -> Maybe (ASTF symI b)
    castExprCM a b =
        simpleMatch
          (\(_ :&: aInfo) _ -> simpleMatch
            (\(_ :&: bInfo) _ -> case teq aInfo bInfo of
              Just Dict -> Just a
              _ -> Nothing
            ) b
          ) a



--------------------------------------------------------------------------------
-- * Code motion
--------------------------------------------------------------------------------

-- | Substituting a sub-expression. Assumes no variable capturing in the
-- expressions involved.
substitute :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a  -- ^ Sub-expression to be replaced
    -> ASTF sym a  -- ^ Replacing sub-expression
    -> ASTF sym b  -- ^ Whole expression
    -> ASTF sym b
substitute iface x y a
    | Just y' <- castExprCM iface y a, alphaEq x a = y'
    | otherwise = subst a
  where
    -- Walk down the application spine, substituting in every argument.
    subst :: AST sym c -> AST sym c
    subst (f :$ arg) = subst f :$ substitute iface x y arg
    subst s          = s
  -- Note: Since `codeMotion` only uses `substitute` to replace sub-expressions
  -- with fresh variables, there's no risk of capturing.

-- | Count the number of occurrences of a sub-expression
count :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => ASTF sym a  -- ^ Expression to count
    -> ASTF sym b  -- ^ Expression to count in
    -> Int
count a b
    | alphaEq a b = 1
    | otherwise   = cnt b
  where
    -- Sum the occurrences found in each argument along the application spine.
    cnt :: AST sym c -> Int
    cnt (f :$ arg) = cnt f + count a arg
    cnt _          = 0

-- | Environment for the expression in the 'choose' function
data Env sym = Env
    { inLambda :: Bool
        -- ^ Whether the current expression is inside a lambda
    , counter :: EF (AST sym) -> Int
        -- ^ Counting the number of occurrences of an expression in the
        -- environment
    , dependencies :: Set Name
        -- ^ The set of variables that are not allowed to occur in the chosen
        -- expression
    }

-- | Checks whether a sub-expression in a given environment can be lifted out
liftable :: BindingDomain sym => Env sym -> ASTF sym a -> Bool
liftable env a = independent && isNothing (prVar a) && heuristic
    -- Lifting dependent expressions is semantically incorrect. Lifting
    -- variables would cause `codeMotion` to loop.
  where
    independent = Set.null $ Set.intersection (freeVars a) (dependencies env)
    heuristic   = inLambda env || (counter env (EF a) > 1)

-- | A sub-expression chosen to be shared together with an evidence that it can
-- actually be shared in the whole expression under consideration
data Chosen sym a
  where
    Chosen :: InjDict sym b a -> ASTF sym b -> Chosen sym a

-- | Choose a sub-expression to share
choose :: forall sym a
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a
    -> Maybe (Chosen sym a)
choose iface a = chooseEnvSub initEnv a
  where
    initEnv = Env
        { inLambda     = False
        , counter      = \(EF b) -> count b a
        , dependencies = Set.empty
        }

    -- Pick the expression itself if it is liftable and an 'InjDict' can be
    -- built for it; otherwise look into its sub-expressions, but only when the
    -- interface allows hoisting over it.
    chooseEnv :: Env sym -> ASTF sym b -> Maybe (Chosen sym a)
    chooseEnv env b
        | liftable env b
        , Just dict <- mkInjDict iface b a
        = Just $ Chosen dict b
        | hoistOver iface b = chooseEnvSub env b
        | otherwise         = Nothing

    -- | Like 'chooseEnv', but does not consider the top expression for sharing
    chooseEnvSub :: Env sym -> AST sym b -> Maybe (Chosen sym a)
    chooseEnvSub env (Sym lam :$ b)
        | Just v <- prLam lam
        = chooseEnv (env' v) b
      where
        -- Under a lambda, the bound variable becomes a dependency.
        env' v = env
            { inLambda     = True
            , dependencies = Set.insert v (dependencies env)
            }
    chooseEnvSub env (s :$ b) = chooseEnvSub env s `mplus` chooseEnv env b
    chooseEnvSub _ _ = Nothing

-- | Monadic worker for 'codeMotion'. The state holds the next variable name to
-- use for an introduced binding.
codeMotionM :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       , MonadState Name m
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> m (ASTF sym a)
codeMotionM iface a
    | Just (Chosen dict b) <- choose iface a = share dict b
    | otherwise                              = descend a
  where
    -- Bind the chosen sub-expression with a "let" and replace its occurrences
    -- in the whole expression by the newly introduced variable.
    share :: InjDict sym b a -> ASTF sym b -> m (ASTF sym a)
    share dict b = do
        b' <- codeMotionM iface b
        -- Take the current name and advance the supply.
        v  <- get; put (v+1)
        let x = Sym (injVariable dict v)
        body <- codeMotionM iface $ substitute iface b x a
        return
            $  Sym (injLet dict)
            :$ b'
            :$ (Sym (injLambda dict v) :$ body)

    -- Nothing to share at this level: transform each argument separately.
    descend :: AST sym b -> m (AST sym b)
    descend (f :$ arg) = liftM2 (:$) (descend f) (codeMotionM iface arg)
    descend s          = return s

-- | Perform common sub-expression elimination and variable hoisting
--
-- The type variable @m@ is unused; it is retained only so that the order of
-- the quantified type variables stays the same.
codeMotion :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> ASTF sym a
codeMotion iface a = flip evalState maxVar $ codeMotionM iface a
  where
    -- Start the name supply above every variable in the expression.
    -- Inserting 0 makes the set non-empty, so 'Set.findMax' cannot fail.
    maxVar = succ $ Set.findMax $ Set.insert 0 $ allVars a