{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall #-}
----------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.SLS.ProbSAT
-- Copyright   :  (c) Masahiro Sakai 2017
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-- References:
--
----------------------------------------------------------------------
module ToySolver.SAT.SLS.ProbSAT
  ( Solver
  , newSolver
  , newSolverWeighted
  , getNumVars
  , getRandomGen
  , setRandomGen
  , getBestSolution
  , getStatistics

  , Options (..)
  , Callbacks (..)
  , Statistics (..)

  , generateUniformRandomSolution

  , probsat
  , walksat
  ) where

import Prelude hiding (break)

import Control.Exception
import Control.Loop
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans
import Control.Monad.Trans.Except
import Data.Array.Base (unsafeRead, unsafeWrite, unsafeAt)
import Data.Array.IArray
import Data.Array.IO
import Data.Array.Unboxed
import Data.Array.Unsafe
import Data.Bits
import Data.Default.Class
import qualified Data.Foldable as F
import Data.Int
import Data.IORef
import Data.Maybe
import Data.Sequence ((|>))
import qualified Data.Sequence as Seq
import Data.Typeable
import Data.Word
import System.Clock
import qualified System.Random.MWC as Rand
import qualified System.Random.MWC.Distributions as Rand
import qualified ToySolver.FileFormat.CNF as CNF
import ToySolver.Internal.Data.IOURef
import qualified ToySolver.Internal.Data.Vec as Vec
import qualified ToySolver.SAT.Types as SAT

-- -------------------------------------------------------------------

-- | Mutable state of a local-search run over a fixed set of weighted
-- clauses.
data Solver
  = Solver
  { svClauses                :: !(Array ClauseId PackedClause)
    -- ^ The normalised, non-empty clauses, indexed from 0.
  , svClauseWeights          :: !(Array ClauseId CNF.Weight)
    -- ^ Weight of each clause.
  , svClauseWeightsF         :: !(UArray ClauseId Double)
    -- ^ The same weights converted with 'fromIntegral' to 'Double'.
  , svClauseNumTrueLits      :: !(IOUArray ClauseId Int32)
    -- ^ Number of literals of each clause that are true under 'svSolution'.
  , svClauseUnsatClauseIndex :: !(IOUArray ClauseId Int)
    -- ^ Position of a clause inside 'svUnsatClauses'.  Entries start at
    -- @-1@ and are only rewritten when a clause enters (or is moved within)
    -- that vector, so a value is meaningful only while the clause is listed.
  , svUnsatClauses           :: !(Vec.UVec ClauseId)
    -- ^ Ids of the clauses that currently have no true literal.

  , svVarOccurs              :: !(Array SAT.Var (UArray Int ClauseId))
    -- ^ For each variable, the ids of the clauses it occurs in.
  , svVarOccursState         :: !(Array SAT.Var (IOUArray Int Bool))
    -- ^ Parallel to 'svVarOccurs': whether the variable's literal in that
    -- occurrence is currently true.
  , svSolution               :: !(IOUArray SAT.Var Bool)
    -- ^ The current assignment.

  , svObj                    :: !(IORef CNF.Weight)
    -- ^ Current objective: total weight of the clauses in 'svUnsatClauses'
    -- plus the weight of input clauses that normalised to the empty clause.

  , svRandomGen              :: !(IORef Rand.GenIO)
    -- ^ Random generator used by the search.
  , svBestSolution           :: !(IORef (CNF.Weight, SAT.Model))
    -- ^ Best objective seen so far together with its assignment.
  , svStatistics             :: !(IORef Statistics)
    -- ^ Statistics of the most recent search.
  }

-- | Index of a clause in 'svClauses'.
type ClauseId = Int

-- | A clause stored as a 0-based array of literals.
type PackedClause = Array Int SAT.Lit

-- | Create a solver for an unweighted CNF by giving every clause weight 1
-- and delegating to 'newSolverWeighted'.
newSolver :: CNF.CNF -> IO Solver
newSolver cnf = do
  let wcnf =
        CNF.WCNF
        { CNF.wcnfNumVars    = CNF.cnfNumVars cnf
        , CNF.wcnfNumClauses = CNF.cnfNumClauses cnf
        , CNF.wcnfTopCost    = fromIntegral (CNF.cnfNumClauses cnf) + 1
        , CNF.wcnfClauses    = [(1,c) | c <- CNF.cnfClauses cnf]
        }
  newSolverWeighted wcnf

-- | Create a solver for a weighted CNF.
--
-- Every clause is passed through 'SAT.normalizeClause'.  Clauses for which
-- it returns 'Nothing' are dropped; clauses that normalise to the empty
-- clause are not stored but their weight is added to the objective.
-- The initial assignment sets every variable to 'False', and the best
-- solution is initialised to that assignment.
newSolverWeighted :: CNF.WCNF -> IO Solver
newSolverWeighted wcnf = do
  let m :: SAT.Var -> Bool
      m _ = False
      nv = CNF.wcnfNumVars wcnf

  objRef <- newIORef (0::Integer)

  cs <- liftM catMaybes $ forM (CNF.wcnfClauses wcnf) $ \(w,pc) -> do
    case SAT.normalizeClause (SAT.unpackClause pc) of
      Nothing -> return Nothing
      Just [] -> modifyIORef' objRef (w+) >> return Nothing
      Just c  -> do
        let c' = listArray (0, length c - 1) c
        seq c' $ return (Just (w,c'))
  let len = length cs
      clauses  = listArray (0, len - 1) (map snd cs)
      weights  :: Array ClauseId CNF.Weight
      weights  = listArray (0, len - 1) (map fst cs)
      weightsF :: UArray ClauseId Double
      weightsF = listArray (0, len - 1) (map (fromIntegral . fst) cs)

  -- Per-variable occurrence lists, accumulated as (clause id, literal value).
  (varOccurs' :: IOArray SAT.Var (Seq.Seq (Int, Bool))) <- newArray (1, nv) Seq.empty

  clauseNumTrueLits <- newArray (bounds clauses) 0
  clauseUnsatClauseIndex <- newArray (bounds clauses) (-1)
  unsatClauses <- Vec.new

  forAssocsM_ clauses $ \(c,clause) -> do
    let n = sum [1 | lit <- elems clause, SAT.evalLit m lit]
    writeArray clauseNumTrueLits c n
    when (n == 0) $ do
      i <- Vec.getSize unsatClauses
      writeArray clauseUnsatClauseIndex c i
      Vec.push unsatClauses c
      -- Strict update: a lazy modifyIORef would leave one pending addition
      -- per unsatisfied clause inside the reference.
      modifyIORef' objRef ((weights ! c) +)
    forM_ (elems clause) $ \lit -> do
      let v = SAT.litVar lit
      let b = SAT.evalLit m lit
      seq b $ modifyArray varOccurs' v (|> (c,b))

  varOccurs <- do
    (arr::IOArray SAT.Var (UArray Int ClauseId)) <- newArray_ (1, nv)
    forM_ [1 .. nv] $ \v -> do
      s <- readArray varOccurs' v
      writeArray arr v $ listArray (0, Seq.length s - 1) (map fst (F.toList s))
    unsafeFreeze arr

  varOccursState <- do
    (arr::IOArray SAT.Var (IOUArray Int Bool)) <- newArray_ (1, nv)
    forM_ [1 .. nv] $ \v -> do
      s <- readArray varOccurs' v
      ss <- newArray_ (0, Seq.length s - 1)
      forM_ (zip [0..] (F.toList s)) $ \(j,a) -> writeArray ss j (snd a)
      writeArray arr v ss
    unsafeFreeze arr

  solution <- newListArray (1, nv) $ [SAT.evalVar m v | v <- [1..nv]]

  bestObj <- readIORef objRef
  bestSol <- freeze solution
  bestSolution <- newIORef (bestObj, bestSol)

  randGen <- newIORef =<< Rand.create

  stat <- newIORef def

  return $
    Solver
    { svClauses = clauses
    , svClauseWeights = weights
    , svClauseWeightsF = weightsF
    , svClauseNumTrueLits = clauseNumTrueLits
    , svClauseUnsatClauseIndex = clauseUnsatClauseIndex
    , svUnsatClauses = unsatClauses

    , svVarOccurs = varOccurs
    , svVarOccursState = varOccursState
    , svSolution = solution

    , svObj = objRef

    , svRandomGen = randGen
    , svBestSolution = bestSolution
    , svStatistics = stat
    }

-- | Flip the value of a variable and update all derived state: the
-- per-occurrence literal values, the true-literal count of every clause the
-- variable occurs in, the list of unsatisfied clauses and the objective.
--
-- The whole update runs under 'mask_'.
flipVar :: Solver -> SAT.Var -> IO ()
flipVar solver v = mask_ $ do
  let occurs = svVarOccurs solver ! v
      occursState = svVarOccursState solver ! v
  seq occurs $ seq occursState $ return ()
  modifyArray (svSolution solver) v not
  forAssocsM_ occurs $ \(j,!c) -> do
    b <- unsafeRead occursState j
    n <- unsafeRead (svClauseNumTrueLits solver) c
    unsafeWrite occursState j (not b)
    if b
      then do
        -- The literal was true and becomes false.
        unsafeWrite (svClauseNumTrueLits solver) c (n-1)
        when (n==1) $ do
          -- It was the last true literal: the clause becomes unsatisfied.
          i <- Vec.getSize (svUnsatClauses solver)
          Vec.push (svUnsatClauses solver) c
          unsafeWrite (svClauseUnsatClauseIndex solver) c i
          modifyIORef' (svObj solver) (+ unsafeAt (svClauseWeights solver) c)
      else do
        -- The literal was false and becomes true.
        unsafeWrite (svClauseNumTrueLits solver) c (n+1)
        when (n==0) $ do
          -- The clause becomes satisfied: move it to the end of the
          -- unsatisfied list (unless it is already there) and pop it.
          s <- Vec.getSize (svUnsatClauses solver)
          i <- unsafeRead (svClauseUnsatClauseIndex solver) c
          unless (i == s-1) $ do
            let i2 = s-1
            c2 <- Vec.unsafeRead (svUnsatClauses solver) i2
            Vec.unsafeWrite (svUnsatClauses solver) i2 c
            Vec.unsafeWrite (svUnsatClauses solver) i c2
            unsafeWrite (svClauseUnsatClauseIndex solver) c2 i
          _ <- Vec.unsafePop (svUnsatClauses solver)
          modifyIORef' (svObj solver) (subtract (unsafeAt (svClauseWeights solver) c))
          return ()

-- | Make the current assignment equal to the given model by calling
-- 'flipVar' on every variable whose value differs.
setSolution :: SAT.IModel m => Solver -> m -> IO ()
setSolution solver m = do
  b <- getBounds (svSolution solver)
  forM_ (range b) $ \v -> do
    val <- readArray (svSolution solver) v
    let val' = SAT.evalVar m v
    unless (val == val') $ do
      flipVar solver v

-- | Number of variables, taken from the bounds of the occurrence table.
getNumVars :: Solver -> IO Int
getNumVars solver = return $ rangeSize $ bounds (svVarOccurs solver)

-- | The random generator currently used by the solver.
getRandomGen :: Solver -> IO Rand.GenIO
getRandomGen solver = readIORef (svRandomGen solver)

-- | Replace the random generator used by the solver.
setRandomGen :: Solver -> Rand.GenIO -> IO ()
setRandomGen solver gen = writeIORef (svRandomGen solver) gen

-- | The best objective value recorded so far and the corresponding model.
getBestSolution :: Solver -> IO (CNF.Weight, SAT.Model)
getBestSolution solver = readIORef (svBestSolution solver)

-- | The statistics stored by the most recent search.
getStatistics :: Solver -> IO Statistics
getStatistics solver = readIORef (svStatistics solver)

-- | Total weight (as 'Double') of the clauses containing the variable that
-- currently have no true literal.
{-# INLINE getMakeValue #-}
getMakeValue :: Solver -> SAT.Var -> IO Double
getMakeValue solver v = do
  let occurs = svVarOccurs solver ! v
      (lb,ub) = bounds occurs
  seq occurs $ seq lb $ seq ub $
    numLoopState lb ub 0 $ \ !r !i -> do
      let c = unsafeAt occurs i
      n <- unsafeRead (svClauseNumTrueLits solver) c
      return $! if n == 0 then (r + unsafeAt (svClauseWeightsF solver) c) else r

-- | Total weight (as 'Double') of the clauses in which the variable's
-- literal is currently true and is the only true literal.
{-# INLINE getBreakValue #-}
getBreakValue :: Solver -> SAT.Var -> IO Double
getBreakValue solver v = do
  let occurs = svVarOccurs solver ! v
      occursState = svVarOccursState solver ! v
      (lb,ub) = bounds occurs
  seq occurs $ seq occursState $ seq lb $ seq ub $
    numLoopState lb ub 0 $ \ !r !i -> do
      b <- unsafeRead occursState i
      if b
        then do
          let c = unsafeAt occurs i
          n <- unsafeRead (svClauseNumTrueLits solver) c
          return $! if n==1 then (r + unsafeAt (svClauseWeightsF solver) c) else r
        else
          return r

-- -------------------------------------------------------------------

-- | Search parameters.
data Options
  = Options
  { optTarget :: !CNF.Weight
    -- ^ The search stops as soon as the objective is at most this value.
  , optMaxTries :: !Int
    -- ^ Number of tries; each try starts from a freshly generated solution.
  , optMaxFlips :: !Int
    -- ^ Maximum number of flips per try.
  , optPickClauseWeighted :: Bool
    -- ^ If set, pick the unsatisfied clause with a weight-dependent draw
    -- instead of uniformly by position.
  }
  deriving (Eq, Show)

instance Default Options where
  def =
    Options
    { optTarget = 0
    , optMaxTries = 1
    , optMaxFlips = 100000
    , optPickClauseWeighted = False
    }

-- | Hooks invoked by the search.
data Callbacks
  = Callbacks
  { cbGenerateInitialSolution :: Solver -> IO SAT.Model
    -- ^ Produces the starting assignment of each try.
  , cbOnUpdateBestSolution :: Solver -> CNF.Weight -> SAT.Model -> IO ()
    -- ^ Called with the new objective and model whenever the best solution
    -- is improved.
  }

instance Default Callbacks where
  def =
    Callbacks
    { cbGenerateInitialSolution = generateUniformRandomSolution
    , cbOnUpdateBestSolution = \_ _ _ -> return ()
    }

-- | Figures recorded at the end of a search.
data Statistics
  = Statistics
  { statTotalCPUTime :: !TimeSpec
    -- ^ Process CPU time spent in the search.
  , statFlips :: !Int
    -- ^ Number of flips performed.
  , statFlipsPerSecond :: !Double
    -- ^ 'statFlips' divided by the CPU time in seconds.
  }
  deriving (Eq, Show)

instance Default Statistics where
  def =
    Statistics
    { statTotalCPUTime = 0
    , statFlips = 0
    , statFlipsPerSecond = 0
    }

-- -------------------------------------------------------------------

-- | Draw an independent random 'Bool' for every variable using the
-- solver's generator.
generateUniformRandomSolution :: Solver -> IO SAT.Model
generateUniformRandomSolution solver = do
  gen <- getRandomGen solver
  n <- getNumVars solver
  (a :: IOUArray Int Bool) <- newArray_ (1,n)
  forM_ [1..n] $ \v -> do
    b <- Rand.uniform gen
    writeArray a v b
  unsafeFreeze a
-- | Compare the current objective with the best one recorded so far; if it
-- is strictly smaller, store a frozen copy of the current assignment as the
-- new best solution and invoke 'cbOnUpdateBestSolution'.
checkCurrentSolution :: Solver -> Callbacks -> IO ()
checkCurrentSolution solver cb = do
  best <- readIORef (svBestSolution solver)
  obj <- readIORef (svObj solver)
  when (obj < fst best) $ do
    sol <- freeze (svSolution solver)
    writeIORef (svBestSolution solver) (obj, sol)
    cbOnUpdateBestSolution cb solver obj sol

-- | Pick one of the currently unsatisfied clauses.
--
-- With 'optPickClauseWeighted' a number is drawn below the current
-- objective and the unsatisfied clauses are scanned in order, subtracting
-- weights until the draw falls inside a clause.  Otherwise a position in
-- the unsatisfied list is drawn uniformly.
--
-- The caller must ensure that at least one clause is unsatisfied.
pickClause :: Solver -> Options -> IO PackedClause
pickClause solver opt = do
  gen <- getRandomGen solver
  if optPickClauseWeighted opt
    then do
      obj <- readIORef (svObj solver)
      s <- Vec.getSize (svUnsatClauses solver)
      -- The objective also contains the weight of input clauses that
      -- normalised to the empty clause, which are not in the unsatisfied
      -- list.  The draw can therefore exceed the total weight of the listed
      -- clauses, so the scan is clamped to the last entry instead of
      -- reading past the end of the vector.
      let lastIx = s - 1
          walk !j !x = do
            c <- Vec.read (svUnsatClauses solver) j
            let w = svClauseWeights solver ! c
            if x < w || j >= lastIx
              then return c
              else walk (j + 1) (x - w)
      x <- rand obj gen
      c <- walk 0 x
      return $ (svClauses solver ! c)
    else do
      s <- Vec.getSize (svUnsatClauses solver)
      j <- Rand.uniformR (0, s - 1) gen -- For integral types inclusive range is used
      liftM (svClauses solver !) $ Vec.read (svUnsatClauses solver) j

-- | Draw a random non-negative 'Integer' smaller than @n@.
--
-- Values of @n@ that fit in a 'Word32' are drawn directly from the
-- inclusive range @[0, n-1]@.  Larger values are built from a recursive
-- draw for the bits above the low 32 and a uniformly random low 32-bit
-- word.
--
-- NOTE(review): for @n <= 0@ the 'Word32' bound wraps around, so callers
-- are presumably expected to pass a positive @n@.
rand :: PrimMonad m => Integer -> Rand.Gen (PrimState m) -> m Integer
rand n gen
  | n <= toInteger (maxBound :: Word32) =
      liftM toInteger $ Rand.uniformR (0, fromIntegral n - 1 :: Word32) gen
  | otherwise = do
      a <- rand (n `shiftR` 32) gen
      (b::Word32) <- Rand.uniform gen
      return $ (a `shiftL` 32) .|. toInteger b

-- | Exception used internally to leave the nested try and flip loops.
data Finished = Finished
  deriving (Show, Typeable)

instance Exception Finished

-- -------------------------------------------------------------------

-- | Driver shared by 'probsat' and 'walksat'.
--
-- Runs 'optMaxTries' tries.  Each try installs the solution produced by
-- 'cbGenerateInitialSolution' and then performs up to 'optMaxFlips' flips:
-- pick an unsatisfied clause, let the supplied function choose one of its
-- variables, flip it, and record the result if it improves the best
-- solution.  The search ends early when no clause is unsatisfied or the
-- objective is at most 'optTarget'.
--
-- Afterwards the elapsed process CPU time and flip count are written to
-- 'svStatistics'.  The flips-per-second figure is a plain floating-point
-- division, so it is not finite when the measured CPU time is zero.
{-# INLINE runLocalSearch #-}
runLocalSearch
  :: Solver
  -> Options
  -> Callbacks
  -> (PackedClause -> IO SAT.Var)  -- ^ variable selection heuristic
  -> IO ()
runLocalSearch solver opt cb pickVar = do
  startCPUTime <- getTime ProcessCPUTime
  flipsRef <- newIOURef (0::Int)

  -- It's faster to use Control.Exception than using Control.Monad.Except
  let body = do
        replicateM_ (optMaxTries opt) $ do
          sol <- cbGenerateInitialSolution cb solver
          setSolution solver sol
          checkCurrentSolution solver cb
          replicateM_ (optMaxFlips opt) $ do
            s <- Vec.getSize (svUnsatClauses solver)
            when (s == 0) $ throwIO Finished
            obj <- readIORef (svObj solver)
            when (obj <= optTarget opt) $ throwIO Finished
            c <- pickClause solver opt
            v <- pickVar c
            flipVar solver v
            modifyIOURef flipsRef inc
            checkCurrentSolution solver cb
  body `catch` (\(_::Finished) -> return ())

  endCPUTime <- getTime ProcessCPUTime
  flips <- readIOURef flipsRef
  let totalCPUTime = endCPUTime `diffTimeSpec` startCPUTime
      totalCPUTimeSec = fromIntegral (toNanoSecs totalCPUTime) / 10^(9::Int)
  writeIORef (svStatistics solver) $
    Statistics
    { statTotalCPUTime = totalCPUTime
    , statFlips = flips
    , statFlipsPerSecond = fromIntegral flips / totalCPUTimeSec
    }

-- | Local search in which the variable to flip is drawn from the picked
-- clause with probability proportional to @f make break@, where @make@ and
-- @break@ are the values computed by 'getMakeValue' and 'getBreakValue'.
probsat :: Solver -> Options -> Callbacks -> (Double -> Double -> Double) -> IO ()
probsat solver opt cb f = do
  gen <- getRandomGen solver

  -- Scratch buffer large enough to hold one weight per literal of the
  -- longest clause.
  let maxClauseLen =
        if rangeSize (bounds (svClauses solver)) == 0
          then 0
          else maximum $ map (rangeSize . bounds) $ elems (svClauses solver)
  (wbuf :: IOUArray Int Double) <- newArray_ (0, maxClauseLen-1)
  wsumRef <- newIOURef (0 :: Double)

  let pickVar :: PackedClause -> IO SAT.Var
      pickVar c = do
        writeIOURef wsumRef 0
        forAssocsM_ c $ \(k,lit) -> do
          let v = SAT.litVar lit
          m <- getMakeValue solver v
          b <- getBreakValue solver v
          let w = f m b
          writeArray wbuf k w
          modifyIOURef wsumRef (+w)
        wsum <- readIOURef wsumRef

        -- Walk the weights until the drawn value is used up; if the index
        -- leaves the clause, fall back to its last literal.
        let go :: Int -> Double -> IO Int
            go !k !a
              | not (inRange (bounds c) k) = return $! snd (bounds c)
              | otherwise = do
                  w <- readArray wbuf k
                  if a <= w
                    then return k
                    else go (k + 1) (a - w)
        k <- go 0 =<< Rand.uniformR (0, wsum) gen
        return $! SAT.litVar (c ! k)

  runLocalSearch solver opt cb pickVar

-- | Local search with the following variable selection for the picked
-- clause: a variable with break value at most zero is taken immediately;
-- otherwise, with probability @p@ a literal of the clause is chosen
-- uniformly at random, and with the remaining probability one of the
-- variables with the smallest break value is chosen uniformly.
walksat :: Solver -> Options -> Callbacks -> Double -> IO ()
walksat solver opt cb p = do
  gen <- getRandomGen solver
  -- Collects the variables that share the smallest break value seen so far.
  (buf :: Vec.UVec SAT.Var) <- Vec.new

  let pickVar :: PackedClause -> IO SAT.Var
      pickVar c = do
        Vec.clear buf
        let (lb,ub) = bounds c
        r <- runExceptT $ do
          _ <- numLoopState lb ub (1.0/0.0) $ \ !b0 !i -> do
            let v = SAT.litVar (c ! i)
            b <- lift $ getBreakValue solver v
            if b <= 0
              then
                throwE v -- freebie move
              else if b < b0
                then do
                  lift $ Vec.clear buf >> Vec.push buf v
                  return b
                else if b == b0
                  then do
                    lift $ Vec.push buf v
                    return b0
                  else do
                    return b0
          return ()
        case r of
          Left v -> return v
          Right _ -> do
            flag <- Rand.bernoulli p gen
            if flag
              then do
                -- random walk move
                i <- Rand.uniformR (lb,ub) gen
                return $! SAT.litVar (c ! i)
              else do
                -- greedy move
                s <- Vec.getSize buf
                if s == 1
                  then
                    Vec.unsafeRead buf 0
                  else do
                    i <- Rand.uniformR (0, s - 1) gen
                    Vec.unsafeRead buf i

  runLocalSearch solver opt cb pickVar

-- -------------------------------------------------------------------

-- | Apply a function to one element of a mutable array.
{-# INLINE modifyArray #-}
modifyArray :: (MArray a e m, Ix i) => a i e -> i -> (e -> e) -> m ()
modifyArray a i f = do
  e <- readArray a i
  writeArray a i (f e)

-- | Run an action on every (index, element) pair of an 'Int'-indexed
-- immutable array, from the lower to the upper bound.
--
-- Elements are fetched with 'unsafeAt' using the index itself as the
-- offset, so the array is expected to have lower bound 0.
{-# INLINE forAssocsM_ #-}
forAssocsM_ :: (IArray a e, Monad m) => a Int e -> ((Int,e) -> m ()) -> m ()
forAssocsM_ a f = do
  let (lb,ub) = bounds a
  numLoop lb ub $ \i ->
    f (i, unsafeAt a i)

-- | Add one.
{-# INLINE inc #-}
inc :: Integral a => a -> a
inc a = a+1

-- -------------------------------------------------------------------