{-# LANGUAGE CPP
, DataKinds
, EmptyCase
, ExistentialQuantification
, FlexibleContexts
, GADTs
, GeneralizedNewtypeDeriving
, KindSignatures
, MultiParamTypeClasses
, OverloadedStrings
, PolyKinds
, ScopedTypeVariables
, TypeFamilies
, TypeOperators
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.CSE (cse) where
import Control.Monad.Reader
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Eq
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Types.DataKind
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
data EAssoc (abt :: [Hakaru] -> Hakaru -> *)
= forall a . EAssoc !(abt '[] a) !(abt '[] a)
newtype Env (abt :: [Hakaru] -> Hakaru -> *) = Env [EAssoc abt]
emptyEnv :: Env a
emptyEnv = Env []
trivial :: (ABT Term abt) => abt '[] a -> Bool
trivial abt = case viewABT abt of
Var _ -> True
Syn (Literal_ _) -> True
_ -> False
lookupEnv
:: forall abt a . (ABT Term abt)
=> abt '[] a
-> Env abt
-> abt '[] a
lookupEnv start (Env env) = go start env
where
go :: abt '[] a -> [EAssoc abt] -> abt '[] a
go ast [] = ast
go ast (EAssoc a b : xs) =
case jmEq1 (typeOf ast) (typeOf a) of
Just Refl | alphaEq ast a -> go b env
_ -> go ast xs
insertEnv
:: forall abt a . (ABT Term abt)
=> abt '[] a
-> abt '[] a
-> Env abt
-> Env abt
insertEnv ast1 ast2 (Env env)
| trivial ast1 = Env (EAssoc ast2 ast1 : env)
| otherwise = Env (EAssoc ast1 ast2 : env)
newtype CSE (abt :: [Hakaru] -> Hakaru -> *) a = CSE { runCSE :: Reader (Env abt) a }
deriving (Functor, Applicative, Monad, MonadReader (Env abt))
replaceCSE
:: (ABT Term abt)
=> abt '[] a
-> CSE abt (abt '[] a)
replaceCSE abt = lookupEnv abt `fmap` ask
cse :: forall abt a . (ABT Term abt) => abt '[] a -> abt '[] a
cse abt = runReader (runCSE (cse' abt)) emptyEnv
cse' :: forall abt xs a . (ABT Term abt) => abt xs a -> CSE abt (abt xs a)
cse' = loop . viewABT
where
loop :: View (Term abt) ys a -> CSE abt (abt ys a)
loop (Var v) = cseVar v
loop (Syn s) = cseTerm s
loop (Bind v b) = fmap (bind v) (loop b)
cseVar
:: (ABT Term abt)
=> Variable a
-> CSE abt (abt '[] a)
cseVar = replaceCSE . var
mklet :: ABT Term abt => Variable b -> abt '[] b -> abt '[] a -> abt '[] a
mklet v rhs body = syn (Let_ :$ rhs :* bind v body :* End)
cseTerm
:: (ABT Term abt)
=> Term abt a
-> CSE abt (abt '[] a)
cseTerm (Let_ :$ rhs :* body :* End) = do
rhs' <- cse' rhs
caseBind body $ \v body' ->
local (insertEnv rhs' (var v)) $
if trivial rhs'
then cse' body'
else fmap (mklet v rhs') (cse' body')
cseTerm term = traverse21 cse' term >>= replaceCSE . syn