{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE Rank2Types                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE ViewPatterns               #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.02
-- |
-- Module      :  Language.Hakaru.Evaluation.ConstantPropagation
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
--
----------------------------------------------------------------
module Language.Hakaru.Evaluation.ConstantPropagation
    ( constantPropagation
    ) where

-- NOTE: CPP directives are only recognised at the start of a line.
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative (Applicative (..))
import Data.Functor ((<$>))
#endif

import Control.Monad.Reader
import Data.Monoid (All (..))
import Language.Hakaru.Evaluation.EvalMonad (runPureEvaluate)
import Language.Hakaru.Syntax.ABT (ABT (..), View (..))
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.IClasses (Foldable21 (..), Traversable21 (..))
import Language.Hakaru.Syntax.Variable

-- | The propagation environment: an association from variables to the
-- literal each one is known to be bound to.
type Env = Assocs Literal

-- | The constant propagation monad. Simply threads through an environment
-- mapping variables to known constant values.
--
-- It is a newtype over @'Reader' 'Env'@; the 'MonadReader' instance (and the
-- rest) are obtained through @GeneralizedNewtypeDeriving@. Use 'runPropM' to
-- get back the underlying reader computation.
newtype PropM a = PropM { runPropM :: Reader Env a }
    deriving (Functor, Applicative, Monad, MonadReader Env)

----------------------------------------------------------------
----------------------------------------------------------------

-- TODO: try evaluating certain things even if not all their immediate
-- subterms are literals.
-- For example:
--   (1) evaluate beta-redexes where the argument is a literal
--   (2) evaluate case-of-constructor if we can
--   (3) handle identity elements for NaryOps
--   (4) Recognize trivial cases for looping constructs:
--         summate a b (const 0) == 0
--         summate a b id == b - a
--         summate a b (const x) == x * (b - a)

-- | Perform basic constant propagation.
--
-- Runs the propagation traversal over a closed term, starting from an
-- empty environment ('emptyAssocs'), and returns the rewritten term.
constantPropagation
    :: forall abt a . (ABT Term abt)
    => abt '[] a
    -> abt '[] a
constantPropagation abt =
    runReader (runPropM $ constantProp' abt) emptyAssocs

-- | Worker for 'constantPropagation': rewrite a (possibly open) term inside
-- 'PropM', consulting the environment for variables known to be literals.
constantProp'
    :: forall abt a xs . (ABT Term abt)
    => abt xs a
    -> PropM (abt xs a)
constantProp' = start
  where
    -- Expose one layer of the term and dispatch on its shape.
    start :: forall b ys . abt ys b -> PropM (abt ys b)
    start = loop . viewABT

    loop :: forall b ys . View (Term abt) ys b -> PropM (abt ys b)
    -- A variable is replaced by its literal when the environment has one;
    -- otherwise it is rebuilt unchanged.
    loop (Var v)    = maybe (var v) (syn . Literal_) . lookupAssoc v <$> ask
    -- Syntax nodes are handled by 'constantPropTerm'.
    loop (Syn s)    = constantPropTerm s
    -- Binders are kept as they are; only the body underneath is rewritten.
    loop (Bind v b) = bind v <$> loop b

-- | 'True' exactly when the term is, at its top level, a 'Literal_' node.
isLiteral :: forall abt b ys . (ABT Term abt) => abt ys b -> Bool
isLiteral abt =
    case viewABT abt of
        Syn (Literal_ _) -> True
        _                -> False

-- | 'True' when every immediate subterm visited by 'foldMap21' satisfies
-- 'isLiteral'. Results are combined with the 'All' monoid, so a term with no
-- subterms at all is presumably reported as foldable too (the identity of
-- 'All' is 'True').
isFoldable :: forall abt b . (ABT Term abt) => Term abt b -> Bool
isFoldable = getAll . foldMap21 (All . isLiteral)

-- | Extract the literal from a term that is a top-level 'Literal_' node;
-- 'Nothing' for every other shape.
getLiteral
    :: forall abt ys b. (ABT Term abt)
    => abt ys b
    -> Maybe (Literal b)
getLiteral e =
    case viewABT e of
        Syn (Literal_ l) -> Just l
        _                -> Nothing

-- | Hand the term to 'runPureEvaluate' when it 'isFoldable'; otherwise just
-- wrap it back up with 'syn', leaving it as it was.
tryEval :: forall abt b . (ABT Term abt) => Term abt b -> abt '[] b
tryEval term
    | isFoldable term = runPureEvaluate (syn term)
    | otherwise       = syn term

-- | Constant propagation for a single syntax node.
--
-- * For a @let@: the right-hand side is rewritten first. If it comes out as
--   a literal, the binding is recorded in the environment and only the
--   rewritten body is returned (the @let@ itself disappears). Otherwise the
--   @let@ is rebuilt around the rewritten right-hand side and body.
--
-- * For any other node: all subterms are rewritten via 'traverse21', and the
--   resulting node is passed to 'tryEval'.
constantPropTerm
    :: (ABT Term abt)
    => Term abt a
    -> PropM (abt '[] a)
constantPropTerm (Let_ :$ rhs :* body :* End) =
    caseBind body $ \v body' -> do
        rhs' <- constantProp' rhs
        case getLiteral rhs' of
            -- Known constant: make it visible to the body via 'local'.
            Just l  -> local (insertAssoc (Assoc v l)) (constantProp' body')
            -- Not a constant: keep the binding, but still rewrite the body.
            Nothing -> do
                body'' <- constantProp' body'
                return $ syn (Let_ :$ rhs' :* bind v body'' :* End)
constantPropTerm term = tryEval <$> traverse21 constantProp' term