module Term.Unification (
unifyLTerm
, unifyLNTerm
, matchLTerm
, matchLNTerm
, unifyLTermFactored
, unifyLNTermFactored
, MaudeHandle
, WithMaude
, startMaude
, getMaudeStats
, mhMaudeSig
, mhFilePath
, MaudeSig
, enableDH
, enableXor
, enableMSet
, minimalMaudeSig
, dhMaudeSig
, xorMaudeSig
, msetMaudeSig
, pairMaudeSig
, symEncMaudeSig
, asymEncMaudeSig
, signatureMaudeSig
, hashMaudeSig
, rrulesForMaudeSig
, allFunctionSymbols
, stRules
, irreducibleFunctionSymbols
, addFunctionSymbol
, addStRule
, module Term.Substitution
, module Term.Rewriting.Definitions
) where
import Control.Applicative
import Control.Monad.RWS
import Control.Monad.Reader
import Control.Monad.Error
import Control.Monad.State
import qualified Data.Map as M
import Data.Map (Map)
import System.IO.Unsafe (unsafePerformIO)
import Term.Rewriting.Definitions
import Term.Substitution
import qualified Term.Maude.Process as UM
import Term.Maude.Process
(MaudeHandle, WithMaude, startMaude, getMaudeStats, mhMaudeSig, mhFilePath)
import Term.Maude.Signature
import Debug.Trace.Ignore
-- | Unify a list of equations modulo the equational theory of the
-- 'MaudeHandle', returning the result in factored form.
--
-- The syntactic (non-AC) part of the problem is solved directly by
-- 'unifyRaw'; the resulting substitution is the first component.
-- Equations between AC-rooted terms that 'unifyRaw' postponed are sent to
-- Maude after that substitution has been applied to them. Each unifier
-- Maude returns is an element of the second component.
--
-- * If the syntactic part fails, the result is @(emptySubst, [])@, i.e.
--   there is no unifier.
-- * If no equations were postponed, the second component is
--   @[emptySubstVFresh]@, i.e. the first component is the only unifier.
--
-- 'flattenUnif' combines both components into a plain list of unifiers.
unifyLTermFactored :: (IsConst c , Show (Lit c LVar), Ord c)
                   => (c -> LSort)
                   -> [Equal (LTerm c)]
                   -> WithMaude (LSubst c, [SubstVFresh c LVar])
unifyLTermFactored sortOf eqs = reader $ \h ->
    (\res -> trace (unlines ["unifyLTerm: " ++ show eqs, "result = " ++ show res]) res) $
    solve h $ execRWST unif sortOf M.empty
  where
    -- Only the final bindings (state) and the postponed AC equations
    -- (writer output) are used, so the per-equation results are discarded.
    unif = sequence_ [ unifyRaw t p | Equal t p <- eqs ]

    solve _ Nothing          = (emptySubst, [])
    solve _ (Just (m, []))   = (substFromMap m, [emptySubstVFresh])
    solve h (Just (m, leqs)) =
        -- NOTE(review): unsafePerformIO hides the Maude call behind a pure
        -- interface; this assumes unifyViaMaude gives the same answer for
        -- the same handle and input.
        (subst, unsafePerformIO (UM.unifyViaMaude h sortOf $
                                     map (applyVTerm subst <$>) leqs))
      where
        subst = substFromMap m
-- | 'unifyLTermFactored' for terms whose constants are 'Name's; the sort of
-- a constant is determined by 'sortOfName'.
unifyLNTermFactored :: [Equal LNTerm]
                    -> WithMaude (LNSubst, [SubstVFresh Name LVar])
unifyLNTermFactored eqs = unifyLTermFactored sortOfName eqs
-- | Unify a list of equations modulo the equational theory of the
-- 'MaudeHandle' and return the unifiers as a flat list. This is the
-- factored result of 'unifyLTermFactored' passed through 'flattenUnif'.
-- An empty list means that there is no unifier.
unifyLTerm :: (IsConst c , Show (Lit c LVar), Ord c)
           => (c -> LSort)
           -> [Equal (LTerm c)]
           -> WithMaude [SubstVFresh c LVar]
unifyLTerm sortOf = fmap flattenUnif . unifyLTermFactored sortOf
-- | 'unifyLTerm' for terms whose constants are 'Name's (sorted by
-- 'sortOfName').
unifyLNTerm :: [Equal LNTerm] -> WithMaude [SubstVFresh Name LVar]
unifyLNTerm eqs = unifyLTerm sortOfName eqs
-- | Turn a factored unification result into a flat list of unifiers.
-- Each substitution in the second component is combined with the shared
-- substitution in the first component using 'composeVFresh'.
flattenUnif :: IsConst c => (LSubst c, [LSubstVFresh c]) -> [LSubstVFresh c]
flattenUnif (subst, substs) =
    trace (show ("flattenUnif", subst, substs, flattened)) flattened
  where
    flattened = [ s `composeVFresh` subst | s <- substs ]
-- | Compute matchers for a list of matching problems modulo the equational
-- theory of the 'MaudeHandle'.
--
-- For each @MatchWith t p@, the variables of the pattern @p@ are bound so
-- that it matches the term @t@. The problems are first solved
-- syntactically with 'matchRaw':
--
-- * on a definite mismatch there is no matcher, so the result is @[]@;
-- * if an AC-rooted subproblem is met, the whole original list of
--   problems is handed to Maude instead;
-- * otherwise the single matcher built from the collected bindings is
--   returned.
matchLTerm :: (IsConst c , Show (Lit c LVar), Ord c)
           => (c -> LSort)
           -> [Match (LTerm c)]
           -> WithMaude [Subst c LVar]
matchLTerm sortOf eqs = reader $ \h ->
    (\res -> trace (unlines ["matchLTerm: " ++ show eqs, "result = " ++ show res]) res) $
    case runState (runErrorT match) M.empty of
      (Left NoMatch,   _)        -> []
      -- NOTE(review): unsafePerformIO hides the Maude call behind a pure
      -- interface; this assumes matchViaMaude is deterministic per input.
      (Left ACProblem, _)        -> unsafePerformIO (UM.matchViaMaude h sortOf eqs)
      (Right _,        mappings) -> [substFromMap mappings]
  where
    -- Only the bindings in the final state matter, so the unit results of
    -- the individual problems are discarded.
    match = sequence_ [ matchRaw sortOf t p | MatchWith t p <- eqs ]
-- | 'matchLTerm' for terms whose constants are 'Name's (sorted by
-- 'sortOfName').
matchLNTerm :: [Match LNTerm] -> WithMaude [Subst Name LVar]
matchLNTerm eqs = matchLTerm sortOfName eqs
-- | The monad used by 'unifyRaw':
--
-- * reader: the function assigning sorts to constants;
-- * writer: equations between AC-rooted terms whose unification is
--   postponed;
-- * state: the variable bindings found so far;
-- * base monad 'Maybe': 'Nothing' signals that there is no unifier.
type UnifyRaw c =
    RWST (c -> LSort)
         [Equal (LTerm c)]
         (M.Map LVar (VTerm c LVar))
         Maybe
-- | Syntactic unification of two terms, with AC-rooted subproblems
-- postponed.
--
-- Both terms are first instantiated with the current bindings. Then:
--
-- * variables are eliminated (see @elim@);
-- * constants are compared;
-- * non-AC and list applications are decomposed argument-wise;
-- * an equation between two applications of the same AC symbol is
--   recorded in the writer output instead of being solved here.
--
-- Failure (no unifier) is signalled by 'mzero'.
unifyRaw :: IsConst c => LTerm c -> LTerm c -> UnifyRaw c ()
unifyRaw l0 r0 = do
    mappings <- get
    sortOf   <- ask
    -- Apply the current bindings to both sides. The state is not modified
    -- in between, so a single substitution built from 'mappings' serves
    -- both terms.
    let subst = substFromMap mappings
        l     = applyVTerm subst l0
        r     = applyVTerm subst r0
    guard (trace (show ("unifyRaw", mappings, l ,r)) True)
    case (viewTerm l, viewTerm r) of
      (Lit (Var vl), Lit (Var vr))
        | vl == vr  -> return ()
        | otherwise -> case (lvarSort vl, lvarSort vr) of
            -- Same sort: eliminate the larger of the two variables.
            (sl, sr) | sl == sr -> if vl < vr then elim vr l
                                   else elim vl r
            _ | sortGeqLTerm sortOf vl r -> elim vl r
            -- Otherwise unification can only succeed by eliminating the
            -- right-hand variable; 'elim' re-checks the sort and fails if
            -- that is impossible.
            _ -> elim vr l
      (Lit (Var vl), _           ) -> elim vl r
      (_,            Lit (Var vr)) -> elim vr l
      (Lit (Con cl), Lit (Con cr)) -> guard (cl == cr)
      (FApp (NonAC lfsym) largs, FApp (NonAC rfsym) rargs) ->
          guard (lfsym == rfsym && length largs == length rargs)
          >> sequence_ (zipWith unifyRaw largs rargs)
      (FApp List largs, FApp List rargs) ->
          guard (length largs == length rargs)
          >> sequence_ (zipWith unifyRaw largs rargs)
      -- Same AC symbol: postpone the equation (delayed unification).
      (FApp (AC lacsym) _, FApp (AC racsym) _) ->
          guard (lacsym == racsym) >> tell [Equal l r]
      -- All unifiable pairs of term constructors are handled above.
      _ -> mzero
  where
    -- Bind variable @v@ to term @t@. This fails if @v@ occurs in @t@ or if
    -- the sort of @t@ is not below that of @v@. The new binding is also
    -- applied to the right-hand sides of all existing bindings.
    elim v t
      | v `occurs` t = mzero
      | otherwise    = do
          sortOf <- ask
          guard (sortGeqLTerm sortOf v t)
          modify (M.insert v t . M.map (applyVTerm (substFromList [(v,t)])))
-- | Reason why syntactic matching in 'matchRaw' stopped.
data MatchFailure
    = NoMatch    -- ^ the terms definitely do not match
    | ACProblem  -- ^ an AC subproblem was met; matching is left to Maude

-- | Required by the 'ErrorT' monad of 'matchRaw'. Generic failures, such as
-- a failing 'guard', are reported as 'NoMatch'.
instance Error MatchFailure where
    noMsg    = NoMatch
    strMsg _ = NoMatch
-- | One-sided syntactic matching of the term @t@ against the pattern @p@.
-- Only variables of @p@ are bound; the bindings are kept in the state.
--
-- Throws 'NoMatch' on a definite mismatch. Throws 'ACProblem' as soon as
-- two AC-rooted terms have to be matched, since that cannot be decided
-- syntactically.
matchRaw :: IsConst c
         => (c -> LSort)
         -> LTerm c
         -> LTerm c
         -> ErrorT MatchFailure (State (Map LVar (VTerm c LVar))) ()
matchRaw sortOf t p = do
    bindings <- get
    guard (trace (show (bindings, t, p)) True)
    case (viewTerm t, viewTerm p) of
      -- A pattern variable is either bound afresh (if the sort fits) or
      -- must agree with the term it is already bound to.
      (_, Lit (Var vp)) -> case M.lookup vp bindings of
          Just tp
            | t == tp   -> return ()
            | otherwise -> throwError NoMatch
          Nothing
            | sortGeqLTerm sortOf vp t -> modify (M.insert vp t)
            | otherwise                -> throwError NoMatch
      (Lit (Con ct), Lit (Con cp)) -> guard (ct == cp)
      (FApp (NonAC tfsym) targs, FApp (NonAC pfsym) pargs) -> do
          guard (tfsym == pfsym && length targs == length pargs)
          zipWithM_ (matchRaw sortOf) targs pargs
      (FApp List targs, FApp List pargs) -> do
          guard (length targs == length pargs)
          zipWithM_ (matchRaw sortOf) targs pargs
      (FApp (AC _) _, FApp (AC _) _) -> throwError ACProblem
      -- All matchable pairs of term constructors are handled above.
      _ -> throwError NoMatch
-- | @sortGeqLTerm st v t@ checks whether the sort of variable @v@ is greater
-- than or equal to the sort of term @t@ (constants are sorted with @st@).
--
-- * Equal sorts always succeed.
-- * Otherwise the result is decided by 'sortCompare', which must report
--   'EQ' or 'GT' for the sort of @v@ relative to that of @t@.
-- * Calls 'error' if exactly one of the two sorts is 'LSortNode'.
sortGeqLTerm :: IsConst c => (c -> LSort) -> LVar -> LTerm c -> Bool
sortGeqLTerm st v t =
    case (lvarSort v, sortOfLTerm st t) of
      (s1, s2) | s1 == s2 -> True
      (LSortNode, _)      -> errNodeSort
      (_, LSortNode)      -> errNodeSort
      (s1, s2)            -> sortCompare s1 s2 `elem` [Just EQ, Just GT]
  where
    errNodeSort = error $
        "sortGeqLTerm: node sort misuse " ++ show v ++ " -> " ++ show t