-- | Common sub-expression elimination and variable hoisting for 'Lam'
-- expressions.

module Feldspar.DSL.Sharing where

import Control.Monad.State
import Data.Typeable

import Feldspar.DSL.Expression
import Feldspar.DSL.Lambda

-- | Replace occurrences of a sub-expression by another expression.
--
-- A position is replaced when 'exprCast' succeeds in casting the replacement
-- to that position's type and 'exprEq' judges the position equal to the
-- sub-expression. A replaced position is not searched further. Otherwise the
-- search continues into lambda bodies (opened with the variable that
-- 'Lambda' supplies) and into both sides of an application.
substitute :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a  -- ^ Sub-expression to be replaced
    -> Lam expr ra a  -- ^ Replacing sub-expression
    -> Lam expr rb b  -- ^ Whole expression
    -> Lam expr rb b
substitute old new whole
    | Just new' <- exprCast new, exprEq old whole = new'
    | otherwise = case whole of
        Lambda body -> Lambda $ \v -> substitute old new (body v)
        fun :$: arg -> substitute old new fun :$: substitute old new arg
        _           -> whole

-- | Count the occurrences of a sub-expression within an expression.
--
-- A matching position contributes exactly 1 and is not searched further, so
-- occurrences nested inside a match are not counted separately.
count :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a  -- ^ Sub-expression
    -> Lam expr rb b  -- ^ Whole expression
    -> Integer
count sub whole = evalState (countM sub whole) 0
  -- NOTE(review): the initial state 0 is presumably the counter that
  -- 'freshVar' / 'exprEqLam' draw fresh variable names from; confirm.

-- | Monadic worker for 'count'.
--
-- The current position counts as one occurrence when it can be cast (with
-- 'exprCast') to the sub-expression's type and 'exprEqLam' judges the two
-- equal; in every other case the search continues via 'countNonEq'.
countM :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a
    -> Lam expr rb b
    -> State Integer Integer
countM sub whole = maybe (countNonEq sub whole) atThisPosition (exprCast whole)
  where
    atThisPosition whole' = do
        same <- exprEqLam sub (whole' `asTypeOf` sub)
        if same then return 1 else countNonEq sub whole

-- | Continue the search for 'countM' below the current position.
--
-- A 'Lambda' body is opened with a variable obtained from 'freshVar'; for an
-- application, the function and the argument are searched (in that order) and
-- their counts added. Any other expression yields 0.
countNonEq :: (ExprEq expr, Typeable ra, Typeable a, Typeable rb, Typeable b)
    => Lam expr ra a
    -> Lam expr rb b
    -> State Integer Integer
countNonEq sub (Lambda body) = freshVar "" >>= countM sub . body
countNonEq sub (fun :$: arg) = do
    inFun <- countM sub fun
    inArg <- countM sub arg
    return (inFun + inArg)
countNonEq _ _ = return 0

-- | A 'Lam' expression whose type parameters @ra@ and @a@ are hidden, keeping
-- only their 'Typeable' evidence, so that sub-expressions of different types
-- can be passed around uniformly.
data SomeLam expr = forall ra a . (Typeable ra, Typeable a) => SomeLam (Lam expr ra a)

-- | Custom parameters to the sharing transformation. The 'necessary'
-- predicate gives a necessary condition for lifting an expression, and
-- 'sufficient' gives a sufficient condition. Note that 'necessary' takes
-- precedence over 'sufficient'.
--
-- The 'sharingPoint' field determines whether the expression is a valid point
-- for introducing a 'Let'.
data Params expr = Params
    { necessary    :: SomeLam expr -> Bool
    , sufficient   :: SomeLam expr -> Bool
    , sharingPoint :: SomeLam expr -> Bool
    }

-- | Context in which 'choose' examines a candidate sub-expression.
data Env expr = Env
    { inLambda :: Bool
        -- ^ Whether the current expression is inside a lambda
    , subExpr  :: Bool
        -- ^ Whether the current expression is a sub-expression (i.e. not the
        -- root that 'choose' was started on)
    , counter  :: SomeLam expr -> Integer
        -- ^ Counting the number of occurrences of an expression in the
        -- environment
    }

-- | Parameters that add no restrictions of their own: every expression passes
-- 'necessary', none passes 'sufficient', and every expression is a sharing
-- point. Lifting is then decided only by the checks built into 'liftable'.
simpleParams :: Params expr
simpleParams = Params
    { necessary    = const True
    , sufficient   = const False
    , sharingPoint = const True
    }

-- | Environment for examining the root of an expression: not inside a
-- lambda, not a sub-expression, and occurrences counted with 'count' over the
-- given (whole) expression.
initEnv :: (ExprEq expr, Typeable ra, Typeable a) => Lam expr ra a -> Env expr
initEnv a = Env
    { inLambda = False
    , subExpr  = False
    , counter  = \(SomeLam b) -> count b a
    }

-- | Name of the placeholder variable that 'choose' feeds into lambda bodies.
-- Shared by 'ph' and 'independent' so that the two always agree.
--
-- NOTE(review): a user variable that happens to carry this name is also
-- treated as a placeholder by 'independent'.
placeholderName :: String
placeholderName = "PLACEHOLDER"

-- | Variable used by 'independent' to open lambda bodies.
dummy = Variable ""

-- | Placeholder variable that 'choose' passes into lambda bodies; any
-- expression mentioning it is rejected by 'independent'.
ph = Variable placeholderName

-- | Checks whether an expression is compound, i.e. a 'Lambda' or an
-- application (':$:').
compound :: Lam expr ra a -> Bool
compound (Lambda _) = True
compound (_ :$: _)  = True
compound _          = False

-- | Checks if the expression does not contain any variables named
-- 'placeholderName' (@"PLACEHOLDER"@). Lambda bodies are inspected by
-- applying them to 'dummy'.
independent :: Lam expr ra a -> Bool
independent (Variable ident) = ident /= placeholderName
independent (Lambda f)       = independent (f dummy)
independent (f :$: a)        = independent f && independent a
independent _                = True

-- | Checks whether a sub-expression in a given environment can be lifted out.
--
-- The expression must not depend on a lambda-bound placeholder, must be a
-- proper sub-expression, and must satisfy 'necessary'. In addition, either the
-- built-in heuristic (compound, and either under a lambda or occurring more
-- than once) or 'sufficient' must hold.
liftable :: (Typeable ra, Typeable a) => Params expr -> Env expr -> Lam expr ra a -> Bool
liftable params env a
    =  independent a  -- Lifting dependent expressions is semantically incorrect
    && subExpr env    -- Otherwise infinite loop
    && necessary params (SomeLam a)
    && (heuristic || sufficient params (SomeLam a))
  where
    heuristic = compound a && (inLambda env || (counter env (SomeLam a) > 1))

-- | Chooses a sub-expression to lift out.
--
-- The current expression is tried first. Otherwise the search continues into
-- a lambda body (opened with the placeholder 'ph' and marked as being inside a
-- lambda), or into the function and then the argument of an application.
-- Returns the first liftable expression found, if any.
choose :: (Typeable ra, Typeable a)
    => Params expr -> Env expr -> Lam expr ra a -> Maybe (SomeLam expr)
choose par env a
    | liftable par env a = Just (SomeLam a)
choose par env (Lambda f) = choose par env' (f ph)
  where
    env' = env {inLambda = True, subExpr = True}
choose par env (f :$: a) = choose par env' f `mplus` choose par env' a
  where
    env' = env {subExpr = True}
choose _ _ _ = Nothing

-- | Perform common sub-expression elimination and variable hoisting.
--
-- At a sharing point where 'choose' finds a liftable sub-expression @b@, the
-- result binds (the shared version of) @b@ with 'let_' and continues on the
-- expression with @b@ substituted by the bound variable. Otherwise the
-- transformation descends into the children via 'descend'.
sharing :: forall expr ra a . (ExprEq expr, Typeable ra, Typeable a)
    => Params expr -> Lam expr ra a -> Lam expr ra a
sharing par a
    -- 'sharingPoint' is tested first: 'choose' may call 'count' over the
    -- whole expression, which is wasted work when the node cannot host a let.
    | sharingPoint par (SomeLam a)
    , Just b <- choose par (initEnv a) a
    = share b
    | otherwise
    = descend par a
  where
    share :: SomeLam expr -> Lam expr ra a
    share (SomeLam b) = let_ "v" (sharing par b) f
      where
        f x = sharing par (substitute b x a)

-- | Apply 'sharing' to the immediate children of an expression; expressions
-- that are neither a 'Lambda' nor an application are returned unchanged.
descend :: (ExprEq expr, Typeable ra, Typeable a)
    => Params expr -> Lam expr ra a -> Lam expr ra a
descend params (Lambda f) = Lambda $ \v -> sharing params (f v)
descend params (f :$: a)  = sharing params f :$: sharing params a
descend _ a               = a

-- | 'sharing' with 'simpleParams'.
simpleSharing :: (ExprEq expr, Typeable ra, Typeable a) => Lam expr ra a -> Lam expr ra a
simpleSharing = sharing simpleParams

-- | Checks if the expression computes a function. Can be used in the
-- parameters passed to 'sharing'.
--
-- The check compares type constructors directly instead of their 'show'
-- output: the function constructor is rendered as @"->"@ by older GHCs but as
-- @"FUN"@ since GHC 9.0, so a string comparison silently breaks.
isFunction :: forall expr ra a . Typeable a => Lam expr ra a -> Bool
isFunction _ = typeRepTyCon (typeOf (undefined :: a)) == funTyCon
  where
    funTyCon = typeRepTyCon (typeOf (undefined :: () -> ()))