module Language.Syntactic.Constructs.Binding.Optimize where
import Control.Monad.Writer
import Data.Set as Set
import Data.Proxy
import Language.Syntactic
import Language.Syntactic.Constructs.Binding
import Language.Syntactic.Constructs.Condition
import Language.Syntactic.Constructs.Construct
import Language.Syntactic.Constructs.Identity
import Language.Syntactic.Constructs.Literal
import Language.Syntactic.Constructs.Tuple
-- | Constant folder for the domain @dom@.
--
-- A constant folder receives an expression together with a value of the
-- expression's result type and returns an expression of that same type.
-- In this module it is only invoked by 'optimizeSymDefault', which passes the
-- expression and the result of evaluating it with 'evalBind'.
type ConstFolder dom =
    forall a. ASTF dom a -> a -> ASTF dom a
-- | Optimization of the symbols of a sub-domain @sub@ occurring inside
-- expressions over the domain @dom@. The @ctx@ parameter is only carried as a
-- 'Proxy' argument.
class EvalBind dom => Optimize sub ctx dom
  where
    -- | Optimize a symbol applied to its arguments.
    --
    -- Arguments, in order: a proxy fixing @ctx@, the constant folder to use,
    -- the symbol itself, and the (not yet optimized) arguments of the symbol.
    --
    -- The result is produced in a 'Writer' whose output is a set of variable
    -- identifiers. In this module the 'Variable' instance reports its own
    -- identifier, the 'Lambda' instance removes the identifier it binds, and
    -- 'optimizeSymDefault' only folds constants when the reported set is
    -- empty.
    optimizeSym
        :: Proxy ctx
        -> ConstFolder dom
        -> sub a
        -> Args (AST dom) a
        -> Writer (Set VarId) (ASTF dom (DenResult a))
-- | A sum of two sub-domains is optimized by dispatching to the 'Optimize'
-- instance of whichever summand the symbol belongs to.
instance (Optimize sub1 ctx dom, Optimize sub2 ctx dom) =>
    Optimize (sub1 :+: sub2) ctx dom
  where
    optimizeSym ctx constFold sym = case sym of
        InjL l -> optimizeSym ctx constFold l
        InjR r -> optimizeSym ctx constFold r
-- | Optimize an expression, returning the result in a 'Writer' whose output
-- is the set of variable identifiers reported by the 'optimizeSym' instances.
--
-- The work is delegated to 'transformNode', which is handed 'optimizeSym'
-- (partially applied to the proxy and the constant folder) as the function to
-- run on a symbol and its arguments.
optimizeM
    :: Optimize dom ctx dom
    => Proxy ctx
    -> ConstFolder dom
    -> ASTF dom a
    -> Writer (Set VarId) (ASTF dom a)
optimizeM ctx constFold expr =
    transformNode (optimizeSym ctx constFold) expr
-- | Optimize an expression.
--
-- Runs 'optimizeM' and keeps only the resulting expression; the set of
-- variable identifiers collected along the way is discarded.
optimize
    :: Optimize dom ctx dom
    => Proxy ctx
    -> ConstFolder dom
    -> ASTF dom a
    -> ASTF dom a
optimize ctx constFold expr =
    fst (runWriter (optimizeM ctx constFold expr))
-- | Default implementation of 'optimizeSym'.
--
-- All arguments are optimized with 'optimizeM' and the symbol is re-applied
-- to the optimized arguments. If the arguments reported no variable
-- identifiers at all, the rebuilt expression is evaluated with 'evalBind' and
-- handed, together with that value, to the constant folder; otherwise the
-- rebuilt expression is returned unchanged.
--
-- The variable identifiers reported by the arguments remain part of the
-- 'Writer' output of the returned action.
optimizeSymDefault
    :: ( sub :<: dom
       , WitnessCons sub
       , Optimize dom ctx dom
       )
    => Proxy ctx
    -> ConstFolder dom
    -> sub a
    -> Args (AST dom) a
    -> Writer (Set VarId) (ASTF dom (DenResult a))
optimizeSymDefault ctx constFold sym args =
    -- Matching on 'ConsWit' brings the evidence needed by 'appArgs' into scope.
    case witnessCons sym of
      ConsWit -> do
        (optArgs, freeVars) <- listen (mapArgsM (optimizeM ctx constFold) args)
        let rebuilt = appArgs (Sym (inj sym)) optArgs
        if Set.null freeVars
          -- No variables were reported, so the expression is evaluated here.
          then return (constFold rebuilt (evalBind rebuilt))
          else return rebuilt
-- | 'Identity' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Identity ctx' :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Identity ctx') ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | 'Construct' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Construct ctx' :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Construct ctx') ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | 'Literal' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Literal ctx' :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Literal ctx') ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | 'Tuple' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Tuple ctx' :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Tuple ctx') ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | 'Select' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Select ctx' :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Select ctx') ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | 'Let' symbols use the default optimization ('optimizeSymDefault').
instance
    ( Let ctxa ctxb :<: dom
    , Optimize dom ctx dom
    ) =>
      Optimize (Let ctxa ctxb) ctx dom
  where
    optimizeSym = optimizeSymDefault
-- | Optimization of conditionals. The cases are tried in this order:
--
-- 1. If optimizing the condition reports no variable identifiers, the
--    optimized condition is evaluated with 'evalBind' and only the selected
--    branch is optimized and returned.
--
-- 2. If the two (unoptimized) branches are equal according to 'alphaEq', the
--    optimized first branch is returned and the condition is dropped.
--
-- 3. Otherwise the whole conditional goes through 'optimizeSymDefault'.
instance
    ( Condition ctx' :<: dom
    , Lambda ctx :<: dom
    , Variable ctx :<: dom
    , AlphaEq dom dom dom [(VarId,VarId)]
    , Optimize dom ctx dom
    ) =>
      Optimize (Condition ctx') ctx dom
  where
    optimizeSym ctx constFold cond@Condition args@(c :* t :* e :* Nil)
        | Set.null condVars =
            optimizeM ctx constFold (if evalBind optCond then t else e)
        | alphaEq t e =
            optimizeM ctx constFold t
        | otherwise =
            optimizeSymDefault ctx constFold cond args
      where
        -- The condition is optimized in a separate 'runWriter', so its
        -- variable set is inspected here and not emitted by the first two
        -- cases.
        (optCond, condVars) = runWriter (optimizeM ctx constFold c)
-- | A variable is returned unchanged; its identifier is recorded in the
-- 'Writer' output.
instance (Variable ctx :<: dom, Optimize dom ctx dom) =>
    Optimize (Variable ctx) ctx dom
  where
    optimizeSym _ _ var@(Variable v) Nil =
        tell (Set.singleton v) >> return (inj var)
-- | A lambda is optimized by optimizing its body and re-applying the lambda
-- symbol to the result. The identifier bound by the lambda is removed from
-- the set of variable identifiers reported by the body.
instance (Lambda ctx :<: dom, Optimize dom ctx dom) =>
    Optimize (Lambda ctx) ctx dom
  where
    optimizeSym ctx constFold lam@(Lambda v) (body :* Nil) = do
        optBody <- censor (Set.delete v) (optimizeM ctx constFold body)
        return (inj lam :$ optBody)