{-# LANGUAGE FlexibleContexts, UndecidableInstances, RecordWildCards #-}
-- | Solving constraints on variable ordering.
module Twee.Constraints where

--import Twee.Base hiding (equals, Term, pattern Fun, pattern Var, lookup, funs)
import qualified Twee.Term as Flat
import qualified Data.Map.Strict as Map
import Twee.Pretty hiding (equals)
import Twee.Utils
import Data.Maybe
import Data.List
import Data.Function
import Data.Graph
import Data.Map.Strict(Map)
import Data.Ord
import Twee.Term hiding (lookup)

-- | An atomic term: either a constant (a function symbol applied to no
-- arguments) or a variable.
data Atom f = Constant (Fun f) | Variable Var deriving (Show, Eq, Ord)

{-# INLINE atoms #-}
-- | Collect the atoms of a term: one 'Constant' for every application with an
-- empty argument list and one 'Variable' for every variable. Other symbols are
-- stepped over.
atoms :: Term f -> [Atom f]
-- 'singleton' is referenced through the qualified import so that it cannot
-- clash with 'Data.List.singleton', which newer versions of base export.
atoms t = aux (Flat.singleton t)
  where
    aux Empty = []
    aux (Cons (App f Empty) rest) = Constant f:aux rest
    aux (Cons (Var x) rest) = Variable x:aux rest
    aux (ConsSym _ rest) = aux rest

-- | Build the term consisting of just the given atom.
toTerm :: Atom f -> Term f
toTerm (Constant f) = build (con f)
toTerm (Variable x) = build (var x)

-- | Convert a term to an atom; 'Nothing' if the term is neither a variable
-- nor an application with an empty argument list.
fromTerm :: Flat.Term f -> Maybe (Atom f)
fromTerm (App f Empty) = Just (Constant f)
fromTerm (Var x) = Just (Variable x)
fromTerm _ = Nothing

instance PrettyTerm f => Pretty (Atom f) where
  pPrint = pPrint . toTerm

-- | A formula over ordering constraints between atoms.
-- @'And' []@ is 'true' and @'Or' []@ is 'false'.
data Formula f
  = Less (Atom f) (Atom f)
  | LessEq (Atom f) (Atom f)
  | And [Formula f]
  | Or [Formula f]
  deriving (Eq, Ord, Show)

instance PrettyTerm f => Pretty (Formula f) where
  pPrintPrec l p form =
    case form of
      Less t u -> hang (pPrint t <+> text "<") 2 (pPrint u)
      LessEq t u -> hang (pPrint t <+> text "<=") 2 (pPrint u)
      And [] -> text "true"
      Or [] -> text "false"
      And xs -> junction " &" xs
      Or xs -> junction " |" xs
    where
      -- Shared layout for non-empty conjunctions and disjunctions.
      junction op xs =
        maybeParens (p > 10)
          (fsep (punctuate (text op) (indentRest (map (pPrintPrec l 11) xs))))
      -- Indent every document but the first.
      indentRest (d:ds) = d:map (nest 2) ds
      indentRest [] = []

-- | Negate a formula, pushing the negation down to the literals.
negateFormula :: Formula f -> Formula f
negateFormula (Less t u) = LessEq u t
negateFormula (LessEq t u) = Less u t
negateFormula (And ts) = Or (map negateFormula ts)
negateFormula (Or ts) = And (map negateFormula ts)

-- | Conjunction of a list of formulas. Nested 'And's are flattened one level,
-- duplicates and 'true' are dropped, the result is 'false' if any conjunct is
-- 'false', and a single remaining conjunct is returned unwrapped.
conj :: [Formula f] -> Formula f
conj forms
  | false `elem` forms' = false
  | otherwise =
    case forms' of
      [x] -> x
      xs -> And xs
  where
    flatten (And xs) = xs
    flatten x = [x]
    forms' = filter (/= true) (usort (concatMap flatten forms))

-- | Disjunction of a list of formulas; the dual of 'conj'.
disj :: [Formula f] -> Formula f
disj forms
  | true `elem` forms' = true
  | otherwise =
    case forms' of
      [x] -> x
      xs -> Or xs
  where
    flatten (Or xs) = xs
    flatten x = [x]
    forms' = filter (/= false) (usort (concatMap flatten forms))

-- | Binary versions of 'conj' and 'disj'.
(&&&), (|||) :: Formula f -> Formula f -> Formula f
x &&& y = conj [x, y]
x ||| y = disj [x, y]

-- | The trivially true and trivially false formulas.
true, false :: Formula f
true = And []
false = Or []

-- | A set of strict inequalities and equalities between atoms, together with
-- the function symbols that have been registered with 'addTerm'.
data Branch f =
  -- Branches are kept normalised wrt equals
  Branch {
    funs :: [Fun f],
    less :: [(Atom f, Atom f)],    -- sorted
    equals :: [(Atom f, Atom f)] } -- sorted, greatest atom first in each pair
  deriving (Eq, Ord)

instance PrettyTerm f => Pretty (Branch f) where
  pPrint Branch{..} =
    braces $ fsep $ punctuate (text ",") $
      [pPrint x <+> text "<" <+> pPrint y | (x, y) <- less ] ++
      [pPrint x <+> text "=" <+> pPrint y | (x, y) <- equals ]

-- | The branch with no constraints.
trueBranch :: Branch f
trueBranch = Branch [] [] []

-- | Replace an atom by the atom the branch's equalities map it to, if any.
norm :: Eq f => Branch f -> Atom f -> Atom f
norm Branch{..} x = fromMaybe x (lookup x equals)

-- | A branch is contradictory if it puts something below the minimal
-- constant, equates two distinct constants, or has a cycle among its strict
-- inequalities.
contradictory :: (Minimal f, Ord f) => Branch f -> Bool
contradictory Branch{..} =
  or [f == minimal | (_, Constant f) <- less] ||
  or [f /= g | (Constant f, Constant g) <- equals] ||
  any cyclic
    (stronglyConnComp
      [(x, x, [y | (x', y) <- less, x == x']) | x <- usort (map fst less)])
  where
    cyclic (AcyclicSCC _) = False
    cyclic (CyclicSCC _) = True

-- | Add a formula to every branch of a list, returning the resulting
-- branches without duplicates.
formAnd :: (Minimal f, Ordered f) => Formula f -> [Branch f] -> [Branch f]
formAnd f bs = usort (bs >>= add f)
  where
    add (Less t u) b = addLess t u b
    add (LessEq t u) b = addLess t u b ++ addEquals t u b
    add (And []) b = [b]
    add (And (g:gs)) b = add g b >>= add (And gs)
    add (Or gs) b = usort (concat [ add g b | g <- gs ])

-- | Expand a formula into a list of branches, one per way of satisfying it;
-- contradictory branches are discarded by 'addLess' and 'addEquals'.
branches :: (Minimal f, Ordered f) => Formula f -> [Branch f]
branches x = aux [x]
  where
    aux [] = [trueBranch]
    aux (And xs:ys) = aux (xs ++ ys)
    aux (Or xs:ys) = usort $ concat [aux (y:ys) | y <- xs]
    aux (Less t u:xs) = usort $ concatMap (addLess t u) (aux xs)
    aux (LessEq t u:xs) =
      usort $ concatMap (addLess t u) rest ++ concatMap (addEquals u t) rest
      where
        -- Expand the remaining formulas once and reuse the result for both
        -- the strict and the equal case.
        rest = aux xs

-- | Add the strict inequality @t < u@ to a branch. The result is empty if the
-- inequality makes the branch contradictory, and a singleton otherwise.
addLess :: (Minimal f, Ordered f) => Atom f -> Atom f -> Branch f -> [Branch f]
-- Nothing is below the minimal constant.
addLess _ (Constant f) _ | f == minimal = []
-- The minimal constant is below everything else, so nothing is recorded.
addLess (Constant f) _ b | f == minimal = [b]
addLess t0 u0 b@Branch{..} =
  filter (not . contradictory)
    [addTerm t (addTerm u b{less = usort ((t, u):less)})]
  where
    t = norm b t0
    u = norm b u0

-- | Add the equality @t = u@ to a branch, rewriting the existing constraints
-- so that the greater atom is replaced by the smaller one. The result is empty
-- if the equality makes the branch contradictory, and a singleton otherwise.
addEquals :: (Minimal f, Ordered f) => Atom f -> Atom f -> Branch f -> [Branch f]
addEquals t0 u0 b@Branch{..}
  | t == u || (t, u) `elem` equals = [b]
  | otherwise =
    filter (not . contradictory)
      [addTerm t (addTerm u b {
         equals =
           usort $
             (t, u):
             [ (x', y')
             | (x, y) <- equals,
               let (y', x') = sort2 (sub x, sub y),
               x' /= y' ],
         less = usort $ [(sub x, sub y) | (x, y) <- less] })]
  where
    sort2 (x, y) = (min x y, max x y)
    -- u is the smaller and t the greater of the two normalised atoms.
    (u, t) = sort2 (norm b t0, norm b u0)
    sub x
      | x == t = u
      | otherwise = x

-- | Register an atom with a branch. A constant not yet in 'funs' is added,
-- along with its strict inequalities against every function already there;
-- anything else leaves the branch unchanged.
addTerm :: (Minimal f, Ordered f) => Atom f -> Branch f -> Branch f
addTerm (Constant f) b
  | f `notElem` funs b =
    b {
      funs = f:funs b,
      less =
        usort $
          [ (Constant f, Constant g) | g <- funs b, f << g ] ++
          [ (Constant g, Constant f) | g <- funs b, g << f ] ++
          less b }
addTerm _ b = b

newtype Model f = Model (Map (Atom f) (Int, Int))
  deriving (Eq, Show)
-- Representation: map from atom to (major, minor)
-- x <  y if major x < major y
-- x <= y if major x = major y and minor x < minor y

instance PrettyTerm f => Pretty (Model f) where
  pPrint (Model m)
    | Map.size m <= 1 = text "empty"
    | otherwise = fsep (go (sortBy (comparing snd) (Map.toList m)))
    where
      -- Unreachable given the size guard above; present for totality.
      go [] = []
      go [(x, _)] = [pPrint x]
      go ((x, (i, _)):xs@((_, (j, _)):_)) =
        (pPrint x <+> text rel):go xs
        where
          rel = if i == j then "<=" else "<"

-- | Describe a model as a chain of literals relating each atom to the next
-- one in the model's order.
modelToLiterals :: Model f -> [Formula f]
modelToLiterals (Model m) = go (sortBy (comparing snd) (Map.toList m))
  where
    go [] = []
    go [_] = []
    go ((x, (i, _)):xs@((y, (j, _)):_)) = rel x y:go xs
      where
        rel = if i == j then LessEq else Less

-- | Build a model in which the given atoms are strictly increasing.
modelFromOrder :: (Minimal f, Ord f) => [Atom f] -> Model f
modelFromOrder xs =
  Model (Map.fromList [(x, (i, i)) | (x, i) <- zip xs [0..]])

-- | All one-step weakenings of a model: either one atom is removed, or one
-- atom is moved into the major group of the atom just before it.
weakenModel :: Model f -> [Model f]
weakenModel (Model m) =
  [ Model (Map.delete x m) | x <- Map.keys m ] ++
  [ Model (Map.fromList xs)
  | xs <- glue (sortBy (comparing snd) (Map.toList m)),
    all ok (groupBy ((==) `on` (fst . snd)) xs) ]
  where
    glue [] = []
    glue [_] = []
    glue (a@(_x, (i1, j1)):b@(y, (i2, _)):xs) =
      [ (a:(y, (i1, j1+1)):xs) | i1 < i2 ] ++
      map (a:) (glue (b:xs))

    -- We must never make two constants equal
    ok xs = length [x | (Constant x, _) <- xs] <= 1

-- | Does the model mention the given variable?
varInModel :: (Minimal f, Ord f) => Model f -> Var -> Bool
varInModel (Model m) x = Variable x `Map.member` m

-- | Split the atoms of a model, taken in model order, into runs of variables
-- delimited by constants. Each triple holds the constant before the run
-- ('minimal' for the first run), the variables of the run, and the constant
-- after the run if there is one. Runs without variables are dropped.
varGroups :: (Minimal f, Ord f) => Model f -> [(Fun f, [Var], Maybe (Fun f))]
varGroups (Model m) =
  filter nonempty (go minimal (map fst (sortBy (comparing snd) (Map.toList m))))
  where
    go f xs =
      case span isVariable xs of
        (_, []) -> [(f, map unVariable xs, Nothing)]
        (ys, Constant g:zs) -> (f, map unVariable ys, Just g):go g zs
        (_, Variable _:_) ->
          error "Twee.Constraints.varGroups: impossible, span stopped at a variable"
    isVariable (Constant _) = False
    isVariable (Variable _) = True
    unVariable (Variable x) = x
    unVariable (Constant _) =
      error "Twee.Constraints.varGroups: impossible, constant in a run of variables"
    nonempty (_, [], _) = False
    nonempty _ = True

-- | Signatures that have a designated minimal function symbol.
class Minimal f where
  minimal :: Fun f

{-# INLINE lessEqInModel #-}
-- | Decide whether @x <= y@ holds in a model, and if so whether strictly.
-- Atoms are compared by their position in the model when both are present;
-- otherwise equal atoms, constants ordered by '<<', and the minimal constant
-- on the left are recognised. 'Nothing' means the inequality is not known to
-- hold.
lessEqInModel :: (Minimal f, Ordered f) => Model f -> Atom f -> Atom f -> Maybe Strictness
lessEqInModel (Model m) x y
  | Just (a, _) <- Map.lookup x m,
    Just (b, _) <- Map.lookup y m,
    a < b = Just Strict
  | Just a <- Map.lookup x m,
    Just b <- Map.lookup y m,
    a < b = Just Nonstrict
  | x == y = Just Nonstrict
  | Constant a <- x, Constant b <- y, a << b = Just Strict
  | Constant a <- x, a == minimal = Just Nonstrict
  | otherwise = Nothing

-- | Solve a branch over the given atoms. If the branch has no equalities the
-- result is a model ('Left') obtained by ordering the atoms consistently with
-- the branch's strict inequalities; otherwise it is the substitution ('Right')
-- described by the equalities.
--
-- Calls 'error' if the constructed model does not satisfy every inequality of
-- the branch, or if the equalities are rejected by 'listToSubst'.
solve :: (Minimal f, Ordered f, PrettyTerm f) => [Atom f] -> Branch f -> Either (Model f) (Subst f)
solve xs branch@Branch{..}
  | null equals && not (all holds less) =
    error $ "Model " ++ prettyShow model ++ " is not a model of " ++ prettyShow branch ++ " (edges = " ++ prettyShow edges ++ ", vs = " ++ prettyShow vs ++ ")"
  | null equals = Left model
  | otherwise = Right sub
  where
    sub =
      fromMaybe (error "Twee.Constraints.solve: listToSubst rejected the bindings derived from the branch's equalities") . listToSubst $
        [(x, toTerm y) | (Variable x, y) <- equals] ++
        [(y, toTerm x) | (x@Constant{}, Variable y) <- equals]
    -- stronglyConnComp yields components in reverse topological order, so
    -- the list is reversed to put smaller atoms first.
    vs = Constant minimal:reverse (flattenSCCs (stronglyConnComp edges))
    edges = [(x, x, [y | (x', y) <- less', x == x']) | x <- as, x /= Constant minimal]
    less' = less ++ [(Constant x, Constant y) | Constant x <- as, Constant y <- as, x << y]
    as = usort $ xs ++ map fst less ++ map snd less
    model = modelFromOrder vs
    -- Does the model make this inequality strictly true?
    holds (t, u) = lessEqInModel model t u == Just Strict

class Ord f => Ordered f where
  -- | Return 'True' if the first term is less than or equal to the second,
  -- in the term ordering.
  lessEq :: Term f -> Term f -> Bool
  -- | Check if the first term is less than or equal to the second in the given model,
  -- and decide whether the inequality is strict or nonstrict.
  lessIn :: Model f -> Term f -> Term f -> Maybe Strictness

-- | Describes whether an inequality is strict or nonstrict.
data Strictness =
    -- | The first term is strictly less than the second.
    Strict
    -- | The first term is less than or equal to the second.
  | Nonstrict deriving (Eq, Show)

-- | Return 'True' if the first argument is strictly less than the second,
-- in the term ordering.
lessThan :: Ordered f => Term f -> Term f -> Bool
lessThan t u = lessEq t u && isNothing (unify t u)

-- | Return the direction in which the terms are oriented according to the term
-- ordering, or 'Nothing' if they cannot be oriented. A result of @'Just' 'LT'@
-- means that the first term is less than /or equal to/ the second.
orientTerms :: Ordered f => Term f -> Term f -> Maybe Ordering
orientTerms t u
  | t == u = Just EQ
  | lessEq t u = Just LT
  | lessEq u t = Just GT
  | otherwise = Nothing