{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}

-- | Simplification of treeless terms ('TTerm').
--
-- 'simplifyTTerm' looks up a few builtin names and then runs the pure
-- 'simplify' pass. That pass propagates atomic values into variables,
-- reduces cases on known constructors, pushes a case into the branches of a
-- case it scrutinises, turns some natural-number builtins into primitive
-- operations, and normalises sums and differences of primitive arithmetic.
module Agda.Compiler.Treeless.Simplify (simplifyTTerm) where

import Control.Arrow (first, second, (***))
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Traversable (traverse)
import Data.List

import Agda.Syntax.Treeless
import Agda.Syntax.Internal (Substitution'(..))
import Agda.Syntax.Literal
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Primitive
import Agda.TypeChecking.Substitute
import Agda.Utils.Maybe

import Agda.Compiler.Treeless.Subst
import Agda.Compiler.Treeless.Pretty
import Agda.Compiler.Treeless.Compare

import Agda.Utils.Pretty
import Agda.Utils.Impossible
#include "undefined.h"

-- | Read-only environment of the simplifier.
data SEnv = SEnv
  { envSubst   :: Substitution' TTerm
    -- Substitution consulted by 'lookupVar' whenever a variable is met.
  , envRewrite :: [(TTerm, TTerm)]
    -- Rewrite rules as (lhs, rhs) pairs; see 'rewrite'. The first matching
    -- rule wins.
  }

-- | The simplifier monad: a 'Reader' over 'SEnv'.
type S = Reader SEnv

-- | Run a simplifier computation starting from the 'IdS' substitution and
-- an empty list of rewrite rules.
runS :: S a -> a
runS m = runReader m $ SEnv IdS []

-- | Look up variable @i@ in the current substitution (via 'lookupS').
lookupVar :: Int -> S TTerm
lookupVar i = asks $ (`lookupS` i) . envSubst

-- | Run a computation with the substitution transformed by @f@.
onSubst :: (Substitution' TTerm -> Substitution' TTerm) -> S a -> S a
onSubst f = local $ \ env -> env { envSubst = f (envSubst env) }

-- | Run a computation with @rho@ applied to both sides of every rewrite rule.
onRewrite :: Substitution' TTerm -> S a -> S a
onRewrite rho = local $ \ env -> env { envRewrite = map (applySubst rho *** applySubst rho) (envRewrite env) }

-- | Run a computation with the extra rule @lhs ~> rhs@. It is put in front of
-- the existing rules, so 'rewrite' tries it first.
addRewrite :: TTerm -> TTerm -> S a -> S a
addRewrite lhs rhs = local $ \ env -> env { envRewrite = (lhs, rhs) : envRewrite env }

-- | Go under @i@ binders: the substitution is lifted by @i@ ('liftS') and
-- the rewrite rules are raised by @i@ ('raiseS').
underLams :: Int -> S a -> S a
underLams i = onRewrite (raiseS i) . onSubst (liftS i)

-- | Go under a single binder.
underLam :: S a -> S a
underLam = underLams 1

-- | Go under a let whose bound expression is @u@. The substitution is
-- extended with @u@ for variable 0 and then weakened by one ('wkS'). The
-- rewrite rules are raised by one.
underLet :: TTerm -> S a -> S a
underLet u = onRewrite (raiseS 1) . onSubst (\rho -> wkS 1 $ u :# rho)

-- | Replace @t@ by the right-hand side of the first rewrite rule whose
-- left-hand side is equal to @t@ according to 'equalTerms'. If no rule
-- matches, @t@ is returned unchanged.
rewrite :: TTerm -> S TTerm
rewrite t = do
  rules <- asks envRewrite
  case [ rhs | (lhs, rhs) <- rules, equalTerms t lhs ] of
    rhs : _ -> pure rhs
    []      -> pure t

-- | Builtin names the simplifier treats specially. Each field holds the
-- result of 'getBuiltinName' in 'simplifyTTerm':
-- @modAux@ for 'builtinNatModSucAux', @divAux@ for 'builtinNatDivSucAux',
-- @natMinus@ for 'builtinNatMinus', @true@ and @false@ for 'builtinTrue'
-- and 'builtinFalse'.
data FunctionKit = FunctionKit
  { modAux, divAux, natMinus, true, false :: Maybe QName }

-- | Simplify a treeless term. Looks up the builtins of 'FunctionKit' and
-- then runs the pure 'simplify' pass with an empty environment ('runS').
simplifyTTerm :: TTerm -> TCM TTerm
simplifyTTerm t = do
  kit <- FunctionKit <$> getBuiltinName builtinNatModSucAux
                     <*> getBuiltinName builtinNatDivSucAux
                     <*> getBuiltinName builtinNatMinus
                     <*> getBuiltinName builtinTrue
                     <*> getBuiltinName builtinFalse
  return $ runS $ simplify kit t

-- | The simplification pass, parameterised by the builtin names in
-- 'FunctionKit'. The fields are brought into scope by @RecordWildCards@.
-- The helpers below live in the @where@ block so they can use those names
-- and the local 'simpl'.
simplify :: FunctionKit -> TTerm -> S TTerm
simplify FunctionKit{..} = simpl
  where
    -- Main traversal. The term first goes through rewrite' (primitive
    -- simplification, then the rewrite rules). The result is then handled
    -- according to its head constructor.
    simpl t = rewrite' t >>= \ t -> case t of

      TDef{}         -> pure t
      TPrim{}        -> pure t
      -- A variable is replaced by its substitution only when that is atomic
      -- (see isAtomic). Otherwise the variable is kept, presumably to avoid
      -- duplicating larger terms.
      TVar x         -> do
        v <- lookupVar x
        pure $ if isAtomic v then v else t

      TApp (TDef f) [TLit (LitNat _ 0), m, n, m']
        -- div/mod are equivalent to quot/rem on natural numbers.
        -- Only the form with first argument 0 and equal 2nd and 4th
        -- arguments is rewritten, to quot/rem of n by (tPlusK 1 m).
        | m == m', Just f == divAux -> simpl $ tOp PQuot n (tPlusK 1 m)
        | m == m', Just f == modAux -> simpl $ tOp PRem n (tPlusK 1 m)

      TApp (TPrim _) _ -> pure t  -- taken care of by rewrite'

      -- General application: simplify head and arguments, then try to turn
      -- natMinus into a primitive subtraction (falls back to tApp).
      TApp f es      -> do
        f  <- simpl f
        es <- traverse simpl es
        maybeMinusToPrim f es
      TLam b         -> TLam <$> underLam (simpl b)
      TLit{}         -> pure t
      TCon{}         -> pure t
      -- The body is simplified with the simplified bound expression in the
      -- substitution; tLet drops the let when that expression is a variable.
      TLet e b       -> do
        e <- simpl e
        tLet e <$> underLet e (simpl b)

      TCase x t d bs -> do
        -- Inspect what the scrutinee variable is bound to, with any
        -- leading lets peeled off.
        v <- lookupVar x
        let (lets, u) = letView v
        case u of                          -- TODO: also for literals
          -- Known constructor: pick the matching alternative.
          -- NOTE(review): lets and as come from lookupVar and are then passed
          -- through simpl again together with the (input-side) branches;
          -- confirm this re-simplification is intended.
          _ | Just (c, as) <- conView u   -> simpl $ matchCon lets c as d bs
          -- Case-of-case: the scrutinee is itself a case on y. The outer case
          -- (re-indexed by rho as case1) is pushed into the default and every
          -- alternative of the inner case, each of which is let-bound as
          -- variable 0.
          TCase y t1 d1 bs1 -> simpl $ mkLets lets $ TCase y t1 (distrDef case1 d1) $
                                       map (distrCase case1) bs1
            where
              -- Γ x Δ -> Γ _ Δ Θ y, where x maps to y and Θ are the lets
              n     = length lets
              rho   = liftS (x + n + 1) (raiseS 1)    `composeS`
                      singletonS (x + n + 1) (TVar 0) `composeS`
                      raiseS (n + 1)
              case1 = applySubst rho (TCase x t d bs)

              -- An unreachable inner default stays unreachable; otherwise
              -- its value is let-bound and the outer case runs on it.
              distrDef v d | isUnreachable d = tUnreachable
                           | otherwise       = tLet d v

              -- Each inner branch result is let-bound and the outer case runs
              -- on it. Under a constructor alternative, v is raised past the
              -- a pattern variables while keeping index 0 (the let binder).
              distrCase v (TACon c a b) = TACon c a $ TLet b $ raiseFrom 1 a v
              distrCase v (TALit l b)   = TALit l   $ TLet b v
              distrCase v (TAGuard g b) = TAGuard g $ TLet b v

          -- Nothing known about the scrutinee: simplify the default and the
          -- alternatives, then rebuild via the smart constructor tCase.
          _ -> do
            d  <- simpl d
            bs <- traverse (simplAlt x) bs
            tCase x t d bs

      TUnit          -> pure t
      TSort          -> pure t
      TErased        -> pure t
      TError{}       -> pure t

    -- View a term as a constructor applied to (possibly no) arguments.
    conView (TCon c)           = Just (c, [])
    conView (TApp (TCon c) as) = Just (c, as)
    conView e                  = Nothing

    -- Split a term into its leading let-bound expressions and the body.
    letView (TLet e b) = first (e :) $ letView b
    letView e          = ([], e)

    -- Re-wrap a body in the given lets (inverse of letView).
    mkLets es b = foldr TLet b es

    -- Select the alternative for constructor c applied to as. Literal and
    -- guard alternatives are skipped. On the matching constructor
    -- alternative the pattern variables are let-bound to the arguments (the
    -- i-th argument raised past the i earlier lets). The branch body is
    -- raised past the scrutinee lets, which are wrapped around the result.
    -- With no matching alternative the default d is returned.
    -- NOTE(review): assumes the constructor is applied to exactly a
    -- arguments (mkLet binds length as variables, raiseFrom assumes a);
    -- confirm this holds for treeless terms.
    matchCon _ _ _ d [] = d
    matchCon lets c as d (TALit{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TAGuard{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TACon c' a b : bs)
      | c == c'   = flip (foldr TLet) lets $ mkLet 0 as (raiseFrom a (length lets) b)
      | otherwise = matchCon lets c as d bs
      where
        mkLet _ []       b = b
        mkLet i (a : as) b = TLet (raise i a) $ mkLet (i + 1) as b

    -- Simplify a primitive application. The arguments are simplified first.
    -- simplPrim' is then tried on a copy whose arguments have variables and
    -- lets inlined. That result is kept only if it is betterThan the
    -- un-inlined application.
    simplPrim (TApp f@TPrim{} args) = do
        args    <- mapM simpl args
        inlined <- mapM inline args
        let u = TApp f args
            v = simplPrim' (TApp f inlined)
        pure $ if v `betterThan` u then v else u
      where
        -- Follow variables through the substitution until they map to
        -- themselves. Inline into nested primitive applications and into let
        -- bodies, except for a let whose body is a case on the bound
        -- variable (TCase 0), which is left alone.
        inline (TVar x)                   = do
          v <- lookupVar x
          if v == TVar x then pure v else inline v
        inline (TApp f@TPrim{} args)      = TApp f <$> mapM inline args
        inline u@(TLet _ (TCase 0 _ _ _)) = pure u
        inline (TLet e b)                 = inline (subst 0 e b)
        inline u                          = pure u
    simplPrim t = pure t

    -- Algebraic rules on binary primitive applications:
    --   PLt : (k + u) < (k + v)        ~> u < v
    --         (constArithView also reads u - k as (-k) + u)
    --   PEq : (k op u) == (k op v)     ~> u == v, for op in {PAdd, PSub}
    --   PMul: a literal 0 on either side ~> 0
    --   PMul/PQuot: (0 - u) op (0 - v) ~> u op v, and a single negated
    --         argument becomes 0 - (u op v)
    --   PRem: (0 - u) rem v ~> 0 - (u rem v'), where v' is v with any
    --         negation removed; u rem (0 - v) ~> u rem v
    -- Any other binary primitive has its arguments simplified recursively
    -- and is then normalised by simplArith. Everything else is unchanged.
    simplPrim' :: TTerm -> TTerm
    simplPrim' (TApp (TPrim PLt) [u, v])
      | Just (PAdd, k, u) <- constArithView u,
        Just (PAdd, j, v) <- constArithView v,
        k == j = tOp PLt u v
    simplPrim' (TApp (TPrim PEq) [u, v])
      | Just (op1, k, u) <- constArithView u,
        Just (op2, j, v) <- constArithView v,
        op1 == op2, k == j,
        elem op1 [PAdd, PSub] = tOp PEq u v
    simplPrim' (TApp (TPrim PMul) [u, v])
      | Just 0 <- intView u = tInt 0
      | Just 0 <- intView v = tInt 0
    simplPrim' (TApp (TPrim op) [u, v])
      | Just u <- negView u,
        Just v <- negView v,
        elem op [PMul, PQuot] = tOp op u v
      | Just u <- negView u,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
      | Just v <- negView v,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
    simplPrim' (TApp (TPrim PRem) [u, v])
      | Just u <- negView u  = simplArith $ tOp PSub (tInt 0) (tOp PRem u (unNeg v))
      | Just v <- negView v  = tOp PRem u v
    simplPrim' (TApp f@(TPrim op) [u, v]) = simplArith $ TApp f [simplPrim' u, simplPrim' v]
    simplPrim' u = u

    -- Strip a negation (0 - v) if present.
    unNeg u | Just v <- negView u = v
            | otherwise           = u

    -- Recognise a negation written as 0 - b.
    negView (TApp (TPrim PSub) [a, b])
      | Just 0 <- intView a = Just b
    negView _ = Nothing

    -- Count arithmetic operations
    -- u is better than v when it has no more binary primitive applications.
    -- Variables and literals are free; any other subterm costs 1000, so
    -- terms containing one are heavily penalised.
    betterThan u v = operations u <= operations v
      where
        operations (TApp (TPrim _) [a, b]) = 1 + operations a + operations b
        operations TVar{}                  = 0
        operations TLit{}                  = 0
        operations _                       = 1000

    -- Primitive simplification followed by the rewrite rules.
    rewrite' t = rewrite =<< simplPrim t

    -- View an addition or subtraction with a natural-number literal k as
    -- (op, k, u). k + u, k - u and u + k keep their operator. u - k is
    -- reported as (PAdd, -k, u).
    constArithView :: TTerm -> Maybe (TPrim, Integer, TTerm)
    constArithView (TApp (TPrim op) [TLit (LitNat _ k), u])
      | elem op [PAdd, PSub] = Just (op, k, u)
    constArithView (TApp (TPrim op) [u, TLit (LitNat _ k)])
      | op == PAdd = Just (op, k, u)
      | op == PSub = Just (PAdd, -k, u)
    constArithView _ = Nothing

    -- Simplify one case alternative of a case on x. Inside a constructor or
    -- literal branch, maybeAddRewrite may add a rule that rewrites the
    -- scrutinee to the matched pattern. For a constructor that pattern is c
    -- applied to the pattern variables TVar (a - 1) .. TVar 0. Guard
    -- alternatives get no rule.
    simplAlt x (TACon c a b) = TACon c a <$> underLams a (maybeAddRewrite (x + a) conTerm $ simpl b)
      where conTerm = mkTApp (TCon c) [TVar i | i <- reverse $ take a [0..]]
    simplAlt x (TALit l b)   = TALit l   <$> maybeAddRewrite x (TLit l) (simpl b)
    simplAlt x (TAGuard g b) = TAGuard   <$> simpl g <*> simpl b

    -- If variable x looks up to itself (TVar x), run cont unchanged.
    -- Otherwise run it with the extra rule: (lookup of x) ~> rhs.
    maybeAddRewrite x rhs cont = do
      v <- lookupVar x
      case v of
        TVar y | x == y -> cont
        _               -> addRewrite v rhs cont

    -- Is the term the builtin true (resp. false) constructor?
    isTrue (TCon c) = Just c == true
    isTrue _        = False

    isFalse (TCon c) = Just c == false
    isFalse _        = False

    -- Replace natMinus a b by the primitive a - b when rewrite' decides that
    -- b <= a. That holds when b < a is true, when b < 1 + a is true, or when
    -- both b < a and a < b are false. Otherwise fall back to tApp.
    maybeMinusToPrim f@(TDef minus) es@[a, b]
      | Just minus == natMinus = do
        b_a  <- rewrite' (tOp PLt b a)
        b_sa <- rewrite' (tOp PLt b (tOp PAdd (tInt 1) a))
        a_b  <- rewrite' (tOp PLt a b)
        if isTrue b_a || isTrue b_sa || isFalse b_a && isFalse a_b
          then pure $ tOp PSub a b
          else tApp f es

    maybeMinusToPrim f es = tApp f es

    -- Smart let: a let that binds a variable is substituted away.
    tLet (TVar x) b = subst 0 (TVar x) b
    tLet e b        = TLet e b

    -- Smart case. Unreachable alternatives are always dropped (bs').
    -- When the default is unreachable:
    --   * with no alternatives left, the default itself is returned;
    --   * if the last alternative is a literal or guard, its body becomes
    --     the new default (a lone alternative collapses to its body);
    --   * if the last alternative is a constructor, the case is kept.
    -- When the default is reachable, the case is kept with bs'.
    tCase :: Int -> CaseType -> TTerm -> [TAlt] -> S TTerm
    tCase x t d bs
      | isUnreachable d =
        case reverse bs' of
          []               -> pure d
          TALit _ b   : as -> pure $ tCase' x t b (reverse as)
          TAGuard _ b : as -> pure $ tCase' x t b (reverse as)
          TACon c a b : _  -> pure $ tCase' x t d bs'
      | otherwise = pure $ TCase x t d bs'
      where
        bs' = filter (not . isUnreachable) bs

        tCase' x t d [] = d
        tCase' x t d bs = TCase x t d bs

    -- Smart application:
    --   * arguments are pushed into let bodies, raised past the let;
    --   * arguments are pushed into the default and branches of a case,
    --     which is then simplified again;
    --   * a variable is inlined when its substitution is atomic (and not
    --     itself) or a lambda;
    --   * a lambda is beta-reduced: a variable argument is substituted
    --     directly, any other argument is let-bound;
    --   * otherwise a plain TApp is built.
    tApp :: TTerm -> [TTerm] -> S TTerm
    tApp (TLet e b) es = TLet e <$> underLet e (tApp b (raise 1 es))
    tApp (TCase x t d bs) es = do
      d  <- tApp d es
      bs <- mapM (`tAppAlt` es) bs
      simpl $ TCase x t d bs    -- will resimplify branches
    tApp (TVar x) es = do
      v <- lookupVar x
      case v of
        _ | v /= TVar x && isAtomic v -> tApp v es
        TLam{} -> tApp v es   -- could blow up the code
        _      -> pure $ mkTApp (TVar x) es
    tApp f [] = pure f
    tApp (TLam b) (TVar i : es) = tApp (subst 0 (TVar i) b) es
    tApp (TLam b) (e : es) = tApp (TLet e b) es
    tApp f es = pure $ TApp f es

    -- Apply es to the body of an alternative. In a constructor alternative,
    -- es is raised past the a pattern variables.
    tAppAlt (TACon c a b) es = TACon c a <$> underLams a (tApp b (raise a es))
    tAppAlt (TALit l b) es   = TALit l   <$> tApp b es
    tAppAlt (TAGuard g b) es = TAGuard g <$> tApp b es

    -- Terms that simpl and tApp substitute for a variable: variables,
    -- constructors, primitives, definitions, literals, sorts, erased terms
    -- and errors.
    isAtomic v = case v of
      TVar{}    -> True
      TCon{}    -> True
      TPrim{}   -> True
      TDef{}    -> True
      TLit{}    -> True
      TSort{}   -> True
      TErased{} -> True
      TError{}  -> True
      _         -> False

-- | A flattened sum: a constant plus a list of signed atoms.
type Arith = (Integer, [Atom])

-- | A subterm that is not itself an addition or subtraction, with a sign.
data Atom = Pos TTerm | Neg TTerm
  deriving (Show, Eq, Ord)

-- | Flip the sign of an atom.
aNeg :: Atom -> Atom
aNeg (Pos a) = Neg a
aNeg (Neg a) = Pos a

-- | Cancel pairs of opposite atoms (Pos t together with Neg t).
aCancel :: [Atom] -> [Atom]
aCancel (a : as)
  | elem (aNeg a) as = aCancel (delete (aNeg a) as)
  | otherwise        = a : aCancel as
aCancel [] = []

-- | Sort in descending order. With the derived 'Ord' on 'Atom' this places
-- all 'Neg' atoms before all 'Pos' atoms.
sortR :: Ord a => [a] -> [a]
sortR = sortBy (flip compare)

-- | Add two sums: add the constants, merge the atoms and cancel opposites.
aAdd :: Arith -> Arith -> Arith
aAdd (a, xs) (b, ys) = (a + b, aCancel $ sortR $ xs ++ ys)

-- | Subtract two sums: subtract the constants, merge with the second list
-- negated and cancel opposites.
aSub :: Arith -> Arith -> Arith
aSub (a, xs) (b, ys) = (a - b, aCancel $ sortR $ xs ++ map aNeg ys)

-- | Turn a sum back into a term.
--
-- * With no atoms, the result is the constant.
-- * If the constant is 0, the first positive atom leads, so no @0 + ...@ is
--   emitted.
-- * If the constant is negative, the first positive atom leads and the
--   constant is subtracted at the end.
-- * Otherwise the constant leads.
fromArith :: Arith -> TTerm
fromArith (n, []) = tInt n
fromArith (0, xs)
  | (ys, Pos a : zs) <- break isPos xs = foldl addAtom a (ys ++ zs)
fromArith (n, xs)
  | n < 0, (ys, Pos a : zs) <- break isPos xs =
    tOp PSub (foldl addAtom a (ys ++ zs)) (tInt (-n))
fromArith (n, xs) = foldl addAtom (tInt n) xs

-- | Is the atom positive?
isPos :: Atom -> Bool
isPos Pos{} = True
isPos Neg{} = False

-- | Add (for 'Pos') or subtract (for 'Neg') an atom from a term.
addAtom :: TTerm -> Atom -> TTerm
addAtom t (Pos a) = tOp PAdd t a
addAtom t (Neg a) = tOp PSub t a

-- | Flatten nested 'PAdd'/'PSub' applications into a sum. Integer literals
-- (per 'intView') go into the constant. Any other subterm becomes a
-- positive atom.
toArith :: TTerm -> Arith
toArith t | Just n <- intView t = (n, [])
toArith (TApp (TPrim PAdd) [a, b]) = aAdd (toArith a) (toArith b)
toArith (TApp (TPrim PSub) [a, b]) = aSub (toArith a) (toArith b)
toArith t = (0, [Pos t])

-- | Normalise the addition/subtraction structure of a term.
simplArith :: TTerm -> TTerm
simplArith = fromArith . toArith