module Theory.Tools.AbstractInterpretation (
interpretAbstractly
, EvaluationStyle(..)
, partialEvaluation
) where
import Debug.Trace
import Control.Basics
import Control.Monad.Bind
import Control.Monad.Reader
import Data.Label
import Data.List
import qualified Data.Set as S
import Data.Traversable (traverse)
import Term.Substitution
import Theory.Model
import Theory.Text.Pretty
-- | Iterate an abstract interpretation of a list of rules until the abstract
-- state stops changing.
--
-- The iteration starts from @initState@ extended with a @Fr@ fact and an
-- @In@ fact, both over a variable named @z@. In every round, each rule is
-- refined by unifying each of its premises with some fact of the current
-- state that has the same fact tag. Every such choice of facts and every
-- resulting unifier yields one instantiated rule. The conclusions of all
-- refined rules are then added to the state.
--
-- The result contains one entry per round: the state at the start of the
-- round, paired with the rules refined against that state. The list stops
-- after the first round whose conclusions leave the state unchanged (as
-- decided by '==' on @s@).
interpretAbstractly
    :: (Eq s, HasFrees i, Apply i, Show i)
    => ([Equal LNFact] -> [LNSubstVFresh])
    -> s
    -> (LNFact -> s -> s)
    -> (s -> [LNFact])
    -> [Rule i]
    -> [(s, [Rule i])]
interpretAbstractly unifyFactEqs initState addFact stateFacts rus =
    iterateFrom seeded
  where
    -- The initial state always also contains a fresh fact and an In fact.
    seeded =
        addFact (freshFact (varTerm (LVar "z" LSortFresh 0)))
                (addFact (inFact (varTerm (LVar "z" LSortMsg 0))) initState)

    -- One round of the interpretation. The pair for the current state is
    -- produced before the fixpoint test is evaluated.
    iterateFrom cur = (cur, refined) : remaining
      where
        refined = rus >>= refineAgainst
        next    = foldl' (flip addFact) cur (concatMap (get rConcs) refined)
        remaining
          | next == cur = []
          | otherwise   = iterateFrom next

        -- All instances of a rule whose premises unify with facts of the
        -- current state. Fresh names avoid the rule's own variables.
        refineAgainst ru = evalFreshT (instantiate ru) (avoid ru)

        -- Choose a candidate fact for every premise (non-deterministically),
        -- then pick one of the unifiers of the resulting equations.
        instantiate ru = do
            eqs   <- mapM candidates (get rPrems ru)
            subst <- msum (map freshToFree (unifyFactEqs eqs))
            return (apply subst ru)

        -- Equations between a premise and every renamed state fact that
        -- carries the same fact tag.
        candidates prem = msum
            [ Equal prem <$> rename fa
            | fa <- stateFacts cur
            , factTag prem == factTag fa
            ]
-- | How much diagnostic output 'partialEvaluation' emits per iteration step.
data EvaluationStyle
    = Silent   -- ^ no diagnostic output
    | Summary  -- ^ one line per step with the number of added facts
    | Tracing  -- ^ like 'Summary', followed by the newly added facts
-- | Run 'interpretAbstractly' over the given rules, using a set of abstracted
-- facts (see 'absFact') as the abstract state. Fact equations are unified
-- with the Maude handle taken from the 'WithMaude' reader.
--
-- Returns the final abstract state and the rules refined in the last round.
-- Those rules are renamed with a fresh name supply starting from
-- 'nothingUsed' and then deduplicated with 'eqModuloFreshnessNoAC'.
--
-- The 'EvaluationStyle' only controls diagnostic output emitted through
-- 'Debug.Trace.trace':
--
-- * 'Silent' prints nothing.
-- * 'Summary' prints one line per step with the number of facts added.
-- * 'Tracing' additionally prints the newly added facts.
partialEvaluation :: EvaluationStyle
                  -> [ProtoRuleE] -> WithMaude (S.Set LNFact, [ProtoRuleE])
partialEvaluation evalStyle ruEs = reader $ \hnd ->
    consumeEvaluation $ interpretAbstractly
        ((`runReader` hnd) . unifyLNFactEqs)
        S.empty
        (S.insert . absFact)
        S.toList
        ruEs
  where
    -- 'interpretAbstractly' always yields at least one (state, rules) pair,
    -- so the empty case is unreachable.
    consumeEvaluation [] = error "partialEvaluation: impossible"
    consumeEvaluation ((st0, rus0) : rest0) =
        go (0 :: Int) st0 rus0 rest0
      where
        -- Last round reached: rename the refined rules and drop duplicates.
        go _ st rus [] =
            ( st
            , nubBy eqModuloFreshnessNoAC $
                map ((`evalFresh` nothingUsed) . rename) rus
            )
        -- Move to the next round, optionally reporting what it added.
        go i st _ ((st', rus') : rest) =
            withTrace (go (i + 1) st' rus' rest)
          where
            -- Number of facts that the step from 'st' to 'st'' added. The
            -- state is grown only by 'S.insert', so this is non-negative.
            incDesc = " partial evaluation: step " ++ show i ++ " added " ++
                      show (S.size st' - S.size st) ++ " facts"
            withTrace = case evalStyle of
                Silent  -> id
                Summary -> trace incDesc
                Tracing -> trace $ incDesc ++ "\n\n" ++
                    ( render $ nest 2 $ numbered' $ map prettyLNFact $
                        S.toList $ st' `S.difference` st ) ++ "\n"
-- | Abstract a fact before it is stored in the abstract state.
--
-- Every @Out@ fact is collapsed to @Out(z)@. For all other facts:
--
-- * Constants are kept.
-- * Applications of @NoEq@ function symbols are kept, and their arguments
--   are abstracted recursively.
-- * Every other subterm (a variable, or an application of any other kind of
--   symbol) is replaced via 'importBinding' by a variable of the subterm's
--   sort. The new variable is named after the original variable, or @z@ for
--   non-variable subterms.
--
-- One binding table is shared across all arguments of the fact. Presumably
-- this means equal subterms map to the same variable; confirm against
-- 'Control.Monad.Bind'.
absFact :: LNFact -> LNFact
absFact (Fact OutFact _) = outFact (varTerm (LVar "z" LSortMsg 0))
absFact (Fact tag ts)    = Fact tag (runAbstraction (traverse abstractTerm ts))
  where
    -- Start with no bindings and an unused fresh-name supply.
    runAbstraction m = evalBind (evalFreshT m nothingUsed) noBindings

    abstractTerm t = case viewTerm t of
        Lit (Con _)            -> pure t
        FApp sym@(NoEq _) args -> fApp sym <$> traverse abstractTerm args
        Lit (Var v)            -> bindTo (lvarName v)
        _                      -> bindTo "z"
      where
        bindTo nm = importBinding mkVar t nm
        -- The replacement variable keeps the sort of the abstracted subterm.
        mkVar base idx = varTerm (LVar base (sortOfLNTerm t) idx)