{-# LANGUAGE CPP #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.EUF.EUFSolver
-- Copyright   :  (c) Masahiro Sakai 2015
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable (CPP)
--
-----------------------------------------------------------------------------
module ToySolver.EUF.EUFSolver
  ( -- * The @Solver@ type
    Solver
  , newSolver

  -- * Problem description
  , FSym
  , Term (..)
  , ConstrID
  , VAFun (..)
  , newFSym
  , newFun
  , newConst
  , assertEqual
  , assertEqual'
  , assertNotEqual
  , assertNotEqual'

  -- * Query
  , check
  , areEqual

  -- * Explanation
  , explain

  -- * Model Construction
  , Entity
  , EntityTuple
  , Model (..)
  , getModel
  , eval
  , evalAp

  -- * Backtracking
  , pushBacktrackPoint
  , popBacktrackPoint

  -- * Low-level operations
  , termToFlatTerm
  , termToFSym
  , fsymToTerm
  , fsymToFlatTerm
  , flatTermToFSym
  ) where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Either
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.IORef

import qualified ToySolver.Internal.Data.Vec as Vec
import ToySolver.EUF.CongruenceClosure (FSym, Term (..), ConstrID, VAFun (..))
import ToySolver.EUF.CongruenceClosure (Model (..), Entity, EntityTuple, eval, evalAp)
import qualified ToySolver.EUF.CongruenceClosure as CC

data Solver
  = Solver
  { svCCSolver :: !CC.Solver
  , svDisequalities :: IORef (Map (Term, Term) (Maybe ConstrID))
  , svExplanation :: IORef IntSet
  , svBacktrackPoints :: !(Vec.Vec (Map (Term, Term) ()))
  }

newSolver :: IO Solver
newSolver = do
  cc <- CC.newSolver
  deqs <- newIORef Map.empty
  expl <- newIORef undefined
  bp <- Vec.new

  let solver = 
        Solver
        { svCCSolver = cc
        , svDisequalities = deqs
        , svExplanation = expl
        , svBacktrackPoints = bp
        }
  return solver

newFSym :: Solver -> IO FSym
newFSym solver = CC.newFSym (svCCSolver solver)

newConst :: Solver -> IO Term
newConst solver = CC.newConst (svCCSolver solver)

newFun :: CC.VAFun a => Solver -> IO a
newFun solver = CC.newFun (svCCSolver solver)

assertEqual :: Solver -> Term -> Term -> IO ()
assertEqual solver t1 t2 = assertEqual' solver t1 t2 Nothing

assertEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertEqual' solver t1 t2 cid = CC.merge' (svCCSolver solver) t1 t2 cid

assertNotEqual :: Solver -> Term -> Term -> IO ()
assertNotEqual solver t1 t2 = assertNotEqual' solver t1 t2 Nothing

assertNotEqual' :: Solver -> Term -> Term -> Maybe ConstrID -> IO ()
assertNotEqual' solver t1 t2 cid = if t1 < t2 then f (t1,t2) cid else f (t2,t1) cid
  where
    f deq cid = do
      ds <- readIORef (svDisequalities solver)
      unless (deq `Map.member` ds) $ do
        _ <- termToFSym solver (fst deq) -- It is important to name the term for model generation
        _ <- termToFSym solver (snd deq) -- It is important to name the term for model generation
        writeIORef (svDisequalities solver) $! Map.insert deq cid ds
        lv <- getCurrentLevel solver
        unless (lv==0) $ do
          Vec.unsafeModify' (svBacktrackPoints solver) (lv - 1) $ Map.insert deq ()

check :: Solver -> IO Bool
check solver = do
  ds <- readIORef (svDisequalities solver)
  liftM isRight $ runExceptT $ forM_ (Map.toList ds) $ \((t1,t2), cid) -> do
    b <- lift $ CC.areCongruent (svCCSolver solver) t1 t2
    if b then do
      Just cs <- lift $ CC.explain (svCCSolver solver) t1 t2
      lift $ writeIORef (svExplanation solver) $!
        case cid of
          Nothing -> cs
          Just c -> IntSet.insert c cs
      throwE ()
    else
      return ()

areEqual :: Solver -> Term -> Term -> IO Bool
areEqual solver t1 t2 = CC.areCongruent (svCCSolver solver) t1 t2

explain :: Solver -> Maybe (Term,Term) -> IO IntSet
explain solver Nothing = readIORef (svExplanation solver)
explain solver (Just (t1,t2)) = do
  ret <- CC.explain (svCCSolver solver) t1 t2
  case ret of
    Nothing -> error "ToySolver.EUF.EUFSolver.explain: should not happen"
    Just cs -> return cs

-- -------------------------------------------------------------------
-- Model construction
-- -------------------------------------------------------------------

getModel :: Solver -> IO Model
getModel = CC.getModel . svCCSolver

-- -------------------------------------------------------------------
-- Backtracking
-- -------------------------------------------------------------------

type Level = Int

getCurrentLevel :: Solver -> IO Level
getCurrentLevel solver = Vec.getSize (svBacktrackPoints solver)

pushBacktrackPoint :: Solver -> IO ()
pushBacktrackPoint solver = do
  CC.pushBacktrackPoint (svCCSolver solver)
  Vec.push (svBacktrackPoints solver) Map.empty

popBacktrackPoint :: Solver -> IO ()
popBacktrackPoint solver = do
  lv <- getCurrentLevel solver
  if lv==0 then
    error "ToySolver.EUF.EUFSolver.popBacktrackPoint: root level"
  else do
    CC.popBacktrackPoint (svCCSolver solver)
    xs <- Vec.unsafePop (svBacktrackPoints solver)
    modifyIORef' (svDisequalities solver) $ (`Map.difference` xs)

termToFlatTerm = CC.termToFlatTerm . svCCSolver
termToFSym     = CC.termToFSym     . svCCSolver
fsymToTerm     = CC.fsymToTerm     . svCCSolver
fsymToFlatTerm = CC.fsymToFlatTerm . svCCSolver
flatTermToFSym = CC.flatTermToFSym . svCCSolver