{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.Cardinality.Internal.Totalizer
-- Copyright   :  (c) Masahiro Sakai 2020
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.Totalizer
  ( Encoder (..)
  , newEncoder

  , Definitions
  , getDefinitions
  , evalDefinitions

  , addAtLeast
  , encodeAtLeast

  , addCardinality
  , encodeCardinality

  , encodeSum
  ) where

import Control.Monad.Primitive
import Control.Monad.State.Strict
import qualified Data.IntSet as IntSet
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Primitive.MutVar
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin


-- | Totalizer encoder state.
--
-- It bundles the underlying Tseitin encoder (used to allocate variables and
-- to emit clauses) with a mutable memo table.  The table maps a set of input
-- literals to the vector of output variables that was created for it by the
-- sum encoding in this module.
data Encoder m
  = Encoder
      (Tseitin.Encoder m)
      (MutVar (PrimState m) (Map SAT.LitSet (Vector SAT.Var)))

-- | Variable allocation is delegated to the wrapped Tseitin encoder; the
-- memo table is not touched.
instance Monad m => SAT.NewVar m (Encoder m) where
  newVar   (Encoder a _) = SAT.newVar a
  newVars  (Encoder a _) = SAT.newVars a
  newVars_ (Encoder a _) = SAT.newVars_ a

-- | Clauses are forwarded unchanged to the wrapped Tseitin encoder.
instance Monad m => SAT.AddClause m (Encoder m) where
  addClause (Encoder a _) = SAT.addClause a

-- | Create a totalizer encoder on top of an existing Tseitin encoder.
--
-- The new encoder starts with an empty memo table.
newEncoder :: PrimMonad m => Tseitin.Encoder m -> m (Encoder m)
newEncoder tseitin = Encoder tseitin <$> newMutVar Map.empty


-- | Definitions of the auxiliary variables introduced by the encoder.
--
-- Each entry pairs a vector of output variables with the set of input
-- literals it was created for (see 'getDefinitions' and 'evalDefinitions').
type Definitions =
  [(Vector SAT.Var, SAT.LitSet)]

-- | Snapshot the encoder's memo table as a list of 'Definitions'.
--
-- Every table entry @(lits, vars)@ is returned with its components swapped,
-- i.e. as @(vars, lits)@, in the key order produced by 'Map.toList'.
getDefinitions :: PrimMonad m => Encoder m -> m Definitions
getDefinitions (Encoder _ tableRef) = do
  table <- readMutVar tableRef
  return [(vars', lits) | (lits, vars') <- Map.toList table]


-- | Compute values for the defined variables from a model of the inputs.
--
-- For each definition @(vars, lits)@, let @n@ be the number of literals in
-- @lits@ that evaluate to true under the model.  The @i@-th variable of
-- @vars@ (counting from 1) is then assigned @i <= n@.
evalDefinitions :: SAT.IModel m => m -> Definitions -> [(SAT.Var, Bool)]
evalDefinitions m defs =
  [ (v, i <= n)
  | (vars', lits) <- defs
  , let n = length (filter (SAT.evalLit m) (IntSet.toList lits))
  , (i, v) <- zip [1 ..] (V.toList vars')
  ]


-- | Assert an at-least constraint @(lhs, rhs)@.
--
-- This is 'addCardinality' on @lhs@ with lower bound @rhs@ and the number of
-- literals in @lhs@ as the upper bound.
addAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m ()
addAtLeast enc (lhs, rhs) = addCardinality enc lhs (rhs, length lhs)


-- | Assert that the number of true literals in @lits@ lies in @[lb, ub]@.
--
-- * If the bounds cannot be violated (@lb <= 0@ and @length lits <= ub@),
--   nothing is added.
-- * If the bounds cannot be met (@length lits < lb@ or @ub < 0@), the empty
--   clause is added.
-- * Otherwise the literals are encoded with 'encodeSum'; the first @lb@
--   outputs are asserted as unit clauses and every output from position
--   @ub@ onwards is asserted negated.
addCardinality :: PrimMonad m => Encoder m -> [SAT.Lit] -> (Int, Int) -> m ()
addCardinality enc lits (lb, ub)
  | lb <= 0 && n <= ub = return ()
  | n < lb || ub < 0 = SAT.addClause enc []
  | otherwise = do
      lits' <- encodeSum enc lits
      forM_ (take lb lits') $ \l -> SAT.addClause enc [l]
      forM_ (drop ub lits') $ \l -> SAT.addClause enc [-l]
  where
    n = length lits


-- | Encode an at-least constraint @(lhs, rhs)@ as a single literal.
--
-- This is 'encodeCardinality' on @lhs@ with lower bound @rhs@ and the number
-- of literals in @lhs@ as the upper bound.
--
-- TODO: consider polarity
encodeAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeast enc (lhs, rhs) = encodeCardinality enc lhs (rhs, length lhs)


-- | Encode "the number of true literals in @lits@ lies in @[lb, ub]@" as a
-- single literal.
--
-- * If the bounds cannot be violated (@lb <= 0@ and @length lits <= ub@),
--   the result is the Tseitin encoding of the empty conjunction.
-- * If the bounds cannot be met (@length lits < lb@ or @ub < 0@), the result
--   is the Tseitin encoding of the empty disjunction.
-- * Otherwise the literals are encoded with 'encodeSum', ordering clauses
--   between consecutive outputs are added, and the result is the conjunction
--   of output @lb - 1@ (when @lb > 0@) and the negation of output @ub@ (when
--   @ub < length lits@), using 0-based positions.
--
-- TODO: consider polarity
encodeCardinality :: PrimMonad m => Encoder m -> [SAT.Lit] -> (Int, Int) -> m SAT.Lit
encodeCardinality enc@(Encoder tseitin _) lits (lb, ub)
  | lb <= 0 && n <= ub = Tseitin.encodeConj tseitin []
  | n < lb || ub < 0 = Tseitin.encodeDisj tseitin []
  | otherwise = do
      lits' <- encodeSum enc lits
      -- Chain the outputs: l2→l1, or equivalently ¬l1→¬l2.
      -- 'drop 1' is used instead of 'tail' so the pairing is total.
      forM_ (zip lits' (drop 1 lits')) $ \(l1, l2) ->
        SAT.addClause enc [-l2, l1]
      -- The guards above leave lb <= n and 0 <= ub here, so both indices are
      -- within 0 .. n-1 (assuming 'encodeSum' yields n outputs).
      Tseitin.encodeConj tseitin $
        [lits' !! (lb - 1) | lb > 0] ++ [- (lits' !! ub) | ub < n]
  where
    n = length lits


-- | List-based wrapper around the vector-based sum encoding.
--
-- The input list is converted to an unboxed vector, encoded, and the
-- resulting vector of output literals is converted back to a list.
encodeSum :: PrimMonad m => Encoder m -> [SAT.Lit] -> m [SAT.Lit]
encodeSum enc lits = V.toList <$> encodeSumV enc (V.fromList lits)


-- | Vector-based worker behind 'encodeSum'.
--
-- An input of at most one literal is returned unchanged.  For a longer input
-- of @n@ literals, @n@ fresh variables @rs@ are allocated, the input is split
-- into a first half of @n `div` 2@ literals and the remainder, both halves
-- are encoded recursively, and clauses are added that relate the halves'
-- outputs to @rs@ so that @rs ! (k-1)@ stands for "at least @k@ of the inputs
-- are true" (see the clause comments below).
--
-- Results are memoised in the encoder's table, keyed by the /set/ of input
-- literals.  A set cannot represent repeated literals, so an input that
-- contains duplicates is encoded without consulting or updating the table;
-- otherwise it could collide with a shorter input over the same set and an
-- output vector of the wrong length would be returned.  As a consequence the
-- outputs created for such inputs do not appear in 'getDefinitions'.
encodeSumV :: PrimMonad m => Encoder m -> Vector SAT.Lit -> m (Vector SAT.Lit)
encodeSumV (Encoder enc tableRef) = f
  where
    f lits
      | n <= 1 = return lits
      | IntSet.size key /= n = do
          -- Repeated literals: the set-based key would be ambiguous, so do
          -- not use the memo table for this input.
          rs <- newOutputs
          defineOutputs rs
      | otherwise = do
          table <- readMutVar tableRef
          case Map.lookup key table of
            Just vars -> return vars
            Nothing -> do
              rs <- newOutputs
              -- Register the outputs before recursing; the recursive calls
              -- re-read the table and add their own entries on top.
              writeMutVar tableRef (Map.insert key rs table)
              defineOutputs rs
      where
        n = V.length lits
        n1 = n `div` 2
        n2 = n - n1
        key = IntSet.fromList (V.toList lits)

        -- One fresh output variable per input literal.
        newOutputs = V.fromList <$> SAT.newVars enc n

        -- Encode both halves and emit the clauses linking them to @rs@.
        defineOutputs rs = do
          let (lits1, lits2) = V.splitAt n1 lits
          lits1' <- f lits1
          lits2' <- f lits2
          forM_ [0 .. n] $ \sigma ->
            -- a + b = sigma, 0 <= a <= n1, 0 <= b <= n2
            forM_ [max 0 (sigma - n2) .. min n1 sigma] $ \a -> do
              let b = sigma - a
              -- card(lits1) >= a ∧ card(lits2) >= b → card(lits) >= sigma
              -- ¬(card(lits1) >= a) ∨ ¬(card(lits2) >= b) ∨ card(lits) >= sigma
              unless (sigma == 0) $
                SAT.addClause enc $
                  [- (lits1' V.! (a - 1)) | a > 0] ++
                  [- (lits2' V.! (b - 1)) | b > 0] ++
                  [rs V.! (sigma - 1)]
              -- card(lits) > sigma → (card(lits1) > a ∨ card(lits2) > b)
              -- card(lits) >= sigma+1 → (card(lits1) >= a+1 ∨ card(lits2) >= b+1)
              -- card(lits1) >= a+1 ∨ card(lits2) >= b+1 ∨ ¬(card(lits) >= sigma+1)
              unless (sigma == n) $
                SAT.addClause enc $
                  [lits1' V.! a | a < n1] ++
                  [lits2' V.! b | b < n2] ++
                  [- (rs V.! sigma)]
          return rs