-- | Decision procedure for conjunctions of literals in the quantifier-free
-- theory of uninterpreted functions with equality, implemented as
-- congruence closure on top of a union-find structure.
module EqualitySolver.Solver
  ( EqFormula
  , EqLiteral
  , EqTerm
  , Name
  , Arity
  , eqF
  , eq
  , neq
  , var
  , fun
  , satisfiableInEq
  ) where

import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Union as U
import Data.List as L
import Data.Map as M
import Data.Maybe
import Data.Set as S

-- |Returns true if the conjunction of literals given as an
-- argument is satisfiable in the first order theory of
-- uninterpreted functions with equality
satisfiableInEq :: EqFormula -> Bool
satisfiableInEq formula = fst $ runDecideEq $ decideEq formula

-- |A conjunction of literals
data EqFormula = EqFormula (Set EqLiteral)
  deriving (Eq, Ord, Show)

-- |Build a conjunction of literals out of a list of literals
eqF :: [EqLiteral] -> EqFormula
eqF = EqFormula . S.fromList

-- | Every term occurring in the formula: both sides of every literal plus
-- all of their (transitive) sub-terms. May contain duplicates.
allTerms :: EqFormula -> [EqTerm]
allTerms (EqFormula lits) = L.concatMap extractTerms $ S.toList lits

-- | Both sides of a literal together with all of their sub-terms.
extractTerms :: EqLiteral -> [EqTerm]
extractTerms (EqLiteral _ l r) = [l, r] ++ subTerms l ++ subTerms r

-- | @contains t s@ is True when @t@ occurs strictly inside @s@, at any depth.
contains :: EqTerm -> EqTerm -> Bool
contains t (Function _ _ args) = t `L.elem` args || L.any (contains t) args
contains _ _ = False

-- | A single (dis)equality between two terms.
data EqLiteral = EqLiteral Predicate EqTerm EqTerm
  deriving (Eq, Ord, Show)

-- |'eq a b' builds the literal a = b
eq :: EqTerm -> EqTerm -> EqLiteral
eq = EqLiteral Eq

-- |'neq a b' builds the literal 'not (a = b)'
neq :: EqTerm -> EqTerm -> EqLiteral
neq = EqLiteral Neq

-- | True for equality literals, False for disequalities.
isEq :: EqLiteral -> Bool
isEq (EqLiteral Eq _ _) = True
isEq _ = False

-- | The two relations a literal can assert.
data Predicate = Eq | Neq
  deriving (Eq, Ord, Show)

type Arity = Int

-- |Make a new arity
arity :: Int -> Arity
arity i = i

type Name = String

-- |Make a new name
name :: String -> Name
name n = n

-- | A first-order term: an uninterpreted function application or a variable.
data EqTerm
  = Function Name Arity [EqTerm]
  | Variable Name
  deriving (Eq, Ord)

instance Show EqTerm where
  show = showEqTerm

-- | Renders @f(x,y)@ for applications and the bare name for variables.
showEqTerm :: EqTerm -> String
showEqTerm (Function fname _ args) =
  fname ++ "(" ++ intercalate "," (L.map show args) ++ ")"
showEqTerm (Variable vname) = vname

-- |Returns a new variable
var :: Name -> EqTerm
var = Variable

-- |Returns a new function
fun :: Name -> Arity -> [EqTerm] -> EqTerm
fun = Function

-- | All proper sub-terms of a term, at any depth (may contain duplicates).
subTerms :: EqTerm -> [EqTerm]
subTerms (Variable _) = []
subTerms (Function _ _ args) = args ++ L.concatMap subTerms args

-- | The solver monad: solver bookkeeping layered over a union-find over terms.
type DecideEq a = StateT EqState (UnionM EqTerm) a

-- | Runs a solver computation from an empty state.
runDecideEq :: DecideEq a -> (a, EqState)
runDecideEq decide = run $ runStateT decide newEqState

-- | Congruence closure: register every term, record which terms contain
-- which, merge classes for all equalities (propagating congruences), then
-- check that no disequality relates two terms of the same class.
decideEq :: EqFormula -> DecideEq Bool
decideEq f@(EqFormula lits) = do
  addTerms terms
  buildContainsMap terms terms
  processEqualities eqs
  processDisequalities diseqs
  where
    -- Deduplicated once up front, so the quadratic containment scan does
    -- not repeat work for terms that occur several times in the formula.
    terms = S.toList $ S.fromList $ allTerms f
    (eqs, diseqs) = L.partition isEq $ S.toList lits

-- | Solver bookkeeping.
--
-- * 'pointMap' maps each term to its union-find node.
-- * 'superTerms' maps a class representative to the terms containing some
--   member of that class; only entries keyed by current representatives
--   are read.
data EqState = EqState
  { pointMap :: Map EqTerm Node
  , superTerms :: Map Node [EqTerm]
  }

-- | Terms containing some member of @t@'s current equivalence class.
termsContaining :: EqTerm -> DecideEq [EqTerm]
termsContaining t = do
  pt <- getRep t
  sts <- gets superTerms
  case M.lookup pt sts of
    Just ts -> return ts
    Nothing -> error $ "Term " ++ show t ++ " not in superTerms"

-- | Union-find representative of the class containing @t@.
getRep :: EqTerm -> DecideEq Node
getRep t = do
  node <- getNode t
  fst <$> U.lookup node

-- | True when both terms currently belong to the same equivalence class.
sameClass :: EqTerm -> EqTerm -> DecideEq Bool
sameClass l r = (==) <$> getRep l <*> getRep r

-- | Merge function for 'U.merge': keep the first label, report nothing.
defaultMerge :: EqTerm -> EqTerm -> (EqTerm, [a])
defaultMerge l _ = (l, [])

-- | Equalities @l = r@ for every pair (l from the first list, r from the
-- second) that is congruent under the current classes.
findCongruences :: [EqTerm] -> [EqTerm] -> DecideEq [EqLiteral]
findCongruences ls rs = L.concat <$> mapM (`congruentWith` rs) ls

-- | Equalities @l = r@ for every @r@ congruent to @l@.
congruentWith :: EqTerm -> [EqTerm] -> DecideEq [EqLiteral]
congruentWith l rs = do
  flags <- mapM (congruent l) rs
  return [EqLiteral Eq l r | (r, True) <- L.zip rs flags]

-- | Two applications are congruent when they share name and arity and their
-- arguments are pairwise in the same class. Variables are never congruent
-- to anything through this rule.
congruent :: EqTerm -> EqTerm -> DecideEq Bool
congruent (Function n1 a1 args1) (Function n2 a2 args2)
  | n1 /= n2 || a1 /= a2 = return False
  | otherwise = equivalentArgs args1 args2
congruent _ _ = return False

-- | Pairwise same-class check. Argument lists of different lengths (possible
-- because 'Arity' is supplied by the caller) are not equivalent.
equivalentArgs :: [EqTerm] -> [EqTerm] -> DecideEq Bool
equivalentArgs [] [] = return True
equivalentArgs (l:ls) (r:rs) = do
  same <- sameClass l r
  if same then equivalentArgs ls rs else return False
equivalentArgs _ _ = return False

-- | True when some pair (a disequality) has both sides in the same class,
-- i.e. the disequality contradicts the equalities processed so far.
classConflict :: [(EqTerm, EqTerm)] -> DecideEq Bool
classConflict [] = return False
classConflict (nextDis:rest) = do
  s <- uncurry sameClass nextDis
  if s then return True else classConflict rest

-- | Merge the classes of @l@ and @r@. Returns the congruences newly implied
-- by the merge (empty if they were already in the same class) and records
-- the combined super-terms under the new representative.
addEq :: EqTerm -> EqTerm -> DecideEq [EqLiteral]
addEq l r = do
  repL <- getRep l
  repR <- getRep r
  -- Read before merging: afterwards only the new representative is valid.
  termsWithL <- termsContaining l
  termsWithR <- termsContaining r
  res <- U.merge defaultMerge repL repR
  case res of
    Nothing -> return []
    Just _ -> do
      newCong <- findCongruences termsWithL termsWithR
      rep <- getRep l
      modify $ \s ->
        s {superTerms = M.insert rep (termsWithL ++ termsWithR) (superTerms s)}
      return newCong

instance Show EqState where
  show = showEqState

-- | Debug rendering: the registered terms, concatenated without separators.
showEqState :: EqState -> String
showEqState (EqState pMap _) = L.concatMap show $ M.keys pMap

-- | State with no registered terms.
newEqState :: EqState
newEqState = EqState M.empty M.empty

-- | The union-find node of @t@, if @t@ has been registered.
nodeForTerm :: EqTerm -> DecideEq (Maybe Node)
nodeForTerm t = gets (M.lookup t . pointMap)

-- | The union-find node of a registered term; unregistered terms are an
-- internal invariant violation.
getNode :: EqTerm -> DecideEq Node
getNode t = do
  p <- nodeForTerm t
  case p of
    Just node -> return node
    Nothing -> error $ "Term " ++ show t ++ " has no union-find node"

-- | Register @t@ with a fresh union-find node unless already registered.
addTerm :: EqTerm -> DecideEq ()
addTerm t = do
  existing <- nodeForTerm t
  case existing of
    Just _ -> return ()
    Nothing -> do
      pt <- new t
      modify $ \eqSt -> eqSt {pointMap = M.insert t pt (pointMap eqSt)}

-- | Register every term in the list.
addTerms :: [EqTerm] -> DecideEq ()
addTerms = mapM_ addTerm

-- | For each term of the first list, record (keyed by its own node) the
-- terms of the second list that contain it.
buildContainsMap :: [EqTerm] -> [EqTerm] -> DecideEq ()
buildContainsMap ls candidates = mapM_ register ls
  where
    register l = do
      tc <- allTermsContaining l candidates
      n <- getNode l
      modify $ \eqSt -> eqSt {superTerms = M.insert n tc (superTerms eqSt)}

-- | The candidates that contain @l@, in their original order.
allTermsContaining :: EqTerm -> [EqTerm] -> DecideEq [EqTerm]
allTermsContaining l candidates = return $ L.filter (contains l) candidates

-- | Process equalities as a worklist; congruences discovered by each merge
-- are pushed onto the front. Non-equality literals are skipped.
processEqualities :: [EqLiteral] -> DecideEq ()
processEqualities [] = return ()
processEqualities (EqLiteral Eq l r : ts) = do
  newEqs <- addEq l r
  processEqualities (newEqs ++ ts)
processEqualities (_ : ts) = processEqualities ts

-- | True when no disequality relates two terms of the same class.
-- Equality literals in the list are ignored.
processDisequalities :: [EqLiteral] -> DecideEq Bool
processDisequalities lits =
  not <$> classConflict [(l, r) | EqLiteral Neq l r <- lits]