-- | An implementation of Knuth-Bendix ordering.

{-# LANGUAGE PatternGuards #-}
module Twee.KBO(lessEq, lessIn) where

import Twee.Base hiding (lessEq, lessIn)
import Data.List
import Twee.Constraints hiding (lessEq, lessIn)
import qualified Data.Map.Strict as Map
import Data.Map.Strict(Map)
import Data.Maybe
import Control.Monad

-- | Check if one term is less than another in KBO.
--
-- The clauses are tried in order:
--
-- * a nullary application of 'minimal' is below everything;
-- * a variable is below itself, and nothing else is below a variable;
-- * a variable is below any term in which it occurs;
-- * for two applications, the left one must be smaller in size, or equal in
--   size with a smaller head symbol, or equal in size with the same head and
--   lexicographically smaller arguments.  In addition the sorted variable
--   list of the left term must be a subsequence of the sorted variable list
--   of the right term.
lessEq :: Function f => Term f -> Term f -> Bool
lessEq (App f Empty) _ | f == minimal = True
lessEq (Var x) (Var y) | x == y = True
lessEq _ (Var _) = False
lessEq (Var x) t = x `elem` vars t
lessEq t@(App f ts) u@(App g us) =
  smallerBySizeOrHead && sort (vars t) `isSubsequenceOf` sort (vars u)
  where
    -- Size decides first; the head symbols and then the arguments are only
    -- consulted when both terms have the same size.
    smallerBySizeOrHead =
      case compare (size t) (size u) of
        LT -> True
        GT -> False
        EQ -> f << g || (f == g && argsLess ts us)

    -- Left-to-right comparison of two argument lists of equal length.
    argsLess Empty Empty = True
    argsLess (Cons a restA) (Cons b restB)
      | a == b = argsLess restA restB
      | not (lessEq a b) = False
      | otherwise =
          -- The first differing pair is ordered.  If the pair can still be
          -- unified, carry on with the remaining arguments under the unifier.
          case unify a b of
            Nothing -> True
            Just sub
              | allSubst (\_ (Cons v Empty) -> isMinimal v) sub ->
                  argsLess (subst sub restA) (subst sub restB)
              | otherwise -> error "weird term inequality"
    argsLess _ _ = error "incorrect function arity"

-- | Check if one term is less than another in a given model.
-- See "notes/kbo under assumptions" for how this works.
--
-- The result is 'Nothing' when the ordering cannot be established,
-- @Just Strict@ when it is strict, and @Just Nonstrict@ otherwise.  The size
-- comparison ('sizeLessIn') runs first; the lexicographic comparison
-- ('lexLessIn') is only consulted when the size comparison is non-strict.
lessIn :: Function f => Model f -> Term f -> Term f -> Maybe Strictness
lessIn model t u =
  case sizeLessIn model t u of
    Nothing -> Nothing
    Just Strict -> Just Strict
    Just Nonstrict -> lexLessIn model t u

-- | Compare the sizes of two terms in a model.
--
-- The difference @size u - size t@ is split into a constant part @k@ (the
-- sizes of the function symbols met in @u@ minus those met in @t@) and a map
-- @m@ from each variable to its number of occurrences in @u@ minus its number
-- of occurrences in @t@.  With @l@ the minimum of the variable part computed
-- by 'minimumIn', the result is @Just Strict@ if @l > -k@, @Just Nonstrict@
-- if @l == -k@, and 'Nothing' otherwise (including when no minimum exists).
sizeLessIn :: Function f => Model f -> Term f -> Term f -> Maybe Strictness
sizeLessIn model t u =
  case minimumIn model m of
    Just l
      | l > -k -> Just Strict
      | l == -k -> Just Nonstrict
    _ -> Nothing
  where
    -- Subterms of t are counted negatively, subterms of u positively.
    -- NOTE(review): presumably 'subterms' yields every subterm including the
    -- term itself, so that each symbol occurrence is counted once -- confirm.
    (k, m) =
      foldr (addSize id)
        (foldr (addSize negate) (0, Map.empty) (subterms t))
        (subterms u)
    addSize op (App f _) (k, m) = (k + op (size f), m)
    addSize op (Var x) (k, m) = (k, Map.insertWith (+) x (op 1) m)

-- | Minimum value of a linear combination of variable sizes in a model.
--
-- The map gives the coefficient of each variable.  The result is the sum of
-- one contribution per variable group of the model ('varGroups') and one per
-- entry of the map; it is 'Nothing' as soon as any contribution is unbounded
-- below.
minimumIn :: Function f => Model f -> Map Var Int -> Maybe Int
minimumIn model t =
  liftM2 (+)
    (fmap sum (mapM minGroup (varGroups model)))
    (fmap sum (mapM minOrphan (Map.toList t)))
  where
    -- A group is a lower bound lo, a list of variables xs and an optional
    -- upper bound hi.
    -- NOTE(review): assumes xs is listed in ascending order between lo and
    -- hi; the suffix-sum test below relies on the same reading -- confirm
    -- against 'varGroups'.
    --
    -- Raising the j-th variable (and hence every later one) by d changes the
    -- total by d times the j-th suffix sum of the coefficients.  So:
    --
    -- * if every suffix sum is non-negative, the minimum is reached with all
    --   variables at size lo;
    -- * otherwise the total is unbounded below without an upper bound, and
    --   with an upper bound hi the minimum is reached by spending the whole
    --   slack (size hi - size lo) on the most negative suffix sum.
    --
    -- The correction must therefore use the smallest suffix sum.  Using the
    -- smallest single coefficient is wrong: for coefficients [-1,-1] it
    -- yields -1 where the true weight is -2, overestimating the minimum.
    minGroup (lo, xs, mhi)
      | all (>= 0) sums = Just (sum coeffs * size lo)
      | otherwise =
          case mhi of
            Nothing -> Nothing
            Just hi ->
              -- 'sums' is non-empty here: the guard above holds for [].
              let coeff = negate (minimum sums) in
              Just $ sum coeffs * size lo + coeff * (size lo - size hi)
      where
        coeffs = map (\x -> Map.findWithDefault 0 x t) xs
        sums = scanr1 (+) coeffs

    -- Variables known to the model contribute nothing here (their
    -- coefficients are looked up by minGroup).  Any other variable
    -- contributes its coefficient k, unless k is negative, in which case
    -- there is no lower bound.
    -- NOTE(review): contributing exactly k presumably reflects a minimum
    -- variable size of 1 -- confirm.
    minOrphan (x, k)
      | varInModel model x = Just 0
      | k < 0 = Nothing
      | otherwise = Just k

-- | Lexicographic part of the comparison in a model.
--
-- The clauses are tried in order:
--
-- * equal terms are related non-strictly;
-- * if both terms convert with 'fromTerm' and 'lessEqInModel' relates them,
--   that answer is returned; if @t@ converts and 'lessEqInModel' relates it
--   to some converted proper subterm of @u@, the answer is strict;
-- * two applications with the same head are compared argument by argument;
--   with different heads the answer is strict if @f << g@ and 'Nothing'
--   otherwise;
-- * a term satisfying 'isMinimal' is non-strictly below anything;
-- * otherwise the answer is 'Nothing'.
lexLessIn :: Function f => Model f -> Term f -> Term f -> Maybe Strictness
lexLessIn _ t u | t == u = Just Nonstrict
lexLessIn cond t u
  | Just a <- fromTerm t,
    Just b <- fromTerm u,
    Just x <- lessEqInModel cond a b = Just x
  | Just a <- fromTerm t,
    any isJust
      [ lessEqInModel cond a b
      | v <- properSubterms u, Just b <- [fromTerm v]] = Just Strict
lexLessIn cond (App f ts) (App g us)
  | f == g = loop ts us
  | f << g = Just Strict
  | otherwise = Nothing
  where
    loop Empty Empty = Just Nonstrict
    loop (Cons t ts) (Cons u us)
      | t == u = loop ts us
      | otherwise =
          case lessIn cond t u of
            Nothing -> Nothing
            Just Strict -> Just Strict
            Just Nonstrict ->
              -- NOTE(review): this lazy pattern is partial; it presumably
              -- relies on terms that are only non-strictly ordered always
              -- being unifiable -- confirm.
              let Just sub = unify t u in
              loop (subst sub ts) (subst sub us)
    loop _ _ = error "incorrect function arity"
lexLessIn _ t _ | isMinimal t = Just Nonstrict
lexLessIn _ _ _ = Nothing