{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
-- |
-- Term-to-term transformations: the optimization pipeline
-- ('optimizations'), a helper for rewriting underneath a lambda
-- ('underLam'), expansion of @Expect@ and @Observe@ nodes
-- ('expandTransformations'), and flattening of nested n-ary
-- operators ('coalesce').
module Language.Hakaru.Syntax.AST.Transforms where

import qualified Data.Sequence as S

import Language.Hakaru.Syntax.ANF      (normalize)
import Language.Hakaru.Syntax.CSE      (cse)
import Language.Hakaru.Syntax.Prune    (prune)
import Language.Hakaru.Syntax.Uniquify (uniquify)
import Language.Hakaru.Syntax.Hoist    (hoist)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind

import Language.Hakaru.Expect       (expect)
import Language.Hakaru.Disintegrate (determine, observe)

-- | Run every optimization pass over a closed term. The passes are
-- applied in this order: 'normalize', 'uniquify', 'hoist', 'cse',
-- 'prune', and finally 'uniquify' once more.
optimizations
    :: (ABT Term abt)
    => abt '[] a
    -> abt '[] a
optimizations e =
      uniquify
    . prune
    . cse
    . hoist
    -- The hoist pass needs globally unique identifiers
    . uniquify
    $ normalize e

-- | Apply a monadic rewrite to the body of a function-typed term.
--
-- * A variable is returned unchanged.
-- * For @Lam_@, the rewrite is applied to the lambda body and the
--   lambda is rebuilt around the result.
-- * For @Let_@, if the type of the bound expression equals the type
--   of the whole term, the search continues inside the bound
--   expression; otherwise it continues inside the body of the let.
-- * Every other term form calls 'error'.
underLam
    :: (ABT Term abt, Monad m)
    => (abt '[] b -> m (abt '[] b))
    -> abt '[] (a ':-> b)
    -> m (abt '[] (a ':-> b))
underLam f e = caseVarSyn e (return . var) $ \term ->
    case term of
        Lam_ :$ lamBody :* End ->
            caseBind lamBody $ \x inner ->
                f inner >>= \inner' ->
                return (syn (Lam_ :$ bind x inner' :* End))

        Let_ :$ letRhs :* letBody :* End ->
            case jmEq1 (typeOf letRhs) (typeOf e) of
                -- The bound expression has the same function type as
                -- the whole term, so recurse into it.
                Just Refl ->
                    underLam f letRhs >>= \letRhs' ->
                    return (syn (Let_ :$ letRhs' :* letBody :* End))
                -- Otherwise recurse into the body of the let.
                Nothing ->
                    caseBind letBody $ \x inner ->
                        underLam f inner >>= \inner' ->
                        return (syn (Let_ :$ letRhs :* bind x inner' :* End))

        _ -> error "TODO: underLam"

-- | Walk the whole term with 'cataABT', replacing each @Expect@ node
-- by the result of 'expect', and each @Observe@ node by the result of
-- @'determine' ('observe' ..)@ when that yields a 'Just'. All other
-- nodes (and @Observe@ nodes for which 'determine' yields 'Nothing')
-- are rebuilt as they were.
expandTransformations
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] a
expandTransformations = cataABT var bind expandTerm
  where
    expandTerm :: forall b. Term abt b -> abt '[] b
    expandTerm (Expect :$ m :* k :* End) = expect m k
    expandTerm t@(Observe :$ m :* obs :* End) =
        case determine (observe m obs) of
            Nothing      -> syn t
            Just result  -> result
    expandTerm t = syn t

-- | If the given term is an n-ary operator application, flatten its
-- arguments with 'coalesceNaryOp'. Variables and every other term
-- form are returned as they are. Only the outermost node of the term
-- is inspected here.
coalesce
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] a
coalesce e = caseVarSyn e var $ \term ->
    case term of
        NaryOp_ op operands -> syn (NaryOp_ op (coalesceNaryOp op operands))
        _                   -> syn term

-- | Flatten the argument sequence of an n-ary operator. Each argument
-- that is itself an application of the same operator is replaced
-- (recursively) by its own flattened arguments; an argument that is
-- an application of a different n-ary operator is passed through
-- 'coalesce'; any other argument is kept unchanged.
coalesceNaryOp
    :: ABT Term abt
    => NaryOp a
    -> S.Seq (abt '[] a)
    -> S.Seq (abt '[] a)
coalesceNaryOp typ args = do
    arg <- args
    case viewABT arg of
        Syn (NaryOp_ typ' args')
            | typ == typ' -> coalesceNaryOp typ args'
            | otherwise   -> return (coalesce arg)
        _ -> return arg