module Twee.Constraints where
#include "errors.h"
import qualified Twee.Term as Flat
import qualified Data.Map.Strict as Map
import Twee.Pretty hiding (equals)
import Twee.Utils
import Data.Maybe
import Data.List
import Data.Function
import Data.Graph
import Data.Map.Strict(Map)
import Data.Ord
import Twee.Term hiding (lookup)
-- | An atom of an ordering constraint: either a constant, represented by its
-- function symbol, or a variable.
data Atom f = Constant (Fun f) | Variable Var deriving Show
-- Eq and Ord are derived standalone because they need the corresponding
-- instance on @Fun f@ rather than on @f@ itself.
deriving instance Eq (Fun f) => Eq (Atom f)
deriving instance Ord (Fun f) => Ord (Atom f)
-- | List the atoms found while walking a term: every function application
-- with an empty argument list becomes a 'Constant' and every variable
-- becomes a 'Variable'. Anything else is stepped over with @ConsSym@
-- (presumably advancing by one symbol so that subterms are still visited;
-- confirm against "Twee.Term").
atoms :: Term f -> [Atom f]
atoms t = collect (singleton t)
  where
    collect pending =
      case pending of
        Empty -> []
        Cons (Fun f Empty) rest -> Constant f : collect rest
        Cons (Var x) rest -> Variable x : collect rest
        ConsSym _ rest -> collect rest
-- | Build the term that corresponds to an atom: a constant term for
-- 'Constant' and a variable term for 'Variable'.
toTerm :: Atom f -> Term f
toTerm atom =
  case atom of
    Constant f -> build (con f)
    Variable x -> build (var x)
-- | Convert a term to an atom when it is a variable or a function applied
-- to an empty argument list; any other term yields 'Nothing'.
fromTerm :: Flat.Term f -> Maybe (Atom f)
fromTerm t =
  case t of
    Fun f Empty -> Just (Constant f)
    Var x -> Just (Variable x)
    _ -> Nothing
-- | Atoms are printed exactly like the term produced by 'toTerm'.
instance (Numbered f, PrettyTerm f) => Pretty (Atom f) where
  pPrint atom = pPrint (toTerm atom)
-- | A formula over atoms: strict and non-strict comparisons combined with
-- n-ary conjunction and disjunction. The empty conjunction stands for
-- truth and the empty disjunction for falsity.
data Formula f =
    -- Strict comparison between two atoms.
    Less (Atom f) (Atom f)
    -- Non-strict comparison between two atoms.
  | LessEq (Atom f) (Atom f)
    -- Conjunction of all listed formulas.
  | And [Formula f]
    -- Disjunction of all listed formulas.
  | Or [Formula f]
  deriving Show
deriving instance Eq (Fun f) => Eq (Formula f)
deriving instance Ord (Fun f) => Ord (Formula f)
-- | Pretty-printer for formulas. Comparisons are printed infix, the empty
-- conjunction as @true@ and the empty disjunction as @false@. Non-empty
-- conjunctions and disjunctions are joined with @&@ and @|@ respectively,
-- wrapped in parentheses when the surrounding precedence exceeds 10, with
-- every operand printed at precedence 11.
instance (Numbered f, PrettyTerm f) => Pretty (Formula f) where
  pPrintPrec l p form =
    case form of
      Less t u -> comparison "<" t u
      LessEq t u -> comparison "<=" t u
      And [] -> text "true"
      Or [] -> text "false"
      And xs -> joined " &" xs
      Or xs -> joined " |" xs
    where
      comparison op t u = hang (pPrint t <+> text op) 2 (pPrint u)
      joined sep xs =
        pPrintParen (p > 10)
          (fsep (punctuate (text sep) (hanging (map (pPrintPrec l 11) xs))))
      -- Indent every operand after the first by two columns. The empty case
      -- cannot be reached because empty operand lists are handled above.
      hanging (d:ds) = d : map (nest 2) ds
      hanging [] = __
-- | Logical negation of a formula. A comparison is negated by swapping its
-- operands and toggling strictness; conjunction and disjunction are
-- exchanged, with the negation pushed into every operand.
negateFormula :: Formula f -> Formula f
negateFormula form =
  case form of
    Less t u -> LessEq u t
    LessEq t u -> Less u t
    And parts -> Or (map negateFormula parts)
    Or parts -> And (map negateFormula parts)
-- | Simplifying conjunction of a list of formulas. Directly nested
-- conjunctions are spliced in (one level only), the operands are sorted
-- and deduplicated with @usort@, and trivially true operands are dropped.
-- The result is @false@ if any operand is @false@, the operand itself if
-- only one remains, and an 'And' otherwise.
conj forms
  | false `elem` operands = false
  | [only] <- operands = only
  | otherwise = And operands
  where
    operands = filter (/= true) (usort (concatMap conjuncts forms))
    conjuncts (And xs) = xs
    conjuncts x = [x]
-- | Simplifying disjunction of a list of formulas. Directly nested
-- disjunctions are spliced in (one level only), the operands are sorted
-- and deduplicated with @usort@, and trivially false operands are dropped.
-- The result is @true@ if any operand is @true@, the operand itself if
-- only one remains, and an 'Or' otherwise.
disj forms
  | true `elem` operands = true
  | [only] <- operands = only
  | otherwise = Or operands
  where
    operands = filter (/= false) (usort (concatMap disjuncts forms))
    disjuncts (Or xs) = xs
    disjuncts x = [x]
-- | Binary conjunction, simplified through @conj@.
(&&&) x y = conj [x, y]

-- | Binary disjunction, simplified through @disj@.
(|||) x y = disj [x, y]

-- | The trivially true formula: a conjunction with no operands.
true = And []

-- | The trivially false formula: a disjunction with no operands.
false = Or []
-- | One conjunctive set of ordering facts.
data Branch f =
  Branch {
    -- Function symbols recorded so far; extended by @addTerm@.
    funs :: [Fun f],
    -- Pairs (x, y) asserting that x is strictly less than y.
    less :: [(Atom f, Atom f)],
    -- Pairs (x, y) asserting that x equals y; @norm@ rewrites x to y.
    equals :: [(Atom f, Atom f)] }
deriving instance Eq (Fun f) => Eq (Branch f)
deriving instance Ord (Fun f) => Ord (Branch f)
-- | A branch is printed as a brace-enclosed, comma-separated list of its
-- strict inequalities followed by its equalities. The recorded function
-- symbols are not printed.
instance (Numbered f, PrettyTerm f) => Pretty (Branch f) where
  pPrint branch =
    braces (fsep (punctuate (text ",") (inequalities ++ equalities)))
    where
      inequalities = [relation "<" x y | (x, y) <- less branch]
      equalities = [relation "=" x y | (x, y) <- equals branch]
      relation op x y = pPrint x <+> text op <+> pPrint y
-- | The branch carrying no facts at all.
trueBranch :: Branch f
trueBranch = Branch { funs = [], less = [], equals = [] }
-- | Normalise an atom with respect to the equalities of a branch: if the
-- atom occurs as the left component of an equality it is replaced by the
-- right component of the first such pair, otherwise it is returned as is.
norm :: Eq f => Branch f -> Atom f -> Atom f
norm b x =
  case lookup x (equals b) of
    Just representative -> representative
    Nothing -> x
-- | Decide whether a branch is unsatisfiable. A branch is contradictory
-- when something is asserted to be strictly below the minimal constant,
-- when two different constants are asserted equal, or when the strict
-- inequalities contain a cycle (including a self-loop).
contradictory :: (Numbered f, Minimal f, Ord f) => Branch f -> Bool
contradictory b = belowMinimal || distinctConstantsEqual || hasCycle
  where
    smaller = less b
    belowMinimal = or [c == minimal | (_, Constant c) <- smaller]
    distinctConstantsEqual = or [c /= d | (Constant c, Constant d) <- equals b]
    hasCycle = any isCyclic (stronglyConnComp graph)
    -- One node per atom that appears on the left of an inequality, with an
    -- edge to every atom it is asserted to be less than.
    graph = [(x, x, successors x) | x <- usort (map fst smaller)]
    successors x = [y | (x', y) <- smaller, x == x']
    isCyclic component =
      case component of
        CyclicSCC _ -> True
        AcyclicSCC _ -> False
-- | Conjoin a formula onto every branch of a list, returning the sorted,
-- deduplicated list of resulting branches. A non-strict comparison splits
-- into the strict case and the equality case; a conjunction adds its
-- operands one after another; a disjunction collects the results of adding
-- each operand separately.
formAnd :: (Numbered f, Minimal f, Ord f) => Formula f -> [Branch f] -> [Branch f]
formAnd f bs = usort (concatMap (extend f) bs)
  where
    extend form b =
      case form of
        Less t u -> addLess t u b
        LessEq t u -> addLess t u b ++ addEquals t u b
        And [] -> [b]
        And (g:gs) -> concatMap (extend (And gs)) (extend g b)
        Or gs -> usort (concatMap (`extend` b) gs)
-- | Expand a formula into the sorted, deduplicated list of branches that
-- together cover it (a disjunctive normal form over branches).
--
-- The worker processes a list of formulas that must all hold. Conjunctions
-- are spliced into that list, disjunctions fork the computation, and
-- comparisons are added to every branch obtained for the remaining
-- formulas.
branches :: (Numbered f, Minimal f, Ord f) => Formula f -> [Branch f]
branches x = aux [x]
  where
    aux [] = [trueBranch]
    aux (And xs:ys) = aux (xs ++ ys)
    aux (Or xs:ys) = usort $ concat [aux (x:ys) | x <- xs]
    aux (Less t u:xs) = usort $ concatMap (addLess t u) (aux xs)
    aux (LessEq t u:xs) =
      usort $
        concatMap (addLess t u) rest ++
        concatMap (addEquals u t) rest
      where
        -- Computed once and shared by both cases; evaluating the recursive
        -- call separately for each case would double the work at every
        -- non-strict comparison.
        rest = aux xs
-- | Add the strict inequality @t0 < u0@ to a branch. Nothing can be below
-- the minimal constant, so that case yields no branch; an inequality whose
-- left side is the minimal constant leaves the branch unchanged. Otherwise
-- both atoms are normalised, the pair is recorded, both atoms are
-- registered with @addTerm@, and the result is kept only if it is not
-- contradictory.
addLess :: (Numbered f, Minimal f, Ord f) => Atom f -> Atom f -> Branch f -> [Branch f]
addLess t0 u0 b
  | isMinimal u0 = []
  | isMinimal t0 = [b]
  | contradictory extended = []
  | otherwise = [extended]
  where
    isMinimal (Constant c) = c == minimal
    isMinimal _ = False
    t = norm b t0
    u = norm b u0
    extended = addTerm t (addTerm u b { less = usort ((t, u) : less b) })
-- | Add the equality of @t0@ and @u0@ to a branch. Both atoms are first
-- normalised against the existing equalities. The result is the unchanged
-- branch when the equality is already known or trivial, the updated branch
-- when it is consistent, and the empty list when it is contradictory.
addEquals :: (Numbered f, Minimal f, Ord f) => Atom f -> Atom f -> Branch f -> [Branch f]
addEquals t0 u0 b@Branch{..}
  | t == u || (t, u) `elem` equals = [b]
  | otherwise =
    filter (not . contradictory)
      [addTerm t (addTerm u b {
         -- Record (t, u), i.e. larger atom mapped to smaller atom, and
         -- rewrite the existing pairs with the substitution. Each rewritten
         -- pair is re-sorted so that it is again stored as
         -- (larger, smaller), and dropped if both sides have become equal.
         equals = usort $ (t, u):[(x', y') | (x, y) <- equals, let (y', x') = sort2 (sub x, sub y), x' /= y'],
         -- Apply the same substitution to both sides of every inequality.
         less = usort $ [(sub x, sub y) | (x, y) <- less] })]
  where
    -- Order a pair as (smaller, larger).
    sort2 (x, y) = (min x y, max x y)
    -- u is the smaller and t the larger of the two normalised atoms.
    (u, t) = sort2 (norm b t0, norm b u0)
    -- Replace the larger atom by the smaller one.
    sub x
      | x == t = u
      | otherwise = x
-- | Register an atom with a branch. A constant whose function symbol is
-- not yet recorded is added to @funs@, together with strict inequalities
-- relating it to every previously recorded symbol according to the 'Ord'
-- instance. Variables and already-known constants leave the branch
-- unchanged.
addTerm :: (Numbered f, Minimal f, Ord f) => Atom f -> Branch f -> Branch f
addTerm atom b =
  case atom of
    Constant f
      | f `notElem` known ->
          let above = [(Constant f, Constant g) | g <- known, f < g]
              below = [(Constant g, Constant f) | g <- known, g < f]
          in b { funs = f : known, less = above ++ below ++ less b }
    _ -> b
  where
    known = funs b
-- | A model assigns every atom it contains a pair of integers. Atoms are
-- ordered by this pair; @lessEqInModel@ treats atoms whose first
-- components differ as strictly ordered, and atoms that share the first
-- component but differ in the second as non-strictly ordered.
newtype Model f = Model (Map (Atom f) (Int, Int))
  deriving Show
-- | A model with at most one atom prints as @empty@. Otherwise the atoms
-- are listed in increasing order of their integer pair, each followed by
-- @<=@ if the next atom has the same first component and by @<@ if not.
instance (Numbered f, PrettyTerm f) => Pretty (Model f) where
  pPrint (Model m)
    | Map.size m <= 1 = text "empty"
    | otherwise = fsep (chain (sortBy (comparing snd) (Map.toList m)))
    where
      -- Only called on lists of at least two entries, so the recursion
      -- always ends at the singleton case.
      chain entries =
        case entries of
          [(x, _)] -> [pPrint x]
          (x, (i, _)) : rest@((_, (j, _)) : _) ->
            (pPrint x <+> text (if i == j then "<=" else "<")) : chain rest
-- | Describe a model as a list of comparisons between neighbouring atoms,
-- taken in increasing order of their integer pair. Neighbours that share
-- the first component give a 'LessEq' literal, all others a 'Less'
-- literal. Models with fewer than two atoms give no literals.
modelToLiterals :: Model f -> [Formula f]
modelToLiterals (Model m) = zipWith literal ordered (drop 1 ordered)
  where
    ordered = sortBy (comparing snd) (Map.toList m)
    literal (x, (i, _)) (y, (j, _))
      | i == j = LessEq x y
      | otherwise = Less x y
-- | Build a model from a list of atoms given in increasing order: the atom
-- at position @i@ (counting from zero) is assigned the pair @(i, i)@.
modelFromOrder :: (Numbered f, Minimal f, Ord f) => [Atom f] -> Model f
modelFromOrder xs = Model (Map.fromList (zipWith rank xs [0..]))
  where
    rank x i = (x, (i, i))
-- | Enumerate the one-step weakenings of a model. There are two kinds:
-- removing a single atom, and merging an atom into the group of the atom
-- that immediately precedes it in the ordering.
weakenModel :: Ord (Fun f) => Model f -> [Model f]
weakenModel (Model m) =
  -- Every model obtained by deleting one atom.
  [ Model (Map.delete x m) | x <- Map.keys m ] ++
  -- Every model obtained by one gluing step, provided that no group of
  -- atoms sharing a first component ends up with more than one constant.
  [ Model (Map.fromList xs)
  | xs <- glue (sortBy (comparing snd) (Map.toList m)),
    all ok (groupBy ((==) `on` (fst . snd)) xs) ]
  where
    glue [] = []
    glue [_] = []
    -- For each adjacent pair whose first components differ, produce the
    -- list in which the second atom takes over the first component of its
    -- predecessor, with the second component one higher. Only that one
    -- entry is changed; the remaining entries are kept as they are.
    glue (a@(_x, (i1, j1)):b@(y, (i2, _)):xs) =
      [ (a:(y, (i1, j1+1)):xs) | i1 < i2 ] ++
      map (a:) (glue (b:xs))
    -- A group is acceptable if it contains at most one constant.
    ok xs = length [x | (Constant x, _) <- xs] <= 1
-- | Test whether a variable occurs as an atom of the model.
varInModel :: (Numbered f, Minimal f, Ord f) => Model f -> Var -> Bool
varInModel (Model m) x = Map.member (Variable x) m
-- | Split the atoms of a model, taken in increasing order, into runs of
-- variables delimited by constants. Each result is the constant that
-- precedes the run (the minimal constant for the first run), the
-- variables of the run, and the constant that follows it, if any. Runs
-- without variables are omitted.
varGroups :: (Numbered f, Minimal f, Ord f) => Model f -> [(Fun f, [Var], Maybe (Fun f))]
varGroups (Model m) =
  [grp | grp@(_, _ : _, _) <- segments minimal ordered]
  where
    ordered = map fst (sortBy (comparing snd) (Map.toList m))
    segments lower pending =
      case break isConstant pending of
        (vs, []) -> [(lower, variables vs, Nothing)]
        (vs, Constant upper : rest) ->
          (lower, variables vs, Just upper) : segments upper rest
    isConstant (Constant _) = True
    isConstant (Variable _) = False
    -- Only ever applied to lists that consist solely of variables.
    variables vs = [v | Variable v <- vs]
-- | Types that have a distinguished minimal element.
class Minimal a where
  -- | The minimal element.
  minimal :: a
-- | The minimal function symbol is obtained by converting the minimal
-- element of the underlying type with @toFun@.
instance (Numbered f, Minimal f) => Minimal (Fun f) where
  minimal = toFun minimal
-- | Decide whether @x <= y@ holds in a model and, if so, how strongly.
--
-- When both atoms are in the model, they are strictly ordered if the first
-- components of their pairs are increasing, and non-strictly ordered if
-- only the pairs as a whole are increasing. If neither applies (or an atom
-- is missing), the answer falls back to facts that hold in every model:
-- an atom is non-strictly below itself, constants are strictly ordered by
-- their 'Ord' instance, and the minimal constant is non-strictly below
-- everything. Otherwise the result is 'Nothing'.
lessEqInModel :: (Numbered f, Minimal f, Ord f) => Model f -> Atom f -> Atom f -> Maybe Strictness
lessEqInModel (Model m) x y =
  case (Map.lookup x m, Map.lookup y m) of
    (Just rankX@(groupX, _), Just rankY@(groupY, _))
      | groupX < groupY -> Just Strict
      | rankX < rankY -> Just Nonstrict
    _ -> modelIndependent
  where
    modelIndependent
      | x == y = Just Nonstrict
      | Constant a <- x, Constant b <- y, a < b = Just Strict
      | Constant a <- x, a == minimal = Just Nonstrict
      | otherwise = Nothing
-- | Solve a branch. If the branch has no equalities, the result is a
-- model (on the left) built from a linear order that is compatible with
-- the inequalities; otherwise it is a substitution (on the right) built
-- from the equalities. The first argument lists extra atoms that the
-- model must mention.
solve :: (Numbered f, Minimal f, Ord f, PrettyTerm f) => [Atom f] -> Branch f -> Either (Model f) (Subst f)
solve xs branch@Branch{..}
  -- Internal consistency check: every inequality of the branch must hold
  -- strictly in the constructed model.
  | null equals && not (all true less) =
    ERROR("Model " ++ prettyShow model ++ " is not a model of " ++ prettyShow branch ++ " (edges = " ++ prettyShow edges ++ ", vs = " ++ prettyShow vs ++ ")")
  | null equals = Left model
  | otherwise = Right sub
  where
    -- Bind each variable on the left of an equality to its right side, and
    -- each variable on the right to the constant on its left.
    -- NOTE(review): flattenSubst presumably returns Nothing when the
    -- bindings cannot be combined, which is treated as impossible here.
    sub = fromMaybe __ . flattenSubst $
      [(x, toTerm y) | (Variable x, y) <- equals] ++
      [(y, toTerm x) | (x@Constant{}, Variable y) <- equals]
    -- The minimal constant followed by all atoms, ordered with the
    -- reversed output of stronglyConnComp over the less-than graph.
    vs = Constant minimal:reverse (flattenSCCs (stronglyConnComp edges))
    -- One node per atom, with an edge to every atom known to be greater.
    edges = [(x, x, [y | (x', y) <- less', x == x']) | x <- as]
    -- The inequalities of the branch plus the built-in order on constants.
    less' = less ++ [(Constant x, Constant y) | Constant x <- as, Constant y <- as, x < y]
    -- All atoms of interest: the requested ones and those in inequalities.
    as = usort $ xs ++ map fst less ++ map snd less
    model = modelFromOrder vs
    -- Local predicate (shadows the top-level formula of the same name).
    true (t, u) = lessEqInModel model t u == Just Strict
-- | Function symbols that come with an ordering on terms.
class Ord f => Ordered f where
  -- | Compare two terms. Identical terms are 'EQ'; otherwise the result is
  -- 'LT' or 'GT' according to which direction of 'lessEq' holds (trying
  -- @t@ below @u@ first), and 'Nothing' if neither does.
  orientTerms :: Term f -> Term f -> Maybe Ordering
  orientTerms t u =
    if t == u then Just EQ
    else if lessEq t u then Just LT
    else if lessEq u t then Just GT
    else Nothing
  -- | Test whether the first term is less than or equal to the second.
  lessEq :: Term f -> Term f -> Bool
  -- | Test whether the first term is below the second in a given model,
  -- reporting whether the relation is strict.
  lessIn :: Model f -> Term f -> Term f -> Maybe Strictness
-- | Whether an ordering relation holds strictly or only non-strictly.
data Strictness = Strict | Nonstrict deriving (Eq, Show)