{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.Cardinality.Internal.Totalizer
-- Copyright   :  (c) Masahiro Sakai 2020
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.Totalizer
  ( Encoder (..)
  , newEncoder

  , Definitions
  , getDefinitions
  , evalDefinitions

  , addAtLeast
  , encodeAtLeast

  , addCardinality
  , encodeCardinality

  , encodeSum
  ) where

import Control.Monad.Primitive
import Control.Monad.State.Strict
import qualified Data.IntSet as IntSet
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Primitive.MutVar
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin


-- | Totalizer encoder state: the underlying Tseitin encoder, paired with a
-- mutable table that maps a set of input literals to the vector of
-- variables allocated for the sum of those literals.
data Encoder m = Encoder (Tseitin.Encoder m) (MutVar (PrimState m) (Map SAT.LitSet (Vector SAT.Var)))

-- | Variable allocation is forwarded to the wrapped Tseitin encoder.
instance Monad m => SAT.NewVar m (Encoder m) where
  newVar   (Encoder a _) = SAT.newVar a
  newVars  (Encoder a _) = SAT.newVars a
  newVars_ (Encoder a _) = SAT.newVars_ a

-- | Clauses are forwarded to the wrapped Tseitin encoder.
instance Monad m => SAT.AddClause m (Encoder m) where
  addClause (Encoder a _) = SAT.addClause a

-- | Create a totalizer 'Encoder' on top of the given Tseitin encoder.
-- The table of already-encoded sums starts out empty.
newEncoder :: PrimMonad m => Tseitin.Encoder m -> m (Encoder m)
newEncoder tseitin = do
  tableRef <- newMutVar Map.empty
  return $ Encoder tseitin tableRef


-- | Definitions recorded by an 'Encoder': each entry pairs a vector of
-- variables with the set of literals it was allocated for.
type Definitions = [(Vector SAT.Var, SAT.LitSet)]

-- | Read the encoder's table and return its entries as 'Definitions',
-- i.e. as @(variables, literal set)@ pairs.
getDefinitions :: PrimMonad m => Encoder m -> m Definitions
getDefinitions (Encoder _ tableRef) = do
  m <- readMutVar tableRef
  return [(vars', lits) | (lits, vars') <- Map.toList m]


-- | Compute values for the defined variables under a model.
--
-- For each definition, @n@ is the number of literals in the set that the
-- model evaluates to true; the @i@-th variable of the vector (counting
-- from 1) is assigned 'True' exactly when @i <= n@.
evalDefinitions :: SAT.IModel m => m -> Definitions -> [(SAT.Var, Bool)]
evalDefinitions m defs = do
  (vars', lits) <- defs
  let n = length [() | l <- IntSet.toList lits, SAT.evalLit m l]
  (i, v) <- zip [1..] (V.toList vars')
  return (v, i <= n)


-- | Assert an at-least constraint @(lhs, rhs)@ by delegating to
-- 'addCardinality' with lower bound @rhs@ and upper bound @length lhs@.
addAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m ()
addAtLeast enc (lhs, rhs) = do
  addCardinality enc lhs (rhs, length lhs)


-- | Assert that the number of true literals in @lits@ lies within the
-- bounds @(lb, ub)@.
--
-- * If the bounds cannot be violated (@lb <= 0@ and @n <= ub@), nothing is added.
-- * If the bounds cannot be met (@n < lb@ or @ub < 0@), the empty clause is added.
-- * Otherwise the sum is encoded with 'encodeSum'; the first @lb@ output
--   literals are asserted and every output literal from index @ub@ onwards
--   is negated.
addCardinality :: PrimMonad m => Encoder m -> [SAT.Lit] -> (Int, Int) -> m ()
addCardinality enc lits (lb, ub) = do
  let n = length lits
  if lb <= 0 && n <= ub then
    return ()
  else if n < lb || ub < 0 then
    SAT.addClause enc []
  else do
    lits' <- encodeSum enc lits
    forM_ (take lb lits') $ \l -> SAT.addClause enc [l]
    forM_ (drop ub lits') $ \l -> SAT.addClause enc [- l]


-- | Encode an at-least constraint @(lhs, rhs)@ as a literal by delegating
-- to 'encodeCardinality' with lower bound @rhs@ and upper bound
-- @length lhs@.
--
-- TODO: consider polarity
encodeAtLeast :: PrimMonad m => Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeast enc (lhs,rhs) = do
  encodeCardinality enc lhs (rhs, length lhs)


-- | Encode the constraint that the number of true literals in @lits@ lies
-- within the bounds @(lb, ub)@, returning a literal for it.
--
-- * If the bounds cannot be violated (@lb <= 0@ and @n <= ub@), the result
--   is the Tseitin encoding of the empty conjunction.
-- * If the bounds cannot be met (@n < lb@ or @ub < 0@), the result is the
--   Tseitin encoding of the empty disjunction.
-- * Otherwise the sum is encoded with 'encodeSum', ordering clauses between
--   consecutive output literals are added, and the result is the
--   conjunction of output @lb - 1@ (when @lb > 0@) and the negation of
--   output @ub@ (when @ub < n@).
--
-- TODO: consider polarity
encodeCardinality :: PrimMonad m => Encoder m -> [SAT.Lit] -> (Int, Int) -> m SAT.Lit
encodeCardinality enc@(Encoder tseitin _) lits (lb, ub) = do
  let n = length lits
  if lb <= 0 && n <= ub then
    Tseitin.encodeConj tseitin []
  else if n < lb || ub < 0 then
    Tseitin.encodeDisj tseitin []
  else do
    lits' <- encodeSum enc lits
    -- 'drop 1' rather than the partial 'tail'; the zip result is the same.
    forM_ (zip lits' (drop 1 lits')) $ \(l1, l2) -> do
      SAT.addClause enc [-l2, l1] -- l2→l1 or equivalently ¬l1→¬l2
    Tseitin.encodeConj tseitin $
      [lits' !! (lb - 1) | lb > 0] ++ [- (lits' !! (ub + 1 - 1)) | ub < n]


-- | List-based wrapper around 'encodeSumV': the input literals are
-- converted to an unboxed vector and the resulting vector of output
-- literals is converted back to a list.
encodeSum :: PrimMonad m => Encoder m -> [SAT.Lit] -> m [SAT.Lit]
encodeSum enc = liftM V.toList . encodeSumV enc . V.fromList


-- | Encode the sum of the given literals as a vector of output literals.
--
-- Inputs of length at most 1 are returned unchanged. Otherwise @n@ fresh
-- variables @rs@ are allocated, where @rs V.! (k - 1)@ stands for
-- @card(lits) >= k@. The input is split into halves of sizes @n1@ and
-- @n2@, each half is encoded recursively, and clauses relating the two
-- half-sums to @rs@ are added in both directions.
--
-- Results are cached in the encoder's table, keyed by the set of input
-- literals, so encoding the same set again returns the same variables.
--
-- NOTE(review): the cache key is an 'IntSet.IntSet', so an input with
-- repeated literals shares its key with the duplicate-free input; confirm
-- that callers never pass repeated literals.
encodeSumV :: PrimMonad m => Encoder m -> Vector SAT.Lit -> m (Vector SAT.Lit)
encodeSumV (Encoder enc tableRef) = f
  where
    f lits
      | n <= 1 = return lits
      | otherwise = do
          m <- readMutVar tableRef
          let key = IntSet.fromList (V.toList lits)
          case Map.lookup key m of
            Just vars -> return vars
            Nothing -> do
              rs <- liftM V.fromList $ SAT.newVars enc n
              -- Record the entry before recursing; the recursive calls
              -- re-read the table, so this insertion is not lost.
              writeMutVar tableRef (Map.insert key rs m)
              case V.splitAt n1 lits of
                (lits1, lits2) -> do
                  lits1' <- f lits1
                  lits2' <- f lits2
                  forM_ [0 .. n] $ \sigma ->
                    -- a + b = sigma, 0 <= a <= n1, 0 <= b <= n2
                    forM_ [max 0 (sigma - n2) .. min n1 sigma] $ \a -> do
                      let b = sigma - a
                      -- card(lits1) >= a ∧ card(lits2) >= b → card(lits) >= sigma
                      -- ¬(card(lits1) >= a) ∨ ¬(card(lits2) >= b) ∨ card(lits) >= sigma
                      unless (sigma == 0) $ do
                        SAT.addClause enc $
                          [- (lits1' V.! (a - 1)) | a > 0] ++
                          [- (lits2' V.! (b - 1)) | b > 0] ++
                          [rs V.! (sigma - 1)]
                      -- card(lits) > sigma → (card(lits1) > a ∨ card(lits2) > b)
                      -- card(lits) >= sigma+1 → (card(lits1) >= a+1 ∨ card(lits2) >= b+1)
                      -- card(lits1) >= a+1 ∨ card(lits2) >= b+1 ∨ ¬(card(lits) >= sigma+1)
                      unless (sigma + 1 == n + 1) $ do
                        SAT.addClause enc $
                          [lits1' V.! (a + 1 - 1) | a + 1 < n1 + 1] ++
                          [lits2' V.! (b + 1 - 1) | b + 1 < n2 + 1] ++
                          [- (rs V.! (sigma + 1 - 1))]
                  return rs
      where
        -- Total number of inputs and the sizes of the two halves.
        n = V.length lits
        n1 = n `div` 2
        n2 = n - n1