{-# LANGUAGE TypeFamilies, CPP, FlexibleContexts, UndecidableInstances, StandaloneDeriving, RecordWildCards, GADTs, ScopedTypeVariables, PatternGuards, PatternSynonyms #-}
-- | Ordering constraints between atoms (constants and variables):
-- formulas, their decomposition into branches, and models that assign
-- each atom a position in an order.
module Twee.Constraints where

#include "errors.h"
--import Twee.Base hiding (equals, Term, pattern Fun, pattern Var, lookup, funs)
import qualified Twee.Term as Flat
import qualified Data.Map.Strict as Map
import Twee.Pretty hiding (equals)
import Twee.Utils
import Data.Maybe
import Data.List
import Data.Function
import Data.Graph
import Data.Map.Strict(Map)
import Data.Ord
import Twee.Term hiding (lookup)

-- | An atom is either a constant (a function symbol) or a variable.
data Atom f = Constant (Fun f) | Variable Var deriving Show
deriving instance Eq (Fun f) => Eq (Atom f)
deriving instance Ord (Fun f) => Ord (Atom f)

-- | List the atoms of a term: every function symbol applied to an empty
-- argument list becomes a @Constant@ and every variable becomes a
-- @Variable@. Any other symbol is stepped over with @ConsSym@.
-- NOTE(review): assumes @ConsSym@ skips one symbol and continues with
-- its arguments; confirm against Twee.Term.
{-# INLINE atoms #-}
atoms :: Term f -> [Atom f]
atoms t = aux (singleton t)
  where
    aux Empty = []
    aux (Cons (Fun f Empty) t) = Constant f:aux t
    aux (Cons (Var x) t) = Variable x:aux t
    aux (ConsSym _ t) = aux t

-- | Turn an atom back into a term.
toTerm :: Atom f -> Term f
toTerm (Constant f) = build (con f)
toTerm (Variable x) = build (var x)

-- | Recognise a term that is a single atom: a function symbol with no
-- arguments or a variable. Every other term gives @Nothing@.
fromTerm :: Flat.Term f -> Maybe (Atom f)
fromTerm (Fun f Empty) = Just (Constant f)
fromTerm (Var x) = Just (Variable x)
fromTerm _ = Nothing

-- | Atoms are printed as the corresponding term.
instance (Numbered f, PrettyTerm f) => Pretty (Atom f) where
  pPrint = pPrint . toTerm

-- | A formula over atoms: strict and non-strict comparisons combined
-- with n-ary conjunction and disjunction. @And []@ is truth and
-- @Or []@ is falsity.
data Formula f =
    Less   (Atom f) (Atom f)
  | LessEq (Atom f) (Atom f)
  | And [Formula f]
  | Or  [Formula f]
  deriving Show
deriving instance Eq (Fun f) => Eq (Formula f)
deriving instance Ord (Fun f) => Ord (Formula f)

instance (Numbered f, PrettyTerm f) => Pretty (Formula f) where
  pPrintPrec _ _ (Less t u) = hang (pPrint t <+> text "<") 2 (pPrint u)
  pPrintPrec _ _ (LessEq t u) = hang (pPrint t <+> text "<=") 2 (pPrint u)
  pPrintPrec _ _ (And []) = text "true"
  pPrintPrec _ _ (Or []) = text "false"
  pPrintPrec l p (And xs) =
    pPrintParen (p > 10)
      (fsep (punctuate (text " &") (nest_ (map (pPrintPrec l 11) xs))))
    where
      -- Indent every conjunct after the first.
      nest_ (x:xs) = x:map (nest 2) xs
      -- Not reachable: the empty conjunction is handled above.
      nest_ [] = __
  pPrintPrec l p (Or xs) =
    pPrintParen (p > 10)
      (fsep (punctuate (text " |") (nest_ (map (pPrintPrec l 11) xs))))
    where
      -- Indent every disjunct after the first.
      nest_ (x:xs) = x:map (nest 2) xs
      -- Not reachable: the empty disjunction is handled above.
      nest_ [] = __

-- | Negate a formula: comparisons are flipped and swapped between strict
-- and non-strict, and conjunction and disjunction are exchanged.
negateFormula :: Formula f -> Formula f
negateFormula (Less t u) = LessEq u t
negateFormula (LessEq t u) = Less u t
negateFormula (And ts) = Or (map negateFormula ts)
negateFormula (Or ts) = And (map negateFormula ts)

-- | Smart conjunction. Directly nested @And@s are flattened, the
-- conjuncts go through @usort@, occurrences of @true@ are dropped, a
-- @false@ conjunct makes the result @false@, and a single remaining
-- conjunct is returned unwrapped.
conj :: Ord (Fun f) => [Formula f] -> Formula f
conj forms
  | false `elem` forms' = false
  | otherwise =
    case forms' of
      [x] -> x
      xs  -> And xs
  where
    flatten (And xs) = xs
    flatten x = [x]
    forms' = filter (/= true) (usort (concatMap flatten forms))

-- | Smart disjunction, dual to @conj@. Directly nested @Or@s are
-- flattened, the disjuncts go through @usort@, occurrences of @false@
-- are dropped, a @true@ disjunct makes the result @true@, and a single
-- remaining disjunct is returned unwrapped.
disj :: Ord (Fun f) => [Formula f] -> Formula f
disj forms
  | true `elem` forms' = true
  | otherwise =
    case forms' of
      [x] -> x
      xs  -> Or xs
  where
    flatten (Or xs) = xs
    flatten x = [x]
    forms' = filter (/= false) (usort (concatMap flatten forms))

-- | Binary versions of @conj@ and @disj@.
(&&&), (|||) :: Ord (Fun f) => Formula f -> Formula f -> Formula f
x &&& y = conj [x, y]
x ||| y = disj [x, y]

-- | The trivially true formula (empty conjunction) and the trivially
-- false formula (empty disjunction).
true, false :: Formula f
true  = And []
false = Or []

-- | One conjunctive case of a formula: the constants seen so far, the
-- strict inequalities and the equalities.
data Branch f =
  -- Branches are kept normalised wrt equals
  Branch {
    funs   :: [Fun f],
    less   :: [(Atom f, Atom f)],
    equals :: [(Atom f, Atom f)] } -- greatest atom first
deriving instance Eq (Fun f) => Eq (Branch f)
deriving instance Ord (Fun f) => Ord (Branch f)

instance (Numbered f, PrettyTerm f) => Pretty (Branch f) where
  pPrint Branch{..} =
    braces $ fsep $ punctuate (text ",") $
      [pPrint x <+> text "<" <+> pPrint y | (x, y) <- less ] ++
      [pPrint x <+> text "=" <+> pPrint y | (x, y) <- equals ]

-- | The branch with no constraints.
trueBranch :: Branch f
trueBranch = Branch [] [] []

-- | Replace an atom by the atom it is recorded as equal to in the
-- branch, if any; otherwise return it unchanged.
norm :: Eq f => Branch f -> Atom f -> Atom f
norm Branch{..} x = fromMaybe x (lookup x equals)

-- | A branch is contradictory if it puts something strictly below the
-- minimal constant, equates two distinct constants, or has a cycle in
-- its strict inequalities.
contradictory :: (Numbered f, Minimal f, Ord f) => Branch f -> Bool
contradictory Branch{..} =
  or [f == minimal | (_, Constant f) <- less] ||
  or [f /= g | (Constant f, Constant g) <- equals] ||
  any cyclic
    (stronglyConnComp
      [(x, x, [y | (x', y) <- less, x == x']) | x <- usort (map fst less)])
  where
    cyclic (AcyclicSCC _) = False
    cyclic (CyclicSCC _) = True

-- | Conjoin a formula onto every branch of a list. A disjunction splits
-- a branch into several; branches for which @addLess@ or @addEquals@
-- return nothing disappear.
formAnd :: (Numbered f, Minimal f, Ord f) => Formula f -> [Branch f] -> [Branch f]
formAnd f bs = usort (bs >>= add f)
  where
    add (Less t u) b = addLess t u b
    -- t <= u holds if either t < u or t = u.
    add (LessEq t u) b = addLess t u b ++ addEquals t u b
    add (And []) b = [b]
    add (And (f:fs)) b = add f b >>= add (And fs)
    add (Or fs) b = usort (concat [ add f b | f <- fs ])

-- | Decompose a formula into the list of branches built from it,
-- starting from @trueBranch@.
branches :: (Numbered f, Minimal f, Ord f) => Formula f -> [Branch f]
branches x = aux [x]
  where
    -- The argument is a worklist of formulas that must all hold.
    aux [] = [trueBranch]
    aux (And xs:ys) = aux (xs ++ ys)
    aux (Or xs:ys) = usort $ concat [aux (x:ys) | x <- xs]
    aux (Less t u:xs) = usort $ concatMap (addLess t u) (aux xs)
    aux (LessEq t u:xs) =
      usort $ concatMap (addLess t u) rest ++ concatMap (addEquals u t) rest
      where
        rest = aux xs

-- | Add the strict inequality @t0 < u0@ to a branch. The result is
-- empty if the inequality puts something below the minimal constant or
-- makes the branch contradictory. An inequality with the minimal
-- constant on the left leaves the branch unchanged.
addLess :: (Numbered f, Minimal f, Ord f) => Atom f -> Atom f -> Branch f -> [Branch f]
addLess _ (Constant c) _ | c == minimal = []
addLess (Constant c) _ b | c == minimal = [b]
addLess t0 u0 b@Branch{..} =
  filter (not . contradictory)
    [addTerm t (addTerm u b{less = usort ((t, u):less)})]
  where
    t = norm b t0
    u = norm b u0

-- | Add the equality @t0 = u0@ to a branch. The greater of the two
-- normalised atoms is replaced by the lesser throughout the branch.
-- The result is empty if the new branch is contradictory.
addEquals :: (Numbered f, Minimal f, Ord f) => Atom f -> Atom f -> Branch f -> [Branch f]
addEquals t0 u0 b@Branch{..}
  | t == u || (t, u) `elem` equals = [b]
  | otherwise =
    filter (not . contradictory)
      [addTerm t (addTerm u b {
         equals = usort $ (t, u):[(x', y') | (x, y) <- equals, let (y', x') = sort2 (sub x, sub y), x' /= y'],
         less   = usort $ [(sub x, sub y) | (x, y) <- less] })]
  where
    sort2 (x, y) = (min x y, max x y)
    -- u is the lesser atom and t the greater one.
    (u, t) = sort2 (norm b t0, norm b u0)
    -- Replace the greater atom by the lesser one.
    sub x
      | x == t = u
      | otherwise = x

-- | Record a constant in the branch if it is not there yet, together
-- with its strict ordering against every constant already recorded.
-- Variables and known constants leave the branch unchanged.
addTerm :: (Numbered f, Minimal f, Ord f) => Atom f -> Branch f -> Branch f
addTerm (Constant f) b
  | f `notElem` funs b =
    b {
      funs = f:funs b,
      less =
        [ (Constant f, Constant g) | g <- funs b, f < g ] ++
        [ (Constant g, Constant f) | g <- funs b, g < f ] ++
        less b }
addTerm _ b = b

-- | A model places atoms in an order.
newtype Model f = Model (Map (Atom f) (Int, Int))
  deriving Show
-- Representation: map from atom to (major, minor)
-- x <  y if major x < major y
-- x <= y if major x = major y and minor x < minor y

instance (Numbered f, PrettyTerm f) => Pretty (Model f) where
  pPrint (Model m)
    | Map.size m <= 1 = text "empty"
    | otherwise = fsep (go (sortBy (comparing snd) (Map.toList m)))
    where
      -- Not reachable given the size check above; kept for totality.
      go [] = []
      go [(x, _)] = [pPrint x]
      go ((x, (i, _)):xs@((_, (j, _)):_)) =
        (pPrint x <+> text rel):go xs
        where
          rel = if i == j then "<=" else "<"

-- | Express a model as comparisons between neighbouring atoms in
-- ascending order: @LessEq@ when both share a major component,
-- @Less@ otherwise.
modelToLiterals :: Model f -> [Formula f]
modelToLiterals (Model m) = go (sortBy (comparing snd) (Map.toList m))
  where
    go [] = []
    go [_] = []
    go ((x, (i, _)):xs@((y, (j, _)):_)) = rel x y:go xs
      where
        rel = if i == j then LessEq else Less

-- | Build a model in which the listed atoms are strictly increasing.
-- If an atom occurs more than once, its last position wins.
modelFromOrder :: (Numbered f, Minimal f, Ord f) => [Atom f] -> Model f
modelFromOrder xs =
  Model (Map.fromList [(x, (i, i)) | (x, i) <- zip xs [0..]])

-- | All one-step weakenings of a model: either forget one atom, or
-- merge an atom into the major group of its predecessor so that a
-- strict step becomes a non-strict one.
weakenModel :: Ord (Fun f) => Model f -> [Model f]
weakenModel (Model m) =
  [ Model (Map.delete x m) | x <- Map.keys m ] ++
  [ Model (Map.fromList xs)
  | xs <- glue (sortBy (comparing snd) (Map.toList m)),
    all ok (groupBy ((==) `on` (fst . snd)) xs) ]
  where
    -- Every way of moving one atom into the major group of the atom
    -- just before it, when the two are in different groups.
    glue [] = []
    glue [_] = []
    glue (a@(_x, (i1, j1)):b@(y, (i2, _)):xs) =
      [ (a:(y, (i1, j1+1)):xs) | i1 < i2 ] ++
      map (a:) (glue (b:xs))

    -- We must never make two constants equal
    ok xs = length [x | (Constant x, _) <- xs] <= 1

-- | Does the model mention this variable?
varInModel :: (Numbered f, Minimal f, Ord f) => Model f -> Var -> Bool
varInModel (Model m) x = Variable x `Map.member` m

-- | Split the atoms of a model, in ascending order, into runs of
-- variables delimited by constants. Each triple holds the constant
-- before the run (starting from @minimal@), the variables of the run,
-- and the constant after the run if there is one. Runs with no
-- variables are dropped.
varGroups :: (Numbered f, Minimal f, Ord f) => Model f -> [(Fun f, [Var], Maybe (Fun f))]
varGroups (Model m) =
  filter nonempty (go minimal (map fst (sortBy (comparing snd) (Map.toList m))))
  where
    go f xs =
      case span isVariable xs of
        (ys, Constant g:zs) -> (f, variables ys, Just g):go g zs
        -- No constant follows, so the run extends to the end.
        (ys, _) -> [(f, variables ys, Nothing)]
    isVariable (Constant _) = False
    isVariable (Variable _) = True
    variables ys = [x | Variable x <- ys]
    nonempty (_, [], _) = False
    nonempty _ = True

-- | Types with a distinguished minimal element.
class Minimal a where
  minimal :: a

instance (Numbered f, Minimal f) => Minimal (Fun f) where
  minimal = toFun minimal

-- | Compare two atoms in a model. Guards are tried in order: a smaller
-- major component is strict, a smaller (major, minor) pair is
-- non-strict, identical atoms are non-strict, two constants ordered by
-- @<@ are strict, and the minimal constant is non-strictly below
-- everything. Otherwise nothing is known.
{-# INLINE lessEqInModel #-}
lessEqInModel :: (Numbered f, Minimal f, Ord f) => Model f -> Atom f -> Atom f -> Maybe Strictness
lessEqInModel (Model m) x y
  | Just (a, _) <- Map.lookup x m,
    Just (b, _) <- Map.lookup y m,
    a < b = Just Strict
  | Just a <- Map.lookup x m,
    Just b <- Map.lookup y m,
    a < b = Just Nonstrict
  | x == y = Just Nonstrict
  | Constant a <- x, Constant b <- y, a < b = Just Strict
  | Constant a <- x, a == minimal = Just Nonstrict
  | otherwise = Nothing

-- | Solve a branch over the given atoms. With no equalities the result
-- is a model built from the strict inequalities; otherwise it is a
-- substitution that eliminates the equated variables. An internal
-- error is raised if the constructed model fails to satisfy some
-- strict inequality of the branch.
solve :: (Numbered f, Minimal f, Ord f, PrettyTerm f) => [Atom f] -> Branch f -> Either (Model f) (Subst f)
solve xs branch@Branch{..}
  | null equals && not (all true less) =
    ERROR("Model " ++ prettyShow model ++ " is not a model of " ++ prettyShow branch ++ " (edges = " ++ prettyShow edges ++ ", vs = " ++ prettyShow vs ++ ")")
  | null equals = Left model
  | otherwise = Right sub
  where
    sub =
      fromMaybe __ . flattenSubst $
        [(x, toTerm y) | (Variable x, y) <- equals] ++
        [(y, toTerm x) | (x@Constant{}, Variable y) <- equals]
    -- The minimal constant first, then the atoms in the reverse of the
    -- order in which the component computation lists them.
    vs = Constant minimal:reverse (flattenSCCs (stronglyConnComp edges))
    edges = [(x, x, [y | (x', y) <- less', x == x']) | x <- as]
    -- The given inequalities plus the ordering between all constants.
    less' = less ++ [(Constant x, Constant y) | Constant x <- as, Constant y <- as, x < y]
    as = usort $ xs ++ map fst less ++ map snd less
    model = modelFromOrder vs
    -- Local check, shadowing the top-level formula of the same name.
    true (t, u) = lessEqInModel model t u == Just Strict

-- | Term orderings over a signature.
class Ord f => Ordered f where
  -- | Compare two terms using @lessEq@; @Nothing@ if neither direction
  -- holds.
  orientTerms :: Term f -> Term f -> Maybe Ordering
  orientTerms t u
    | t == u = Just EQ
    | lessEq t u = Just LT
    | lessEq u t = Just GT
    | otherwise = Nothing

  lessEq :: Term f -> Term f -> Bool
  lessIn :: Model f -> Term f -> Term f -> Maybe Strictness

-- | Whether a comparison is strict or non-strict.
data Strictness = Strict | Nonstrict deriving (Eq, Show)