module EqualitySolver.Solver(
  EqFormula, EqLiteral, EqTerm,
  Name, Arity,
  eqF, eq, neq,
  var, fun,
  satisfiableInEq) where

import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Union as U
import Data.List as L
import Data.Map as M
import Data.Maybe
import Data.Set as S

-- |Returns true if the conjunction of literals given as an
-- argument is satisfiable in the first order theory of
-- uninterpreted functions with equality
satisfiableInEq :: EqFormula -> Bool
satisfiableInEq formula = fst $ runDecideEq $ decideEq formula

-- A conjunction of literals
data EqFormula = EqFormula (Set EqLiteral)
                 deriving (Eq, Ord, Show)

-- |Build a conjunction of literals out of a list of literals
eqF = EqFormula . S.fromList

allTerms :: EqFormula -> [EqTerm]
allTerms (EqFormula lits) = L.concatMap extractTerms $ S.toList lits

extractTerms (EqLiteral _ l r) = [l, r] ++ subTerms l ++ subTerms r

contains :: EqTerm -> EqTerm -> Bool
contains t (Function _ _ args) = tIsArg || tInArg
  where
    tIsArg = t `L.elem` args
    tInArg = L.any (contains t) args
contains _ _ = False

data EqLiteral
  = EqLiteral Predicate EqTerm EqTerm
    deriving (Eq, Ord, Show)

-- |'eq a b' builds the literal a = b
eq = EqLiteral Eq

-- |'neq a b' builds the literal 'not (a = b)'
neq = EqLiteral Neq

isEq (EqLiteral Eq _ _) = True
isEq _ = False

data Predicate
  = Eq
  | Neq
    deriving (Eq, Ord, Show)

type Arity = Int

-- |Make a new arity
arity :: Int -> Arity
arity i = i

type Name = String

-- |Make a new name
name :: String -> Name
name n = n


data EqTerm
  = Function Name Arity [EqTerm]
  | Variable Name
    deriving (Eq, Ord)

instance Show EqTerm where
  show = showEqTerm

showEqTerm :: EqTerm -> String
showEqTerm (Function name arity args) = name ++ "(" ++ intercalate "," (L.map show args) ++ ")"
showEqTerm (Variable name) = name

-- |Returns a new variable
var = Variable
-- |Returns a new function
fun = Function

subTerms :: EqTerm -> [EqTerm]
subTerms (Variable _) = []
subTerms (Function _ _ args) = args ++ L.concatMap subTerms args


type DecideEq a = StateT EqState (UnionM EqTerm) a

runDecideEq :: DecideEq a -> (a, EqState)
runDecideEq decide =  run $ runStateT decide newEqState

decideEq :: EqFormula -> DecideEq Bool
decideEq f@(EqFormula lits) = do
  addTerms $ allTerms f
  buildContainsMap (allTerms f) (allTerms f)
  processEqualities eqs
  processDisequalities diseqs
  where
    litList = S.toList lits
    eqs = L.filter isEq litList
    diseqs = L.filter  (not . isEq) litList

data EqState
  = EqState {
    pointMap :: Map EqTerm Node,
    superTerms :: Map Node [EqTerm]
    }

termsContaining :: EqTerm -> DecideEq [EqTerm]
termsContaining t = do
  pt <- getRep t
  sts <- gets superTerms
  case M.lookup pt sts of
    Just ts -> return ts
    Nothing -> error $ "Term " ++ show t ++ " not in superTerms"

getRep :: EqTerm -> DecideEq Node
getRep t = do
  dt <- getNode t
  (repr, lab) <- U.lookup dt
  return repr

sameClass :: EqTerm -> EqTerm -> DecideEq Bool
sameClass l r = do
  repL <- getRep l
  repR <- getRep r
  return $ repL == repR

defaultMerge :: EqTerm -> EqTerm -> (EqTerm, [a])
defaultMerge l r = (l, [])

findCongruences :: [EqTerm] -> [EqTerm] -> DecideEq [EqLiteral]
findCongruences [] rs = return []
findCongruences (l:ls) rs = do
  congWithL <- congruentWith l rs
  rest <- findCongruences ls rs
  return $ congWithL ++ rest

congruentWith :: EqTerm -> [EqTerm] -> DecideEq [EqLiteral]
congruentWith l [] = return []
congruentWith l (r:rs) = do
  areCong <- congruent l r
  rest <- congruentWith l rs
  return $ case areCong of
    True -> EqLiteral Eq l r : rest
    False -> rest

congruent :: EqTerm -> EqTerm -> DecideEq Bool
congruent (Function n1 a1 args1) (Function n2 a2 args2) =
  case n1 /= n2 || a1 /= a2 of
    True -> return False
    False -> equivalentArgs args1 args2

equivalentArgs :: [EqTerm] -> [EqTerm] -> DecideEq Bool
equivalentArgs [] [] = return True
equivalentArgs (l:ls) (r:rs) = do
  same <- sameClass l r
  case same of
    True -> equivalentArgs ls rs
    False -> return False

classConflict :: [(EqTerm, EqTerm)] -> DecideEq Bool
classConflict [] = return False
classConflict (nextDis:rest) = do
  s <- uncurry sameClass nextDis
  case s of
    True -> return False
    _ -> classConflict rest

addEq :: EqTerm -> EqTerm -> DecideEq [EqLiteral]
addEq l r = do
  repL <- getRep l
  repR <- getRep r
  termsWithL <- termsContaining l
  termsWithR <- termsContaining r
  res <- U.merge defaultMerge repL repR
  case res of
    Nothing -> return []
    _ -> do
      newCong <- findCongruences termsWithL termsWithR
      oldSt <- gets superTerms
      rep <- getRep l
      modify $ \s -> s { superTerms = M.insert rep (termsWithL ++ termsWithR) oldSt }
      return newCong

instance Show EqState where
  show = showEqState

showEqState :: EqState -> String
showEqState (EqState pMap _) = L.concatMap (show . fst) $ M.toList pMap

newEqState = EqState M.empty M.empty

nodeForTerm :: EqTerm -> DecideEq (Maybe Node)
nodeForTerm t = do
  pMap <- gets pointMap
  return $ M.lookup t pMap

getNode :: EqTerm -> DecideEq Node
getNode t = do
  p <- nodeForTerm t
  return $ fromJust p

addTerm :: EqTerm -> DecideEq ()
addTerm t = do
  point <- nodeForTerm t
  case point of
    Just p -> return ()
    Nothing -> do
      pts <- gets pointMap
      pt <- new t
      modify $ \eqSt -> eqSt { pointMap = M.insert t pt pts }

addTerms :: [EqTerm] -> DecideEq ()
addTerms [] = return ()
addTerms (t:ts) = do
  addTerm t
  addTerms ts
  return ()

buildContainsMap :: [EqTerm] -> [EqTerm] -> DecideEq ()
buildContainsMap [] _ = return ()
buildContainsMap (l:ls) r = do
  tc <- allTermsContaining l r
  oldSup <- gets superTerms
  n <- getNode l
  modify $ \eqSt -> eqSt { superTerms = M.insert n tc oldSup }
  buildContainsMap ls r

allTermsContaining :: EqTerm -> [EqTerm] -> DecideEq [EqTerm]
allTermsContaining l [] = return []
allTermsContaining l (r:rs) =
  case contains l r of
    True -> do
      tc <- allTermsContaining l rs
      return $ r:tc
    False -> allTermsContaining l rs

processEqualities :: [EqLiteral] -> DecideEq ()
processEqualities [] = return ()
processEqualities (EqLiteral Eq l r:ts) = do
  newEqs <- addEq l r
  processEqualities (newEqs ++ ts)
  
processDisequalities :: [EqLiteral] -> DecideEq Bool
processDisequalities [] = return True
processDisequalities (EqLiteral Neq l r : ts) = do
  same <- sameClass l r
  case same of
    True -> return False
    False -> processDisequalities ts