-- Solver layered over "ToySolver.EUF.CongruenceClosure": equalities are
-- forwarded to the congruence-closure solver, while disequalities and
-- backtrack points are tracked by this module.
module ToySolver.EUF.EUFSolver
(
-- * The @Solver@ type
Solver
, newSolver
-- * Problem description
, FSym
, Term (..)
, ConstrID
, VAFun (..)
, newFSym
, newFun
, newConst
, assertEqual
, assertEqual'
, assertNotEqual
, assertNotEqual'
-- * Query
, check
, areEqual
-- * Explanation
, explain
-- * Model construction (re-exported from the congruence-closure module)
, Entity
, EntityTuple
, Model (..)
, getModel
, eval
, evalAp
-- * Backtracking
, pushBacktrackPoint
, popBacktrackPoint
-- * Low-level conversions (forwarded to the congruence-closure solver)
, termToFlatTerm
, termToFSym
, fsymToTerm
, fsymToFlatTerm
, flatTermToFSym
) where
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef
import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC
-- | Solver state: a congruence-closure solver plus the bookkeeping for
-- disequalities, explanations and backtracking maintained by this module.
--
-- All fields are strict so that a 'Solver' never holds unevaluated
-- references (only the references themselves are forced, not their
-- contents).
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Underlying congruence-closure solver.
  , svDisequalities :: !(IORef (Map (Term, Term) (Maybe ConstrID)))
    -- ^ Asserted disequalities, keyed by the pair of terms, together with
    -- the optional constraint ID supplied at assertion time.
  , svExplanation :: !(IORef IntSet)
    -- ^ Explanation written by 'check' when it finds a violated disequality.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per open backtrack point, holding the disequality keys
    -- added since that point was pushed.
  }
-- | Create a fresh solver with no disequalities and no backtrack points.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  -- No explanation exists until 'check' has found a violated disequality.
  -- Seed the reference with a descriptive error instead of a bare
  -- 'undefined', so that reading it too early fails with a useful message.
  expl <- newIORef $
    error "ToySolver.EUF.EUFSolver.explain: no explanation available (check has not failed)"
  bp <- Vec.new
  return
    Solver
      { svCCSolver = cc
      , svDisequalities = deqs
      , svExplanation = expl
      , svBacktrackPoints = bp
      }
-- | Allocate a function symbol by delegating to 'CC.newFSym' on the
-- wrapped congruence-closure solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver
-- | Create a constant term by delegating to 'CC.newConst' on the wrapped
-- congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver
-- | Create a function value of any 'CC.VAFun' type by delegating to
-- 'CC.newFun' on the wrapped congruence-closure solver.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver
-- | Assert that two terms are equal, without attaching a constraint ID.
-- Shorthand for 'assertEqual'' with 'Nothing'.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver lhs rhs = assertEqual' solver lhs rhs Nothing
-- | Assert that two terms are equal, optionally tagged with a constraint ID.
-- The equality is passed straight to 'CC.merge'' on the wrapped
-- congruence-closure solver.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' = CC.merge' . svCCSolver
-- | Assert that two terms are distinct, without attaching a constraint ID.
-- Shorthand for 'assertNotEqual'' with 'Nothing'.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver lhs rhs = assertNotEqual' solver lhs rhs Nothing
-- | Assert that two terms are distinct, optionally tagged with a
-- constraint ID.
--
-- The pair is normalised so that the smaller term comes first, which makes
-- @t1 /= t2@ and @t2 /= t1@ share one map entry. If the disequality is
-- already recorded, nothing changes (the existing constraint ID is kept).
-- Otherwise it is stored and, when inside a backtrack point, also noted in
-- the innermost backtrack frame so 'popBacktrackPoint' can remove it.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  ds <- readIORef (svDisequalities solver)
  unless (deq `Map.member` ds) $ do
    -- Results are discarded; the calls are made for their effect on the
    -- congruence-closure solver.
    -- NOTE(review): presumably this registers a symbol for each term.
    _ <- termToFSym solver (fst deq)
    _ <- termToFSym solver (snd deq)
    writeIORef (svDisequalities solver) $! Map.insert deq cid ds
    lv <- getCurrentLevel solver
    -- Level @lv@ means @lv@ frames are open; the innermost one lives at
    -- index @lv - 1@.
    unless (lv == 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (lv - 1) $ Map.insert deq ()
  where
    -- Canonical ordering of the pair, used as the map key.
    deq = if t1 < t2 then (t1, t2) else (t2, t1)
-- | Check whether the asserted disequalities are consistent with the
-- current congruence closure.
--
-- Returns 'True' when no asserted disequality relates two congruent terms.
-- On the first violated disequality (in key order of the map), the
-- explanation of the offending congruence, plus the disequality's own
-- constraint ID if it has one, is stored for @'explain' solver Nothing@
-- and 'False' is returned.
check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  -- 'throwE' aborts the traversal at the first violation; 'isRight' then
  -- turns "ran to completion" into 'True'.
  liftM isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1, t2), cid) -> do
    violated <- lift $ CC.areCongruent cc t1 t2
    when violated $ do
      ret <- lift $ CC.explain cc t1 t2
      case ret of
        -- The terms were just reported congruent, so an explanation is
        -- expected; fail loudly, as 'explain' does, rather than through an
        -- anonymous pattern-match failure.
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
        Just cs ->
          lift $ writeIORef (svExplanation solver) $!
            maybe cs (`IntSet.insert` cs) cid
      throwE ()
  where
    cc = svCCSolver solver
-- | Ask the wrapped congruence-closure solver, via 'CC.areCongruent',
-- whether two terms are currently congruent.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver
-- | Retrieve a set of constraint IDs serving as an explanation.
--
-- * With 'Nothing', return the explanation stored by 'check'.
-- * With @'Just' (t1, t2)@, ask the congruence-closure solver to explain
--   the pair; a missing explanation raises an error.
explain :: Solver -> Maybe (Term,Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      CC.explain (svCCSolver solver) t1 t2 >>=
        maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return
-- | Obtain a 'Model' from the wrapped congruence-closure solver via
-- 'CC.getModel'.
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)
-- | Backtracking depth: the number of currently open backtrack points.
type Level = Int

-- | Current backtracking depth, i.e. the size of the backtrack-point
-- vector. Level 0 is the root level.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel = Vec.getSize . svBacktrackPoints
-- | Open a new backtrack point: first on the wrapped congruence-closure
-- solver, then by pushing an empty frame that will collect the
-- disequalities asserted from here on.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver =
  CC.pushBacktrackPoint (svCCSolver solver)
    >> Vec.push (svBacktrackPoints solver) Map.empty
-- | Close the innermost backtrack point.
--
-- Pops the wrapped congruence-closure solver's backtrack point, then
-- removes every disequality recorded in the popped frame. Calling this at
-- the root level (no open backtrack point) raises an error.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  lv <- getCurrentLevel solver
  when (lv == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  added <- Vec.unsafePop (svBacktrackPoints solver)
  -- Drop exactly the keys that were added since the matching push.
  modifyIORef' (svDisequalities solver) (`Map.difference` added)
-- | Forward to 'CC.termToFlatTerm' on the wrapped congruence-closure solver.
-- NOTE(review): no explicit signature; the type is inferred from the CC function.
termToFlatTerm solver = CC.termToFlatTerm (svCCSolver solver)
-- | Forward to 'CC.termToFSym' on the wrapped congruence-closure solver.
-- NOTE(review): no explicit signature; the type is inferred from the CC function.
termToFSym solver = CC.termToFSym (svCCSolver solver)
-- | Forward to 'CC.fsymToTerm' on the wrapped congruence-closure solver.
-- NOTE(review): no explicit signature; the type is inferred from the CC function.
fsymToTerm solver = CC.fsymToTerm (svCCSolver solver)
-- | Forward to 'CC.fsymToFlatTerm' on the wrapped congruence-closure solver.
-- NOTE(review): no explicit signature; the type is inferred from the CC function.
fsymToFlatTerm solver = CC.fsymToFlatTerm (svCCSolver solver)
-- | Forward to 'CC.flatTermToFSym' on the wrapped congruence-closure solver.
-- NOTE(review): no explicit signature; the type is inferred from the CC function.
flatTermToFSym solver = CC.flatTermToFSym (svCCSolver solver)