module Theory.Constraint.Solver.Reduction (
Reduction
, execReduction
, runReduction
, ChangeIndicator(..)
, whenChanged
, applyChangeList
, whileChanging
, getProofContext
, getMaudeHandle
, labelNodeId
, insertFreshNode
, insertFreshNodeConc
, insertGoal
, insertAtom
, insertEdges
, insertChain
, insertAction
, insertLess
, insertFormula
, reducibleFormula
, markGoalAsSolved
, removeSolvedSplitGoals
, substSystem
, substNodes
, substEdges
, substLastAtom
, substLessAtoms
, substFormulas
, substSolvedFormulas
, SplitStrategy(..)
, solveNodeIdEqs
, solveTermEqs
, solveFactEqs
, solveRuleEqs
, solveSubstEqs
, conjoinSystem
, module Logic.Connectives
) where
import Debug.Trace
import Prelude hiding (id, (.))
import qualified Data.Foldable as F
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (mapAccumL)
import Safe
import Control.Basics
import Control.Category
import Control.Monad.Bind
import Control.Monad.Disj
import Control.Monad.Reader
import Control.Monad.State (StateT, execStateT, gets, runStateT)
import Text.PrettyPrint.Class
import Extension.Data.Label
import Extension.Data.Monoid (Monoid(..))
import Extension.Prelude
import Logic.Connectives
import Theory.Constraint.Solver.Contradictions
import Theory.Constraint.Solver.Types
import Theory.Constraint.System
import Theory.Model
-- | The monad in which constraint-system reductions run: state over the
-- current 'System', on top of 'FreshT' (fresh variables, cf. 'freshLVar'),
-- 'DisjT' (alternative outcomes, cf. 'disjunctionOfList' and the 'Disj'
-- result of 'runReduction'), and a 'Reader' supplying the 'ProofContext'.
type Reduction = StateT System (FreshT (DisjT (Reader ProofContext)))
-- | Run a reduction from the given system, proof context and fresh-name
-- state. Every alternative outcome is returned as one disjunct, consisting
-- of the result paired with the final system, and the final fresh state.
runReduction :: Reduction a -> ProofContext -> System -> FreshState
             -> Disj ((a, System), FreshState)
runReduction m ctxt se fs =
    Disj (runReader (runDisjT (runFreshT (runStateT m se) fs)) ctxt)
-- | Like 'runReduction', but discards the reduction's result and keeps only
-- the final system and fresh state of every alternative outcome.
execReduction :: Reduction a -> ProofContext -> System -> FreshState
              -> Disj (System, FreshState)
execReduction m ctxt se fs =
    Disj (runReader (runDisjT (runFreshT (execStateT m se) fs)) ctxt)
-- | Reports whether a reduction step modified the constraint system.
-- Under the derived 'Ord', 'Unchanged' sorts before 'Changed'.
data ChangeIndicator
    = Unchanged
    | Changed
    deriving (Show, Eq, Ord)
-- | Combining indicators yields 'Changed' as soon as either side changed;
-- 'Unchanged' is the identity.
instance Monoid ChangeIndicator where
    mempty = Unchanged
    mappend Unchanged other = other
    mappend Changed   _     = Changed
-- | 'True' exactly for 'Changed'.
wasChanged :: ChangeIndicator -> Bool
wasChanged = (== Changed)
-- | Run the given action only if the indicator is 'Changed'.
whenChanged :: Monad m => ChangeIndicator -> m () -> m ()
whenChanged indicator action = when (wasChanged indicator) action
-- | Run the given actions in order. The result is 'Changed' iff the list is
-- non-empty; the actions themselves are not inspected for actual changes.
applyChangeList :: [Reduction ()] -> Reduction ChangeIndicator
applyChangeList changes
  | null changes = return Unchanged
  | otherwise    = sequence_ changes >> return Changed
-- | Repeat @reduction@ until one run reports 'Unchanged'. The result is
-- 'Changed' iff at least one run reported 'Changed'. This loops forever if
-- @reduction@ never reports 'Unchanged'.
whileChanging :: Reduction ChangeIndicator -> Reduction ChangeIndicator
whileChanging reduction = loop Unchanged
  where
    loop acc = reduction >>= \indicator -> case indicator of
        Unchanged -> return acc
        Changed   -> loop Changed
-- | The proof context provided by the reader layer of 'Reduction'.
getProofContext :: Reduction ProofContext
getProofContext = do
    ctxt <- ask
    return ctxt
-- | The Maude handle stored in the proof context ('pcMaudeHandle').
getMaudeHandle :: Reduction MaudeHandle
getMaudeHandle = get pcMaudeHandle <$> ask
-- | Insert a fresh node labelled with an instance of one of the given rules
-- (see 'insertFreshNode'). The reduction then branches over the conclusions
-- of that instance, returning the instance, the chosen conclusion position
-- and its fact.
insertFreshNodeConc :: [RuleAC] -> Reduction (RuleACInst, NodeConc, LNFact)
insertFreshNodeConc rules = do
    (nodeId, ruleInst) <- insertFreshNode rules
    (concIdx, concFact) <- disjunctionOfList (enumConcs ruleInst)
    return (ruleInst, (nodeId, concIdx), concFact)
-- | Create a fresh node variable (name hint @"vr"@) and label it with an
-- instance of one of the given rules via 'labelNodeId'.
insertFreshNode :: [RuleAC] -> Reduction (NodeId, RuleACInst)
insertFreshNode rules = do
    nodeId   <- freshLVar "vr" LSortNode
    ruleInst <- labelNodeId nodeId rules
    return (nodeId, ruleInst)
-- | Label node @i@ with an instance of one of the given rules. The reduction
-- branches over the alternatives ('disjunctionOfList'). Returns the chosen
-- rule instance.
--
-- Besides recording the label in 'sNodes', every premise of the instance is
-- exploited (see @exploitPrem@ below): In and Fr premises get dedicated
-- provider nodes and edges, KU premises become action goals at earlier
-- nodes, and all other premises become 'PremiseG' goals.
labelNodeId :: NodeId -> [RuleAC] -> Reduction RuleACInst
labelNodeId = \i rules -> do
    (ru, mrconstrs) <- importRule =<< disjunctionOfList rules
    -- Constraints returned alongside the instance (if any) are added to the
    -- equation store.
    solveRuleConstraints mrconstrs
    modM sNodes (M.insert i ru)
    exploitPrems i ru
    return ru
  where
    -- Instantiate the rule with 'someRuleACInst', evaluated from an empty
    -- binding map. NOTE(review): presumably this renames its variables apart;
    -- confirm against 'someRuleACInst'.
    importRule ru = someRuleACInst ru `evalBindT` noBindings

    -- Intruder send rule: premise KU(m), conclusion In(m), action K(m).
    mkISendRuleAC m = return $ Rule (IntrInfo (ISendRule))
                                    [kuFact m] [inFact m] [kLogFact m]

    -- Fresh rule: no premises, conclusion Fr(m), no actions.
    mkFreshRuleAC m = Rule (ProtoInfo (ProtoRuleACInstInfo FreshRule []))
                           [] [freshFact m] []

    exploitPrems i ru = mapM_ (exploitPrem i ru) (enumPrems ru)

    exploitPrem i ru (v, fa) = case fa of
        -- In(m): provide it by a new send node j with an edge from its
        -- first conclusion; then exploit j's own KU(m) premise.
        Fact InFact [m] -> do
            j <- freshLVar "vf" LSortNode
            ruKnows <- mkISendRuleAC m
            modM sNodes (M.insert j ruKnows)
            modM sEdges (S.insert $ Edge (j, ConcIdx 0) (i, v))
            exploitPrems j ruKnows

        -- Fr(m): provide it by a new Fresh-rule node j. If m is not already
        -- a fresh variable, it is equated (splitting immediately) with a new
        -- variable of sort LSortFresh.
        Fact FreshFact [m] -> do
            j <- freshLVar "vf" LSortNode
            modM sNodes (M.insert j (mkFreshRuleAC m))
            unless (isFreshVar m) $ do
                n <- varTerm <$> freshLVar "n" LSortFresh
                void (solveTermEqs SplitNow [Equal m n])
            modM sEdges (S.insert $ Edge (j, ConcIdx 0) (i,v))

        -- KU premise: require the same KU fact as an action at a fresh node
        -- j ordered before i.
        _ | isKUFact fa -> do
              j <- freshLVar "vk" LSortNode
              insertLess j i
              void (insertAction j fa)

          -- Any other premise becomes an open premise goal. It is flagged
          -- when its index v is among the rule's loop breakers.
          | otherwise -> insertGoal (PremiseG (i,v) fa) (v `elem` breakers)
      where
        -- Loop breakers come from 'praciLoopBreakers' of the rule info; the
        -- other branch of 'ruleInfo' yields none.
        breakers = ruleInfo (get praciLoopBreakers) (const []) $ get rInfo ru
-- | Add an open chain goal from the given conclusion to the given premise;
-- the goal is not flagged as looping.
insertChain :: NodeConc -> NodePrem -> Reduction ()
insertChain conc prem =
    let goal = ChainG conc prem
    in  insertGoal goal False
-- | Insert edges from conclusions to premises. Each entry also carries the
-- conclusion fact and the premise fact. These facts are equated first,
-- splitting immediately if needed. Then every edge is added to 'sEdges'.
insertEdges :: [(NodeConc, LNFact, LNFact, NodePrem)] -> Reduction ()
insertEdges edges = do
    let factEqs  = [ Equal concFact premFact | (_, concFact, premFact, _) <- edges ]
        newEdges = [ Edge c p | (c, _, _, p) <- edges ]
    void (solveFactEqs SplitNow factEqs)
    modM sEdges (\es -> foldr S.insert es newEdges)
-- | Add the action goal @ActionG i fa@, unless it is already in 'sGoals'. In
-- that case nothing happens and 'Unchanged' is returned.
--
-- Suppose @fa@ is a KU fact whose term, viewed through 'viewTerm2', is a
-- pair, an inverse, a multiplication or a union. Then each component must
-- also be known: every component gets a fresh node, ordered before @i@, with
-- a KU action goal for it, and the result is 'Changed'. For all other facts
-- the goal is inserted and 'Unchanged' is returned.
insertAction :: NodeId -> LNFact -> Reduction ChangeIndicator
insertAction i fa = do
    goals <- getM sGoals
    if goal `M.member` goals
      then return Unchanged
      else do
        insertGoal goal False
        case kFactView fa of
          Just (UpK, viewTerm2 -> FPair m1 m2) -> requireAll [m1, m2]
          Just (UpK, viewTerm2 -> FInv m)      -> requireAll [m]
          Just (UpK, viewTerm2 -> FMult ms)    -> requireAll ms
          Just (UpK, viewTerm2 -> FUnion ms)   -> requireAll ms
          _                                    -> return Unchanged
  where
    goal = ActionG i fa

    -- Require KU of every given term, in order, and report a change.
    requireAll ts = mapM_ requiresKU ts >> return Changed

    -- Add a KU action goal for @t@ at a fresh node ordered before @i@.
    requiresKU t = do
        j <- freshLVar "vk" LSortNode
        insertLess j i
        void (insertAction j (kuFact t))
-- | Record the ordering atom @(earlier, later)@ in 'sLessAtoms'.
insertLess :: NodeId -> NodeId -> Reduction ()
insertLess earlier later = modM sLessAtoms $ S.insert (earlier, later)
-- | Declare @i@ the last node. If a last node @j@ is already recorded, @i@
-- and @j@ are equated via 'solveNodeIdEqs'. Otherwise @i@ is stored in
-- 'sLastAtom' and 'Unchanged' is returned.
insertLast :: NodeId -> Reduction ChangeIndicator
insertLast i = getM sLastAtom >>= maybe recordLast equateWith
  where
    recordLast   = setM sLastAtom (Just i) >> return Unchanged
    equateWith j = solveNodeIdEqs [Equal i j]
-- | Insert an atom by dispatching on its kind. Ordering atoms always report
-- 'Unchanged'; the other kinds report whatever their insertion returns.
insertAtom :: LNAtom -> Reduction ChangeIndicator
insertAtom ato = case ato of
    Action i fa -> insertAction (ltermNodeId' i) fa
    Last i      -> insertLast (ltermNodeId' i)
    Less i j    -> insertLess (ltermNodeId' i) (ltermNodeId' j) >> return Unchanged
    EqE x y     -> solveTermEqs SplitNow [Equal x y]
-- | Insert a guarded formula into the constraint system, reducing it where
-- possible. Formulas already contained in 'sFormulas' or 'sSolvedFormulas'
-- are ignored.
--
-- * Atoms are inserted via 'insertAtom'; conjunctions are split into their
--   conjuncts.
-- * Disjunctions are stored in 'sFormulas' and get a 'DisjG' goal.
-- * Existentials are instantiated with fresh variables.
-- * @All []. i < j ==> False@ is replaced by @i = j | j < i@.
-- * @All []. i = j ==> False@ on node ids is replaced by @i < j | j < i@.
-- * @All []. Last(i) ==> False@ is replaced by @j < i | i < j@, where @j@ is
--   the last node (created fresh if none is recorded yet).
-- * Any other universally quantified formula is stored in 'sFormulas'.
--
-- Only the top-level formula is added to 'sSolvedFormulas' when reduced
-- (@mark = True@); derived sub-formulas are inserted with @mark = False@.
-- Existential formulas are always recorded as solved.
insertFormula :: LNGuarded -> Reduction ()
insertFormula = do
    insert True
  where
    -- Read the current formula sets, then dispatch via insert'.
    insert mark fm = do
        formulas       <- getM sFormulas
        solvedFormulas <- getM sSolvedFormulas
        insert' mark formulas solvedFormulas fm

    insert' mark formulas solvedFormulas fm
      | fm `S.member` formulas       = return ()
      | fm `S.member` solvedFormulas = return ()
      | otherwise = case fm of
          GAto ato -> do
              markAsSolved
              void (insertAtom (bvarToLVar ato))

          GConj fms -> do
              markAsSolved
              mapM_ (insert False) (getConj fms)

          -- Disjunctions stay open: stored as formula plus a DisjG goal.
          GDisj disj -> do
              modM sFormulas (S.insert fm)
              insertGoal (DisjG disj) False

          -- Recorded as solved unconditionally (independent of mark) before
          -- the bound variables are replaced by fresh ones.
          GGuarded Ex ss as gf -> do
              modM sSolvedFormulas $ S.insert fm
              xs <- mapM (uncurry freshLVar) ss
              let body = gconj (map GAto as ++ [gf])
              -- Bound variable index 0 refers to the last element of ss,
              -- hence the reverse.
              insert False (substBound (zip [0..] (reverse xs)) body)

          GGuarded All [] [Less i j] gf | gf == gfalse -> do
              markAsSolved
              insert False (gdisj [GAto (EqE i j), GAto (Less j i)])

          -- Only applies when both sides are node ids (bltermNodeId).
          GGuarded All [] [EqE i@(bltermNodeId -> Just _)
                               j@(bltermNodeId -> Just _) ] gf
            | gf == gfalse -> do
                markAsSolved
                insert False (gdisj [GAto (Less i j), GAto (Less j i)])

          GGuarded All [] [Last i] gf | gf == gfalse -> do
              markAsSolved
              lst <- getM sLastAtom
              -- Use the recorded last node, or introduce a fresh one.
              j <- case lst of
                  Nothing -> do
                      j <- freshLVar "last" LSortNode
                      void (insertLast j)
                      return (varTerm (Free j))
                  Just j -> return (varTerm (Free j))
              insert False $ gdisj [ GAto (Less j i), GAto (Less i j) ]

          GGuarded All _ _ _ -> modM sFormulas (S.insert fm)
      where
        -- Only the top-level call (mark = True) records fm as solved.
        markAsSolved = when mark $ modM sSolvedFormulas $ S.insert fm
-- | 'True' iff 'insertFormula' would reduce the formula by one of its rules,
-- as opposed to storing it unchanged in 'sFormulas'. Disjunctions give
-- 'False': 'insertFormula' keeps them in 'sFormulas' (adding a 'DisjG' goal).
--
-- Each 'GGuarded All' case mirrors the corresponding pattern and guard in
-- 'insertFormula' exactly. Re-inserting a formula classified as reducible
-- therefore never simply stores it back again.
reducibleFormula :: LNGuarded -> Bool
reducibleFormula fm = case fm of
    GAto _                                          -> True
    GConj _                                         -> True
    GGuarded Ex _ _ _                               -> True
    GGuarded All [] [Less _ _] gf                   -> gf == gfalse
    -- Node-id disequalities; insertFormula reduces these to i < j | j < i.
    GGuarded All [] [EqE (bltermNodeId -> Just _)
                         (bltermNodeId -> Just _)] gf -> gf == gfalse
    GGuarded All [] [Last _] gf                     -> gf == gfalse
    _                                               -> False
-- | Merge two statuses of the same goal. The result is solved if either
-- status is, takes the minimum of the two numeric fields, and sets the
-- looping flag if either status has it.
combineGoalStatus :: GoalStatus -> GoalStatus -> GoalStatus
combineGoalStatus (GoalStatus solvedA ageA loopA) (GoalStatus solvedB ageB loopB) =
    let solved = solvedA || solvedB
        age    = ageA `min` ageB
        loops  = loopA || loopB
    in  GoalStatus solved age loops
-- | Insert a goal with the given status, numbered with 'sNextGoalNr' (which
-- is then incremented). If the goal already exists, the new and old
-- statuses are merged with 'combineGoalStatus'.
insertGoalStatus :: Goal -> GoalStatus -> Reduction ()
insertGoalStatus goal status = do
    nr <- getM sNextGoalNr
    let numbered = set gsNr nr status
    modM sGoals (M.insertWith' combineGoalStatus goal numbered)
    sNextGoalNr =: succ nr
-- | Insert an unsolved goal. The flag is stored as the looping flag of its
-- status; the goal number is assigned by 'insertGoalStatus'.
insertGoal :: Goal -> Bool -> Reduction ()
insertGoal goal looping =
    let initialStatus = GoalStatus False 0 looping
    in  insertGoalStatus goal initialStatus
-- | Mark a goal as solved. @how@ describes how the goal was solved and only
-- appears in the debug trace.
--
-- 'ChainG' goals and 'PremiseG' goals with a KD fact are deleted from
-- 'sGoals'. All other goals keep their entry with 'gsSolved' set. Solving a
-- 'DisjG' goal also moves its disjunction from 'sFormulas' to
-- 'sSolvedFormulas'. Each status update emits a 'trace' message; a goal
-- missing from 'sGoals' only produces a warning trace.
markGoalAsSolved :: String -> Goal -> Reduction ()
markGoalAsSolved how goal = case goal of
    ChainG _ _                  -> dropGoal
    PremiseG _ fa | isKDFact fa -> dropGoal
                  | otherwise   -> flagSolved
    ActionG _ _                 -> flagSolved
    SplitG _                    -> flagSolved
    DisjG disj                  -> do
        modM sFormulas (S.delete $ GDisj disj)
        modM sSolvedFormulas (S.insert $ GDisj disj)
        flagSolved
  where
    dropGoal = modM sGoals (M.delete goal)

    flagSolved = do
        goals <- getM sGoals
        case M.lookup goal goals of
          Nothing ->
              trace ("markGoalAsSolved: inexistent goal " ++ show goal) $ return ()
          Just status ->
              trace (msg status) $
                  modM sGoals $ M.insert goal $ set gsSolved True status

    msg status = render $ nest 2 $ fsep $
        [ text ("solved goal nr. "++ show (get gsNr status))
          <-> parens (text how) <> colon
        , nest 2 (prettyGoal goal) ]
-- | Delete every 'SplitG' goal whose split no longer exists in the equation
-- store ('splitExists').
removeSolvedSplitGoals :: Reduction ()
removeSolvedSplitGoals = do
    goals   <- getM sGoals
    eqStore <- getM sEqStore
    let stale = [ g | g@(SplitG sid) <- M.keys goals, not (splitExists eqStore sid) ]
    mapM_ (modM sGoals . M.delete) stale
-- | Apply the current substitution 'sSubst' to every component of the
-- system. The result is 'Changed' iff substituting the nodes or the goals
-- reported a change; the other components do not report changes.
substSystem :: Reduction ChangeIndicator
substSystem = do
    nodesChanged <- substNodes
    sequence_ [ substEdges, substLastAtom, substLessAtoms
              , substFormulas, substSolvedFormulas, substLemmas ]
    goalsChanged <- substGoals
    substNextGoalNr
    return (nodesChanged <> goalsChanged)
-- | Apply the current substitution to a single component of the system via
-- 'substPart'. 'substNextGoalNr' does nothing.
substEdges, substLessAtoms, substLastAtom, substFormulas,
  substSolvedFormulas, substLemmas, substNextGoalNr :: Reduction ()
substEdges          = substPart sEdges
substLessAtoms      = substPart sLessAtoms
substLastAtom       = substPart sLastAtom
substFormulas       = substPart sFormulas
substSolvedFormulas = substPart sSolvedFormulas
substLemmas         = substPart sLemmas
substNextGoalNr     = return ()
-- | Apply the current substitution 'sSubst' to the system component at the
-- given label.
substPart :: Apply a => (System :-> a) -> Reduction ()
substPart l = modM l . apply =<< getM sSubst
-- | Substitute into the node map. First the node ids are substituted
-- ('substNodeIds'), then the rule instances labelling the nodes. The change
-- indicator comes from the node-id phase.
substNodes :: Reduction ChangeIndicator
substNodes = do
    changed <- substNodeIds
    subst   <- getM sSubst
    modM sNodes (M.map (apply subst))
    return changed
-- | Replace the node map with the given assignment. Entries are grouped by
-- node id; the first entry of each group, as returned by 'groupSortOn', is
-- kept. Every other rule instance in the group is equated with the kept one
-- ('solveRuleEqs' with 'SplitLater'). Returns 'Changed' iff such equations
-- arose.
setNodes :: [(NodeId, RuleACInst)] -> Reduction ChangeIndicator
setNodes nodes0 = do
    sNodes =: M.fromList representatives
    if null ruleEqs
      then return Unchanged
      else solveRuleEqs SplitLater ruleEqs >> return Changed
  where
    merged          = map merge (groupSortOn fst nodes0)
    ruleEqs         = concatMap fst merged
    representatives = map snd merged

    -- Keep the head of a group; equate the others' rules with its rule.
    merge []           = unreachable "setNodes"
    merge (keep : dups) = ([ Equal (snd keep) (snd dup) | dup <- dups ], keep)
-- | Apply the substitution to all node ids, merging nodes whose ids
-- coincide ('setNodes'). This repeats until no new rule equations arise.
substNodeIds :: Reduction ChangeIndicator
substNodeIds = whileChanging $ do
    subst <- getM sSubst
    nodes <- M.toList <$> getM sNodes
    setNodes [ (apply subst nid, ru) | (nid, ru) <- nodes ]
-- | Rebuild 'sGoals' under the current substitution.
--
-- KU action goals on a message variable, product or union that the
-- substitution actually changes are re-added through 'insertAction'. That
-- may add component goals, and their previous status is not carried over.
-- All other goals are re-inserted with their status, merged via
-- 'combineGoalStatus' when two goals become equal. The results of all
-- re-insertions are combined into the returned indicator.
substGoals :: Reduction ChangeIndicator
substGoals = do
    subst    <- getM sSubst
    oldGoals <- M.toList <$> getM sGoals
    sGoals =: M.empty
    let reinsert (goal, status) = case goal of
            ActionG i fa@(kFactView -> Just (UpK, m))
              | (isMsgVar m || isProduct m || isUnion m) && apply subst m /= m ->
                  insertAction i (apply subst fa)
            _ -> do
                modM sGoals $
                    M.insertWith' combineGoalStatus (apply subst goal) status
                return Unchanged
    mconcat <$> mapM reinsert oldGoals
-- | Conjoin the constraint system @sys@ into the current system.
--
-- Both systems must have the same 'sCaseDistKind'; otherwise this calls
-- 'error'. Solved formulas, lemmas and edges are merged by set union. Last
-- atoms, less atoms, goals (except split goals) and formulas of @sys@ are
-- re-inserted through the regular insertion functions. Node maps and
-- equation stores are merged. Finally the substitution of @sys@ is solved
-- and propagated through the whole system.
conjoinSystem :: System -> Reduction ()
conjoinSystem sys = do
    kind <- getM sCaseDistKind
    unless (kind == get sCaseDistKind sys) $
        error "conjoinSystem: typing-kind mismatch"
    -- Components merged by plain set union.
    joinSets sSolvedFormulas
    joinSets sLemmas
    joinSets sEdges
    -- Re-insert atoms so that e.g. two different last nodes get equated
    -- (see insertLast).
    F.mapM_ insertLast $ get sLastAtom sys
    F.mapM_ (uncurry insertLess) $ get sLessAtoms sys
    -- Split goals of sys are dropped; split goals for sys's disjunctions are
    -- re-created below from the merged equation store.
    mapM_ (uncurry insertGoalStatus) $ filter (not . isSplitGoal . fst) $ M.toList $ get sGoals sys
    F.mapM_ insertFormula $ get sFormulas sys
    -- Merge node maps (sys's nodes listed first); setNodes equates rule
    -- instances that end up on the same node id. Its indicator is discarded.
    _ <- (setNodes . (M.toList (get sNodes sys) ++) . M.toList) =<< getM sNodes
    -- Add each disjunction of sys's equation store to the current store,
    -- collecting the split ids addDisj hands back.
    eqs <- getM sEqStore
    let (eqs',splitIds) = (mapAccumL addDisj eqs (map snd . getConj $ get sConjDisjEqs sys))
    setM sEqStore eqs'
    mapM_ (`insertGoal` False) $ SplitG <$> splitIds
    void (solveSubstEqs SplitNow $ get sSubst sys)
    -- Propagate the (possibly extended) substitution; the change indicator
    -- is ignored.
    void substSystem
  where
    -- Union the set at @proj@ with the corresponding set of sys.
    joinSets :: Ord a => (System :-> S.Set a) -> Reduction ()
    joinSets proj = modM proj (`S.union` get proj sys)
-- | What to do with a case split arising while solving equations. 'SplitNow'
-- performs it immediately; 'SplitLater' defers it by adding a 'SplitG' goal.
data SplitStrategy
    = SplitNow
    | SplitLater
-- | Signal a contradiction via 'contradictoryIf' when the equation store is
-- false ('eqsIsFalse').
noContradictoryEqStore :: Reduction ()
noContradictoryEqStore = do
    eqStore <- getM sEqStore
    contradictoryIf (eqsIsFalse eqStore)
-- | Add term equations to the equation store.
--
-- Equations for which 'evalEqual' already holds are dropped. If none remain,
-- the result is 'Unchanged'. Otherwise the equations are added with 'addEqs',
-- which may produce a case split. Depending on the strategy, the split is
-- performed immediately ('SplitNow', branching over all alternatives) or
-- recorded as a 'SplitG' goal ('SplitLater'). The store is then simplified,
-- checked for contradictions, and 'Changed' is returned.
solveTermEqs :: SplitStrategy -> [Equal LNTerm] -> Reduction ChangeIndicator
solveTermEqs splitStrat eqs0 =
    case filter (not . evalEqual) eqs0 of
      []  -> do return Unchanged
      eqs1 -> do
        hnd <- getMaudeHandle
        -- Snapshot of the system taken before the equation store is updated;
        -- it is passed to substCreatesNonNormalTerms for simplification.
        se  <- gets id
        (eqs2, maySplitId) <- addEqs hnd eqs1 =<< getM sEqStore
        setM sEqStore
            =<< simp hnd (substCreatesNonNormalTerms hnd se)
            =<< case (maySplitId, splitStrat) of
                  -- Branch over the alternatives of the split right away;
                  -- fromJustNote fails if performSplit returns Nothing.
                  (Just splitId, SplitNow) -> disjunctionOfList
                                                $ fromJustNote "solveTermEqs"
                                                $ performSplit eqs2 splitId
                  -- Keep the split in the store and solve it later as a goal.
                  (Just splitId, SplitLater) -> do
                      insertGoal (SplitG splitId) False
                      return eqs2
                  _                          -> return eqs2
        noContradictoryEqStore
        return Changed
-- | Solve the equations @v = t@ for every binding @v |-> t@ of the given
-- substitution, using the given split strategy.
solveSubstEqs :: SplitStrategy -> LNSubst -> Reduction ChangeIndicator
solveSubstEqs split = solveTermEqs split . map toEq . substToList
  where
    toEq (v, t) = Equal (varTerm v) t
-- | Equate node ids by lifting them to terms; splits are performed
-- immediately.
solveNodeIdEqs :: [Equal NodeId] -> Reduction ChangeIndicator
solveNodeIdEqs eqs = solveTermEqs SplitNow [ varTerm <$> eq | eq <- eqs ]
-- | Equate facts. If any pair has different fact tags, this is a
-- contradiction. Otherwise the argument lists are equated term-wise (see
-- 'solveListEqs').
solveFactEqs :: SplitStrategy -> [Equal LNFact] -> Reduction ChangeIndicator
solveFactEqs split eqs = do
    let tagsAgree = all (evalEqual . fmap factTag) eqs
    contradictoryIf (not tagsAgree)
    solveListEqs (solveTermEqs split) [ fmap factTerms eq | eq <- eqs ]
-- | Equate rule instances. If the rule infos differ for any pair, this is a
-- contradiction. Otherwise conclusions, premises and actions are equated
-- fact-wise, in that order (see 'solveListEqs').
solveRuleEqs :: SplitStrategy -> [Equal RuleACInst] -> Reduction ChangeIndicator
solveRuleEqs split eqs = do
    contradictoryIf $ not $ all (evalEqual . fmap (get rInfo)) eqs
    let onPart l = map (fmap (get l)) eqs
    solveListEqs (solveFactEqs split) $
        onPart rConcs ++ onPart rPrems ++ onPart rActs
-- | Equate lists element-wise. A length mismatch in any pair is a
-- contradiction. Otherwise all element equations are passed to @solver@ in
-- a single call.
solveListEqs :: ([Equal a] -> Reduction b) -> [Equal [a]] -> Reduction b
solveListEqs solver eqs = do
    contradictoryIf $ any (not . evalEqual . fmap length) eqs
    solver [ Equal l r | Equal lhs rhs <- eqs, (l, r) <- zip lhs rhs ]
-- | Add optional rule-variant constraints to the equation store. The split
-- they introduce ('addRuleVariants') is recorded as a 'SplitG' goal. The
-- store is then simplified and checked for contradictions. 'Nothing' is a
-- no-op.
solveRuleConstraints :: Maybe RuleACConstrs -> Reduction ()
solveRuleConstraints Nothing = return ()
solveRuleConstraints (Just eqConstr) = do
    hnd <- getMaudeHandle
    (ruleEqs, variantSplit) <- addRuleVariants eqConstr <$> getM sEqStore
    insertGoal (SplitG variantSplit) False
    simplified <- simp hnd (const (const False)) ruleEqs
    setM sEqStore simplified
    noContradictoryEqStore