module ToySolver.SAT.PBO.BCD2
( Options (..)
, solve
) where
import Control.Concurrent.STM
import Control.Monad
import Data.IORef
import Data.Default.Class
import qualified Data.IntSet as IntSet
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Vector as V
import qualified ToySolver.SAT as SAT
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.PBO.Context as C
import qualified ToySolver.Combinatorial.SubsetSum as SubsetSum
import Text.Printf
-- | Switches that tune the BCD2 search.
--
-- A value with every switch turned on is available through 'def'.
data Options = Options
  { optEnableHardening :: Bool
    -- ^ When set, each round fixes selectors whose remaining weight no
    -- longer fits between the current lower and upper bound.
  , optEnableBiasedSearch :: Bool
    -- ^ When set, the trial bound of each core is placed according to the
    -- ratio of unsatisfiable to total solver calls; otherwise it is the
    -- plain midpoint.
  , optSolvingNormalFirst :: Bool
    -- ^ When set and no model is known yet, one unassumed solver call is
    -- made before the main loop starts.
  }
-- | The default configuration enables hardening, biased search and the
-- initial unassumed solve.
instance Default Options where
  def = Options
    { optEnableHardening    = True
    , optEnableBiasedSearch = True
    , optSolvingNormalFirst = True
    }
-- | Book-keeping for one core tracked by the search.
data CoreInfo = CoreInfo
  { coreLits :: SAT.LitSet
    -- ^ Literals that make up the core.
  , coreLBRef :: !(IORef Integer)
    -- ^ Mutable lower bound on the cost of the core (read via 'getCoreLB').
  , coreUBSelectors :: !(IORef (Map Integer SAT.Lit))
    -- ^ Selector literal created so far for each upper-bound value that
    -- has been tried on this core.
  }
-- | Allocate a fresh 'CoreInfo' for the given literals with an initial
-- lower bound and no upper-bound selectors yet.
newCoreInfo :: SAT.LitSet -> Integer -> IO CoreInfo
newCoreInfo lits lb =
  -- The lower-bound reference is created first, then the selector map,
  -- matching the field order of 'CoreInfo'.
  CoreInfo lits <$> newIORef lb <*> newIORef Map.empty
-- | Retire a core: permanently switch off every soft upper-bound constraint
-- that was created for it.
--
-- The selectors stored in 'coreUBSelectors' are the ones handed to
-- @SAT.addPBAtMostSoft@ when a bound is first tried.
-- NOTE(review): this relies on a soft constraint being active only while its
-- selector is true, and on literal negation being integer negation — confirm
-- against the solver API.
deleteCoreInfo :: SAT.Solver -> CoreInfo -> IO ()
deleteCoreInfo solver core = do
  selectors <- readIORef (coreUBSelectors core)
  -- Fixing a selector to *false* disables the constraint it guards; asserting
  -- it true would instead turn the tentative bound into a hard constraint.
  forM_ (Map.elems selectors) $ \sel ->
    SAT.addClause solver [-sel]
-- | Read the current lower bound recorded for a core.
getCoreLB :: CoreInfo -> IO Integer
getCoreLB core = readIORef (coreLBRef core)
-- | Entry point: normalize the context with 'C.normalize' and hand it to
-- the BCD2 search loop.
solve :: C.Context cxt => cxt -> SAT.Solver -> Options -> IO ()
solve cxt solver opt = solveWBO normalized solver opt
  where
    normalized = C.normalize cxt
-- | Core-guided search loop of BCD2.
--
-- The objective of @cxt@ is turned into a set of weighted selector literals
-- (one per objective term).  Selectors that have not yet appeared in a core
-- are passed to the solver as plain assumptions; selectors that have been
-- grouped into a core are bounded through a per-core soft upper-bound
-- constraint whose trial bound lies between the core's lower bound and an
-- estimate of its cost.  Each round either
--
-- * finds a model, which is reported with 'C.addSolution', or
-- * fails, in which case the failed assumptions either raise the lower bound
--   of a single core or are merged into a new core.
--
-- The loop stops once the context reports that it is finished, which this
-- function triggers itself when the upper bound drops below the lower bound
-- or when the solver fails without any failed assumption.
--
-- NOTE(review): the code assumes the objective is minimised and that a
-- literal is negated by integer negation — confirm against @ToySolver.SAT@.
solveWBO :: C.Context cxt => cxt -> SAT.Solver -> Options -> IO ()
solveWBO cxt solver opt = do
  C.logMessage cxt $ printf "BCD2: Hardening is %s." (if optEnableHardening opt then "enabled" else "disabled")
  C.logMessage cxt $ printf "BCD2: BiasedSearch is %s." (if optEnableBiasedSearch opt then "enabled" else "disabled")

  SAT.modifyConfig solver $ \config -> config{ SAT.configEnableBackwardSubsumptionRemoval = True }

  -- Selectors still assumed directly / already absorbed into some core /
  -- permanently fixed by hardening.
  unrelaxedRef <- newIORef $ IntSet.fromList [lit | (lit,_) <- sels]
  relaxedRef   <- newIORef IntSet.empty
  hardenedRef  <- newIORef IntSet.empty
  -- Both counters start at 1, so the bias ratio nunsat / (nunsat + nsat)
  -- never divides by zero and starts at 1/2.
  nsatRef   <- newIORef 1
  nunsatRef <- newIORef 1

  lastUBRef <- newIORef $ SAT.pbUpperBound obj

  coresRef <- newIORef []
  -- Global lower bound: the sum of the lower bounds of all current cores.
  let getLB = do
        xs <- readIORef coresRef
        foldM (\s core -> do{ v <- getCoreLB core; return $! s + v }) 0 xs

  -- Remaining weight per selector; starts at the full weight and shrinks by
  -- every lower-bound increase attributed to a core containing the selector.
  deductedWeightRef <- newIORef weights
  let deductWeight d core =
        modifyIORef' deductedWeightRef $ IntMap.unionWith (+) $
          IntMap.fromList [(lit, negate d) | lit <- IntSet.toList (coreLits core)]
      updateLB oldLB core = do
        newLB <- getLB
        C.addLowerBound cxt newLB
        deductWeight (newLB - oldLB) core
        -- Also state the core's lower bound to the solver as a constraint.
        SAT.addPBAtLeast solver (coreCostFun core) =<< getCoreLB core

  let loop = do
        lb <- getLB
        ub <- do
          ub1 <- atomically $ C.getSearchUpperBound cxt
          let ub2 = refineUB (map fst obj) ub1
          when (ub1 /= ub2) $ C.logMessage cxt $ printf "BCD2: refineUB: %d -> %d" ub1 ub2
          return ub2
        when (ub < lb) $ C.setFinished cxt

        fin <- atomically $ C.isFinished cxt
        unless fin $ do

          -- Hardening: a selector whose remaining weight would push the cost
          -- above the upper bound can be fixed to true for good.
          when (optEnableHardening opt) $ do
            deductedWeight <- readIORef deductedWeightRef
            hardened <- readIORef hardenedRef
            let lits = IntMap.keysSet (IntMap.filter (\w -> lb + w > ub) deductedWeight) `IntSet.difference` hardened
            unless (IntSet.null lits) $ do
              C.logMessage cxt $ printf "BCD2: hardening fixes %d literals" (IntSet.size lits)
              forM_ (IntSet.toList lits) $ \lit -> SAT.addClause solver [lit]
              modifyIORef' unrelaxedRef (`IntSet.difference` lits)
              modifyIORef' relaxedRef   (`IntSet.difference` lits)
              modifyIORef' hardenedRef  (`IntSet.union` lits)

          -- Tell the solver about an improved global upper bound only once.
          ub0 <- readIORef lastUBRef
          when (ub < ub0) $ do
            C.logMessage cxt $ printf "BCD2: updating upper bound: %d -> %d" ub0 ub
            SAT.addPBAtMost solver obj ub
            writeIORef lastUBRef ub

          cores     <- readIORef coresRef
          unrelaxed <- readIORef unrelaxedRef
          relaxed   <- readIORef relaxedRef
          hardened  <- readIORef hardenedRef
          nsat   <- readIORef nsatRef
          nunsat <- readIORef nunsatRef
          C.logMessage cxt $ printf "BCD2: %d <= obj <= %d" lb ub
          C.logMessage cxt $ printf "BCD2: #cores=%d, #unrelaxed=%d, #relaxed=%d, #hardened=%d"
            (length cores) (IntSet.size unrelaxed) (IntSet.size relaxed) (IntSet.size hardened)
          C.logMessage cxt $ printf "BCD2: #sat=%d #unsat=%d bias=%d/%d" nsat nunsat nunsat (nunsat + nsat)

          lastModel <- atomically $ C.getBestModel cxt
          -- For every core pick a trial upper bound and obtain the selector
          -- guarding it; the map goes from selector to (core, trial bound).
          selMap <- liftM IntMap.fromList $ forM cores $ \core -> do
            coreLB <- getCoreLB core
            let coreUB = SAT.pbUpperBound (coreCostFun core)
            if coreUB < coreLB then do
              -- The core cannot satisfy its own lower bound: the problem is
              -- infeasible, which is recorded by adding the empty clause.
              C.logMessage cxt $ printf "BCD2: coreLB (%d) exceeds coreUB (%d)" coreLB coreUB
              SAT.addClause solver []
              sel <- getCoreUBAssumption core coreUB
              return (sel, (core, coreUB))
            else do
              let estimated =
                    case lastModel of
                      Nothing -> coreUB
                      Just m  ->
                        -- The model's cost may lie below the core's lower
                        -- bound, so clamp the estimate from below.
                        coreLB `max` SAT.evalPBLinSum m (coreCostFun core)
                  mid =
                    if optEnableBiasedSearch opt
                    then coreLB + (estimated - coreLB) * nunsat `div` (nunsat + nsat)
                    else (coreLB + estimated) `div` 2
              sel <- getCoreUBAssumption core mid
              return (sel, (core,mid))

          ret <- SAT.solveWith solver (IntMap.keys selMap ++ IntSet.toList unrelaxed)

          if ret then do
            modifyIORef' nsatRef (+1)
            C.addSolution cxt =<< SAT.getModel solver
            loop
          else do
            modifyIORef' nunsatRef (+1)
            failed <- SAT.getFailedAssumptions solver

            case failed of
              -- Unsatisfiable without any assumption: nothing left to search.
              [] -> C.setFinished cxt
              -- Exactly one core's trial bound failed: its cost exceeds mid.
              [sel] | Just (core,mid) <- IntMap.lookup sel selMap -> do
                C.logMessage cxt $ printf "BCD2: updating lower bound of a core"
                let newCoreLB  = mid + 1
                    newCoreLB' = refineLB [weights IntMap.! lit | lit <- IntSet.toList (coreLits core)] newCoreLB
                when (newCoreLB /= newCoreLB') $ C.logMessage cxt $
                  printf "BCD2: refineLB: %d -> %d" newCoreLB newCoreLB'
                writeIORef (coreLBRef core) newCoreLB'
                -- The refuted bound is switched off by fixing its selector
                -- to false.
                SAT.addClause solver [-sel]
                updateLB lb core
                loop
              -- Otherwise merge the failed selectors and the cores whose
              -- bounds failed into one new core.
              _ -> do
                let failed'     = IntSet.fromList failed
                    torelax     = unrelaxed `IntSet.intersection` failed'
                    intersected = [(core,mid) | (sel,(core,mid)) <- IntMap.toList selMap, sel `IntSet.member` failed']
                    disjoint    = [core | (sel,(core,_)) <- IntMap.toList selMap, sel `IntSet.notMember` failed']
                modifyIORef' unrelaxedRef (`IntSet.difference` torelax)
                modifyIORef' relaxedRef (`IntSet.union` torelax)
                -- Smallest amount by which the merged cost must exceed the
                -- sum of the old lower bounds.  The list is non-empty here:
                -- every failed assumption is a core selector or an unrelaxed
                -- selector, and 'failed' is not empty in this branch.
                delta <- do
                  xs1 <- forM intersected $ \(core,mid) -> do
                    coreLB <- getCoreLB core
                    return $ mid - coreLB + 1
                  let xs2 = [weights IntMap.! lit | lit <- IntSet.toList torelax]
                  return $! minimum (xs1 ++ xs2)
                let mergedCoreLits = IntSet.unions $ torelax : [coreLits core | (core,_) <- intersected]
                mergedCoreLB <- liftM ((delta +) . sum) $ mapM (getCoreLB . fst) intersected
                let mergedCoreLB' = refineLB [weights IntMap.! lit | lit <- IntSet.toList mergedCoreLits] mergedCoreLB
                mergedCore <- newCoreInfo mergedCoreLits mergedCoreLB'
                writeIORef coresRef (mergedCore : disjoint)
                forM_ intersected $ \(core, _) -> deleteCoreInfo solver core

                if null intersected then
                  C.logMessage cxt $ printf "BCD2: found a new core of size %d: cost of the new core is >=%d"
                    (IntSet.size torelax) mergedCoreLB'
                else
                  C.logMessage cxt $ printf "BCD2: relaxing %d and merging with %d cores into a new core of size %d: cost of the new core is >=%d"
                    (IntSet.size torelax) (length intersected) (IntSet.size mergedCoreLits) mergedCoreLB'
                when (mergedCoreLB /= mergedCoreLB') $
                  C.logMessage cxt $ printf "BCD2: refineLB: %d -> %d" mergedCoreLB mergedCoreLB'
                updateLB lb mergedCore
                loop

  -- Optionally look for any model first, without assumptions.
  best <- atomically $ C.getBestModel cxt
  case best of
    Nothing | optSolvingNormalFirst opt -> do
      ret <- SAT.solve solver
      if ret then
        C.addSolution cxt =<< SAT.getModel solver
      else
        C.setFinished cxt
    _ -> return ()

  loop

  where
    obj :: SAT.PBLinSum
    obj = C.getObjectiveFunction cxt

    -- One selector per objective term: the negation of the term's literal,
    -- so that assuming the selector asks the solver not to pay the weight.
    sels :: [(SAT.Lit, Integer)]
    sels = [(-lit, w) | (w,lit) <- obj]

    weights :: SAT.LitMap Integer
    weights = IntMap.fromList sels

    -- Cost of a core expressed over the original objective literals, i.e.
    -- the negations of the core's selectors.
    coreCostFun :: CoreInfo -> SAT.PBLinSum
    coreCostFun c = [(weights IntMap.! lit, -lit) | lit <- IntSet.toList (coreLits c)]

    -- Return the selector guarding "cost of core <= ub", creating it on
    -- first use.  New selectors are chained to their neighbours in the map
    -- so that a tighter bound implies every looser one.
    getCoreUBAssumption :: CoreInfo -> Integer -> IO SAT.Lit
    getCoreUBAssumption core ub = do
      m <- readIORef (coreUBSelectors core)
      case Map.splitLookup ub m of
        (_, Just sel, _) -> return sel
        (lo, Nothing, hi)  -> do
          sel <- SAT.newVar solver
          SAT.addPBAtMostSoft solver sel (coreCostFun core) ub
          writeIORef (coreUBSelectors core) (Map.insert ub sel m)
          unless (Map.null lo) $
            -- selector of the next smaller bound implies sel
            SAT.addClause solver [- snd (Map.findMax lo), sel]
          unless (Map.null hi) $
            -- sel implies selector of the next larger bound
            SAT.addClause solver [- sel, snd (Map.findMin hi)]
          return sel
-- | Tighten a lower bound @n@ using the weights @ws@.
--
-- The candidate is whatever 'SubsetSum.minSubsetSum' reports for @ws@ and
-- @n@ (presumably the smallest subset sum of @ws@ that is at least @n@ —
-- confirm against the SubsetSum module).  When it reports nothing, the
-- result is one more than the sum of all positive weights, a value that
-- exceeds every subset sum.
refineLB
  :: [Integer] -- ^ @ws@
  -> Integer   -- ^ @n@
  -> Integer
refineLB ws n = maybe unreachable fst (SubsetSum.minSubsetSum (V.fromList ws) n)
  where
    unreachable = sum (filter (> 0) ws) + 1
-- | Tighten an upper bound @n@ using the weights @ws@.
--
-- A negative @n@ is returned unchanged.  Otherwise the candidate is whatever
-- 'SubsetSum.maxSubsetSum' reports for @ws@ and @n@ (presumably the largest
-- subset sum of @ws@ that is at most @n@ — confirm against the SubsetSum
-- module).  When it reports nothing, the result is one less than the sum of
-- all negative weights, a value that lies below every subset sum.
refineUB
  :: [Integer] -- ^ @ws@
  -> Integer   -- ^ @n@
  -> Integer
refineUB ws n
  | n < 0 = n
  | otherwise =
      case SubsetSum.maxSubsetSum (V.fromList ws) n of
        -- Mirror image of the fallback in 'refineLB': step just past the
        -- smallest reachable sum.
        Nothing -> sum [w | w <- ws, w < 0] - 1
        Just (best, _) -> best