{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE ScopedTypeVariables, FlexibleInstances, MultiParamTypeClasses, ExistentialQuantification #-}
module ToySolver.SAT.Encoder.Tseitin
(
Encoder
, newEncoder
, newEncoderWithPBLin
, setUsePB
, Polarity (..)
, negatePolarity
, polarityPos
, polarityNeg
, polarityBoth
, polarityNone
, Formula
, evalFormula
, addFormula
, encodeFormula
, encodeFormulaWithPolarity
, encodeConj
, encodeConjWithPolarity
, encodeDisj
, encodeDisjWithPolarity
, encodeITE
, encodeITEWithPolarity
, encodeXOR
, encodeXORWithPolarity
, encodeFASum
, encodeFASumWithPolarity
, encodeFACarry
, encodeFACarryWithPolarity
, getDefinitions
) where
import Control.Monad
import Control.Monad.Primitive
import Data.Primitive.MutVar
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.IntSet as IntSet
import ToySolver.Data.Boolean
import ToySolver.Data.BoolExpr
import qualified ToySolver.SAT.Types as SAT
-- | Formulas accepted by this encoder: a 'BoolExpr' whose atoms are SAT
-- literals.
type Formula = BoolExpr SAT.Lit
-- | Evaluate a 'Formula' under a model.
--
-- The formula is folded with 'SAT.evalLit' applied to the given model as the
-- interpretation of its atoms.
evalFormula :: SAT.IModel m => m -> Formula -> Bool
evalFormula model formula = fold evalAtom formula
  where
    evalAtom lit = SAT.evalLit model lit
-- | Encoder state: a wrapped clause sink plus one memo table per connective.
--
-- Every table maps a definition key to @(v, posDefined, negDefined)@: the
-- variable allocated for that key and two flags recording whether the
-- defining clauses for the positive and for the negative direction have
-- already been emitted (see 'encodeWithPolarityHelper').
data Encoder m =
  forall a. SAT.AddClause m a =>
  Encoder
  { encBase :: a
    -- ^ Wrapped object that receives new variables and clauses.
  , encAddPBAtLeast :: Maybe (SAT.PBLinSum -> Integer -> m ())
    -- ^ Optional action for adding a pseudo-boolean at-least constraint;
    -- 'Nothing' when built by 'newEncoder', 'Just' when built by
    -- 'newEncoderWithPBLin'.
  , encUsePB :: !(MutVar (PrimState m) Bool)
    -- ^ Flag set through 'setUsePB'; consulted by 'encodeConjWithPolarity'
    -- together with 'encAddPBAtLeast'.
  , encConjTable :: !(MutVar (PrimState m) (Map SAT.LitSet (SAT.Var, Bool, Bool)))
    -- ^ Conjunctions, keyed by their set of literals.
  , encITETable :: !(MutVar (PrimState m) (Map (SAT.Lit, SAT.Lit, SAT.Lit) (SAT.Var, Bool, Bool)))
    -- ^ If-then-else definitions, keyed by @(condition, then, else)@.
  , encXORTable :: !(MutVar (PrimState m) (Map (SAT.Lit, SAT.Lit) (SAT.Var, Bool, Bool)))
    -- ^ Exclusive-or definitions, keyed by the pair of operands.
  , encFASumTable :: !(MutVar (PrimState m) (Map (SAT.Lit, SAT.Lit, SAT.Lit) (SAT.Var, Bool, Bool)))
    -- ^ Full-adder sum definitions, keyed by the three inputs.
  , encFACarryTable :: !(MutVar (PrimState m) (Map (SAT.Lit, SAT.Lit, SAT.Lit) (SAT.Var, Bool, Bool)))
    -- ^ Full-adder carry definitions, keyed by the three inputs.
  }
-- | Variable allocation is forwarded to the wrapped 'encBase' object.
instance Monad m => SAT.NewVar m (Encoder m) where
  newVar enc = case enc of
    Encoder{ encBase = base } -> SAT.newVar base
  newVars enc = case enc of
    Encoder{ encBase = base } -> SAT.newVars base
  newVars_ enc = case enc of
    Encoder{ encBase = base } -> SAT.newVars_ base
-- | Clauses added through the encoder are forwarded to the wrapped
-- 'encBase' object.
instance Monad m => SAT.AddClause m (Encoder m) where
  addClause enc = case enc of
    Encoder{ encBase = base } -> SAT.addClause base
-- | Create an encoder on top of @solver@.
--
-- All memo tables start empty, the use-PB flag starts as 'False', and no
-- pseudo-boolean action is installed ('encAddPBAtLeast' is 'Nothing').
newEncoder :: PrimMonad m => SAT.AddClause m a => a -> m (Encoder m)
newEncoder solver = do
  pbFlag <- newMutVar False
  conjT <- emptyTable
  iteT <- emptyTable
  xorT <- emptyTable
  sumT <- emptyTable
  carryT <- emptyTable
  return $
    Encoder
    { encBase = solver
    , encAddPBAtLeast = Nothing
    , encUsePB = pbFlag
    , encConjTable = conjT
    , encITETable = iteT
    , encXORTable = xorT
    , encFASumTable = sumT
    , encFACarryTable = carryT
    }
  where
    -- A fresh mutable reference holding an empty table.
    emptyTable :: PrimMonad n => n (MutVar (PrimState n) (Map k v))
    emptyTable = newMutVar Map.empty
-- | Create an encoder on top of a @solver@ that also accepts linear
-- pseudo-boolean constraints.
--
-- Identical to a freshly created plain encoder (empty tables, use-PB flag
-- 'False') except that 'encAddPBAtLeast' is set to @'SAT.addPBAtLeast' solver@.
newEncoderWithPBLin :: PrimMonad m => SAT.AddPBLin m a => a -> m (Encoder m)
newEncoderWithPBLin solver = do
  pbFlag <- newMutVar False
  conjT <- emptyTable
  iteT <- emptyTable
  xorT <- emptyTable
  sumT <- emptyTable
  carryT <- emptyTable
  return $
    Encoder
    { encBase = solver
    , encAddPBAtLeast = Just (SAT.addPBAtLeast solver)
    , encUsePB = pbFlag
    , encConjTable = conjT
    , encITETable = iteT
    , encXORTable = xorT
    , encFASumTable = sumT
    , encFACarryTable = carryT
    }
  where
    -- A fresh mutable reference holding an empty table.
    emptyTable :: PrimMonad n => n (MutVar (PrimState n) (Map k v))
    emptyTable = newMutVar Map.empty
-- | Overwrite the encoder's use-PB flag.
--
-- The flag is read by 'encodeConjWithPolarity', which only takes the
-- pseudo-boolean path when 'encAddPBAtLeast' is also present.
setUsePB :: PrimMonad m => Encoder m -> Bool -> m ()
setUsePB Encoder{ encUsePB = ref } flag = writeMutVar ref flag
-- | Assert that a formula holds by adding clauses to the encoder.
--
-- Conjunctions are asserted conjunct by conjunct, (negated) equivalences and
-- if-then-else are turned directly into two clauses over encoded literals,
-- negated disjunctions/implications are rewritten into conjunctions, and
-- everything else is flattened into a single clause by 'encodeToClause'.
addFormula :: PrimMonad m => Encoder m -> Formula -> m ()
addFormula encoder (And xs) = forM_ xs (addFormula encoder)
addFormula encoder (Equiv a b) = do
  la <- encodeFormula encoder a
  lb <- encodeFormula encoder b
  SAT.addClause encoder [SAT.litNot la, lb] -- a -> b
  SAT.addClause encoder [SAT.litNot lb, la] -- b -> a
addFormula encoder (Not (Not a)) = addFormula encoder a
addFormula encoder (Not (Or xs)) = addFormula encoder (andB (map notB xs))
addFormula encoder (Not (Imply a b)) = addFormula encoder (a .&&. notB b)
addFormula encoder (Not (Equiv a b)) = do
  la <- encodeFormula encoder a
  lb <- encodeFormula encoder b
  SAT.addClause encoder [la, lb]                         -- a or b
  SAT.addClause encoder [SAT.litNot la, SAT.litNot lb]   -- not a or not b
addFormula encoder (ITE c t e) = do
  -- The condition is used in both directions; the branches only positively.
  cLit <- encodeFormula encoder c
  tLit <- encodeFormulaWithPolarity encoder polarityPos t
  eLit <- encodeFormulaWithPolarity encoder polarityPos e
  SAT.addClause encoder [-cLit, tLit] -- c -> t
  SAT.addClause encoder [ cLit, eLit] -- not c -> e
addFormula encoder formula =
  encodeToClause encoder formula >>= SAT.addClause encoder
-- | Turn a formula into a single clause whose disjunction stands for it.
--
-- Disjunctive structure (disjunctions, negated conjunctions, implications)
-- is flattened into the clause; any other subformula contributes one literal
-- obtained from 'encodeFormulaWithPolarity' with 'polarityPos'.
encodeToClause :: PrimMonad m => Encoder m -> Formula -> m SAT.Clause
encodeToClause encoder (And [x]) = encodeToClause encoder x
encodeToClause encoder (Or xs) =
  liftM concat (mapM (encodeToClause encoder) xs)
encodeToClause encoder (Not (Not x)) = encodeToClause encoder x
encodeToClause encoder (Not (And xs)) =
  encodeToClause encoder (orB (map notB xs))
encodeToClause encoder (Imply a b) =
  encodeToClause encoder (notB a .||. b)
encodeToClause encoder formula = do
  lit <- encodeFormulaWithPolarity encoder polarityPos formula
  return [lit]
-- | Encode a formula into a literal, emitting the defining clauses for both
-- polarities ('polarityBoth').
encodeFormula :: PrimMonad m => Encoder m -> Formula -> m SAT.Lit
encodeFormula encoder formula =
  encodeFormulaWithPolarity encoder polarityBoth formula
-- | Shared memoisation logic for all definitional encodings.
--
-- Looks up key @k@ in the table. If it is absent, a fresh variable is
-- allocated, the requested defining clauses are emitted and the entry is
-- recorded. If it is present, only the directions that are requested now but
-- were not emitted before are added, and the entry's flags are updated.
--
-- The two callbacks receive the definition variable and emit the clauses for
-- the positive and the negative direction respectively. The positive one is
-- always run before the negative one.
encodeWithPolarityHelper
  :: (PrimMonad m, Ord k)
  => Encoder m
  -> MutVar (PrimState m) (Map k (SAT.Var, Bool, Bool))
  -> (SAT.Lit -> m ()) -> (SAT.Lit -> m ())
  -> Polarity
  -> k
  -> m SAT.Var
encodeWithPolarityHelper encoder tableRef definePos defineNeg (Polarity pos neg) k = do
  table <- readMutVar tableRef
  case Map.lookup k table of
    Nothing -> do
      v <- SAT.newVar encoder
      when pos $ definePos v
      when neg $ defineNeg v
      modifyMutVar' tableRef (Map.insert k (v, pos, neg))
      return v
    Just (v, posDefined, negDefined) -> do
      let needPos = pos && not posDefined
          needNeg = neg && not negDefined
      when needPos $ definePos v
      when needNeg $ defineNeg v
      -- Only touch the table when at least one flag actually changes.
      when (needPos || needNeg) $
        modifyMutVar' tableRef (Map.insert k (v, posDefined || pos, negDefined || neg))
      return v
-- | Encode a formula into a literal, emitting only the defining clauses
-- required by the given polarity.
--
-- Atoms are returned unchanged. Negation flips the polarity for its operand.
-- The operands of an equivalence and the condition of an if-then-else are
-- always encoded with 'polarityBoth'.
encodeFormulaWithPolarity :: PrimMonad m => Encoder m -> Polarity -> Formula -> m SAT.Lit
encodeFormulaWithPolarity _ _ (Atom l) = return l
encodeFormulaWithPolarity encoder p (And xs) = do
  lits <- mapM (encodeFormulaWithPolarity encoder p) xs
  encodeConjWithPolarity encoder p lits
encodeFormulaWithPolarity encoder p (Or xs) = do
  lits <- mapM (encodeFormulaWithPolarity encoder p) xs
  encodeDisjWithPolarity encoder p lits
encodeFormulaWithPolarity encoder p (Not x) = do
  lit <- encodeFormulaWithPolarity encoder (negatePolarity p) x
  return (SAT.litNot lit)
encodeFormulaWithPolarity encoder p (Imply x y) =
  encodeFormulaWithPolarity encoder p (notB x .||. y)
encodeFormulaWithPolarity encoder p (Equiv x y) = do
  xLit <- encodeFormulaWithPolarity encoder polarityBoth x
  yLit <- encodeFormulaWithPolarity encoder polarityBoth y
  -- Re-express the equivalence as two implications over the encoded literals.
  let forward = Atom xLit .=>. Atom yLit
      backward = Atom yLit .=>. Atom xLit
  encodeFormulaWithPolarity encoder p (forward .&&. backward)
encodeFormulaWithPolarity encoder p (ITE c t e) = do
  cLit <- encodeFormulaWithPolarity encoder polarityBoth c
  tLit <- encodeFormulaWithPolarity encoder p t
  eLit <- encodeFormulaWithPolarity encoder p e
  encodeITEWithPolarity encoder p cLit tLit eLit
-- | Return a literal standing for the conjunction of the given literals,
-- defined for both polarities.
encodeConj :: PrimMonad m => Encoder m -> [SAT.Lit] -> m SAT.Lit
encodeConj encoder ls = encodeConjWithPolarity encoder polarityBoth ls
-- | Return a literal standing for the conjunction of the given literals,
-- emitting only the defining clauses required by the polarity.
--
-- A one-element list is returned as is. Otherwise the literals are collected
-- into a set (removing duplicates). If the empty conjunction already has a
-- variable in the table, that variable is used as a stand-in for the constant
-- true: its negation collapses the whole conjunction to that negated literal,
-- and the variable itself is dropped from the set. If exactly one literal
-- remains it is returned directly; otherwise the set is looked up or defined
-- through 'encodeWithPolarityHelper'.
encodeConjWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> [SAT.Lit] -> m SAT.Lit
encodeConjWithPolarity _ _ [l] = return l
encodeConjWithPolarity encoder polarity ls = do
  usePB <- readMutVar (encUsePB encoder)
  table <- readMutVar (encConjTable encoder)
  let ls3 = IntSet.fromList ls
      ls2 = case Map.lookup IntSet.empty table of
              Nothing -> ls3
              Just (litTrue, _, _)
                | litFalse `IntSet.member` ls3 -> IntSet.singleton litFalse
                | otherwise -> IntSet.delete litTrue ls3
                where litFalse = SAT.litNot litTrue

      -- Positive direction: l -> (l1 and ... and ln).
      definePos :: SAT.Lit -> m ()
      definePos l =
        case encAddPBAtLeast encoder of
          Just addPBAtLeast | usePB -> do
            -- forall i. (l -> li)  <=>  sum li - n*l >= 0
            let n = IntSet.size ls2
            addPBAtLeast ((- fromIntegral n, l) : [(1,li) | li <- IntSet.toList ls2]) 0
          _ ->
            forM_ (IntSet.toList ls2) $ \li ->
              -- (l -> li)  <=>  (not l or li)
              SAT.addClause encoder [SAT.litNot l, li]

      -- Negative direction: (l1 and ... and ln) -> l,
      -- i.e. the single clause (not l1 or ... or not ln or l).
      defineNeg :: SAT.Lit -> m ()
      defineNeg l =
        SAT.addClause encoder (l : map SAT.litNot (IntSet.toList ls2))

  -- Match on the element list instead of testing the size and calling the
  -- partial 'head': a singleton set needs no definition at all.
  case IntSet.toList ls2 of
    [l] -> return l
    _ -> encodeWithPolarityHelper encoder (encConjTable encoder) definePos defineNeg polarity ls2
-- | Return a literal standing for the disjunction of the given literals,
-- defined for both polarities.
encodeDisj :: PrimMonad m => Encoder m -> [SAT.Lit] -> m SAT.Lit
encodeDisj encoder ls = encodeDisjWithPolarity encoder polarityBoth ls
-- | Return a literal standing for the disjunction of the given literals.
--
-- A one-element list is returned as is. Otherwise the disjunction is encoded
-- via De Morgan: the negation of the conjunction of the negated literals,
-- with the polarity flipped for that conjunction.
encodeDisjWithPolarity :: PrimMonad m => Encoder m -> Polarity -> [SAT.Lit] -> m SAT.Lit
encodeDisjWithPolarity _ _ [l] = return l
encodeDisjWithPolarity encoder p ls =
  liftM SAT.litNot $
    encodeConjWithPolarity encoder (negatePolarity p) (map SAT.litNot ls)
-- | Return a literal standing for @if c then t else e@, defined for both
-- polarities.
encodeITE :: PrimMonad m => Encoder m -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeITE encoder c t e = encodeITEWithPolarity encoder polarityBoth c t e
-- | Return a literal standing for @if c then t else e@, emitting only the
-- defining clauses required by the polarity.
--
-- A negative condition literal is normalised first by negating it and
-- swapping the two branches, so table keys always carry a non-negative
-- condition.
encodeITEWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeITEWithPolarity encoder polarity c t e
  | c < 0 = encodeITEWithPolarity encoder polarity (negate c) e t
  | otherwise =
      encodeWithPolarityHelper encoder (encITETable encoder) definePos defineNeg polarity (c, t, e)
  where
    -- x -> ite(c,t,e)
    definePos :: SAT.Lit -> m ()
    definePos x = mapM_ (SAT.addClause encoder)
      [ [-x, -c, t]  -- x and c -> t
      , [-x, c, e]   -- x and not c -> e
      , [t, e, -x]   -- implied by the two clauses above
      ]

    -- ite(c,t,e) -> x
    defineNeg :: SAT.Lit -> m ()
    defineNeg x = mapM_ (SAT.addClause encoder)
      [ [-c, -t, x]  -- c and t -> x
      , [c, -e, x]   -- not c and e -> x
      , [-t, -e, x]  -- implied by the two clauses above
      ]
-- | Return a literal standing for the exclusive-or of two literals, defined
-- for both polarities.
encodeXOR :: PrimMonad m => Encoder m -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeXOR encoder a b = encodeXORWithPolarity encoder polarityBoth a b
-- | Return a literal standing for the exclusive-or of @a@ and @b@, emitting
-- only the defining clauses required by the polarity. Definitions are
-- memoised under the key @(a, b)@.
encodeXORWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeXORWithPolarity encoder polarity a b =
  encodeWithPolarityHelper encoder (encXORTable encoder) definePos defineNeg polarity (a, b)
  where
    -- x -> (a xor b)
    definePos :: SAT.Lit -> m ()
    definePos x = mapM_ (SAT.addClause encoder)
      [ [a, b, -x]    -- not a and not b -> not x
      , [-a, -b, -x]  -- a and b -> not x
      ]

    -- (a xor b) -> x
    defineNeg :: SAT.Lit -> m ()
    defineNeg x = mapM_ (SAT.addClause encoder)
      [ [a, -b, x]    -- not a and b -> x
      , [-a, b, x]    -- a and not b -> x
      ]
-- | Return a literal standing for the sum output of a full adder over three
-- literals, defined for both polarities.
encodeFASum :: forall m. PrimMonad m => Encoder m -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeFASum encoder a b c = encodeFASumWithPolarity encoder polarityBoth a b c
-- | Return a literal standing for the sum output of a full adder, i.e. a
-- literal that is true when an odd number of @a@, @b@, @c@ are true.
-- Only the defining clauses required by the polarity are emitted, and
-- definitions are memoised under the key @(a, b, c)@.
encodeFASumWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeFASumWithPolarity encoder polarity a b c =
  encodeWithPolarityHelper encoder (encFASumTable encoder) definePos defineNeg polarity (a, b, c)
  where
    -- x -> FASum(a,b,c), stated as: an even number of true inputs forces not x.
    definePos :: SAT.Lit -> m ()
    definePos x = mapM_ (SAT.addClause encoder)
      [ [a, b, c, -x]     -- not a, not b, not c -> not x
      , [a, -b, -c, -x]   -- not a, b, c -> not x
      , [-a, b, -c, -x]   -- a, not b, c -> not x
      , [-a, -b, c, -x]   -- a, b, not c -> not x
      ]

    -- FASum(a,b,c) -> x: an odd number of true inputs forces x.
    defineNeg :: SAT.Lit -> m ()
    defineNeg x = mapM_ (SAT.addClause encoder)
      [ [-a, -b, -c, x]   -- a, b, c -> x
      , [-a, b, c, x]     -- a, not b, not c -> x
      , [a, -b, c, x]     -- not a, b, not c -> x
      , [a, b, -c, x]     -- not a, not b, c -> x
      ]
-- | Return a literal standing for the carry output of a full adder over
-- three literals, defined for both polarities.
encodeFACarry :: forall m. PrimMonad m => Encoder m -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeFACarry encoder a b c = encodeFACarryWithPolarity encoder polarityBoth a b c
-- | Return a literal standing for the carry output of a full adder, i.e. a
-- literal that is true when at least two of @a@, @b@, @c@ are true.
-- Only the defining clauses required by the polarity are emitted, and
-- definitions are memoised under the key @(a, b, c)@.
encodeFACarryWithPolarity :: forall m. PrimMonad m => Encoder m -> Polarity -> SAT.Lit -> SAT.Lit -> SAT.Lit -> m SAT.Lit
encodeFACarryWithPolarity encoder polarity a b c =
  encodeWithPolarityHelper encoder (encFACarryTable encoder) definePos defineNeg polarity (a, b, c)
  where
    -- x -> FACarry(a,b,c): any two false inputs force not x.
    definePos :: SAT.Lit -> m ()
    definePos x = mapM_ (SAT.addClause encoder)
      [ [a, b, -x]    -- not a and not b -> not x
      , [a, c, -x]    -- not a and not c -> not x
      , [b, c, -x]    -- not b and not c -> not x
      ]

    -- FACarry(a,b,c) -> x: any two true inputs force x.
    defineNeg :: SAT.Lit -> m ()
    defineNeg x = mapM_ (SAT.addClause encoder)
      [ [-a, -b, x]   -- a and b -> x
      , [-a, -c, x]   -- a and c -> x
      , [-b, -c, x]   -- b and c -> x
      ]
-- | List every definition variable together with the formula it stands for.
--
-- Entries are produced table by table (conjunctions, if-then-else,
-- exclusive-or, full-adder sum, full-adder carry), each in ascending key
-- order. The polarity flags stored in the tables are ignored.
getDefinitions :: PrimMonad m => Encoder m -> m [(SAT.Var, Formula)]
getDefinitions encoder = do
  conjs <- readMutVar (encConjTable encoder)
  ites <- readMutVar (encITETable encoder)
  xors <- readMutVar (encXORTable encoder)
  sums <- readMutVar (encFASumTable encoder)
  carries <- readMutVar (encFACarryTable encoder)
  return $
    definitions conjs (\ls -> andB (map Atom (IntSet.toList ls)))
      ++ definitions ites (\(c, t, e) -> ite (Atom c) (Atom t) (Atom e))
      ++ definitions xors (\(a, b) -> (Atom a .||. Atom b) .&&. (Atom (-a) .||. Atom (-b)))
      ++ definitions sums (\(a, b, c) -> dnf [[a,b,c],[a,-b,-c],[-a,b,-c],[-a,-b,c]])
      ++ definitions carries (\(a, b, c) -> dnf [[a,b],[a,c],[b,c]])
  where
    -- Pair each table entry's variable with the formula rebuilt from its key.
    definitions :: Map k (SAT.Var, Bool, Bool) -> (k -> Formula) -> [(SAT.Var, Formula)]
    definitions table toFormula =
      [ (v, toFormula key) | (key, (v, _, _)) <- Map.toList table ]

    -- Disjunction of conjunctions of literals.
    dnf :: [[SAT.Lit]] -> Formula
    dnf terms = orB [ andB (map Atom term) | term <- terms ]
-- | Which directions of a definition are required.
--
-- 'encodeWithPolarityHelper' emits the positive-direction defining clauses
-- when the first flag is set and the negative-direction ones when the second
-- flag is set.
data Polarity
  = Polarity
  { polarityPosOccurs :: Bool -- ^ Positive-direction clauses are needed.
  , polarityNegOccurs :: Bool -- ^ Negative-direction clauses are needed.
  }
  deriving (Eq, Show)
-- | Swap the positive and negative flags of a 'Polarity'.
negatePolarity :: Polarity -> Polarity
negatePolarity Polarity{ polarityPosOccurs = pos, polarityNegOccurs = neg } =
  Polarity{ polarityPosOccurs = neg, polarityNegOccurs = pos }
-- | Only the positive flag is set.
polarityPos :: Polarity
polarityPos = Polarity{ polarityPosOccurs = True, polarityNegOccurs = False }
-- | Only the negative flag is set.
polarityNeg :: Polarity
polarityNeg = Polarity{ polarityPosOccurs = False, polarityNegOccurs = True }
-- | Both the positive and the negative flag are set.
polarityBoth :: Polarity
polarityBoth = Polarity{ polarityPosOccurs = True, polarityNegOccurs = True }
-- | Neither flag is set.
polarityNone :: Polarity
polarityNone = Polarity{ polarityPosOccurs = False, polarityNegOccurs = False }