{-# LANGUAGE CPP #-}
-- | Simplification pass over treeless terms.
--
--   The single entry point is 'simplifyTTerm'.  The pass walks a 'TTerm'
--   while tracking a substitution for the variables in scope and a list of
--   rewrite rules learned from enclosing case branches.
module Agda.Compiler.Treeless.Simplify (simplifyTTerm) where

import Control.Arrow (first, second, (***))
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Traversable (traverse)
import qualified Data.List as List

import Agda.Syntax.Treeless
import Agda.Syntax.Internal (Substitution'(..))
import Agda.Syntax.Literal
import Agda.TypeChecking.Monad
import Agda.TypeChecking.Monad.Builtin
import Agda.TypeChecking.Primitive
import Agda.TypeChecking.Substitute
import Agda.Utils.Maybe

import Agda.Compiler.Treeless.Subst
import Agda.Compiler.Treeless.Pretty
import Agda.Compiler.Treeless.Compare

import Agda.Utils.Pretty
import Agda.Utils.Impossible

#include "undefined.h"

-- | Environment of the simplifier.
data SEnv = SEnv
  { envSubst   :: Substitution' TTerm
    -- ^ Current substitution for the variables in scope.
  , envRewrite :: [(TTerm, TTerm)]
    -- ^ Rewrite rules as (lhs, rhs) pairs; earlier entries take precedence.
  }

-- | The simplifier monad: a reader over 'SEnv'.
type S = Reader SEnv

-- | Run a simplifier action with the identity substitution and no rewrite
--   rules.
runS :: S a -> a
runS m = runReader m $ SEnv IdS []

-- | Look up a variable in the current substitution.
lookupVar :: Int -> S TTerm
lookupVar i = asks $ (`lookupS` i) . envSubst

-- | Run an action with a modified substitution.
onSubst :: (Substitution' TTerm -> Substitution' TTerm) -> S a -> S a
onSubst f = local $ \ env -> env { envSubst = f (envSubst env) }

-- | Run an action with the given substitution applied to both sides of every
--   rewrite rule.
onRewrite :: Substitution' TTerm -> S a -> S a
onRewrite rho = local $ \ env ->
  env { envRewrite = map (applySubst rho *** applySubst rho) (envRewrite env) }

-- | Run an action with one more rewrite rule, placed in front of the
--   existing ones.
addRewrite :: TTerm -> TTerm -> S a -> S a
addRewrite lhs rhs = local $ \ env -> env { envRewrite = (lhs, rhs) : envRewrite env }

-- | Go under @i@ binders: raise the rewrite rules by @i@ and lift the
--   substitution by @i@.
underLams :: Int -> S a -> S a
underLams i = onRewrite (raiseS i) . onSubst (liftS i)

-- | Go under a single binder.
underLam :: S a -> S a
underLam = underLams 1

-- | Go under a let binding whose bound term is @u@: the rewrite rules are
--   raised by one and the substitution is extended with @u@ and weakened.
underLet :: TTerm -> S a -> S a
underLet u = onRewrite (raiseS 1) . onSubst (\rho -> wkS 1 $ u :# rho)

-- | Run an action with variable @x@ replaced in place by @u@.
bindVar :: Int -> TTerm -> S a -> S a
bindVar x u = onSubst (inplaceS x u `composeS`)

-- | Apply the first rewrite rule whose left-hand side is 'equalTerms' to the
--   given term; return the term unchanged if no rule applies.
rewrite :: TTerm -> S TTerm
rewrite t = do
  rules <- asks envRewrite
  case [ rhs | (lhs, rhs) <- rules, equalTerms t lhs ] of
    rhs : _ -> pure rhs
    []      -> pure t

-- | Names of the builtins the simplifier gives special treatment to.
--   Each one is 'Nothing' when the corresponding builtin is not bound.
data FunctionKit = FunctionKit
  { modAux, divAux, natMinus, true, false :: Maybe QName }

-- | Simplify a treeless term.  Looks up the builtin names needed by the
--   simplifier and then runs the pure pass.
simplifyTTerm :: TTerm -> TCM TTerm
simplifyTTerm t = do
  kit <- FunctionKit <$> getBuiltinName builtinNatModSucAux
                     <*> getBuiltinName builtinNatDivSucAux
                     <*> getBuiltinName builtinNatMinus
                     <*> getBuiltinName builtinTrue
                     <*> getBuiltinName builtinFalse
  return $ runS $ simplify kit t

-- | The simplifier proper.  All helpers live in the @where@ clause because
--   they refer to the builtin names brought into scope from the kit.
simplify :: FunctionKit -> TTerm -> S TTerm
simplify FunctionKit{..} = simpl
  where
    -- Simplify primitives and apply rewrite rules, collapse chained cases,
    -- then dispatch on the shape of the resulting term.
    simpl = rewrite' >=> unchainCase >=> \ t -> case t of

      TDef{}         -> pure t
      TPrim{}        -> pure t
      TVar{}         -> pure t

      TApp (TDef f) [TLit (LitNat _ 0), m, n, m']
        -- div/mod are equivalent to quot/rem on natural numbers.
        | m == m', Just f == divAux -> simpl $ tOp PQuot n (tPlusK 1 m)
        | m == m', Just f == modAux -> simpl $ tOp PRem n (tPlusK 1 m)

      -- Word64 primitives --

      --  toWord (a ∙ b) == toWord a ∙64 toWord b
      TPFn PITo64 (TPOp op a b)
        | Just op64 <- opTo64 op -> simpl $ tOp op64 (TPFn PITo64 a) (TPFn PITo64 b)
        where
          opTo64 op = lookup op [(PAdd, PAdd64), (PSub, PSub64), (PMul, PMul64),
                                 (PQuot, PQuot64), (PRem, PRem64)]

      TApp (TPrim _) _ -> pure t  -- taken care of by rewrite'

      TCoerce t -> TCoerce <$> simpl t

      TApp f es -> do
        f  <- simpl f
        es <- traverse simpl es
        maybeMinusToPrim f es
      TLam b    -> TLam <$> underLam (simpl b)
      TLit{}    -> pure t
      TCon{}    -> pure t
      TLet e b -> do
        e <- simpl e
        case e of
          TPFn P64ToI a -> do
            -- Inline calls to P64ToI since these trigger optimisations.
            -- Ideally, the optimisations would trigger anyway, but at the
            -- moment they only do if inlining the entire let looks like a
            -- good idea.
            let rho = inplaceS 0 (TPFn P64ToI (TVar 0))
            tLet a <$> underLet a (simpl (applySubst rho b))
          _ -> tLet e <$> underLet e (simpl b)

      TCase x t d bs -> do
        v <- lookupVar x
        let (lets, u) = tLetView v
        case u of                          -- TODO: also for literals
          _ | Just (c, as)     <- conView u   -> simpl $ matchCon lets c as d bs
            | Just (k, TVar y) <- plusKView u -> simpl . mkLets lets . TCase y t d =<< mapM (matchPlusK y x k) bs
          -- The scrutinee is itself bound to a case: push the outer case into
          -- the default and every branch of the inner one.
          TCase y t1 d1 bs1 -> simpl $ mkLets lets $ TCase y t1 (distrDef case1 d1) $
                                       map (distrCase case1) bs1
            where
              -- Γ x Δ -> Γ _ Δ Θ y, where x maps to y and Θ are the lets
              n     = length lets
              rho   = liftS (x + n + 1) (raiseS 1)    `composeS`
                      singletonS (x + n + 1) (TVar 0) `composeS`
                      raiseS (n + 1)
              case1 = applySubst rho (TCase x t d bs)

              distrDef v d | isUnreachable d = tUnreachable
                           | otherwise       = tLet d v

              distrCase v (TACon c a b) = TACon c a $ TLet b $ raiseFrom 1 a v
              distrCase v (TALit l b)   = TALit l   $ TLet b v
              distrCase v (TAGuard g b) = TAGuard g $ TLet b v

          _ -> do
            d  <- simpl d
            bs <- traverse (simplAlt x) bs
            tCase x t d bs

      TUnit    -> pure t
      TSort    -> pure t
      TErased  -> pure t
      TError{} -> pure t

    -- View a term as a constructor applied to (possibly zero) arguments.
    conView (TCon c)           = Just (c, [])
    conView (TApp (TCon c) as) = Just (c, as)
    conView e                  = Nothing

    -- Collapse chained cases (case x of bs -> vs; _ -> case x of bs' -> vs' ==>
    --                         case x of bs -> vs; bs' -> vs')
    unchainCase :: TTerm -> S TTerm
    unchainCase e@(TCase x t d bs) = do
      let (lets, u) = tLetView d
          k = length lets
      return $ case u of
        TCase y _ d' bs' | x + k == y ->
          mkLets lets $ TCase y t d' $ raise k bs ++ bs'
        _ -> e
    unchainCase e = return e

    -- Wrap a body in a sequence of let bindings (first binding outermost).
    mkLets es b = foldr TLet b es

    -- Case on a known constructor application: select the first alternative
    -- for that constructor and let-bind its arguments; literal and guard
    -- alternatives are skipped and the default is used if nothing matches.
    matchCon _ _ _ d [] = d
    matchCon lets c as d (TALit{}   : bs) = matchCon lets c as d bs
    matchCon lets c as d (TAGuard{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TACon c' a b : bs)
      | c == c'   = flip (foldr TLet) lets $ mkLet 0 as (raiseFrom a (length lets) b)
      | otherwise = matchCon lets c as d bs
      where
        mkLet _ []       b = b
        mkLet i (a : as) b = TLet (raise i a) $ mkLet (i + 1) as b

    -- Simplify let y = x + k in case y of j     -> u; _ | g[y]     -> v
    -- to       let y = x + k in case x of j - k -> u; _ | g[x + k] -> v
    matchPlusK :: Int -> Int -> Integer -> TAlt -> S TAlt
    matchPlusK x y k (TALit (LitNat r j) b) = return $ TALit (LitNat r (j - k)) b
    matchPlusK x y k (TAGuard g b) = flip TAGuard b <$> simpl (applySubst (inplaceS y (tPlusK k (TVar x))) g)
    matchPlusK x y k TACon{} = __IMPOSSIBLE__
    matchPlusK x y k TALit{} = __IMPOSSIBLE__

    -- Simplify a primitive application.  The arguments are simplified, then
    -- a second version is built with variables replaced by their bindings;
    -- the simplified inlined version is kept only if it is 'betterThan' the
    -- plain one.
    simplPrim (TApp f@TPrim{} args) = do
        args    <- mapM simpl args
        inlined <- mapM inline args
        let u = TApp f args
            v = simplPrim' (TApp f inlined)
        pure $ if v `betterThan` u then v else u
      where
        inline (TVar x)                   = do
          v <- lookupVar x
          if v == TVar x then pure v else inline v
        inline (TApp f@TPrim{} args)      = TApp f <$> mapM inline args
        inline u@(TLet _ (TCase 0 _ _ _)) = pure u
        inline (TLet e b)                 = inline (subst 0 e b)
        inline u                          = pure u
    simplPrim t = pure t

    -- Pure simplification of primitive applications.  Equations whose guards
    -- all fail fall through to the later equations.
    simplPrim' :: TTerm -> TTerm
    simplPrim' (TApp (TPrim PSeq) (u : v : vs))
      | u == v             = mkTApp v vs
      | TApp TCon{} _ <- u = mkTApp v vs
      | TApp TLit{} _ <- u = mkTApp v vs
    simplPrim' (TApp (TPrim PLt) [u, v])
      | Just (PAdd, k, u) <- constArithView u,
        Just (PAdd, j, v) <- constArithView v,
        k == j = tOp PLt u v
      | Just (PAdd, k, v) <- constArithView v,
        TApp (TPrim P64ToI) [u] <- u,
        k >= 2^64, Just trueCon <- true = TCon trueCon
      | Just k <- intView u
      , Just j <- intView v
      , Just trueCon <- true
      , Just falseCon <- false = if k < j then TCon trueCon else TCon falseCon
    simplPrim' (TApp (TPrim op) [u, v])
      | elem op [PGeq, PLt, PEqI]
      , Just (PAdd, k, u) <- constArithView u
      , Just j <- intView v = TApp (TPrim op) [u, tInt (j - k)]
    simplPrim' (TApp (TPrim PEqI) [u, v])
      | Just (op1, k, u) <- constArithView u,
        Just (op2, j, v) <- constArithView v,
        op1 == op2, k == j,
        elem op1 [PAdd, PSub] = tOp PEqI u v
    -- Arithmetic with a zero operand.  When the result is zero we return the
    -- zero operand itself rather than a fresh integer literal, so that the
    -- 64-bit operators keep producing a Word64 literal.
    simplPrim' (TPOp op u v)
      | zeroL, isMul || isDiv = u
      | zeroL, isAdd          = v
      | zeroR, isMul          = v
      | zeroR, isAdd || isSub = u
      where zeroL = Just 0 == intView u || Just 0 == word64View u
            zeroR = Just 0 == intView v || Just 0 == word64View v
            isAdd = elem op [PAdd, PAdd64]
            isSub = elem op [PSub, PSub64]
            isMul = elem op [PMul, PMul64]
            isDiv = elem op [PQuot, PQuot64, PRem, PRem64]
    -- Move negations (0 - x) out of products and quotients.
    simplPrim' (TApp (TPrim op) [u, v])
      | Just u <- negView u,
        Just v <- negView v,
        elem op [PMul, PQuot] = tOp op u v
      | Just u <- negView u,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
      | Just v <- negView v,
        elem op [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
    simplPrim' (TApp (TPrim PRem) [u, v])
      | Just u <- negView u  = simplArith $ tOp PSub (tInt 0) (tOp PRem u (unNeg v))
      | Just v <- negView v  = tOp PRem u v

    -- (fromWord a == fromWord b) = (a ==64 b)
    simplPrim' (TPOp op (TPFn P64ToI a) (TPFn P64ToI b))
      | Just op64 <- opTo64 op = tOp op64 a b
      where
        opTo64 op = lookup op [(PEqI, PEq64), (PLt, PLt64)]

    -- toWord/fromWord k == fromIntegral k
    simplPrim' (TPFn PITo64 (TLit (LitNat r n)))    = TLit (LitWord64 r (fromIntegral n))
    simplPrim' (TPFn P64ToI (TLit (LitWord64 r n))) = TLit (LitNat r (fromIntegral n))

    -- toWord (fromWord a) == a
    simplPrim' (TPFn PITo64 (TPFn P64ToI a)) = a

    simplPrim' (TApp f@(TPrim op) [u, v]) = simplArith $ TApp f [simplPrim' u, simplPrim' v]
    simplPrim' u = u

    -- Strip one negation (0 - x) if present.
    unNeg u | Just v <- negView u = v
            | otherwise           = u

    -- Recognise a negation, i.e. a subtraction from the literal zero.
    negView (TApp (TPrim PSub) [a, b]) | Just 0 <- intView a = Just b
    negView _ = Nothing

    -- Count arithmetic operations
    betterThan u v = operations u <= operations v
      where
        -- The seq equation has to come before the generic two-argument
        -- equation below; otherwise a two-argument seq is caught by the
        -- generic equation and the penalty is never applied.
        operations (TApp (TPrim PSeq) (a : _))
          | notVar a                       = 1000000  -- only seq on variables!
        operations (TApp (TPrim _) [a, b]) = 1 + operations a + operations b
        operations (TApp (TPrim _) [a])    = 1 + operations a
        operations TVar{}                  = 0
        operations TLit{}                  = 0
        operations TCon{}                  = 0
        operations TDef{}                  = 0
        operations _                       = 1000

        notVar TVar{} = False
        notVar _      = True

    -- Simplify primitives first, then apply the rewrite rules to the result.
    rewrite' t = rewrite =<< simplPrim t

    -- View a term as an addition or subtraction involving a natural number
    -- literal.  A literal on the right of a subtraction is returned as the
    -- addition of its negation.
    constArithView :: TTerm -> Maybe (TPrim, Integer, TTerm)
    constArithView (TApp (TPrim op) [TLit (LitNat _ k), u])
      | elem op [PAdd, PSub] = Just (op, k, u)
    constArithView (TApp (TPrim op) [u, TLit (LitNat _ k)])
      | op == PAdd = Just (op, k, u)
      | op == PSub = Just (PAdd, -k, u)
    constArithView _ = Nothing

    -- Simplify one alternative of a case on variable x.  Inside constructor
    -- and literal branches the scrutinee is known to equal the pattern.
    simplAlt x (TACon c a b) = TACon c a <$> underLams a (maybeAddRewrite (x + a) conTerm $ simpl b)
      where conTerm = mkTApp (TCon c) [TVar i | i <- reverse $ take a [0..]]
    simplAlt x (TALit l b)   = TALit l   <$> maybeAddRewrite x (TLit l) (simpl b)
    simplAlt x (TAGuard g b) = TAGuard   <$> simpl g <*> simpl b

    -- If x is already bound we add a rewrite, otherwise we bind x to rhs.
    maybeAddRewrite x rhs cont = do
      v <- lookupVar x
      case v of
        TVar y | x == y -> bindVar x rhs $ cont
        _ -> addRewrite v rhs cont

    -- Is the term the builtin true constructor?
    isTrue (TCon c) = Just c == true
    isTrue _        = False

    -- Is the term the builtin false constructor?
    isFalse (TCon c) = Just c == false
    isFalse _        = False

    -- Turn an application of the builtin natural-number minus into the
    -- primitive subtraction when 'checkLeq' can show b ≤ a; otherwise build
    -- an ordinary application.
    maybeMinusToPrim f@(TDef minus) es@[a, b]
      | Just minus == natMinus = do
          leq <- checkLeq b a
          if leq then pure $ tOp PSub a b
                 else tApp f es

    maybeMinusToPrim f es = tApp f es

    -- Smart let: a let of a variable is substituted away, and a let whose
    -- body is just the bound variable collapses to the bound term.
    tLet (TVar x) b = subst 0 (TVar x) b
    tLet e (TVar 0) = e
    tLet e b        = TLet e b

    -- Smart case: drops unreachable alternatives, promotes a trailing
    -- literal or guard alternative to the default when the default is
    -- unreachable, and merges a default that is a case on the same variable.
    tCase :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    tCase x t d [] = pure d
    tCase x t d bs
      | isUnreachable d =
        case reverse bs' of
          [] -> pure d
          TALit _ b   : as  -> tCase x t b (reverse as)
          TAGuard _ b : as  -> tCase x t b (reverse as)
          TACon c a b : _   -> tCase' x t d bs'
      | otherwise = do
        d' <- lookupIfVar d
        case d' of
          TCase y _ d bs'' | x == y ->
            tCase x t d (bs' ++ filter noOverlap bs'')
          _ -> tCase' x t d bs'
      where
        bs' = filter (not . isUnreachable) bs

        lookupIfVar (TVar i) = lookupVar i
        lookupIfVar t = pure t

        -- Keep only inner alternatives not already covered by an outer one.
        noOverlap b = not $ any (overlapped b) bs'
        overlapped (TACon c _ _)  (TACon c' _ _) = c == c'
        overlapped (TALit l _)    (TALit l' _)   = l == l'
        overlapped _              _              = False

    -- Drop unreachable cases for Nat and Int cases.
    pruneLitCases :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    pruneLitCases x t d bs | CTNat == caseType t =
      case complete bs [] Nothing of
        Just bs' -> tCase x t tUnreachable bs'
        Nothing  -> return $ TCase x t d bs
      where
        -- Collect alternatives until every literal below the smallest
        -- guard bound seen so far has been covered.
        complete bs small (Just upper)
          | null $ [0..upper - 1] List.\\ small = Just []
        complete (b@(TALit (LitNat _ n) _) : bs) small upper =
          (b :) <$> complete bs (n : small) upper
        complete (b@(TAGuard (TApp (TPrim PGeq) [TVar y, TLit (LitNat _ j)]) _) : bs) small upper | x == y =
          (b :) <$> complete bs small (Just $ maybe j (min j) upper)
        complete _ _ _ = Nothing

    pruneLitCases x t d bs
      | CTInt == caseType t = return $ TCase x t d bs -- TODO
      | otherwise           = return $ TCase x t d bs

    -- Case constructor used once unreachable alternatives have been removed.
    tCase' x t d [] = return d
    tCase' x t d bs = pruneLitCases x t d bs

    -- Smart application: pushes arguments under lets and into case branches,
    -- looks through variables bound to atomic terms or lambdas, and turns
    -- applied lambdas into substitutions or lets.
    tApp :: TTerm -> [TTerm] -> S TTerm
    tApp (TLet e b) es = TLet e <$> underLet e (tApp b (raise 1 es))
    tApp (TCase x t d bs) es = do
      d  <- tApp d es
      bs <- mapM (`tAppAlt` es) bs
      simpl $ TCase x t d bs    -- will resimplify branches
    tApp (TVar x) es = do
      v <- lookupVar x
      case v of
        _ | v /= TVar x && isAtomic v -> tApp v es
        TLam{} -> tApp v es   -- could blow up the code
        _      -> pure $ mkTApp (TVar x) es
    tApp f [] = pure f
    tApp (TLam b) (TVar i : es) = tApp (subst 0 (TVar i) b) es
    tApp (TLam b) (e : es) = tApp (TLet e b) es
    tApp f es = pure $ TApp f es

    -- Apply the body of a case alternative to the given arguments.
    tAppAlt (TACon c a b) es = TACon c a <$> underLams a (tApp b (raise a es))
    tAppAlt (TALit l b) es   = TALit l   <$> tApp b es
    tAppAlt (TAGuard g b) es = TAGuard g <$> tApp b es

    -- Terms without subterms.
    isAtomic v = case v of
      TVar{}    -> True
      TCon{}    -> True
      TPrim{}   -> True
      TDef{}    -> True
      TLit{}    -> True
      TSort{}   -> True
      TErased{} -> True
      TError{}  -> True
      _         -> False

    -- Try to establish a ≤ b from literals, the range of fromWord, and the
    -- comparisons recorded in the rewrite rules.  A False result only means
    -- that none of these checks succeeded.
    checkLeq a b = do
      rho <- asks envSubst
      rwr <- asks envRewrite
      let nf = toArith . applySubst rho
          -- Known facts x' < y' (comparison rewritten to true) ...
          less = [ (nf a, nf b) | (TPOp PLt a b, rhs) <- rwr, isTrue rhs ]
          -- ... and x' ≤ y' (the reverse comparison rewritten to false).
          leq  = [ (nf b, nf a) | (TPOp PLt a b, rhs) <- rwr, isFalse rhs ]

          -- Difference of the constant parts when the atoms coincide.
          match (j, as) (k, bs)
            | as == bs  = Just (j - k)
            | otherwise = Nothing

          -- Do we have x ≤ y given the known fact x' + d ≤ y' ?
          -- (d = 1 encodes x' < y', d = 0 encodes x' ≤ y')
          matchEqn d x y (x', y') = isJust $ do
            k <- match x x'     -- x = x' + k
            j <- match y y'     -- y = y' + j
            guard (k <= j + d)  -- x ≤ y if k ≤ j + d

          matchLess = matchEqn 1
          matchLeq  = matchEqn 0

          literal (j, []) (k, []) = j <= k
          literal _ _ = False

          -- k + fromWord x ≤ y  if  k + 2^64 - 1 ≤ y
          wordUpperBound (k, [Pos (TApp (TPrim P64ToI) _)]) y = go (k + 2^64 - 1, []) y
          wordUpperBound _ _ = False

          -- x ≤ k + fromWord y  if  x ≤ k
          wordLowerBound a (k, [Pos (TApp (TPrim P64ToI) _)]) = go a (k, [])
          wordLowerBound _ _ = False

          go x y = or
            [ literal x y
            , wordUpperBound x y
            , wordLowerBound x y
            , any (matchLess x y) less
            , any (matchLeq x y) leq ]

      return $ go (nf a) (nf b)

-- | Normal form for sums: a constant plus a list of signed atoms.
type Arith = (Integer, [Atom])

-- | A term occurring positively or negatively in a sum.
data Atom = Pos TTerm | Neg TTerm
  deriving (Show, Eq, Ord)

-- | Flip the sign of an atom.
aNeg :: Atom -> Atom
aNeg (Pos a) = Neg a
aNeg (Neg a) = Pos a

-- | Remove pairs of atoms that differ only in sign.
aCancel :: [Atom] -> [Atom]
aCancel (a : as)
  | elem (aNeg a) as = aCancel (List.delete (aNeg a) as)
  | otherwise        = a : aCancel as
aCancel [] = []

-- | Sort in descending order.
sortR :: Ord a => [a] -> [a]
sortR = List.sortBy (flip compare)

-- | Add two sums, cancelling opposite atoms.
aAdd :: Arith -> Arith -> Arith
aAdd (a, xs) (b, ys) = (a + b, aCancel $ sortR $ xs ++ ys)

-- | Subtract the second sum from the first, cancelling opposite atoms.
aSub :: Arith -> Arith -> Arith
aSub (a, xs) (b, ys) = (a - b, aCancel $ sortR $ xs ++ map aNeg ys)

-- | Convert a sum back to a term.  When the constant is zero or negative the
--   result starts from the first positive atom (if there is one) instead of
--   from the constant.
fromArith :: Arith -> TTerm
fromArith (n, []) = tInt n
fromArith (0, xs)
  | (ys, Pos a : zs) <- break isPos xs = foldl addAtom a (ys ++ zs)
fromArith (n, xs)
  | n < 0, (ys, Pos a : zs) <- break isPos xs =
    tOp PSub (foldl addAtom a (ys ++ zs)) (tInt (-n))
fromArith (n, xs) = foldl addAtom (tInt n) xs

-- | Is the atom positive?
isPos :: Atom -> Bool
isPos Pos{} = True
isPos Neg{} = False

-- | Add (or subtract, for a negative atom) an atom to a term.
addAtom :: TTerm -> Atom -> TTerm
addAtom t (Pos a) = tOp PAdd t a
addAtom t (Neg a) = tOp PSub t a

-- | Convert a term to a sum.  Integer literals, additions and subtractions
--   are interpreted; any other term becomes a single positive atom.
toArith :: TTerm -> Arith
toArith t | Just n <- intView t = (n, [])
toArith (TApp (TPrim PAdd) [a, b]) = aAdd (toArith a) (toArith b)
toArith (TApp (TPrim PSub) [a, b]) = aSub (toArith a) (toArith b)
toArith t = (0, [Pos t])

-- | Normalise additions and subtractions in a term via 'Arith'.
simplArith :: TTerm -> TTerm
simplArith = fromArith . toArith