{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE CPP #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.EUF.EUFSolver
-- Copyright   :  (c) Masahiro Sakai 2015
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.EUF.EUFSolver
  ( -- * The @Solver@ type
    Solver
  , newSolver

  -- * Problem description
  , FSym
  , Term (..)
  , ConstrID
  , VAFun (..)
  , newFSym
  , newFun
  , newConst
  , assertEqual
  , assertEqual'
  , assertNotEqual
  , assertNotEqual'

  -- * Query
  , check
  , areEqual

  -- * Explanation
  , explain

  -- * Model Construction
  , Entity
  , EntityTuple
  , Model (..)
  , getModel
  , eval
  , evalAp

  -- * Backtracking
  , pushBacktrackPoint
  , popBacktrackPoint

  -- * Low-level operations
  , termToFlatTerm
  , termToFSym
  , fsymToTerm
  , fsymToFlatTerm
  , flatTermToFSym
  ) where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef

import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC

-- | EUF (equality with uninterpreted functions) solver.
--
-- Asserted equalities are handed to the underlying congruence-closure
-- solver; asserted disequalities are recorded here and checked against it
-- in 'check'.
data Solver
  = Solver
  { svCCSolver :: !CC.Solver
    -- ^ Congruence-closure solver that receives every asserted equality.
  , svDisequalities :: !(IORef (Map (Term, Term) (Maybe ConstrID)))
    -- ^ Asserted disequalities, keyed by the pair of terms with the smaller
    -- term first, together with the optional user-supplied constraint ID.
  , svExplanation :: !(IORef IntSet)
    -- ^ Constraint IDs explaining the conflict found by the most recent
    -- 'check' that returned 'False'.
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
    -- ^ One entry per open backtrack point, holding the disequalities added
    -- at that level so that 'popBacktrackPoint' can retract them.
  }

-- | Create a fresh solver with no terms, no asserted (dis)equalities and no
-- backtrack points, i.e. at level 0.
newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  -- Start from an empty explanation instead of 'undefined', so that reading
  -- the explanation before any failing 'check' does not hit a bottom value.
  expl <- newIORef IntSet.empty
  bp <- Vec.new
  return $ Solver
    { svCCSolver        = cc
    , svDisequalities   = deqs
    , svExplanation     = expl
    , svBacktrackPoints = bp
    }

-- | Allocate a fresh function symbol in the underlying congruence-closure
-- solver.
newFSym :: Solver -> IO FSym
newFSym = CC.newFSym . svCCSolver

-- | Create a fresh constant term in the underlying congruence-closure solver.
newConst :: Solver -> IO Term
newConst = CC.newConst . svCCSolver

-- | Create a fresh uninterpreted function. The shape of the result is
-- determined by the 'VAFun' instance chosen at type @a@.
newFun :: CC.VAFun a => Solver -> IO a
newFun = CC.newFun . svCCSolver

-- | Assert that two terms are equal, without attaching a constraint ID.
assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver lhs rhs = assertEqual' solver lhs rhs Nothing

-- | Assert that two terms are equal. The optional constraint ID is passed to
-- the congruence-closure solver so that it can show up in explanations.
assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' = CC.merge' . svCCSolver

-- | Assert that two terms are distinct, without attaching a constraint ID.
assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver lhs rhs = assertNotEqual' solver lhs rhs Nothing

-- | Assert that two terms are distinct, optionally tagging the assertion with
-- a constraint ID that is reported when 'check' finds it violated.
--
-- The pair is stored with the smaller term first, so @t1 /= t2@ and
-- @t2 /= t1@ are recorded as the same disequality. Re-asserting a
-- disequality that is already recorded does nothing, and the previously
-- stored constraint ID is kept.
assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = do
  let key@(lhs, rhs) = if t1 < t2 then (t1, t2) else (t2, t1)
  known <- readIORef (svDisequalities solver)
  when (Map.notMember key known) $ do
    -- Name both sides as function symbols; this is important for model
    -- generation.
    mapM_ (termToFSym solver) [lhs, rhs]
    writeIORef (svDisequalities solver) $! Map.insert key cid known
    level <- getCurrentLevel solver
    -- Below the root, also record the pair in the innermost backtrack point
    -- so that 'popBacktrackPoint' retracts it.
    when (level /= 0) $
      Vec.unsafeModify' (svBacktrackPoints solver) (level - 1) (Map.insert key ())

-- | Check the asserted disequalities against the current equalities.
--
-- Returns 'False' as soon as a disequality is found whose two sides are
-- congruent. In that case the explanation of their equality, plus the
-- disequality's own constraint ID if it has one, is stored so that
-- @'explain' solver Nothing@ can retrieve it. Returns 'True' otherwise and
-- leaves the stored explanation unchanged.
check :: Solver -> IO Bool
check solver = do
  deqs <- readIORef (svDisequalities solver)
  fmap isRight $ runExceptT $ forM_ (Map.toList deqs) $ \((t1, t2), cid) -> do
    congruent <- lift $ CC.areCongruent cc t1 t2
    when congruent $ do
      ret <- lift $ CC.explain cc t1 t2
      -- Congruent terms should always be explainable. Fail with a clear
      -- message instead of an anonymous pattern-match failure.
      reason <- case ret of
        Just cs -> return cs
        Nothing -> error "ToySolver.EUF.EUFSolver.check: should not happen"
      lift $ writeIORef (svExplanation solver) $!
        maybe reason (`IntSet.insert` reason) cid
      -- Abort the traversal at the first violated disequality.
      throwE ()
  where
    cc = svCCSolver solver

-- | Whether the congruence-closure solver currently considers the two terms
-- congruent under the asserted equalities.
areEqual :: Solver -> Term -> Term -> IO Bool
areEqual = CC.areCongruent . svCCSolver

-- | Return a set of constraint IDs explaining a conflict or an equality.
--
-- * @explain solver Nothing@ reads the explanation stored by the most recent
--   'check' that returned 'False'.
-- * @explain solver (Just (t1, t2))@ asks the congruence-closure solver why
--   @t1@ and @t2@ are equal. It calls 'error' if that solver returns no
--   explanation (presumably because the terms are not known to be equal).
explain :: Solver -> Maybe (Term, Term) -> IO IntSet
explain solver query =
  case query of
    Nothing -> readIORef (svExplanation solver)
    Just (t1, t2) ->
      CC.explain (svCCSolver solver) t1 t2
        >>= maybe (error "ToySolver.EUF.EUFSolver.explain: should not happen") return

-- -------------------------------------------------------------------
-- Model construction
-- -------------------------------------------------------------------

-- | Build a model of the current state using the underlying
-- congruence-closure solver.
getModel :: Solver -> IO Model
getModel solver = CC.getModel (svCCSolver solver)

-- -------------------------------------------------------------------
-- Backtracking
-- -------------------------------------------------------------------

-- | Backtracking level: the number of backtrack points currently open
-- (0 is the root level).
type Level = Int

-- | Current backtracking level: the number of backtrack points that have
-- been pushed and not yet popped.
getCurrentLevel :: Solver -> IO Level
getCurrentLevel = Vec.getSize . svBacktrackPoints

-- | Open a new backtrack point, both in the congruence-closure solver and for
-- the disequalities tracked here. Disequalities asserted after this call are
-- retracted by the matching 'popBacktrackPoint'.
pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver =
  CC.pushBacktrackPoint (svCCSolver solver)
    >> Vec.push (svBacktrackPoints solver) Map.empty

-- | Close the innermost backtrack point. This backtracks the
-- congruence-closure solver and retracts every disequality recorded at the
-- current level.
--
-- Calling this at the root level, with no open backtrack point, is an error.
popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  level <- getCurrentLevel solver
  when (level == 0) $
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  CC.popBacktrackPoint (svCCSolver solver)
  addedHere <- Vec.unsafePop (svBacktrackPoints solver)
  modifyIORef' (svDisequalities solver) (\deqs -> deqs `Map.difference` addedHere)

-- | Low-level: convert a term to the congruence-closure solver's flat-term
-- representation.
termToFlatTerm solver t = CC.termToFlatTerm (svCCSolver solver) t
-- | Low-level: get the function symbol that names a term in the
-- congruence-closure solver.
termToFSym :: Solver -> Term -> IO FSym
termToFSym solver t = CC.termToFSym (svCCSolver solver) t
-- | Low-level: recover the term named by a function symbol in the
-- congruence-closure solver.
fsymToTerm :: Solver -> FSym -> IO Term
fsymToTerm solver fsym = CC.fsymToTerm (svCCSolver solver) fsym
-- | Low-level: get the flat term associated with a function symbol in the
-- congruence-closure solver.
fsymToFlatTerm solver fsym = CC.fsymToFlatTerm (svCCSolver solver) fsym
-- | Low-level: get the function symbol that names a flat term in the
-- congruence-closure solver.
flatTermToFSym solver ft = CC.flatTermToFSym (svCCSolver solver) ft