{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_HADDOCK show-extensions #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.PBNLC
-- Copyright   :  (c) Masahiro Sakai 2015
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PBNLC
  (
  -- * The encoder type
    Encoder
  , newEncoder
  , getTseitinEncoder

  -- * Adding constraints
  , addPBNLAtLeast
  , addPBNLAtMost
  , addPBNLExactly
  , addPBNLAtLeastSoft
  , addPBNLAtMostSoft
  , addPBNLExactlySoft

  -- * Linearization
  , linearizePBSum
  , linearizePBSumWithPolarity
  ) where

import Control.Monad.Primitive
import ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin
import ToySolver.Internal.Util (revForM)

-- | Encoder for non-linear pseudo-Boolean constraints.
--
-- It pairs an arbitrary base solver/encoder (anything with an 'SAT.AddPBLin'
-- instance) with a Tseitin encoder.  Variable creation and linear
-- constraints are forwarded to the base, while product terms of non-linear
-- constraints are first replaced by literals obtained from the Tseitin
-- encoder.
data Encoder m
  = forall a. SAT.AddPBLin m a => Encoder
  { encBase    :: a
    -- ^ underlying solver/encoder that receives all variables and linear
    -- constraints (only usable via pattern matching, since its type is
    -- existentially quantified)
  , encTseitin :: Tseitin.Encoder m
    -- ^ Tseitin encoder used to introduce a literal for each product term
  }

-- Variable creation is delegated unchanged to the wrapped base encoder.
instance Monad m => SAT.NewVar m (Encoder m) where
  newVar   Encoder{ encBase = base } = SAT.newVar base
  newVars  Encoder{ encBase = base } = SAT.newVars base
  newVars_ Encoder{ encBase = base } = SAT.newVars_ base

-- Clauses are delegated unchanged to the wrapped base encoder.
instance Monad m => SAT.AddClause m (Encoder m) where
  addClause Encoder{ encBase = base } = SAT.addClause base

-- Cardinality constraints are delegated unchanged to the wrapped base
-- encoder.
instance Monad m => SAT.AddCardinality m (Encoder m) where
  addAtLeast Encoder{ encBase = base } = SAT.addAtLeast base
  addAtMost  Encoder{ encBase = base } = SAT.addAtMost base
  addExactly Encoder{ encBase = base } = SAT.addExactly base

-- Linear pseudo-Boolean constraints (hard and soft) are delegated unchanged
-- to the wrapped base encoder.
instance Monad m => SAT.AddPBLin m (Encoder m) where
  addPBAtLeast     Encoder{ encBase = base } = SAT.addPBAtLeast base
  addPBAtMost      Encoder{ encBase = base } = SAT.addPBAtMost base
  addPBExactly     Encoder{ encBase = base } = SAT.addPBExactly base
  addPBAtLeastSoft Encoder{ encBase = base } = SAT.addPBAtLeastSoft base
  addPBAtMostSoft  Encoder{ encBase = base } = SAT.addPBAtMostSoft base
  addPBExactlySoft Encoder{ encBase = base } = SAT.addPBExactlySoft base

-- | Create a non-linear PB encoder from a base solver/encoder and a Tseitin
-- encoder.
--
-- Nothing is added to either component at construction time; the two are
-- simply stored together.
--
-- NOTE(review): presumably the Tseitin encoder should emit its clauses into
-- the same solver as @base@; this is not checked here — confirm at callers.
newEncoder :: (SAT.AddPBLin m a) => a -> Tseitin.Encoder m -> m (Encoder m)
newEncoder base tseitin = return (Encoder base tseitin)

-- | Get the Tseitin encoder used to introduce literals for product terms.
getTseitinEncoder :: Encoder m -> Tseitin.Encoder m
getTseitinEncoder Encoder{ encTseitin = tseitin } = tseitin

-- Non-linear constraints are handled in three steps:
--
--   1. split off the constant part of the sum (terms with no literals),
--   2. replace every remaining product term by a single literal via
--      'linearizePBSumWithPolarity', and
--   3. post the resulting linear constraint, with the constant moved to the
--      right-hand side, through the 'SAT.AddPBLin' instance.
instance PrimMonad m => SAT.AddPBNL m (Encoder m) where
  -- lhs >= rhs: the sum occurs only positively, so a one-sided
  -- (polarityPos) definition of the product literals is sufficient.
  addPBNLAtLeast enc lhs rhs = do
    let (c, terms) = partitionConstantTerms lhs
    lhs' <- linearizePBSumWithPolarity enc Tseitin.polarityPos terms
    SAT.addPBAtLeast enc lhs' (rhs - c)

  -- lhs <= rhs is rewritten as (-lhs) >= (-rhs).
  addPBNLAtMost enc lhs rhs =
    addPBNLAtLeast enc [(negate c, ls) | (c, ls) <- lhs] (negate rhs)

  -- lhs = rhs: the sum occurs in both polarities, so the product literals
  -- must be fully defined ('linearizePBSum' uses polarityBoth).
  addPBNLExactly enc lhs rhs = do
    let (c, terms) = partitionConstantTerms lhs
    lhs' <- linearizePBSum enc terms
    SAT.addPBExactly enc lhs' (rhs - c)

  -- Soft variant of 'addPBNLAtLeast', guarded by the selector @sel@.
  addPBNLAtLeastSoft enc sel lhs rhs = do
    let (c, terms) = partitionConstantTerms lhs
    lhs' <- linearizePBSumWithPolarity enc Tseitin.polarityPos terms
    SAT.addPBAtLeastSoft enc sel lhs' (rhs - c)

  -- Soft variant of 'addPBNLAtMost': (-lhs) >= (-rhs) guarded by @sel@.
  addPBNLAtMostSoft enc sel lhs rhs =
    addPBNLAtLeastSoft enc sel [(negate c, ls) | (c, ls) <- lhs] (negate rhs)

  -- Soft variant of 'addPBNLExactly', guarded by the selector @sel@.
  addPBNLExactlySoft enc sel lhs rhs = do
    let (c, terms) = partitionConstantTerms lhs
    lhs' <- linearizePBSum enc terms
    SAT.addPBExactlySoft enc sel lhs' (rhs - c)

-- | Split a 'PBSum' into the total of its constant terms (those whose
-- literal list is empty) and the list of its remaining non-constant terms,
-- kept in their original order.
partitionConstantTerms :: PBSum -> (Integer, PBSum)
partitionConstantTerms s =
  ( sum [c | (c, []) <- s]
  , [t | t@(_, ls) <- s, not (null ls)]
  )

-- | Linearize a non-linear 'PBSum' into a linear 'PBLinSum'.
--
-- Every product term is replaced by a literal that is fully equivalent to
-- the conjunction of its literals, so the result may be used in any
-- polarity.  @linearizePBSum enc s@ is equivalent to
-- @linearizePBSumWithPolarity enc Tseitin.polarityBoth s@.
linearizePBSum
  :: PrimMonad m
  => Encoder m
  -> PBSum
  -> m PBLinSum
linearizePBSum enc = linearizePBSumWithPolarity enc Tseitin.polarityBoth

-- | Linearize a non-linear 'PBSum' into a linear 'PBLinSum'.
--
-- The input 'PBSum' is assumed to occur only in the specified polarity /p/.
-- Each term @(c, ls)@ is replaced by @(c, v)@, where @v@ is the literal
-- returned by 'Tseitin.encodeConjWithPolarity' for the conjunction of @ls@.
-- For @c > 0@ the conjunction is encoded with polarity /p/; otherwise it is
-- encoded with the negated polarity.  This ensures every term moves in the
-- same direction relative to the original sum:
--
-- * If @'Tseitin.polarityPosOccurs' p@, the value of the resulting
--   'PBLinSum' is /less than/ or /equal to/ the value of the original
--   'PBSum'.  This makes it sound to post @lhs' >= rhs@ in place of
--   @lhs >= rhs@.
--
-- * If @'Tseitin.polarityNegOccurs' p@, the value of the resulting
--   'PBLinSum' is /greater than/ or /equal to/ the value of the original
--   'PBSum'.
linearizePBSumWithPolarity
  :: PrimMonad m
  => Encoder m
  -> Tseitin.Polarity -- ^ polarity /p/
  -> PBSum
  -> m PBLinSum
linearizePBSumWithPolarity Encoder{ encTseitin = tseitin } p xs =
  revForM xs $ \(c, ls) -> do
    -- A non-positive coefficient flips the direction in which the term
    -- moves with its literal, so the literal's polarity is flipped as well.
    let p' = if c > 0 then p else Tseitin.negatePolarity p
    l <- Tseitin.encodeConjWithPolarity tseitin p' ls
    return (c, l)