{-# OPTIONS_GHC -Wall #-} {-# LANGUAGE ScopedTypeVariables, BangPatterns, TypeFamilies #-} ----------------------------------------------------------------------------- -- | -- Module : ToySolver.SAT.MessagePassing.SurveyPropagation -- Copyright : (c) Masahiro Sakai 2016 -- License : BSD-style -- -- Maintainer : masahiro.sakai@gmail.com -- Stability : provisional -- Portability : non-portable (ScopedTypeVariables, BangPatterns, TypeFamilies) -- -- References: -- -- * Alfredo Braunstein, Marc Mézard and Riccardo Zecchina. -- Survey Propagation: An Algorithm for Satisfiability, -- -- -- * Corrie Scalisi. Visualizing Survey Propagation in 3-SAT Factor Graphs, -- . -- ----------------------------------------------------------------------------- module ToySolver.SAT.MessagePassing.SurveyPropagation ( -- * The Solver type Solver , newSolver , deleteSolver -- * Problem information , getNVars , getNConstraints -- * Parameters , getTolerance , setTolerance , getIterationLimit , setIterationLimit , getNThreads , setNThreads -- * Computing marginal distributions , initializeRandom , initializeRandomDirichlet , propagate , getVarProb -- * Solving , fixLit , unfixLit -- * Debugging , printInfo ) where import Control.Applicative import Control.Concurrent import Control.Concurrent.STM import Control.Exception import Control.Loop import Control.Monad import qualified Data.Array.IArray as A import qualified Data.IntMap as IntMap import qualified Data.IntSet as IntSet import Data.IORef import Data.Maybe (fromJust) import qualified Data.Vector as V import qualified Data.Vector.Mutable as VM import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM import Data.Vector.Generic ((!)) import qualified Data.Vector.Generic as VG import qualified Data.Vector.Generic.Mutable as VGM import qualified Numeric.Log as L import qualified System.Random.MWC as Rand import qualified System.Random.MWC.Distributions as Rand import qualified ToySolver.SAT.Types as SAT 
infixr 8 ^*

-- | Raise a log-domain number to a real power: @x ^* b@ represents @x@ to
-- the power @b@, obtained by scaling the stored logarithm by @b@.
(^*) :: Num a => L.Log a -> a -> L.Log a
L.Exp a ^* b = L.Exp (a*b)

-- | Complement of a log-domain probability, i.e. @comp x = 1 - x@.
--
-- The argument of 'L.log1p' is clamped to @-1@, so an input that exceeds @1@
-- (e.g. through rounding error) yields @0@ instead of NaN.
comp :: (RealFloat a, L.Precise a) => L.Log a -> L.Log a
comp (L.Exp a) = L.Exp $ L.log1p $ max (-1) $ negate (exp a)

type ClauseIndex = Int
type EdgeIndex = Int

-- | Mutable survey-propagation state over the factor graph of a weighted CNF.
--
-- Variables are numbered from 1 by callers and stored at index @v - 1@;
-- clauses and edges are 0-based.  Every literal occurrence in a clause is
-- one edge.
data Solver
  = Solver
  { svVarEdges :: !(V.Vector (VU.Vector EdgeIndex)) -- edges incident to each variable
  , svVarProbT :: !(VUM.IOVector (L.Log Double)) -- W^{(+)}_i, written by computeVarProb
  , svVarProbF :: !(VUM.IOVector (L.Log Double)) -- W^{(-)}_i, written by computeVarProb
  , svVarFixed :: !(VM.IOVector (Maybe Bool)) -- value set by fixLit, Nothing if free

  , svClauseEdges :: !(V.Vector (VU.Vector EdgeIndex)) -- edges of each clause
  , svClauseWeight :: !(VU.Vector Double) -- weight given to newSolver for each clause

  , svEdgeLit :: !(VU.Vector SAT.Lit) -- i
  , svEdgeClause :: !(VU.Vector ClauseIndex) -- a
  , svEdgeSurvey :: !(VUM.IOVector (L.Log Double)) -- η_{a → i}
  , svEdgeProbU :: !(VUM.IOVector (L.Log Double)) -- Π^u_{i → a} / (Π^u_{i → a} + Π^s_{i → a} + Π^0_{i → a})

  , svTolRef :: !(IORef Double)
  , svIterLimRef :: !(IORef (Maybe Int))
  , svNThreadsRef :: !(IORef Int)
  }

-- | Create a solver for @nv@ variables and a list of @(weight, clause)@ pairs.
--
-- Defaults: tolerance 0.01, iteration limit @Just 1000@, 1 thread.  All
-- surveys start at 0.5 and all variable biases at 0 (no marginal has been
-- computed yet).
--
-- NOTE(review): literals whose variable lies outside @1..nv@ still get an
-- edge, but that edge is not attached to any variable in 'svVarEdges'.
newSolver :: Int -> [(Double, SAT.Clause)] -> IO Solver
newSolver nv clauses = do
  let num_clauses = length clauses
      num_edges = sum [length c | (_,c) <- clauses]

  varEdgesRef <- newIORef IntMap.empty
  clauseEdgesM <- VGM.new num_clauses
  (edgeLitM :: VUM.IOVector SAT.Lit) <- VGM.new num_edges
  (edgeClauseM :: VUM.IOVector ClauseIndex) <- VGM.new num_edges

  -- Edges are numbered consecutively in clause order.
  ref <- newIORef 0
  forM_ (zip [0..] clauses) $ \(i,(_,c)) -> do
    es <- forM c $ \lit -> do
      e <- readIORef ref
      modifyIORef' ref (+1)
      modifyIORef' varEdgesRef (IntMap.insertWith IntSet.union (abs lit) (IntSet.singleton e))
      VGM.unsafeWrite edgeLitM e lit
      VGM.unsafeWrite edgeClauseM e i
      return e
    VGM.unsafeWrite clauseEdgesM i (VG.fromList es)

  varEdges <- readIORef varEdgesRef
  clauseEdges <- VG.unsafeFreeze clauseEdgesM
  edgeLit <- VG.unsafeFreeze edgeLitM
  edgeClause <- VG.unsafeFreeze edgeClauseM

  -- Initialize all surveys with non-zero values.
  -- If we initialize to zero, following trivial solution exists:
  --
  -- η_{a→i} = 0 for all i and a.
  --
  -- Π^0_{i→a} = 1, Π^u_{i→a} = Π^s_{i→a} = 0 for all i and a,
  --
  -- \^{Π}^{0}_i = 1, \^{Π}^{+}_i = \^{Π}^{-}_i = 0
  --
  edgeSurvey <- VGM.replicate num_edges 0.5
  edgeProbU <- VGM.new num_edges
  varFixed <- VGM.replicate nv Nothing
  -- Explicitly start the biases at probability 0: a bare 'VGM.new' would
  -- leave bits that decode as probability 1 (or garbage), so 'getVarProb'
  -- before a converged 'propagate' returned an inconsistent triple.
  varProbT <- VGM.replicate nv 0
  varProbF <- VGM.replicate nv 0

  tolRef <- newIORef 0.01
  maxIterRef <- newIORef (Just 1000)
  nthreadsRef <- newIORef 1

  let solver = Solver
        { svVarEdges = VG.generate nv $ \i ->
            case IntMap.lookup (i+1) varEdges of
              Nothing -> VG.empty
              Just es -> VG.fromListN (IntSet.size es) (IntSet.toList es)
        , svVarProbT = varProbT
        , svVarProbF = varProbF
        , svVarFixed = varFixed

        , svClauseEdges = clauseEdges
        , svClauseWeight = VG.fromListN (VG.length clauseEdges) (map fst clauses)

        , svEdgeLit = edgeLit
        , svEdgeClause = edgeClause
        , svEdgeSurvey = edgeSurvey
        , svEdgeProbU = edgeProbU

        , svTolRef = tolRef
        , svIterLimRef = maxIterRef
        , svNThreadsRef = nthreadsRef
        }
  return solver

-- | Release the solver.  All state is garbage collected, so this is a no-op.
deleteSolver :: Solver -> IO ()
deleteSolver _solver = return ()

-- | Randomly initialize the surveys.
--
-- For each clause with @n >= 2@ edges, every survey is drawn uniformly from
-- @[0.5/n, 1.5/n]@.  Unit clauses get survey 1; empty clauses are skipped.
initializeRandom :: Solver -> Rand.GenIO -> IO ()
initializeRandom solver gen = do
  VG.forM_ (svClauseEdges solver) $ \es -> do
    case VG.length es of
      0 -> return ()
      1 -> VGM.unsafeWrite (svEdgeSurvey solver) (es ! 0) 1
      n -> do
        let p :: Double
            p = 1 / fromIntegral n
        VG.forM_ es $ \e -> do
          d <- Rand.uniformR (p*0.5, p*1.5) gen
          VGM.unsafeWrite (svEdgeSurvey solver) e (realToFrac d)

-- | Randomly initialize the surveys.
--
-- For each clause with @len >= 2@ edges, the surveys are drawn jointly from
-- a symmetric Dirichlet(1, ..., 1) distribution.  Unit clauses get survey 1;
-- empty clauses are skipped.
initializeRandomDirichlet :: Solver -> Rand.GenIO -> IO ()
initializeRandomDirichlet solver gen = do
  VG.forM_ (svClauseEdges solver) $ \es -> do
    case VG.length es of
      0 -> return ()
      1 -> VGM.unsafeWrite (svEdgeSurvey solver) (es ! 0) 1
      len -> do
        (ps :: V.Vector Double) <- Rand.dirichlet (VG.replicate len 1) gen
        numLoop 0 (len-1) $ \i -> do
          VGM.unsafeWrite (svEdgeSurvey solver) (es ! i) (realToFrac (ps ! i))

-- | number of variables of the problem.
getNVars :: Solver -> IO Int
getNVars solver = return $ VG.length (svVarEdges solver)

-- | number of constraints of the problem.
getNConstraints :: Solver -> IO Int
getNConstraints solver = return $ VG.length (svClauseEdges solver)

-- | number of edges of the factor graph
getNEdges :: Solver -> IO Int
getNEdges solver = return $ VG.length (svEdgeLit solver)

-- | Convergence tolerance: 'propagate' stops once the largest change of any
-- survey during one sweep is at most this value.
getTolerance :: Solver -> IO Double
getTolerance solver = readIORef (svTolRef solver)

setTolerance :: Solver -> Double -> IO ()
setTolerance solver !tol = writeIORef (svTolRef solver) tol

-- | Maximum number of sweeps performed by 'propagate'
-- (@Nothing@ means unlimited).
getIterationLimit :: Solver -> IO (Maybe Int)
getIterationLimit solver = readIORef (svIterLimRef solver)

setIterationLimit :: Solver -> Maybe Int -> IO ()
setIterationLimit solver val = writeIORef (svIterLimRef solver) val

-- | Number of threads used by 'propagate'; values @<= 1@ run single-threaded.
getNThreads :: Solver -> IO Int
getNThreads solver = readIORef (svNThreadsRef solver)

setNThreads :: Solver -> Int -> IO ()
setNThreads solver val = writeIORef (svNThreadsRef solver) val

-- | Run survey propagation until the surveys converge (returns @True@, after
-- computing the variable biases read by 'getVarProb') or the iteration limit
-- is reached (returns @False@, biases are not updated).
propagate :: Solver -> IO Bool
propagate solver = do
  nthreads <- getNThreads solver
  if nthreads > 1 then propagateMT solver nthreads else propagateST solver

-- | Length of the longest edge list, or 0 when there are no lists.
-- Used to size scratch buffers; unlike 'VG.maximum' it is total.
maxEdgeListLength :: V.Vector (VU.Vector EdgeIndex) -> Int
maxEdgeListLength = VG.foldl' (\m es -> max m (VG.length es)) 0

-- | Single-threaded implementation of 'propagate'.
propagateST :: Solver -> IO Bool
propagateST solver = do
  tol <- getTolerance solver
  lim <- getIterationLimit solver
  nv <- getNVars solver
  nc <- getNConstraints solver
  -- One scratch buffer serves both updateEdgeProb (two slots per variable
  -- edge) and updateEdgeSurvey (one slot per clause edge).
  let max_v_len = maxEdgeListLength (svVarEdges solver)
      max_c_len = maxEdgeListLength (svClauseEdges solver)
  tmp <- VGM.new (max (max_v_len * 2) max_c_len)
  let loop !i
        | Just l <- lim, i >= l = return False
        | otherwise = do
            numLoop 1 nv $ \v -> updateEdgeProb solver v tmp
            let f maxDelta c = max maxDelta <$> updateEdgeSurvey solver c tmp
            delta <- foldM f 0 [0 .. nc-1]
            if delta <= tol then do
              numLoop 1 nv $ \v -> computeVarProb solver v
              return True
            else
              loop (i+1)
  loop 0

-- | Commands sent from 'propagateMT' to its worker threads.
data WorkerCommand
  = WCUpdateEdgeProb -- run updateEdgeProb on the worker's variables
  | WCUpdateSurvey -- run updateEdgeSurvey on the worker's clauses
  | WCComputeVarProb -- run computeVarProb on the worker's variables
  | WCTerminate -- exit the worker loop

-- | Multi-threaded implementation of 'propagate'.
--
-- Variables and clauses are split into @nthreads@ contiguous ranges, one per
-- worker.  Workers that end up with an empty range still answer every
-- command.  An exception in a worker is rethrown in the calling thread.
-- On exit the workers are terminated (normal exit) or killed (exception).
propagateMT :: Solver -> Int -> IO Bool
propagateMT solver nthreads = do
  tol <- getTolerance solver
  lim <- getIterationLimit solver
  nv <- getNVars solver
  nc <- getNConstraints solver

  mask $ \restore -> do
    ex <- newEmptyTMVarIO
    -- Wait for @m@, but rethrow the first exception reported by a worker.
    let wait :: STM a -> IO a
        wait m = join $ atomically $ liftM return m `orElse` liftM throwIO (takeTMVar ex)

    workers <- do
      let mV = (nv + nthreads - 1) `div` nthreads
          mC = (nc + nthreads - 1) `div` nthreads
      forM [0..nthreads-1] $ \i -> do
        -- Clamp the lower bounds too: when nthreads does not evenly cover
        -- nv/nc, trailing workers must get an empty range rather than a
        -- negative-length or out-of-range slice.
        let lbV = min (mV * i + 1) (nv + 1) -- inclusive
            ubV = min (lbV + mV) (nv + 1) -- exclusive
            lbC = min (mC * i) nc -- inclusive
            ubC = min (lbC + mC) nc -- exclusive
        let max_v_len = maxEdgeListLength $ VG.slice (lbV - 1) (ubV - lbV) (svVarEdges solver)
            max_c_len = maxEdgeListLength $ VG.slice lbC (ubC - lbC) (svClauseEdges solver)
        tmp <- VGM.new (max (max_v_len*2) max_c_len)
        reqVar <- newEmptyMVar
        respVar <- newEmptyTMVarIO
        respVar2 <- newEmptyTMVarIO
        th <- forkIO $ do
          let loop = do
                cmd <- takeMVar reqVar
                case cmd of
                  WCTerminate -> return ()
                  WCUpdateEdgeProb -> do
                    numLoop lbV (ubV-1) $ \v -> updateEdgeProb solver v tmp
                    atomically $ putTMVar respVar ()
                    loop
                  WCUpdateSurvey -> do
                    let f maxDelta c = max maxDelta <$> updateEdgeSurvey solver c tmp
                    delta <- foldM f 0 [lbC .. ubC-1]
                    atomically $ putTMVar respVar2 delta
                    loop
                  WCComputeVarProb -> do
                    numLoop lbV (ubV-1) $ \v -> computeVarProb solver v
                    atomically $ putTMVar respVar ()
                    loop
          restore loop `catch` \(e :: SomeException) -> atomically (tryPutTMVar ex e >> return ())
        return (th, reqVar, respVar, respVar2)

    let sendAll cmd = forM_ workers $ \(_,reqVar,_,_) -> putMVar reqVar cmd
        awaitAll = forM_ workers $ \(_,_,respVar,_) -> wait (takeTMVar respVar)

    let loop !i
          | Just l <- lim, i >= l = return False
          | otherwise = do
              sendAll WCUpdateEdgeProb
              awaitAll
              sendAll WCUpdateSurvey
              delta <- foldM (\delta (_,_,_,respVar2) -> max delta <$> wait (takeTMVar respVar2)) 0 workers
              if delta <= tol then do
                sendAll WCComputeVarProb
                awaitAll
                return True
              else
                loop (i+1)

    ret <- try $ restore $ loop 0
    case ret of
      Right b -> do
        -- Every worker is idle and waiting on its empty request MVar here,
        -- on both the converged and the iteration-limit path, so it must be
        -- told to exit; otherwise it would stay blocked forever.
        sendAll WCTerminate
        return b
      Left (e :: SomeException) -> do
        mapM_ (\(th,_,_,_) -> killThread th) workers
        throwIO e

-- | Recompute, for every edge incident to variable @v@, the normalized
-- probability Π^u_{i→a} / (Π^u_{i→a} + Π^s_{i→a} + Π^0_{i→a}) and store it
-- in 'svEdgeProbU'.
--
-- If @v@ has been fixed, the value is 0 for edges whose literal is satisfied
-- by the fixed value and 1 for the others.
--
-- tmp must have at least @VG.length (svVarEdges solver ! (v - 1)) * 2@ elements
updateEdgeProb :: Solver -> SAT.Var -> VUM.IOVector (L.Log Double) -> IO ()
updateEdgeProb solver v tmp = do
  let i = v - 1
      edges = svVarEdges solver ! i
  m <- VGM.unsafeRead (svVarFixed solver) i
  case m of
    Just val -> do
      VG.forM_ edges $ \e -> do
        let lit = svEdgeLit solver ! e
            flag = (lit > 0) == val
        VGM.unsafeWrite (svEdgeProbU solver) e (if flag then 0 else 1)
    Nothing -> do
      -- Forward pass: store prefix products, then a backward pass combines
      -- them with suffix products to exclude each edge's own clause.
      let f !k !val1_pre !val2_pre
            | k >= VG.length edges = return ()
            | otherwise = do
                let e = edges ! k
                    a = svEdgeClause solver ! e
                VGM.unsafeWrite tmp (k*2) val1_pre
                VGM.unsafeWrite tmp (k*2+1) val2_pre
                eta_ai <- VGM.unsafeRead (svEdgeSurvey solver) e -- η_{a→i}
                let w = svClauseWeight solver ! a
                    lit2 = svEdgeLit solver ! e
                    val1_pre' = if lit2 > 0 then val1_pre * comp eta_ai ^* w else val1_pre
                    val2_pre' = if lit2 > 0 then val2_pre else val2_pre * comp eta_ai ^* w
                f (k+1) val1_pre' val2_pre'
      f 0 1 1
      -- tmp ! (k*2)   == Π_{a∈edges[0..k-1], a∈V^{+}(i)} (1 - eta_ai)^{w_a}
      -- tmp ! (k*2+1) == Π_{a∈edges[0..k-1], a∈V^{-}(i)} (1 - eta_ai)^{w_a}

      let g !k !val1_post !val2_post
            | k < 0 = return ()
            | otherwise = do
                let e = edges ! k
                    a = svEdgeClause solver ! e
                    lit2 = svEdgeLit solver ! e
                val1_pre <- VGM.unsafeRead tmp (k*2)
                val2_pre <- VGM.unsafeRead tmp (k*2+1)
                let val1 = val1_pre * val1_post -- val1 == Π_{b∈edges, b∈V^{+}(i), a≠b} (1 - eta_bi)^{w_b}
                    val2 = val2_pre * val2_post -- val2 == Π_{b∈edges, b∈V^{-}(i), a≠b} (1 - eta_bi)^{w_b}
                eta_ai <- VGM.unsafeRead (svEdgeSurvey solver) e -- η_{a→i}
                let w = svClauseWeight solver ! a
                    val1_post' = if lit2 > 0 then val1_post * comp eta_ai ^* w else val1_post
                    val2_post' = if lit2 > 0 then val2_post else val2_post * comp eta_ai ^* w
                let pi_0 = val1 * val2 -- Π^0_{i→a}
                    pi_u = if lit2 > 0 then comp val2 * val1 else comp val1 * val2 -- Π^u_{i→a}
                    pi_s = if lit2 > 0 then comp val1 * val2 else comp val2 * val1 -- Π^s_{i→a}
                VGM.unsafeWrite (svEdgeProbU solver) e (pi_u / L.sum [pi_0, pi_u, pi_s])
                g (k-1) val1_post' val2_post'
      g (VG.length edges - 1) 1 1

-- | Recompute the surveys η_{a→i} of every edge of clause @a@ as the product
-- of 'svEdgeProbU' over the clause's other edges, and return the largest
-- absolute change of any of these surveys.
--
-- tmp must have at least @VG.length (svClauseEdges solver ! a)@ elements
updateEdgeSurvey :: Solver -> ClauseIndex -> VUM.IOVector (L.Log Double) -> IO Double
updateEdgeSurvey solver a tmp = do
  let edges = svClauseEdges solver ! a
  let f !k !p_pre
        | k >= VG.length edges = return ()
        | otherwise = do
            let e = edges ! k
            VGM.unsafeWrite tmp k p_pre
            -- p: normalized Π^u of this edge, i.e. the probability that its
            -- variable is forced to falsify its literal in this clause.
            p <- VGM.unsafeRead (svEdgeProbU solver) e
            f (k+1) (p_pre * p)
  let g !k !p_post !maxDelta
        | k < 0 = return maxDelta
        | otherwise = do
            let e = edges ! k
            -- p_post == Π_{e∈edges[k+1..]} p_e
            p_pre <- VGM.unsafeRead tmp k -- Π_{e∈edges[0..k-1]} p_e
            p <- VGM.unsafeRead (svEdgeProbU solver) e
            eta_ai <- VGM.unsafeRead (svEdgeSurvey solver) e
            let eta_ai' = p_pre * p_post -- Π_{e∈edges[0,..,k-1,k+1,..]} p_e
            VGM.unsafeWrite (svEdgeSurvey solver) e eta_ai'
            let delta = abs (realToFrac eta_ai' - realToFrac eta_ai)
            g (k-1) (p_post * p) (max delta maxDelta)
  f 0 1
  -- tmp ! k == Π_{e∈edges[0..k-1]} p_e
  g (VG.length edges - 1) 1 0

-- | Compute the biases W^{(+)}_i and W^{(-)}_i of variable @v@ from the
-- current surveys and store them in 'svVarProbT' and 'svVarProbF'.
computeVarProb :: Solver -> SAT.Var -> IO ()
computeVarProb solver v = do
  let i = v - 1
      f (val1,val2) e = do
        let lit = svEdgeLit solver ! e
            a = svEdgeClause solver ! e
            w = svClauseWeight solver ! a
        eta_ai <- VGM.unsafeRead (svEdgeSurvey solver) e
        let val1' = if lit > 0 then val1 * comp eta_ai ^* w else val1
            val2' = if lit < 0 then val2 * comp eta_ai ^* w else val2
        return (val1',val2')
  (val1,val2) <- VG.foldM' f (1,1) (svVarEdges solver ! i)
  let p0 = val1 * val2 -- \^{Π}^{0}_i
      pp = comp val1 * val2 -- \^{Π}^{+}_i
      pn = comp val2 * val1 -- \^{Π}^{-}_i
  let wp = pp / (pp + pn + p0)
      wn = pn / (pp + pn + p0)
  VGM.unsafeWrite (svVarProbT solver) i wp -- W^{(+)}_i
  VGM.unsafeWrite (svVarProbF solver) i wn -- W^{(-)}_i

-- | Get the marginal probability of the variable to be @True@, @False@ and unspecified respectively.
getVarProb :: Solver -> SAT.Var -> IO (Double, Double, Double)
getVarProb solver v = do
  let idx = v - 1
  probT <- VGM.unsafeRead (svVarProbT solver) idx
  probF <- VGM.unsafeRead (svVarProbF solver) idx
  let pt = realToFrac probT
      pf = realToFrac probF
  return (pt, pf, 1 - (pt + pf))

-- | Fix the variable of @lit@ to the value that makes @lit@ true.
fixLit :: Solver -> SAT.Lit -> IO ()
fixLit solver lit =
  VGM.unsafeWrite (svVarFixed solver) (abs lit - 1) (Just (lit > 0))

-- | Release the variable of @lit@ (the sign of @lit@ is ignored).
unfixLit :: Solver -> SAT.Lit -> IO ()
unfixLit solver lit =
  VGM.unsafeWrite (svVarFixed solver) (abs lit - 1) Nothing

-- | Dump the solver state to stdout: one
-- @(clause, literal, survey, probU)@ tuple per edge, then one
-- @(var, W+, W-, W+ - W-)@ tuple per variable.
printInfo :: Solver -> IO ()
printInfo solver = do
  (surveys :: VU.Vector (L.Log Double)) <- VG.freeze (svEdgeSurvey solver)
  (probU :: VU.Vector (L.Log Double)) <- VG.freeze (svEdgeProbU solver)
  let edgeRow e eta = (svEdgeClause solver ! e, svEdgeLit solver ! e, eta, probU ! e)
      edgeRows = zipWith edgeRow [0..] (VG.toList surveys)
  putStrLn $ "edges: " ++ show edgeRows

  (pt :: VU.Vector (L.Log Double)) <- VG.freeze (svVarProbT solver)
  (pf :: VU.Vector (L.Log Double)) <- VG.freeze (svVarProbF solver)
  nv <- getNVars solver
  let varRow v =
        let i = v - 1
            t = realToFrac (pt ! i) :: Double
            f = realToFrac (pf ! i) :: Double
        in (v, t, f, t - f)
  putStrLn $ "vars: " ++ show (map varRow [1..nv])