{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Evaluation.ConstantPropagation
( constantPropagation
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative (Applicative (..))
import Data.Functor ((<$>))
#endif
import Control.Monad.Reader
import Data.Monoid (All (..))
import Language.Hakaru.Evaluation.EvalMonad (runPureEvaluate)
import Language.Hakaru.Syntax.ABT (ABT (..), View (..))
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.IClasses (Foldable21 (..),
Traversable21 (..))
import Language.Hakaru.Syntax.Variable
-- | Environment threaded through the pass: associations from variables to
-- 'Literal' values. 'constantPropTerm' extends it when the right-hand side
-- of a @Let_@ turns out to be a literal, and 'constantProp'' consults it
-- whenever it reaches a variable.
type Env = Assocs Literal
-- | The monad in which constant propagation runs: a plain 'Reader' over the
-- current 'Env'. The derived 'MonadReader' instance is what gives the pass
-- access to @ask@ and @local@.
newtype PropM r = PropM
    { runPropM :: Reader Env r
      -- ^ Unwrap to the underlying reader computation.
    }
    deriving (Functor, Applicative, Monad, MonadReader Env)
-- | Entry point of the pass. Runs 'constantProp'' over a closed term,
-- starting from an empty environment ('emptyAssocs'), and returns the
-- rewritten term.
constantPropagation
    :: forall abt a . (ABT Term abt)
    => abt '[] a
    -> abt '[] a
constantPropagation term = runReader propagated emptyAssocs
  where
    -- The reader computation that rewrites @term@ once given an environment.
    propagated :: Reader Env (abt '[] a)
    propagated = runPropM (constantProp' term)
-- | Rewrite a (possibly open) term under the current environment.
--
-- The term is inspected through 'viewABT' and handled by shape:
--
-- * a variable is replaced by its literal when the environment has one,
--   and is otherwise rebuilt unchanged with 'var';
-- * a syntax node is delegated to 'constantPropTerm';
-- * a binder is kept, and its body is rewritten recursively.
constantProp'
    :: forall abt a xs . (ABT Term abt)
    => abt xs a
    -> PropM (abt xs a)
constantProp' = go . viewABT
  where
    go :: forall b ys . View (Term abt) ys b -> PropM (abt ys b)
    go view =
        case view of
            Var v ->
                -- Substitute the bound literal if there is one.
                asks (maybe (var v) (syn . Literal_) . lookupAssoc v)
            Syn t ->
                constantPropTerm t
            Bind v inner ->
                fmap (bind v) (go inner)
-- | 'True' exactly when the term, seen through 'viewABT', is a syntax node
-- built with @Literal_@; variables, binders and every other syntax node
-- give 'False'.
isLiteral :: forall abt b ys . (ABT Term abt) => abt ys b -> Bool
isLiteral (viewABT -> Syn (Literal_ _)) = True
isLiteral _                             = False
-- | Decide whether a term is a candidate for evaluation: every subterm
-- visited by 'foldMap21' must satisfy 'isLiteral'. The per-child results
-- are combined with the 'All' monoid (logical conjunction).
isFoldable :: forall abt b . (ABT Term abt) => Term abt b -> Bool
isFoldable term = getAll (foldMap21 literalChild term)
  where
    -- Tag a single subterm with whether it is a literal.
    literalChild :: forall ys c . abt ys c -> All
    literalChild child = All (isLiteral child)
-- | Extract the literal from a term that is a @Literal_@ syntax node
-- (as seen through 'viewABT'); any other term yields 'Nothing'.
getLiteral :: forall abt ys b. (ABT Term abt) => abt ys b -> Maybe (Literal b)
getLiteral (viewABT -> Syn (Literal_ lit)) = Just lit
getLiteral _                               = Nothing
-- | Wrap a term with 'syn' and, when 'isFoldable' accepts it, pass the
-- wrapped term through 'runPureEvaluate'. Terms that are not foldable are
-- returned as wrapped, without evaluation.
tryEval :: forall abt b . (ABT Term abt) => Term abt b -> abt '[] b
tryEval term =
    if isFoldable term
        then runPureEvaluate wrapped
        else wrapped
  where
    -- The term as an ABT; shared by both branches.
    wrapped :: abt '[] b
    wrapped = syn term
-- | Rewrite a single syntax node.
--
-- For @Let_@ the right-hand side is rewritten first. If it becomes a
-- literal, the binding is dropped: the bound variable is associated with
-- that literal in the environment and only the rewritten body is returned.
-- Otherwise the @Let_@ is rebuilt from the rewritten right-hand side and
-- the rewritten body.
--
-- Every other node has its subterms rewritten with 'constantProp'' (via
-- 'traverse21') and the result is handed to 'tryEval'.
constantPropTerm
    :: (ABT Term abt)
    => Term abt a
    -> PropM (abt '[] a)
constantPropTerm term =
    case term of
        Let_ :$ rhs :* body :* End ->
            caseBind body $ \x inner -> do
                rhs' <- constantProp' rhs
                case getLiteral rhs' of
                    Nothing ->
                        -- Not a constant: keep the binding around the
                        -- rewritten body.
                        (\inner' -> syn (Let_ :$ rhs' :* bind x inner' :* End))
                            <$> constantProp' inner
                    Just lit ->
                        -- Constant: rewrite the body with @x@ mapped to it.
                        local (insertAssoc (Assoc x lit)) (constantProp' inner)
        _ ->
            fmap tryEval (traverse21 constantProp' term)