{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Solver.MessagePassing.SurveyPropagation
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-- References:
--
-- * Alfredo Braunstein, Marc Mézard and Riccardo Zecchina.
--   Survey Propagation: An Algorithm for Satisfiability,
--   <http://arxiv.org/abs/cs/0212002>
--
-- * Corrie Scalisi. Visualizing Survey Propagation in 3-SAT Factor Graphs,
--   <http://classes.soe.ucsc.edu/cmps290c/Winter06/proj/corriereport.pdf>.
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Solver.MessagePassing.SurveyPropagation
  (
  -- * The Solver type
    Solver
  , newSolver
  , deleteSolver

  -- * Problem information
  , getNVars
  , getNConstraints

  -- * Parameters
  , getTolerance
  , setTolerance
  , getIterationLimit
  , setIterationLimit
  , getNThreads
  , setNThreads

  -- * Computing marginal distributions
  , initializeRandom
  , initializeRandomDirichlet
  , propagate
  , getVarProb

  -- * Solving
  , fixLit
  , unfixLit

  -- * Debugging
  , printInfo
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Loop
import Control.Monad
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.IORef
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Data.Vector.Generic ((!))
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import Numeric
import qualified Numeric.Log as L
import qualified System.Random.MWC as Rand
import qualified System.Random.MWC.Distributions as Rand

import qualified ToySolver.SAT.Types as SAT

infixr 8 ^*

-- | Raise a log-domain number to a real power.
--
-- An @'L.Log' a@ stores the logarithm of the number it represents, so
-- @x ^* w@ (mathematically @x ** w@) is obtained by scaling that stored
-- logarithm by @w@.  The fixity is the same as that of '(^)'.
--
-- NOTE(review): for @x = 0@ (stored logarithm @-Infinity@) and @w = 0@ the
-- result is NaN rather than 1.
(^*) :: Num a => L.Log a -> a -> L.Log a
(^*) (L.Exp logX) w = L.Exp (logX * w)

-- | Complement of a log-domain probability: @comp p == 1 - p@.
--
-- With @a = log p@, the result is @log (1 - exp a)@.  It is evaluated with
-- the numerically stable split described by Mächler ("Accurately computing
-- log(1 - exp(-|a|))"):
--
--   * for @a > -log 2@ (i.e. @p > 1/2@): @log (-expm1 a)@;
--
--   * otherwise: @log1p (-exp a)@.
--
-- Using only the second form loses every significant digit when @p@ is close
-- to 1.  For example, @p = exp (-1e-20)@ would give exactly 0 instead of
-- roughly @1e-20@.  That matters here because a complement of 0 zeroes out
-- every product it enters.
--
-- Inputs with @p >= 1@ (e.g. from rounding) are clamped so the result is 0
-- instead of NaN.
comp :: RealFloat a => L.Log a -> L.Log a
comp (L.Exp a)
  | a > negate (log 2) = L.Exp $ log $ max 0 $ negate (expm1 a)
  | otherwise          = L.Exp $ log1p $ max (-1) $ negate (exp a)

-- | Index of a clause (factor node); clauses are numbered from 0 in the
-- order they are passed to 'newSolver'.
type ClauseIndex = Int

-- | Index of an edge of the factor graph, i.e. of one literal occurrence
-- inside a clause.
type EdgeIndex = Int

-- | Mutable state of the survey-propagation solver.
--
-- The factor graph has one node per variable, one node per clause and one
-- edge per literal occurrence.  Variables are numbered from 1 and are stored
-- at index @v - 1@ of the per-variable vectors.  Clauses and edges are
-- numbered from 0.  Values of type @'L.Log' Double@ are kept in the log
-- domain.
data Solver
  = Solver
  { svVarEdges     :: !(V.Vector (VU.Vector EdgeIndex))
    -- ^ edges incident to each variable, in ascending order
  , svVarProbT     :: !(VUM.IOVector (L.Log Double))
    -- ^ per-variable value written by @computeVarProb@ (presumably the
    -- probability of the variable being true; confirm there)
  , svVarProbF     :: !(VUM.IOVector (L.Log Double))
    -- ^ counterpart of 'svVarProbT' for the value false (same caveat)
  , svVarFixed     :: !(VM.IOVector (Maybe Bool))
    -- ^ @Just b@ when the variable is fixed to @b@, 'Nothing' otherwise
  , svClauseEdges  :: !(V.Vector (VU.Vector EdgeIndex))
    -- ^ edges of each clause, in the order of the clause's literals
  , svClauseWeight :: !(VU.Vector Double)
    -- ^ weight of each clause
  , svEdgeLit      :: !(VU.Vector SAT.Lit)
    -- ^ the literal @i@ of each edge
  , svEdgeClause   :: !(VU.Vector ClauseIndex)
    -- ^ the clause @a@ of each edge
  , svEdgeSurvey   :: !(VUM.IOVector (L.Log Double))
    -- ^ survey η_{a → i} of each edge
  , svEdgeProbU    :: !(VUM.IOVector (L.Log Double))
    -- ^ Π^u_{i → a} / (Π^u_{i → a} + Π^s_{i → a} + Π^0_{i → a}) of each edge
  , svTolRef       :: !(IORef Double)
    -- ^ convergence tolerance
  , svIterLimRef   :: !(IORef (Maybe Int))
    -- ^ iteration limit ('Nothing' means unlimited)
  , svNThreadsRef  :: !(IORef Int)
    -- ^ number of threads used by 'propagate'
  }

-- | Create a solver for a weighted CNF formula.
--
-- @newSolver nv clauses@ builds the factor graph of a problem over the
-- variables @1..nv@.  Each element of @clauses@ pairs a clause weight with
-- the clause itself.  Every literal occurrence becomes one edge.  Edges are
-- numbered @0, 1, ...@ in the order the literals are encountered, and clauses
-- are numbered by their position in @clauses@.
--
-- Initial state:
--
--   * every survey is @0.5@;
--   * no variable is fixed;
--   * tolerance @0.01@, iteration limit @Just 1000@, one thread.
newSolver :: Int -> [(Double, SAT.PackedClause)] -> IO Solver
newSolver nv clauses = do
  let nClauses = length clauses
      nEdges   = sum [VG.length c | (_, c) <- clauses]

  -- Builders that are filled while walking over the clauses.
  occRef <- newIORef IntMap.empty -- variable -> set of its edges
  nextEdge <- newIORef 0
  clauseEdgesM <- VGM.new nClauses
  (edgeLitM :: VUM.IOVector SAT.Lit) <- VGM.new nEdges
  (edgeClauseM :: VUM.IOVector ClauseIndex) <- VGM.new nEdges

  forM_ (zip [0 ..] clauses) $ \(ci, (_, c)) -> do
    let addEdge lit = do
          e <- readIORef nextEdge
          modifyIORef' nextEdge (+ 1)
          modifyIORef' occRef (IntMap.insertWith IntSet.union (abs lit) (IntSet.singleton e))
          VGM.unsafeWrite edgeLitM e lit
          VGM.unsafeWrite edgeClauseM e ci
          return e
    es <- mapM addEdge (SAT.unpackClause c)
    VGM.unsafeWrite clauseEdgesM ci (VG.fromList es)

  occ <- readIORef occRef
  clauseEdges <- VG.unsafeFreeze clauseEdgesM
  edgeLit <- VG.unsafeFreeze edgeLitM
  edgeClause <- VG.unsafeFreeze edgeClauseM

  -- Surveys must start non-zero.  With every survey at zero the following
  -- trivial solution exists:
  --
  -- η_{a→i} = 0 for all i and a,
  -- Π^0_{i→a} = 1, Π^u_{i→a} = Π^s_{i→a} = 0 for all i and a,
  -- \^{Π}^{0}_i = 1, \^{Π}^{+}_i = \^{Π}^{-}_i = 0
  edgeSurvey <- VGM.replicate nEdges 0.5
  edgeProbU <- VGM.new nEdges

  varFixed <- VGM.replicate nv Nothing
  varProbT <- VGM.new nv
  varProbF <- VGM.new nv

  tolRef <- newIORef 0.01
  iterLimRef <- newIORef (Just 1000)
  nthreadsRef <- newIORef 1

  -- Edges of variable v (1-based) as an ascending vector.
  let edgesOfVar :: Int -> VU.Vector EdgeIndex
      edgesOfVar v = case IntMap.lookup v occ of
        Nothing -> VG.empty
        Just s  -> VG.fromListN (IntSet.size s) (IntSet.toList s)

  return Solver
    { svVarEdges     = VG.generate nv (edgesOfVar . (+ 1))
    , svVarProbT     = varProbT
    , svVarProbF     = varProbF
    , svVarFixed     = varFixed
    , svClauseEdges  = clauseEdges
    , svClauseWeight = VG.fromListN nClauses (map fst clauses)
    , svEdgeLit      = edgeLit
    , svEdgeClause   = edgeClause
    , svEdgeSurvey   = edgeSurvey
    , svEdgeProbU    = edgeProbU
    , svTolRef       = tolRef
    , svIterLimRef   = iterLimRef
    , svNThreadsRef  = nthreadsRef
    }

-- | Release a solver.  This is a no-op: the solver only holds 'IORef's and
-- vectors, which the garbage collector reclaims.
deleteSolver :: Solver -> IO ()
deleteSolver = const (return ())

-- | Randomly initialise the surveys of every clause.
--
-- For a clause with @n >= 2@ literals, each of its surveys is drawn
-- uniformly from @[p/2, 3p/2]@ with @p = 1/n@.  A unit clause gets the
-- survey 1, and an empty clause is left untouched.
initializeRandom :: Solver -> Rand.GenIO -> IO ()
initializeRandom solver gen = VG.forM_ (svClauseEdges solver) initClause
  where
    surveys = svEdgeSurvey solver

    initClause :: VU.Vector EdgeIndex -> IO ()
    initClause es
      | n == 0    = return ()
      | n == 1    = VGM.unsafeWrite surveys (es ! 0) 1
      | otherwise = VG.forM_ es $ \e -> do
          d <- Rand.uniformR (p * 0.5, p * 1.5) gen
          VGM.unsafeWrite surveys e (realToFrac d)
      where
        n = VG.length es
        p :: Double
        p = 1 / fromIntegral n

-- | Randomly initialise the surveys of every clause from a Dirichlet
-- distribution.
--
-- The surveys of a clause with @len >= 2@ literals are one sample of the
-- symmetric Dirichlet distribution with all concentration parameters equal
-- to 1 (uniform on the simplex), so they sum to 1.  A unit clause gets the
-- survey 1, and an empty clause is left untouched.
initializeRandomDirichlet :: Solver -> Rand.GenIO -> IO ()
initializeRandomDirichlet solver gen = VG.forM_ (svClauseEdges solver) initClause
  where
    surveys = svEdgeSurvey solver

    initClause :: VU.Vector EdgeIndex -> IO ()
    initClause es = case VG.length es of
      0 -> return ()
      1 -> VGM.unsafeWrite surveys (es ! 0) 1
      len -> do
        (ps :: V.Vector Double) <- Rand.dirichlet (VG.replicate len 1) gen
        -- The k-th edge of the clause receives the k-th Dirichlet component.
        VG.imapM_ (\k e -> VGM.unsafeWrite surveys e (realToFrac (ps ! k))) es

-- | Number of variables of the problem (the @nv@ given to 'newSolver').
getNVars :: Solver -> IO Int
getNVars = return . VG.length . svVarEdges

-- | Number of constraints (clauses) of the problem.
getNConstraints :: Solver -> IO Int
getNConstraints = return . VG.length . svClauseEdges

-- | Number of edges of the factor graph, i.e. the total number of literal
-- occurrences over all clauses.
getNEdges :: Solver -> IO Int
getNEdges = return . VG.length . svEdgeLit

-- | Convergence tolerance (default @0.01@).  'propagate' stops once the
-- largest per-clause change reported in one iteration is at most this value.
getTolerance :: Solver -> IO Double
getTolerance = readIORef . svTolRef

-- | Set the convergence tolerance used by 'propagate'.  The value is forced
-- before it is stored, so no thunk ends up in the reference.
setTolerance :: Solver -> Double -> IO ()
setTolerance solver tol = writeIORef (svTolRef solver) $! tol

-- | Maximum number of iterations of 'propagate' (default @Just 1000@);
-- 'Nothing' means no limit.
getIterationLimit :: Solver -> IO (Maybe Int)
getIterationLimit = readIORef . svIterLimRef

-- | Set the maximum number of iterations of 'propagate'; 'Nothing' removes
-- the limit.
setIterationLimit :: Solver -> Maybe Int -> IO ()
setIterationLimit = writeIORef . svIterLimRef

-- | Number of threads used by 'propagate' (default 1).
getNThreads :: Solver -> IO Int
getNThreads = readIORef . svNThreadsRef

-- | Set the number of threads used by 'propagate'.  Any value @<= 1@ selects
-- the single-threaded implementation.
setNThreads :: Solver -> Int -> IO ()
setNThreads = writeIORef . svNThreadsRef

-- | Run survey propagation until the surveys converge or the iteration
-- limit is reached.
--
-- Returns 'True' on convergence and 'False' when the iteration limit was hit
-- first.  The multi-threaded implementation is used when more than one
-- thread is configured (see 'setNThreads'); otherwise the single-threaded
-- one is used.
propagate :: Solver -> IO Bool
propagate solver = getNThreads solver >>= run
  where
    run :: Int -> IO Bool
    run n
      | n > 1     = propagateMT solver n
      | otherwise = propagateST solver

-- | Single-threaded implementation of 'propagate'.
--
-- Each iteration does two things:
--
--   * recomputes the edge probabilities of every variable @1..nv@;
--   * recomputes the surveys of every clause @0..nc-1@, tracking the largest
--     change reported by @updateEdgeSurvey@.
--
-- Once that change is at most the tolerance, the variable marginals are
-- computed and 'True' is returned.  If the iteration limit is reached first,
-- 'False' is returned and the marginals are not recomputed.
propagateST :: Solver -> IO Bool
propagateST solver = do
  tol <- getTolerance solver
  lim <- getIterationLimit solver
  nv <- getNVars solver
  nc <- getNConstraints solver
  let -- Length of the longest adjacency list, or 0 when there is none.
      -- ('VG.maximum' is partial and would throw for a problem with no
      -- variables or no clauses.)
      maxLength :: V.Vector (VU.Vector EdgeIndex) -> Int
      maxLength = VG.foldl' (\m es -> max m (VG.length es)) 0
      max_v_len = maxLength (svVarEdges solver)
      max_c_len = maxLength (svClauseEdges solver)
  -- Scratch buffer shared by the update functions; updateEdgeProb needs two
  -- slots per incident edge of a variable.
  tmp <- VGM.new (max (max_v_len * 2) max_c_len)
  let loop :: Int -> IO Bool
      loop !i
        | Just l <- lim, i >= l = return False
        | otherwise = do
            numLoop 1 nv $ \v -> updateEdgeProb solver v tmp
            let f :: Double -> ClauseIndex -> IO Double
                f maxDelta c = max maxDelta <$> updateEdgeSurvey solver c tmp
            delta <- foldM f 0 [0 .. nc - 1]
            if delta <= tol then do
              numLoop 1 nv $ \v -> computeVarProb solver v
              return True
            else
              loop (i + 1)
  loop 0

-- | Request sent by 'propagateMT' to one of its worker threads; each worker
-- acts only on its own chunk of variables or clauses.
data WorkerCommand
  = WCUpdateEdgeProb -- ^ recompute edge probabilities of the chunk's variables
  | WCUpdateSurvey   -- ^ recompute surveys of the chunk's clauses, report the largest change
  | WCComputeVarProb -- ^ compute the marginals of the chunk's variables
  | WCTerminate      -- ^ leave the worker loop

-- | Multi-threaded implementation of 'propagate'.
--
-- Variables @1..nv@ and clauses @0..nc-1@ are split into @nthreads@
-- contiguous chunks of at most @ceil (nv / nthreads)@ and
-- @ceil (nc / nthreads)@ elements.  Each chunk is owned by one worker thread
-- started with 'forkIO'.  The calling thread drives all workers in
-- lock-step, waiting for every worker to finish a phase before starting the
-- next one:
--
--   1. 'WCUpdateEdgeProb': recompute edge probabilities;
--   2. 'WCUpdateSurvey': recompute surveys and report the largest change;
--   3. on convergence, 'WCComputeVarProb': compute the variable marginals.
--
-- Returns 'True' once the largest change of an iteration is at most the
-- tolerance, or 'False' when the iteration limit is reached first.  Workers
-- are shut down on every exit path.  If a worker throws, its exception is
-- re-thrown here after all workers have been killed.
propagateMT :: Solver -> Int -> IO Bool
propagateMT solver nthreads = do
  tol <- getTolerance solver
  lim <- getIterationLimit solver
  nv <- getNVars solver
  nc <- getNConstraints solver

  mask $ \restore -> do
    -- The first exception raised by any worker.
    (ex :: TMVar SomeException) <- newEmptyTMVarIO
    let -- Run an STM action, but re-throw a worker's exception instead if one
        -- has been reported, so the main thread never waits on a dead worker.
        wait :: STM a -> IO a
        wait m = join $ atomically $ liftM return m `orElse` liftM throwIO (takeTMVar ex)

        -- Length of the longest adjacency list in a slice, 0 for an empty
        -- slice (unlike 'VG.maximum', which is partial).
        maxLength :: V.Vector (VU.Vector EdgeIndex) -> Int
        maxLength = VG.foldl' (\m es -> max m (VG.length es)) 0

        chunkV = (nv + nthreads - 1) `div` nthreads
        chunkC = (nc + nthreads - 1) `div` nthreads

    workers <- forM [0 .. nthreads - 1] $ \i -> do
      -- When nthreads does not divide nv / nc evenly, the last chunks can
      -- start past the end.  Clamp the bounds so that those chunks are empty
      -- instead of producing an invalid (negative-length) slice.
      let lbV = min (chunkV * i + 1) (nv + 1) -- inclusive
          ubV = min (lbV + chunkV) (nv + 1)   -- exclusive
          lbC = min (chunkC * i) nc           -- inclusive
          ubC = min (lbC + chunkC) nc         -- exclusive
          max_v_len = maxLength $ VG.slice (lbV - 1) (ubV - lbV) (svVarEdges solver)
          max_c_len = maxLength $ VG.slice lbC (ubC - lbC) (svClauseEdges solver)
      -- Per-worker scratch buffer (updateEdgeProb needs two slots per edge).
      tmp <- VGM.new (max (max_v_len * 2) max_c_len)
      reqVar <- newEmptyMVar
      respVar <- newEmptyTMVarIO
      respVar2 <- newEmptyTMVarIO
      th <- forkIO $ do
        let loop :: IO ()
            loop = do
              cmd <- takeMVar reqVar
              case cmd of
                WCTerminate -> return ()
                WCUpdateEdgeProb -> do
                  numLoop lbV (ubV - 1) $ \v -> updateEdgeProb solver v tmp
                  atomically $ putTMVar respVar ()
                  loop
                WCUpdateSurvey -> do
                  let f :: Double -> Int -> IO Double
                      f maxDelta c = max maxDelta <$> updateEdgeSurvey solver c tmp
                  delta <- foldM f 0 [lbC .. ubC - 1]
                  atomically $ putTMVar respVar2 delta
                  loop
                WCComputeVarProb -> do
                  numLoop lbV (ubV - 1) $ \v -> computeVarProb solver v
                  atomically $ putTMVar respVar ()
                  loop
        -- forkIO inherits the masked state; only the command loop runs
        -- unmasked, and any exception is reported to the main thread.
        restore loop `catch` \(e :: SomeException) ->
          atomically (tryPutTMVar ex e >> return ())
      return (th, reqVar, respVar, respVar2)

    let broadcast :: WorkerCommand -> IO ()
        broadcast cmd = forM_ workers $ \(_, reqVar, _, _) -> putMVar reqVar cmd

        awaitAll :: IO ()
        awaitAll = forM_ workers $ \(_, _, respVar, _) -> wait (takeTMVar respVar)

        loop :: Int -> IO Bool
        loop !it
          | Just l <- lim, it >= l = return False
          | otherwise = do
              broadcast WCUpdateEdgeProb
              awaitAll
              broadcast WCUpdateSurvey
              delta <- foldM (\d (_, _, _, respVar2) -> max d <$> wait (takeTMVar respVar2)) 0 workers
              if delta <= tol then do
                broadcast WCComputeVarProb
                awaitAll
                return True
              else
                loop (it + 1)

    ret <- try $ restore $ loop 0
    case ret of
      Right b -> do
        -- Shut the workers down on every normal exit.  This includes the
        -- case where the iteration limit was hit; otherwise they would stay
        -- blocked in takeMVar.  Every reqVar is empty here, so putMVar does
        -- not block.
        broadcast WCTerminate
        return b
      Left (e :: SomeException) -> do
        mapM_ (\(th, _, _, _) -> killThread th) workers
        throwIO e

-- | Recompute the value stored in @svEdgeProbU@ for every edge incident to
-- variable @v@.
--
-- Consider an edge @e@ between clause @a@ and variable @i = v - 1@. Let @P@
-- (resp. @N@) be the product of @(1 - η_{b→i})^{w_b}@ over the clauses
-- @b ≠ a@ in which @i@ occurs positively (resp. negatively). Here @w_b@ is
-- the weight of clause @b@. The value written for @e@ is
--
-- > Π^u_{i→a} / (Π^0_{i→a} + Π^u_{i→a} + Π^s_{i→a})
--
-- with
--
-- * @Π^0 = P·N@,
-- * if @i@ occurs positively in @a@: @Π^u = (1 - N)·P@ and @Π^s = (1 - P)·N@,
-- * otherwise the roles of @P@ and @N@ are swapped.
--
-- If the variable has been fixed (see 'fixLit'), the value is @0@ on every
-- edge whose literal is made true by the fixed value, and @1@ on the others.
--
-- @tmp@ is scratch space. It must have at least
-- @VG.length (svVarEdges solver ! (v - 1)) * 2@ elements.
updateEdgeProb :: Solver -> SAT.Var -> VUM.IOVector (L.Log Double) -> IO ()
updateEdgeProb solver v tmp = do
  let i = v - 1
      edges = svVarEdges solver ! i
  m <- VGM.unsafeRead (svVarFixed solver) i
  case m of
    Just val ->
      VG.forM_ edges $ \e -> do
        let lit = svEdgeLit solver ! e
            -- True iff the fixed value of the variable makes this literal true.
            flag = (lit > 0) == val
        VGM.unsafeWrite (svEdgeProbU solver) e (if flag then 0 else 1)
    Nothing -> do
      -- Multiply the factor (1 - η_{a→i})^{w_a} of edge @e@ (clause @a@)
      -- into the positive-side product if @i@ occurs positively in @a@,
      -- and into the negative-side product otherwise.
      let absorb :: Int -> L.Log Double -> L.Log Double -> IO (L.Log Double, L.Log Double)
          absorb e pos neg = do
            eta_ai <- VGM.unsafeRead (svEdgeSurvey solver) e -- η_{a→i}
            let w = svClauseWeight solver ! (svEdgeClause solver ! e)
                factor = comp eta_ai ^* w
            return $
              if svEdgeLit solver ! e > 0
                then (pos * factor, neg)
                else (pos, neg * factor)
          {-# INLINE absorb #-}

          -- Forward sweep. Before absorbing edge k it stores the prefix products
          --   tmp ! (k*2)   == Π_{b∈edges[0..k-1], b∈V^{+}(i)} (1 - η_{b→i})^{w_b}
          --   tmp ! (k*2+1) == Π_{b∈edges[0..k-1], b∈V^{-}(i)} (1 - η_{b→i})^{w_b}
          sweepFwd :: Int -> L.Log Double -> L.Log Double -> IO ()
          sweepFwd !k !pos_pre !neg_pre
            | k >= VG.length edges = return ()
            | otherwise = do
                VGM.unsafeWrite tmp (k*2) pos_pre
                VGM.unsafeWrite tmp (k*2+1) neg_pre
                (pos_pre', neg_pre') <- absorb (edges ! k) pos_pre neg_pre
                sweepFwd (k+1) pos_pre' neg_pre'

          -- Backward sweep. It carries the products over edges[k+1..].
          -- Together with the prefix stored in tmp, this yields the products
          -- over every incident clause except the clause of edge k.
          sweepBwd :: Int -> L.Log Double -> L.Log Double -> IO ()
          sweepBwd !k !pos_post !neg_post
            | k < 0 = return ()
            | otherwise = do
                let e = edges ! k
                    positive = svEdgeLit solver ! e > 0
                pos_pre <- VGM.unsafeRead tmp (k*2)
                neg_pre <- VGM.unsafeRead tmp (k*2+1)
                let pos = pos_pre * pos_post -- Π_{b∈V^{+}(i), b≠a} (1 - η_{b→i})^{w_b}
                    neg = neg_pre * neg_post -- Π_{b∈V^{-}(i), b≠a} (1 - η_{b→i})^{w_b}
                    pi_0 = pos * neg -- Π^0_{i→a}
                    pi_u = if positive then comp neg * pos else comp pos * neg -- Π^u_{i→a}
                    pi_s = if positive then comp pos * neg else comp neg * pos -- Π^s_{i→a}
                VGM.unsafeWrite (svEdgeProbU solver) e (pi_u / L.sum [pi_0, pi_u, pi_s])
                (pos_post', neg_post') <- absorb e pos_post neg_post
                sweepBwd (k-1) pos_post' neg_post'

      sweepFwd 0 1 1
      sweepBwd (VG.length edges - 1) 1 1

-- | Recompute the survey @η_{a→i}@ on every edge of clause @a@. Returns the
-- largest absolute change of any survey.
--
-- The new survey on edge @k@ of the clause is the product of the
-- @svEdgeProbU@ values of all the clause's /other/ edges. A forward sweep
-- stores the prefix products in @tmp@. A backward sweep then carries the
-- suffix product and combines the two.
--
-- @tmp@ is scratch space. It must have at least
-- @VG.length (svClauseEdges solver ! a)@ elements.
updateEdgeSurvey :: Solver -> ClauseIndex -> VUM.IOVector (L.Log Double) -> IO Double
updateEdgeSurvey solver a tmp = do
  let edges = svClauseEdges solver ! a

  -- Forward sweep: tmp ! k := Π_{j∈[0..k-1]} p_j, where p_j is the
  -- svEdgeProbU value of edge (edges ! j).
  VG.ifoldM'_
    (\prefix k e -> do
        VGM.unsafeWrite tmp k prefix
        pU <- VGM.unsafeRead (svEdgeProbU solver) e
        return (prefix * pU))
    1
    edges

  -- Backward sweep. @suffix@ is Π_{j∈[k+1..]} p_j. The new survey on edge k
  -- is prefix * suffix, i.e. the product over every other edge.
  let sweep :: Int -> L.Log Double -> Double -> IO Double
      sweep !k !suffix !maxDelta
        | k < 0 = return maxDelta
        | otherwise = do
            let e = edges ! k
            prefix <- VGM.unsafeRead tmp k
            pU <- VGM.unsafeRead (svEdgeProbU solver) e
            oldEta <- VGM.unsafeRead (svEdgeSurvey solver) e
            let newEta = prefix * suffix
                delta = abs (realToFrac newEta - realToFrac oldEta)
            VGM.unsafeWrite (svEdgeSurvey solver) e newEta
            sweep (k - 1) (suffix * pU) (max delta maxDelta)

  sweep (VG.length edges - 1) 1 0

-- | Recompute the marginals of variable @v@ from the surveys on its incident
-- edges. The results are stored in @svVarProbT@ and @svVarProbF@.
--
-- Let @P@ (resp. @N@) be the product of @(1 - η_{a→i})^{w_a}@ over the
-- clauses @a@ in which @i = v - 1@ occurs positively (resp. negatively).
-- With @Π^0 = P·N@, @Π^+ = (1 - P)·N@ and @Π^- = (1 - N)·P@, the stored
-- values are @W^{(+)}_i = Π^+ / (Π^+ + Π^- + Π^0)@ and
-- @W^{(-)}_i = Π^- / (Π^+ + Π^- + Π^0)@.
--
-- NOTE(review): @svVarFixed@ is not consulted here, so a fixed variable's
-- marginals are still derived from the surveys. Confirm callers expect this.
computeVarProb :: Solver -> SAT.Var -> IO ()
computeVarProb solver v = do
  let idx = v - 1
      -- Fold one incident edge into the (positive, negative) products.
      step (pos, neg) e = do
        eta <- VGM.unsafeRead (svEdgeSurvey solver) e
        let weight = svClauseWeight solver ! (svEdgeClause solver ! e)
            factor = comp eta ^* weight
        return $ case compare (svEdgeLit solver ! e) 0 of
          GT -> (pos * factor, neg)
          LT -> (pos, neg * factor)
          EQ -> (pos, neg)
  (pos, neg) <- VG.foldM' step (1, 1) (svVarEdges solver ! idx)
  let pi0   = pos * neg      -- \^{Π}^{0}_i
      piPos = comp pos * neg -- \^{Π}^{+}_i
      piNeg = comp neg * pos -- \^{Π}^{-}_i
      total = piPos + piNeg + pi0
  VGM.unsafeWrite (svVarProbT solver) idx (piPos / total) -- W^{(+)}_i
  VGM.unsafeWrite (svVarProbF solver) idx (piNeg / total) -- W^{(-)}_i

-- | Get the marginal probabilities that variable @v@ is @True@, @False@ and
-- unspecified, in that order.
--
-- The first two components are the values currently stored for @v@ (written
-- by @computeVarProb@). The third is @1 - (True + False)@.
getVarProb :: Solver -> SAT.Var -> IO (Double, Double, Double)
getVarProb solver v = do
  let readProb field = realToFrac <$> VGM.unsafeRead (field solver) (v - 1)
  pt <- readProb svVarProbT
  pf <- readProb svVarProbF
  return (pt, pf, 1 - (pt + pf))

-- | Fix the variable underlying @lit@ to the value that makes @lit@ true.
--
-- Positive literals fix the variable to @True@, negative ones to @False@. The
-- value is recorded in @svVarFixed@, which @updateEdgeProb@ consults.
fixLit :: Solver -> SAT.Lit -> IO ()
fixLit solver lit = VGM.unsafeWrite (svVarFixed solver) idx (Just (lit > 0))
  where
    -- zero-based index of the literal's variable
    idx = abs lit - 1

-- | Remove any fixed value from the variable underlying @lit@. The sign of
-- @lit@ is ignored.
unfixLit :: Solver -> SAT.Lit -> IO ()
unfixLit solver lit = VGM.unsafeWrite (svVarFixed solver) idx Nothing
  where
    -- zero-based index of the literal's variable
    idx = abs lit - 1

-- | Print a snapshot of the solver state to stdout, for debugging.
--
-- The first line lists every edge as @(clause, literal, survey, svEdgeProbU)@.
-- The second lists every variable @v@ in @[1 .. nVars]@ as
-- @(v, P(True), P(False), P(True) - P(False))@, using the stored marginals.
printInfo :: Solver -> IO ()
printInfo solver = do
  surveys <- VG.freeze (svEdgeSurvey solver) :: IO (VU.Vector (L.Log Double))
  probU <- VG.freeze (svEdgeProbU solver) :: IO (VU.Vector (L.Log Double))
  let edgeInfo =
        [ (svEdgeClause solver ! e, svEdgeLit solver ! e, eta, probU ! e)
        | (e, eta) <- zip [0 ..] (VG.toList surveys)
        ]
  putStrLn ("edges: " ++ show edgeInfo)

  probT <- VG.freeze (svVarProbT solver) :: IO (VU.Vector (L.Log Double))
  probF <- VG.freeze (svVarProbF solver) :: IO (VU.Vector (L.Log Double))
  nv <- getNVars solver
  let asDouble :: L.Log Double -> Double
      asDouble = realToFrac
      varInfo =
        [ (v, t, f, t - f)
        | v <- [1 .. nv]
        , let t = asDouble (probT ! (v - 1))
              f = asDouble (probF ! (v - 1))
        ]
  putStrLn ("vars: " ++ show varInfo)