-- | Basic types for SAT solving (variables, literals, models, clauses)
-- together with normalization helpers for clauses, cardinality
-- constraints and pseudo-Boolean (PB) constraints.
module SAT.Types
(
-- * Variable
Var
, VarSet
, VarMap
, validVar
-- * Model
, Model
-- * Literal
, Lit
, LitSet
, LitMap
, litUndef
, validLit
, literal
, litNot
, litVar
, litPolarity
, evalLit
-- * Clause
, Clause
, normalizeClause
-- * Cardinality constraint
, normalizeAtLeast
-- * Pseudo-Boolean constraint
, normalizePBSum
, normalizePBAtLeast
, normalizePBExactly
, cutResolve
, cardinalityReduction
, negatePBAtLeast
, pbEval
, pbLowerBound
, pbUpperBound
) where
import Control.Monad
import Control.Exception
import Data.Array.Unboxed
import Data.Ord
import Data.List
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import qualified Data.Set as Set
-- | A propositional variable. Valid variables are positive integers
-- (see 'validVar').
type Var = Int
-- | A set of variables, backed by 'IS.IntSet'.
type VarSet = IS.IntSet
-- | A map keyed by variables, backed by 'IM.IntMap'.
type VarMap = IM.IntMap
-- | A variable is valid exactly when it is strictly positive.
validVar :: Var -> Bool
validVar = (> 0)
-- | A model: an unboxed array indexed by variable whose entries are the
-- Boolean values assigned to the variables (see 'evalLit').
type Model = UArray Var Bool
-- | A literal, encoded as a non-zero 'Int': a positive value is the positive
-- literal of that variable, a negative value is its negation
-- (see 'litPolarity', 'litVar' and 'litNot').
type Lit = Int
-- | Placeholder for "no literal". Its value @0@ is rejected by 'validLit'.
litUndef :: Lit
litUndef = 0
-- | A set of literals, backed by 'IS.IntSet'.
type LitSet = IS.IntSet
-- | A map keyed by literals, backed by 'IM.IntMap'.
type LitMap = IM.IntMap
-- | A literal is valid exactly when it is non-zero.
validLit :: Lit -> Bool
validLit = (/= 0)
-- | Build a literal from a variable and a polarity: 'True' yields the
-- positive literal, 'False' the negated one.
--
-- Asserts that the variable is valid.
literal
  :: Var   -- ^ variable
  -> Bool  -- ^ polarity
  -> Lit
literal v polarity = assert (validVar v) lit
  where
    lit
      | polarity  = v
      | otherwise = litNot v
-- | Negation of a literal (arithmetic negation of its encoding).
--
-- Asserts that the literal is valid.
litNot :: Lit -> Lit
litNot l = assert (validLit l) (negate l)
-- | The variable underlying a literal (absolute value of its encoding).
--
-- Asserts that the literal is valid.
litVar :: Lit -> Var
litVar l = assert (validLit l) (abs l)
-- | Polarity of a literal: 'True' for a positive literal, 'False' for a
-- negated one.
--
-- Asserts that the literal is valid.
litPolarity :: Lit -> Bool
litPolarity l = assert (validLit l) (0 < l)
-- | Truth value of a literal under a model: a positive literal takes the
-- value stored for its variable, any other literal takes the negation of the
-- value stored at its absolute value.
--
-- No validity assertion is performed on the literal here.
evalLit :: Model -> Lit -> Bool
evalLit m l
  | l > 0     = m ! l
  | otherwise = not (m ! abs l)
-- | A clause, stored as a list of literals (presumably read as their
-- disjunction; see 'normalizeClause').
type Clause = [Lit]
-- | Normalize a clause.
--
-- Duplicate literals are removed and the remaining literals are returned in
-- ascending 'Int' order. The result is 'Nothing' when the clause contains
-- some literal together with its negation.
normalizeClause :: Clause -> Maybe Clause
normalizeClause lits = assert (even (IS.size clashing)) result
  where
    litSet = IS.fromList lits

    -- Literals whose negation is also present; both members of each such
    -- pair end up here, which is why the size is asserted to be even.
    clashing = IS.intersection litSet (IS.map litNot litSet)

    result
      | IS.null clashing = Just (IS.toList litSet)
      | otherwise        = Nothing
-- | Normalize a cardinality constraint given as @(lits, n)@, read as
-- "at least @n@ of @lits@ are true".
--
-- The literals are collected into an 'IS.IntSet' (so duplicates collapse and
-- the output is in ascending 'Int' order). Every complementary pair
-- @x@ / @litNot x@ is then removed: exactly one literal of such a pair is
-- true under any assignment, so each removed pair lowers the threshold by
-- one.
normalizeAtLeast :: ([Lit],Int) -> ([Lit],Int)
normalizeAtLeast (lits,n) =
  assert (even (IS.size complementary)) (IS.toList kept, n')
  where
    litSet = IS.fromList lits

    -- Literals whose negation is also present. Both members of every pair
    -- are in this set, hence its size is even.
    complementary = litSet `IS.intersection` IS.map litNot litSet

    kept = litSet `IS.difference` complementary

    -- One unit of the threshold is already satisfied per removed pair.
    n' = n - (IS.size complementary `div` 2)
-- | Normalize a pseudo-Boolean sum @c1*l1 + ... + cn*ln + k@, given as
-- @([(ci,li)], k)@, into an equivalent sum whose coefficients are all
-- strictly positive and in which every variable occurs at most once.
--
-- Two passes are composed:
--
-- 1. Every term is rewritten over the positive literal of its variable
--    (using @c * not v = c - c * v@) and coefficients of the same variable
--    are added up.
--
-- 2. Terms with coefficient zero are dropped, and terms with a negative
--    coefficient are flipped to the opposite literal (using
--    @c * l = c + (-c) * not l@), adjusting the constant accordingly.
normalizePBSum :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
normalizePBSum = flipNegatives . mergeByVar
  where
    -- Pass 1: accumulate one coefficient per variable plus the constant.
    mergeByVar :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    mergeByVar (terms, k) = ([(c, v) | (v, c) <- IM.toList coeffs], k')
      where
        (coeffs, k') = foldl' addTerm (IM.empty, k) terms

        addTerm :: (VarMap Integer, Integer) -> (Integer,Lit) -> (VarMap Integer, Integer)
        addTerm (acc, off) (c, l)
          | litPolarity l = (IM.insertWith (+) l c acc, off)
          | otherwise     = (IM.insertWith (+) (litNot l) (negate c) acc, off + c)

    -- Pass 2: drop zero terms and make every coefficient positive. Terms are
    -- pushed onto the front of the accumulator, so the output order is the
    -- reverse of the input order of this pass.
    flipNegatives :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    flipNegatives (terms, k) = foldl' pushTerm ([], k) terms
      where
        pushTerm (acc, off) t@(c, l)
          | c == 0    = (acc, off)
          | c < 0     = ((negate c, litNot l) : acc, off + c)
          | otherwise = (t : acc, off)
-- | Normalize a pseudo-Boolean constraint @c1*l1 + ... + cn*ln >= b@, given
-- as @([(ci,li)], b)@.
--
-- The result has strictly positive coefficients, each variable at most once,
-- no coefficient larger than the right-hand side, and coefficients divided
-- by their greatest common divisor. A constraint whose normalized right-hand
-- side is not positive is trivially true and is returned as @([], 0)@.
normalizePBAtLeast :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
normalizePBAtLeast a =
  case toPositiveForm a of
    (xs,n)
      | n > 0     -> divideByGcd (saturate n xs, n)
      | otherwise -> ([], 0)  -- trivially true
  where
    -- 'normalizePBSum' treats its second component as a constant added to
    -- the sum, so @lhs >= n@ is passed as @lhs + (-n)@ (implicitly @>= 0@)
    -- and the resulting constant is negated again to get back a right-hand
    -- side.
    toPositiveForm :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    toPositiveForm (xs,n) =
      case normalizePBSum (xs, negate n) of
        (ys,m) -> (ys, negate m)

    -- A coefficient larger than the right-hand side is no stronger than one
    -- equal to it, so clamp it.
    saturate :: Integer -> [(Integer,Lit)] -> [(Integer,Lit)]
    saturate n xs = [assert (c > 0) (min n c, l) | (c,l) <- xs]

    -- Divide through by the gcd of the coefficients. The left-hand side is
    -- always a multiple of @d@, so the right-hand side is rounded up.
    divideByGcd :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    divideByGcd ([],n) = ([],n)
    divideByGcd (xs,n) = ([(c `div` d, l) | (c,l) <- xs], (n + d - 1) `div` d)
      where
        d = foldl1' gcd [c | (c,_) <- xs]
-- | Normalize a pseudo-Boolean equality @c1*l1 + ... + cn*ln = b@, given as
-- @([(ci,li)], b)@.
--
-- The result has strictly positive coefficients, each variable at most once,
-- and coefficients divided by their greatest common divisor. An equality
-- that cannot hold (negative normalized right-hand side, or a right-hand
-- side not divisible by the gcd of the coefficients) is returned as the
-- false constraint @([], 1)@.
normalizePBExactly :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
normalizePBExactly a =
  case toPositiveForm a of
    (xs,n)
      | n >= 0    -> divideByGcd (xs, n)
      | otherwise -> ([], 1)  -- false: a non-negative sum cannot be negative
  where
    -- 'normalizePBSum' treats its second component as a constant added to
    -- the sum, so @lhs = n@ is passed as @lhs + (-n)@ (implicitly @= 0@) and
    -- the resulting constant is negated again to get back a right-hand side.
    toPositiveForm :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    toPositiveForm (xs,n) =
      case normalizePBSum (xs, negate n) of
        (ys,m) -> (ys, negate m)

    -- Divide through by the gcd of the coefficients; the left-hand side is
    -- always a multiple of @d@, so a right-hand side that is not makes the
    -- equality unsatisfiable.
    divideByGcd :: ([(Integer,Lit)], Integer) -> ([(Integer,Lit)], Integer)
    divideByGcd ([],n) = ([],n)
    divideByGcd (xs,n)
      | n `mod` d == 0 = ([(c `div` d, l) | (c,l) <- xs], n `div` d)
      | otherwise      = ([], 1)  -- false
      where
        d = foldl1' gcd [c | (c,_) <- xs]
-- | Combine two @>=@ pseudo-Boolean constraints on the variable @v@.
--
-- Each constraint is scaled so that the coefficients of @v@ in the two
-- constraints become equal (@s1*c1 == s2*c2@), the scaled constraints are
-- added, and the sum is passed to 'normalizePBAtLeast'.
--
-- The two constraints are asserted to mention @v@ with opposite polarities.
-- Calls 'error' if @v@ does not occur in one of the constraints.
cutResolve :: ([(Integer,Lit)],Integer) -> ([(Integer,Lit)],Integer) -> Var -> ([(Integer,Lit)],Integer)
cutResolve (lhs1,rhs1) (lhs2,rhs2) v = assert (l1 == litNot l2) $ normalizePBAtLeast pb
  where
    (c1,l1) = termFor lhs1
    (c2,l2) = termFor lhs2

    -- First term of a left-hand side whose literal is over @v@.
    termFor :: [(Integer,Lit)] -> (Integer,Lit)
    termFor lhs =
      case [t | t@(_,l) <- lhs, litVar l == v] of
        (t:_) -> t
        []    -> error "cutResolve: variable does not occur in constraint"

    g  = gcd c1 c2
    s1 = c2 `div` g
    s2 = c1 `div` g
    pb = ( [(s1*c,l) | (c,l) <- lhs1] ++ [(s2*c,l) | (c,l) <- lhs2]
         , s1*rhs1 + s2*rhs2
         )
-- | Derive a cardinality constraint @(lits, k)@ from a pseudo-Boolean
-- constraint @lhs >= rhs@.
--
-- Presumably the input is expected to be normalized (positive coefficients)
-- and not trivially true -- TODO confirm against callers.
--
-- Calls 'error' when either helper runs off the end of @lhs@ (in particular
-- when @lhs@ is empty).
cardinalityReduction :: ([(Integer,Lit)],Integer) -> ([Lit],Int)
cardinalityReduction (lhs,rhs) = (ls, rhs')
  where
    -- Threshold: walk the terms from the largest coefficient down and count
    -- how many are consumed before the running sum reaches @rhs@.
    rhs' = go1 0 0 (sortBy (flip (comparing fst)) lhs)

    -- NOTE(review): the sum is tested only while a further term remains, so
    -- if @rhs@ is reached exactly when the last term is consumed this hits
    -- the 'error' case instead of returning the count -- confirm intended.
    go1 !s !k ((a,_):ts)
      | s < rhs   = go1 (s+a) (k+1) ts
      | otherwise = k
    go1 _ _ [] = error "cardinalityReduction: should not happen"

    -- Literals: walk the terms from the smallest coefficient up, skipping
    -- terms while @a - 1@ stays below the remaining guard.
    ls = go2 (minimum (rhs : map (subtract 1 . fst) lhs)) (sortBy (comparing fst) lhs)

    -- NOTE(review): the initial guard is at most @(smallest coefficient) - 1@,
    -- so the first test already fails and the result is every literal except
    -- the one with the smallest coefficient (the @otherwise@ branch returns
    -- @ts@, not the current term) -- confirm intended.
    go2 !guard' ((a,_) : ts)
      | a - 1 < guard' = go2 (guard' - a) ts
      | otherwise      = map snd ts
    go2 _ [] = error "cardinalityReduction: should not happen"
-- | Negate a pseudo-Boolean constraint @lhs >= rhs@.
--
-- Over the integers, @not (lhs >= rhs)@ is @lhs <= rhs - 1@, which is
-- @-lhs >= -rhs + 1@. The result is again a @>=@ constraint, with every
-- coefficient negated; it is not normalized.
negatePBAtLeast :: ([(Integer, Lit)], Integer) -> ([(Integer, Lit)], Integer)
negatePBAtLeast (xs, rhs) = ([(negate c, lit) | (c,lit) <- xs], negate rhs + 1)
-- | Value of a pseudo-Boolean sum under a model: the sum of the
-- coefficients whose literal evaluates to 'True'.
pbEval :: Model -> [(Integer, Lit)] -> Integer
pbEval m xs = sum (map fst satisfied)
  where
    satisfied = filter (evalLit m . snd) xs
-- | Smallest value a pseudo-Boolean sum can take: the sum of its negative
-- coefficients (non-negative coefficients contribute zero).
pbLowerBound :: [(Integer, Lit)] -> Integer
pbLowerBound xs = sum (map (min 0 . fst) xs)
-- | Largest value a pseudo-Boolean sum can take: the sum of its positive
-- coefficients (non-positive coefficients contribute zero).
pbUpperBound :: [(Integer, Lit)] -> Integer
pbUpperBound xs = sum (map (max 0 . fst) xs)