{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE CPP #-} module Tip.Pass.Lift (lambdaLift, letLift, axiomatizeLambdas, boolOpLift) where #include "errors.h" import Tip.Core import Tip.Fresh import Tip.Utils import Data.Char (toLower) import Data.Either import Data.List import Data.Generics.Geniplate import Control.Applicative import Control.Monad import Control.Monad.Writer import qualified Data.Map as Map type LiftM a = WriterT [Function a] Fresh type TopLift a = Expr a -> LiftM a (Expr a) liftAnywhere :: (Name a,TransformBiM (LiftM a) (Expr a) (t a)) => TopLift a -> t a -> Fresh (t a,[Function a]) liftAnywhere top = runWriterT . transformExprInM top liftTheory :: Name a => TopLift a -> Theory a -> Fresh (Theory a) liftTheory top thy0 = do (Theory{..},new_func_decls) <- liftAnywhere top thy0 return Theory{thy_funcs = new_func_decls ++ thy_funcs,..} lambdaLiftTop :: Name a => TopLift a lambdaLiftTop e0 = case e0 of Lam lam_args lam_body -> do g_name <- lift (freshNamed "lam") let g_args = free e0 let g_tvs = freeTyVars e0 let g_type = map lcl_type lam_args :=>: exprType lam_body let g = Function g_name g_tvs g_args g_type (Lam lam_args lam_body) tell [g] return (applyFunction g (map TyVar g_tvs) (map Lcl g_args)) _ -> return e0 -- | Defunctionalization. -- -- > f x = ... \ y -> e [ x ] ... -- -- becomes -- -- > f x = ... g x ... -- > g x = \ y -> e [ x ] -- -- where @g@ is a fresh function. -- -- After this pass, lambdas only exist at the top level of functions. lambdaLift :: Name a => Theory a -> Fresh (Theory a) lambdaLift = liftTheory lambdaLiftTop letLiftTop :: Name a => TopLift a letLiftTop e0 = case e0 of Let xl@(Local x xt) b e -> do let fvs = free b let tvs = freeTyVars b let xfn = Function x tvs fvs (exprType b) b tell [xfn] lift ((applyFunction xfn (map TyVar tvs) (map Lcl fvs) // xl) e) _ -> return e0 -- | Lift lets to the top level. -- -- > let x = b[fvs] in e[x] -- -- becomes -- -- > e[f fvs] -- > f fvs = b[fvs] letLift :: Name a => Theory a -> Fresh (Theory a) letLift = liftTheory letLiftTop axLamFunc :: Function a -> Maybe (Signature a,Formula a) axLamFunc Function{..} = case func_body of Lam lam_args e -> let abs = Signature func_name (PolyType func_tvs (map lcl_type func_args) func_res) fm = Formula Assert (Defunction func_name) func_tvs (mkQuant Forall (func_args ++ lam_args) (apply (applySignature abs (map TyVar func_tvs) (map Lcl func_args)) (map Lcl lam_args) === e)) in Just (abs,fm) _ -> Nothing -- | Axiomatize lambdas. -- -- > f x = \ y -> E[x,y] -- -- becomes -- -- > declare-fun f ... -- > assert (forall x y . @ (f x) y = E[x,y]) axiomatizeLambdas :: forall a. Name a => Theory a -> Fresh (Theory a) axiomatizeLambdas thy0 = do arrows <- fmap Map.fromList (mapM makeArrow arities) ats <- fmap Map.fromList (mapM (makeAt arrows) arities) return $ transformBi (eliminateArrows arrows) $ transformBi (eliminateAts ats) thy { thy_sigs = Map.elems ats ++ thy_sigs thy, thy_sorts = Map.elems arrows ++ thy_sorts thy } where thy = thy0 { thy_sigs = new_abs ++ thy_sigs thy0, thy_funcs = survivors, thy_asserts = new_form ++ thy_asserts thy0 } (survivors,new) = partitionEithers [ maybe (Left fn) Right (axLamFunc fn) | fn <- thy_funcs thy0 ] (new_abs,new_form) = unzip new arities = usort [ length args | args :=>: _ <- universeBi thy :: [Type a] ] makeArrow n = do ty <- freshNamed ("fun" ++ show n) tvs <- replicateM (n+1) fresh return (n, Sort ty tvs) makeAt arrows n = do name <- freshNamed ("apply" ++ show n) tvs <- mapM (freshNamed . mkTyVarName) [0..(n-1)] tv <- freshNamed (mkTyVarName n) let Sort{..} = Map.findWithDefault __ n arrows ty = TyCon sort_name (map TyVar (tvs ++ [tv])) return $ (n, Signature name (PolyType (tvs ++ [tv]) (ty:map TyVar tvs) (TyVar tv))) eliminateArrows arrows (args :=>: res) = TyCon sort_name (map (eliminateArrows arrows) (args ++ [res])) where Sort{..} = Map.findWithDefault __ (length args) arrows eliminateArrows _ ty = ty eliminateAts ats (Builtin At :@: (e:es)) = Gbl (Global sig_name sig_type (args ++ [res])) :@: map (eliminateAts ats) (e:es) where args :=>: res = exprType e Signature{..} = Map.findWithDefault __ (length args) ats eliminateAts _ e = e mkTyVarName :: Int -> String mkTyVarName x = vars !! x where vars = ["a","b","c","d"] ++ ["t" ++ show i | i <- [0..]] boolOpTop :: Name a => TopLift a boolOpTop e0 = case e0 of Builtin x :@: es | x `elem` [And,Or,Implies] -> do f <- lift (freshNamed (map toLower (show x))) as <- lift (sequence [ (`Local` boolType) <$> fresh | _ <- es ]) let fn = Function f [] as boolType (Builtin x :@: map Lcl as) tell [fn] return (applyFunction fn [] es) _ -> return e0 -- | Lifts boolean operators to the top level -- -- replaces -- (and r s t) -- with -- f r s t -- and -- f x y z = and x y z -- -- Run CollapseEqual and BoolOpToIf afterwards boolOpLift :: Name a => Theory a -> Fresh (Theory a) boolOpLift = liftTheory boolOpTop