-- | A decision procedure for ground equality, -- based on the paper "Proof-producing Congruence Closure". module Test.QuickSpec.Reasoning.CongruenceClosure(CC, newSym, (=:=), (=?=), rep, evalCC, execCC, runCC, ($$), S, funUse, argUse, lookup, initial, frozen) where import Prelude hiding (lookup) import Control.Monad import Control.Monad.Trans.State.Strict import Data.IntMap(IntMap) import qualified Data.IntMap as IntMap import Test.QuickSpec.Reasoning.UnionFind(UF, Replacement((:>))) import qualified Test.QuickSpec.Reasoning.UnionFind as UF import Data.Maybe import Data.List(foldl') -- import Test.QuickCheck -- import Test.QuickCheck.Arbitrary -- import Test.QuickCheck.Monadic import Text.Printf lookup2 :: Int -> Int -> IntMap (IntMap a) -> Maybe a lookup2 k1 k2 m = IntMap.lookup k2 (IntMap.findWithDefault IntMap.empty k1 m) insert2 :: Int -> Int -> a -> IntMap (IntMap a) -> IntMap (IntMap a) insert2 k1 k2 v m = IntMap.insertWith IntMap.union k1 (IntMap.singleton k2 v) m delete2 :: Int -> Int -> IntMap (IntMap a) -> IntMap (IntMap a) delete2 k1 k2 m = IntMap.adjust (IntMap.delete k2) k1 m data FlatEqn = (Int, Int) := Int deriving (Eq, Ord) data S = S { -- in all these maps, the keys are representatives, the values may not be funUse :: !(IntMap [(Int, Int)]), argUse :: !(IntMap [(Int, Int)]), lookup :: IntMap (IntMap Int), uf :: UF.S } type CC = State S liftUF :: UF a -> CC a liftUF m = do s <- get let (x, uf') = UF.runUF (uf s) m put s { uf = uf' } return x invariant :: String -> CC () invariant _ = return () -- invariant str = do -- S funUse argUse lookup <- get -- -- keys of all maps are representatives -- let check phase x = do -- b <- liftUF (UF.isRep x) -- if b then return () else error (printf "%s, %s appears as a key in %s but is not a rep in:\nfunUse=%s\nargUse=%s\nlookup=%s" str (show x) phase (show funUse) (show argUse) (show lookup)) -- mapM_ (check "funUse") (IntMap.keys funUse) -- mapM_ (check "argUse") (IntMap.keys argUse) -- mapM_ (check "lookup") (IntMap.keys lookup) -- mapM_ (mapM_ (check "inner lookup") . IntMap.keys) (IntMap.elems lookup) modifyFunUse f = modify (\s -> s { funUse = f (funUse s) }) modifyArgUse f = modify (\s -> s { argUse = f (argUse s) }) addFunUses xs s = modifyFunUse (IntMap.insertWith (++) s xs) addArgUses xs s = modifyArgUse (IntMap.insertWith (++) s xs) modifyLookup f = modify (\s -> s { lookup = f (lookup s) }) putLookup l = modifyLookup (const l) newSym :: CC Int newSym = liftUF UF.newSym ($$) :: Int -> Int -> CC Int f $$ x = do invariant (printf "before %s$$%s" (show f) (show x)) m <- gets lookup f' <- rep f x' <- rep x invariant (printf "at %s$$%s:1" (show f) (show x)) case lookup2 x' f' m of Nothing -> do c <- newSym invariant (printf "at %s$$%s:2" (show f) (show x)) putLookup (insert2 x' f' c m) addFunUses [(x', c)] f' addArgUses [(f', c)] x' invariant (printf "after %s$$%s" (show f) (show x)) return c Just k -> return k (=:=) :: Int -> Int -> CC Bool a =:= b = propagate (a, b) (=?=) :: Int -> Int -> CC Bool t =?= u = liftM2 (==) (rep t) (rep u) propagate (a, b) = do (unified, pending) <- propagate1 (a, b) mapM_ propagate pending return unified propagate1 (a, b) = do invariant (printf "before propagate (%s, %s)" (show a) (show b)) res <- liftUF (a UF.=:= b) case res of Nothing -> return (False, []) Just (r :> r') -> do funUses <- gets (IntMap.lookup r . funUse) argUses <- gets (IntMap.lookup r . argUse) case (funUses, argUses) of (Nothing, Nothing) -> return (True, []) _ -> fmap (\x -> (True, x)) (updateUses r r' (fromMaybe [] funUses) (fromMaybe [] argUses)) updateUses r r' funUses argUses = do modifyFunUse (IntMap.delete r) modifyArgUse (IntMap.delete r) modifyLookup (IntMap.delete r) forM_ funUses $ \(x, _) -> do x' <- rep x modifyLookup (delete2 x' r) invariant (printf "after deleting %s" (show r)) let repPair (x, c) = do x' <- rep x return (x', c) funUses' <- mapM repPair funUses argUses' <- mapM repPair argUses m <- gets lookup let foldUses insert lookup pending m uses = foldl' op e uses where op (pending, newUses, m) (x', c) = case lookup x' m of Just k -> ((c, k):pending, newUses, m) Nothing -> (pending, (x', c):newUses, insert x' c m) e = (pending, [], m) (funPending, funNewUses, m') = foldUses (\x' c m -> insert2 x' r' c m) (\x' m -> lookup2 x' r' m) [] m funUses' (pending, argNewUses, argM) = foldUses IntMap.insert IntMap.lookup funPending (IntMap.findWithDefault IntMap.empty r' m') argUses' addFunUses funNewUses r' addArgUses argNewUses r' putLookup (if IntMap.null argM then m' else IntMap.insert r' argM m') invariant (printf "after updateUses (%s, %s)" (show r) (show r')) return pending rep :: Int -> CC Int rep s = liftUF (UF.rep s) runCC :: S -> CC a -> (a, S) runCC s m = runState m s evalCC :: S -> CC a -> a evalCC s m = fst (runCC s m) execCC :: S -> CC a -> S execCC s m = snd (runCC s m) initial :: Int -> S initial n = S IntMap.empty IntMap.empty IntMap.empty (UF.initial n) frozen :: CC a -> CC a frozen x = fmap (evalState x) get