{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
-- Copyright   :  (c) Masahiro Sakai 2019
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.Cardinality.Internal.ParallelCounter
  ( addAtLeastParallelCounter
  , encodeAtLeastParallelCounter
  ) where

import Control.Monad.Primitive
import Control.Monad.State.Strict
import Data.Bits
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- | Add an at-least constraint to the encoder using the parallel-counter
-- encoding.
--
-- The constraint is first reified into a single literal by
-- 'encodeAtLeastParallelCounter', and that literal is then asserted as a
-- unit clause.
addAtLeastParallelCounter :: PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m ()
addAtLeastParallelCounter enc constr =
  encodeAtLeastParallelCounter enc constr >>= \l -> SAT.addClause enc [l]

-- TODO: consider polarity
-- | Encode an at-least constraint @(lhs, rhs)@ as a single literal using a
-- parallel counter.
--
-- * If @rhs <= 0@ the constraint trivially holds, and the constant from
--   @'Tseitin.encodeConj' enc []@ (true) is returned.
-- * If @lhs@ has fewer than @rhs@ literals it can never hold, and the
--   constant from @'Tseitin.encodeDisj' enc []@ (false) is returned.
-- * Otherwise the literals of @lhs@ are summed into a little-endian binary
--   counter whose width is the bit length of @rhs@
--   ('encodeSumParallelCounter'). That counter is compared against @rhs@
--   ('encodeGE'), and the result is OR-ed with every carry that overflowed
--   the counter width. Such a carry implies a count of at least
--   @2^width@, which already exceeds @rhs@.
encodeAtLeastParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> SAT.AtLeast -> m SAT.Lit
encodeAtLeastParallelCounter enc (lhs, rhs)
  | rhs <= 0 = Tseitin.encodeConj enc [] -- true
  | length lhs < rhs = Tseitin.encodeDisj enc [] -- false
  | otherwise = do
      let rhsBits = toBitsLE (fromIntegral rhs)
      (cnt, overflowBits) <- encodeSumParallelCounter enc (length rhsBits) lhs
      isGE <- encodeGE enc cnt rhsBits
      Tseitin.encodeDisj enc (isGE : overflowBits)
  where
    -- Little-endian binary digits of a non-negative integer, ending at the
    -- most significant set bit (so @toBitsLE 0 == []@). Only ever called
    -- with a positive value here, because the guard above excludes @rhs <= 0@.
    toBitsLE :: Integer -> [Bool]
    toBitsLE 0 = []
    toBitsLE k = testBit k 0 : toBitsLE (k `shiftR` 1)

-- | Sum the literals in @lits@ into a little-endian binary number of at most
-- @w@ bits, using a divide-and-conquer ("parallel") counter.
--
-- The input is split into two halves of equal length, each half is counted
-- recursively, and the two counts are added bit by bit. Each bit position
-- uses 'Tseitin.encodeFASum' and 'Tseitin.encodeFACarry' (full-adder sum and
-- carry, by their names), and each carry feeds the next position. When the
-- input has odd length, its last literal becomes the carry-in of that
-- addition. Otherwise the carry-in is the constant from
-- @'Tseitin.encodeDisj' enc []@.
--
-- Returns @(count, overflow)@:
--
-- * @count@ has at most @w@ bits.
-- * @overflow@ collects every carry produced past bit @w - 1@ of some
--   partial sum, most recently produced first.
encodeSumParallelCounter :: forall m. PrimMonad m => Tseitin.Encoder m -> Int -> [SAT.Lit] -> m ([SAT.Lit], [SAT.Lit])
encodeSumParallelCounter enc w lits = runStateT (count (V.fromList lits)) []
  where
    -- Add two equal-width little-endian numbers plus a carry-in, one bit
    -- position at a time; @acc@ holds the sum bits produced so far, in
    -- reverse. Once position @w@ is reached, the pending carry is pushed onto
    -- the overflow state instead of becoming a new most significant bit.
    ripple :: Int -> [SAT.Lit] -> [SAT.Lit] -> [SAT.Lit] -> SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    ripple pos acc _ _ carry
      | pos == w = do
          modify (carry :)
          return (reverse acc)
    ripple _ acc [] [] carry = return (reverse (carry : acc))
    ripple pos acc (x : xs) (y : ys) carry = do
      s <- lift (Tseitin.encodeFASum enc x y carry)
      carry' <- lift (Tseitin.encodeFACarry enc x y carry)
      ripple (pos + 1) (s : acc) xs ys carry'
    ripple _ _ _ _ _ = error "encodeSumParallelCounter: should not happen"

    -- Count the literals of a vector. Both halves have the same length, so
    -- their counts have the same width, which is what 'ripple' requires.
    count :: Vector SAT.Lit -> StateT [SAT.Lit] m [SAT.Lit]
    count v
      | V.null v = return []
      | otherwise = do
          let half = V.length v `div` 2
          lo <- count (V.take half v)
          hi <- count (V.slice half half v)
          carryIn <-
            if even (V.length v)
              then lift (Tseitin.encodeDisj enc [])
              else return (V.last v)
          ripple 0 [] lo hi carryIn

-- | Encode the comparison @lhs >= rhs@. Here @lhs@ is a little-endian list of
-- literals (bit 0 first), and @rhs@ is a constant given as little-endian bits.
--
-- Bits are scanned from the least significant upwards, carrying a literal
-- that means "the low bits of @lhs@ seen so far are >= the same bits of
-- @rhs@". It starts as true (the constant from @'Tseitin.encodeConj' enc []@),
-- since empty prefixes are equal.
--
-- * Where @rhs@ has a 1, the new value is @l AND acc@.
-- * Where @rhs@ has a 0 (or has run out), the new value is @l OR acc@.
-- * If @lhs@ runs out while @rhs@ still has a 1 bit somewhere, the result is
--   the constant from @'Tseitin.encodeDisj' enc []@ (false).
encodeGE :: forall m. PrimMonad m => Tseitin.Encoder m -> [SAT.Lit] -> [Bool] -> m SAT.Lit
encodeGE enc lhs rhs = Tseitin.encodeConj enc [] >>= step lhs rhs -- start from true
  where
    step :: [SAT.Lit] -> [Bool] -> SAT.Lit -> m SAT.Lit
    step [] bs acc
      | or bs = Tseitin.encodeDisj enc [] -- false
      | otherwise = return acc
    step (l : ls) bs acc = do
      -- Missing high bits of rhs count as 0.
      let (needBit, bs') = case bs of
            [] -> (False, [])
            (b : rest) -> (b, rest)
      acc' <-
        if needBit
          then Tseitin.encodeConj enc [l, acc]
          else Tseitin.encodeDisj enc [l, acc]
      step ls bs' acc'