module Folly.Formula(
  Term, Formula,
  fvt, subTerm, isVar, isConst, isFunc,
  funcName, funcArgs, varName,
  appendVarName, collectVars,
  var, func, constant,
  te, fa, pr, con, dis, neg, imp, bic, t, f,
  (/\), (\/),
  vars, freeVars, isAtom, stripNegations,
  generalize, subFormula,
  applyToTerms,
  literalArgs,
  toPNF, toSkolemForm, skf,
  toClausalForm,
  matchingLiterals) where

import Control.Monad
import Data.Set as S
import Data.List as L
import Data.Map as M

data Term =
  Constant String    |
  Var String         |
  Func String [Term]
  deriving (Eq, Ord)
           
instance Show Term where
  show = showTerm
   
showTerm :: Term -> String
showTerm (Constant name) = name
showTerm (Var name) = name
showTerm (Func name args) = name ++ "(" ++ (concat $ intersperse ", " $ L.map showTerm args) ++ ")"

isVar (Var _) = True
isVar _ = False

isFunc (Func _ _) = True
isFunc _ = False

isConst (Constant _) = True
isConst _ = False

funcName (Func n _) = n

funcArgs (Func _ a) = a

var n = Var n
func n args = case (L.take 3 n) == "skl" of
  True -> error $ "Function names beginning with skl are reserved for skolemization"
  False -> Func n args
constant n = Constant n

varName (Var n) = n

appendVarName :: String -> Term -> Term
appendVarName suffix (Var n) = Var (n ++ suffix)
appendVarName suffix (Func name args) = Func name $ L.map (appendVarName suffix) args
appendVarName _ t = t

fvt :: Term -> Set Term
fvt (Constant _) = S.empty
fvt (Var n) = S.fromList [(Var n)]
fvt (Func name args) = S.foldl S.union S.empty (S.fromList (L.map fvt args))

subTerm :: Map Term Term -> Term -> Term
subTerm _ (Constant name) = Constant name
subTerm sub (Func name args) = (Func name (L.map (subTerm sub) args))
subTerm sub (Var x) = case M.lookup (Var x) sub of
  Just s -> s
  Nothing -> (Var x)

data Formula =
  T                            | 
  F                            |
  P String [Term]              |
  B String Formula Formula     |
  N Formula                    |
  Q String Term Formula
  deriving (Eq, Ord)
           
instance Show Formula where
  show = showFormula
  
showFormula :: Formula -> String
showFormula T = "True"
showFormula F = "False"
showFormula (P predName args) = predName ++ "[" ++ (concat $ intersperse ", " $ L.map showTerm args)  ++ "]"
showFormula (N (P name args)) = "~" ++ show (P name args)
showFormula (N f) = "~(" ++ show f ++ ")"
showFormula (B op f1 f2) = "(" ++ show f1 ++ " " ++ op ++ " "  ++ show f2 ++ ")"
showFormula (Q q t f) = "(" ++ q ++ " "  ++ show t ++ " . " ++ show f ++ ")"

applyToTerms :: Formula -> (Term -> Term) -> Formula
applyToTerms (P n args) f = P n $ L.map f args
applyToTerms (B n l r) f = B n (applyToTerms l f) (applyToTerms r f)
applyToTerms (Q n v l) f = Q n (f v) (applyToTerms l f)
applyToTerms (N l) f = N (applyToTerms l f)

collectVars :: Formula -> [Term]
collectVars (P _ args) = L.concatMap (\t -> if isVar t then [t] else []) args
collectVars (N f) = collectVars f
collectVars (B _ a b) = collectVars a ++ collectVars b
collectVars (Q _ v f) = v:(collectVars f)
collectVars _ = []

te :: Term -> Formula -> Formula
te v@(Var _) f = Q "E" v f
te t _ = error $ "Cannot quantify over non-variable term " ++ show t

fa :: Term -> Formula -> Formula
fa v@(Var _) f = Q "V" v f
fa t _ = error $ "Cannot quantify over non-variable term " ++ show t

pr name args = P name args
con f1 f2 = B "&" f1 f2
dis f1 f2 = B "|" f1 f2
imp f1 f2 = B "->" f1 f2
bic f1 f2 = B "<->" f1 f2
neg f = N f
t = T
f = F

(/\) = con
(\/) = dis

vars :: Formula -> Set Term
vars T = S.empty
vars F = S.empty
vars (P name terms) = S.fold S.union S.empty $ S.fromList (L.map fvt terms)
vars (B _ f1 f2) = S.union (vars f1) (vars f2)
vars (N f) = vars f
vars (Q _ v f) = S.insert v (vars f)

freeVars :: Formula -> Set Term
freeVars T = S.empty
freeVars F = S.empty
freeVars (P name terms) = S.fold S.union S.empty $ S.fromList (L.map fvt terms)
freeVars (B _ f1 f2) = S.union (freeVars f1) (freeVars f2)
freeVars (N f) = freeVars f
freeVars (Q _ v f) = S.delete v (freeVars f)

isAtom :: Formula -> Bool
isAtom (P _ _) = True
isAtom _ = False

stripNegations :: Formula -> Formula
stripNegations (N t) = t
stripNegations f = f

literalArgs :: Formula -> [Term]
literalArgs (P _ a) = a
literalArgs (N (P _ a)) = a
literalArgs l = error $ show l ++ " is not a literal"

matchingLiterals :: Formula -> Formula -> Bool
matchingLiterals (P n1 _) (N (P n2 _)) = n1 == n2
matchingLiterals (N (P n1 _)) (P n2 _) = n1 == n2
matchingLiterals (P _ _) (P _ _) = False
matchingLiterals (N (P _ _)) (N (P _ _)) = False
matchingLiterals l1 l2 = error $ show l1 ++ " or " ++ show l2 ++ " is not a literal"

generalize :: Formula -> Formula
generalize f = applyList genFreeVar f
  where
    genFreeVar = L.map fa (S.toList (freeVars f))

applyList :: [a -> a] -> a -> a
applyList [] a = a
applyList (f:fs) a = applyList fs (f a)

variant :: Set Term -> Term -> Term
variant vars x@(Var n) = case S.member x vars of
  True -> variant vars (Var (n ++ "'"))
  False -> x
  
subFormula :: Map Term Term -> Formula -> Formula
subFormula subst (P name args) = P name $ L.map (subTerm subst) args
subFormula subst (B op f1 f2) = B op (subFormula subst f1) (subFormula subst f2)
subFormula subst (N f) = N (subFormula subst f)
subFormula subst q@(Q _ _ _) = subQuant subst q
subFormula subst f = f

subQuant :: Map Term Term -> Formula -> Formula
subQuant subst (Q n v f) = case (M.filter (== v) subst) == M.empty of
  True -> Q n v (subFormula subst f)
  False -> Q n vNew $ subFormula (M.insert v vNew subst) f
  where
    vNew = variant (freeVars (subFormula (M.delete v subst) f)) v
    
    
toPNF :: Formula -> Formula
toPNF = (transformFormula pullQuantifiers) .
        (transformFormula simplifyFormula) .
        (transformFormula pushNegation) .
        (transformFormula elimVacuousQuantifiers) .
        (transformFormula replaceImp) .
        (transformFormula replaceBic)

pullQuantifiers f@(B "&" (Q "V" x p) (Q "V" y q)) = pullQ True True f fa con x y p q
pullQuantifiers f@(B "|" (Q "E" x p) (Q "E" y q)) = pullQ True True f te dis x y p q
pullQuantifiers f@(B "|" (Q "V" x p) q) = pullQ True False f fa dis x x p q
pullQuantifiers f@(B "|" p (Q "V" y q)) = pullQ False True f fa dis y y p q
pullQuantifiers f@(B "|" (Q "E" x p) q) = pullQ True False f te dis x x p q
pullQuantifiers f@(B "|" p (Q "E" y q)) = pullQ False True f te dis y y p q
pullQuantifiers f@(B "&" (Q "V" x p) q) = pullQ True False f fa con x x p q
pullQuantifiers f@(B "&" p (Q "V" y q)) = pullQ False True f fa con y y p q
pullQuantifiers f@(B "&" (Q "E" x p) q) = pullQ True False f te con x x p q
pullQuantifiers f@(B "&" p (Q "E" y q)) = pullQ False True f te con y y p q
pullQuantifiers f = f

pullQ :: Bool ->
         Bool ->
         Formula ->
         (Term -> Formula -> Formula) ->
         (Formula -> Formula -> Formula) ->
         Term ->
         Term ->
         Formula ->
         Formula ->
         Formula
pullQ l r f quant op x y p q =
  let z = variant (freeVars f) x in
  let ps = if l then subFormula (M.singleton x z) p else p in
  let qs = if r then subFormula (M.singleton y z) q else q in
  quant z (pullQuantifiers $ op ps qs)

simplifyFormula (N (N f)) = f
simplifyFormula (N T) = F
simplifyFormula (N F) = T
simplifyFormula (B "|" T f) = T
simplifyFormula (B "|" f T) = T
simplifyFormula (B "|" F F) = F
simplifyFormula (B "&" F f) = F
simplifyFormula (B "&" f F) = F
simplifyFormula (B "&" T T) = T
simplifyFormula f = f

pushNegation (N (B "|" f1 f2)) = B "&" (pushNegation (N f1)) (pushNegation (N f2))
pushNegation (N (B "&" f1 f2)) = B "|" (pushNegation (N f1)) (pushNegation (N f2))
pushNegation (N (Q "V" x f)) = Q "E" x (pushNegation (N f))
pushNegation (N (Q "E" x f)) = Q "V" x (pushNegation (N f))
pushNegation f = f

elimVacuousQuantifiers (Q n x f) = case S.member x (freeVars f) of
  True -> Q n x f
  False -> f
elimVacuousQuantifiers f = f

replaceImp (B "->" f1 f2) = dis (neg f1) f2
replaceImp f = f

replaceBic (B "<->" f1 f2) = con (imp f1 f2) (imp f2 f1)
replaceBic f = f

transformFormula :: (Formula -> Formula) -> Formula -> Formula
transformFormula tran (B op f1 f2) = tran (B op (transformFormula tran f1) (transformFormula tran f2))
transformFormula tran (Q q x f) = tran (Q q x (transformFormula tran f))
transformFormula tran (N f) = tran (N (transformFormula tran f))
transformFormula tran f = tran f

-- Conversion to Skolem form
toSkolemForm :: Formula -> Formula
toSkolemForm = skolemize . toPNF

skolemize :: Formula -> Formula
skolemize f = (transformFormula removeExistential) $ replaceVarsWithSkolemFuncs f

removeExistential :: Formula -> Formula
removeExistential (Q "E" v f) = f
removeExistential f = f

replaceVarsWithSkolemFuncs :: Formula -> Formula
replaceVarsWithSkolemFuncs f = subFormula varsToSkolemFuncs f
  where
    varsToSkolemFuncs = collectSkolemFuncs f 0 []
    
collectSkolemFuncs :: Formula -> Int -> [Term] -> Map Term Term
collectSkolemFuncs (Q "E" v f) n vars = M.insert v (skf n vars) (collectSkolemFuncs f (n+1) vars)
collectSkolemFuncs (Q "V" v f) n vars = collectSkolemFuncs f n (v:vars)
collectSkolemFuncs _ _ _ = M.empty

skf :: Int -> [Term] -> Term
skf n vars = Func ("skl" ++ show n) vars

-- Conversion to clausal form
toClausalForm :: Formula -> [[Formula]]
toClausalForm = splitClauses . removeUniversals . distributeDisjunction . toSkolemForm

distributeDisjunction :: (Formula) -> (Formula)
distributeDisjunction f = transformFormula distrDis f

distrDis :: (Formula) -> (Formula)
distrDis (B "|" (B "&" l r) f) = (B "&" (B "|" l f) (B "|" r f))
distrDis (B "|" f (B "&" l r)) = (B "&" (B "|" f l) (B "|" f r))
distrDis f = f

removeUniversals :: Formula -> Formula
removeUniversals (Q "V" v f) = removeUniversals f
removeUniversals f = f

splitClauses :: Formula -> [[Formula]]
splitClauses (B "&" l r) = (splitClauses l) ++ (splitClauses r)
splitClauses f = [splitDis f]

splitDis :: Formula -> [Formula]
splitDis (B "|" l r) = (splitDis l) ++ (splitDis r)
splitDis f = [f]