{-# LANGUAGE ScopedTypeVariables, FlexibleContexts #-}
{-# OPTIONS_GHC -Wall #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.SAT.Encoder.PB.Internal.Sorter
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable
--
-- References:
--
-- * [ES06] N. Eén and N. Sörensson. Translating Pseudo-Boolean
--   Constraints into SAT. JSAT 2:1–26, 2006.
--
-----------------------------------------------------------------------------
module ToySolver.SAT.Encoder.PB.Internal.Sorter
  ( Base
  , UDigit
  , UNumber
  , isRepresentable
  , encode
  , decode

  , Cost
  , optimizeBase

  , genSorterCircuit
  , sortVector

  , addPBLinAtLeastSorter
  , encodePBLinAtLeastSorter
  ) where

-- Imported explicitly because mtl >= 2.3 no longer re-exports
-- 'when', 'unless' and 'forM_' from Control.Monad.State / Writer.
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.State
import Control.Monad.Writer
import Data.List
import Data.Maybe
import Data.Ord
import Data.Vector (Vector, (!))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

import ToySolver.Data.BoolExpr
import ToySolver.Data.Boolean
import qualified ToySolver.SAT.Types as SAT
import qualified ToySolver.SAT.Encoder.Tseitin as Tseitin

-- ------------------------------------------------------------------------
-- Circuit-like implementation of Batcher's odd-even mergesort

-- | Generate the sequence of compare-and-swap operations for @len@ inputs.
--
-- Each pair @(i,j)@ names the two positions (0-based) that a comparator
-- acts on; the pairs are listed in the order in which they must be applied.
genSorterCircuit :: Int -> [(Int,Int)]
genSorterCircuit len = execWriter (mergeSort (V.iterateN len (+1) 0)) []
  where
    -- The writer accumulates list-prepending functions; applying the
    -- result to @[]@ above yields the comparators in emission order.
    genCompareAndSwap i j = tell ((i,j) :)

    -- Sort both halves recursively, then merge them.
    mergeSort is
      | V.length is <= 1 = return ()
      | V.length is == 2 = genCompareAndSwap (is!0) (is!1)
      | otherwise =
          case halve is of
            (is1,is2) -> do
              mergeSort is1
              mergeSort is2
              oddEvenMerge is

    -- Merge the odd-indexed and even-indexed subsequences separately,
    -- then emit comparators on neighbouring positions from index 1 on.
    oddEvenMerge is
      | V.length is <= 1 = return ()
      | V.length is == 2 = genCompareAndSwap (is!0) (is!1)
      | otherwise =
          case splitOddEven is of
            (os,es) -> do
              oddEvenMerge os
              oddEvenMerge es
              forM_ [2,3 .. V.length is-1] $ \i -> do
                genCompareAndSwap (is!(i-1)) (is!i)

-- | Split a vector into a prefix and the remaining suffix.
--
-- The prefix length is half of the smallest power of two that is not
-- smaller than the vector length.  Vectors of length at most one are
-- returned unchanged together with an empty vector.
halve :: Vector a -> (Vector a, Vector a)
halve v
  | V.length v <= 1 = (v, V.empty)
  | otherwise = (V.slice 0 len1 v, V.slice len1 len2 v)
  where
    -- smallest power of two >= V.length v
    n = until (>= V.length v) (* 2) 1
    len1 = n `div` 2
    len2 = V.length v - len1

-- | Split a vector into its odd-indexed elements (first component) and its
-- even-indexed elements (second component).
splitOddEven :: Vector a -> (Vector a, Vector a)
splitOddEven v = (V.generate len1 (\i -> v V.! (i*2+1)), V.generate len2 (\i -> v V.! (i*2)))
  where
    len1 = V.length v `div` 2
    len2 = (V.length v + 1) `div` 2

-- | Run the comparator network of 'genSorterCircuit' over a vector.
--
-- For each comparator @(i,j)@ the two elements are exchanged when the one
-- at position @i@ is greater, so the smaller element ends up at @i@.
sortVector :: (Ord a) => Vector a -> Vector a
sortVector v = V.create $ do
  m <- V.thaw v
  forM_ (genSorterCircuit (V.length v)) $ \(i,j) -> do
    vi <- MV.read m i
    vj <- MV.read m j
    when (vi > vj) $ do
      MV.write m i vj
      MV.write m j vi
  return m

-- ------------------------------------------------------------------------

-- | A mixed-radix base: the radix of each digit position, least
-- significant position first.
type Base = [Int]

-- | A single digit of a mixed-radix number.
type UDigit = Int

-- | A mixed-radix number, least significant digit first.
type UNumber = [UDigit]

-- | Check whether 'encode' can represent the given number in the given base,
-- i.e. whether the quotient left over after dividing by every radix either
-- reaches zero or fits into a 'UDigit'.
isRepresentable :: Base -> Integer -> Bool
isRepresentable _ 0 = True
isRepresentable [] x = x <= fromIntegral (maxBound :: UDigit)
isRepresentable (b:bs) x = isRepresentable bs (x `div` fromIntegral b)

-- | Convert a number to its mixed-radix digits, least significant first.
--
-- Encoding stops as soon as the remaining value is zero, so the result may
-- be shorter than the base.  Once the base is exhausted, a remaining
-- non-zero value becomes one final digit; it is an error if that value does
-- not fit into a 'UDigit' (see 'isRepresentable').
encode :: Base -> Integer -> UNumber
encode _ 0 = []
encode [] x
  | x <= fromIntegral (maxBound :: UDigit) = [fromIntegral x]
  | otherwise = error "ToySolver.SAT.Encoder.PB.Internal.Sorter.encode: value is not representable in the given base"
encode (b:bs) x = fromIntegral (x `mod` fromIntegral b) : encode bs (x `div` fromIntegral b)

-- | Convert mixed-radix digits (least significant first) back to a number.
--
-- At most one digit may follow after the base has been exhausted.
decode :: Base -> UNumber -> Integer
decode _ [] = 0
decode [] [x] = fromIntegral x
decode (b:bs) (x:xs) = fromIntegral x + fromIntegral b * decode bs xs
decode [] (_:_:_) = error "ToySolver.SAT.Encoder.PB.Internal.Sorter.decode: more digits than the base allows"

{-
test1 = encode [3,5] 164 -- [2,4,10]
test2 = decode [3,5] [2,4,10] -- 164

test3 = optimizeBase [1,1,2,2,3,3,3,3,7]
-}

-- ------------------------------------------------------------------------

-- | Cost measure minimised by 'optimizeBase'.
type Cost = Integer

-- | Candidate radices tried by 'optimizeBase'.
primes :: [Int]
primes = [2, 3, 5, 7, 11, 13, 17]

-- | Search for a base, built from the radices in 'primes', for the given
-- coefficients.
--
-- The search keeps the cheapest candidate found so far in the state.  The
-- cost of a candidate is the sum of all digits of all coefficients: the
-- residues accumulated for each chosen radix plus the sum of the remaining
-- quotients.  A candidate is only recorded when the sum of the remaining
-- quotients is at most 1024.
optimizeBase :: [Integer] -> Base
optimizeBase xs = reverse $ fst $ fromMaybe noResult $ execState (search xs [] 0) Nothing
  where
    -- Every search path ends with an empty quotient list, whose sum is 0,
    -- so a candidate is always recorded before the search finishes.
    noResult = error "ToySolver.SAT.Encoder.PB.Internal.Sorter.optimizeBase: no base was found"

    -- @search ys base cost@: @ys@ are the remaining quotients, @base@ the
    -- radices chosen so far (most recent first), @cost@ the accumulated
    -- residue sum.
    search :: [Integer] -> Base -> Integer -> State (Maybe (Base, Cost)) ()
    search ys base cost = do
      -- every remaining positive quotient contributes at least 1
      let lb = cost + sum [1 | y <- ys, y > 0]
      best <- get
      case best of
        Just (_bestBase, bestCost) | bestCost <= lb -> return ()
        _ -> do
          when (sum ys <= 1024) $ do
            let cost' = cost + sum ys
            case best of
              Just (_bestBase, bestCost) | bestCost < cost' -> return ()
              _ -> put $ Just (base, cost')
          unless (null ys) $ do
            let delta = sortBy (comparing snd) [(p, sum [y `mod` fromIntegral p | y <- ys]) | p <- primes]
                -- extend the base by p, dropping quotients that became zero
                descend p extra =
                  search [d | y <- ys, let d = y `div` fromIntegral p, d > 0] (p : base) (cost + extra)
            case delta of
              -- a radix dividing every quotient costs nothing: commit to it
              (p,0) : _ -> descend p 0
              _ -> forM_ delta $ \(p,s) -> descend p s

-- ------------------------------------------------------------------------

-- | Encode the constraint with sorters and add the resulting formula to the
-- encoder via 'Tseitin.addFormula'.
addPBLinAtLeastSorter :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m ()
addPBLinAtLeastSorter enc constr = do
  formula <- encodePBLinAtLeastSorter' enc constr
  Tseitin.addFormula enc formula

-- | Encode the constraint with sorters and return the literal produced by
-- 'Tseitin.encodeFormula' for the resulting formula.
encodePBLinAtLeastSorter :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m SAT.Lit
encodePBLinAtLeastSorter enc constr = do
  formula <- encodePBLinAtLeastSorter' enc constr
  Tseitin.encodeFormula enc formula

-- | Build the formula for a constraint @(lhs, rhs)@.
--
-- The base is chosen from the coefficients of @lhs@ by 'optimizeBase'.
-- If @rhs@ is not representable in that base the result is 'false'.
encodePBLinAtLeastSorter' :: PrimMonad m => Tseitin.Encoder m -> SAT.PBLinAtLeast -> m Tseitin.Formula
encodePBLinAtLeastSorter' enc (lhs,rhs) = do
  let base = optimizeBase [c | (c,_) <- lhs]
  if isRepresentable base rhs then do
    sorters <- genSorters enc base [(encode base c, l) | (c,l) <- lhs] []
    return $ lexComp base sorters (encode base rhs)
  else do
    return false

-- | Build one sorter per digit position and return their output vectors,
-- least significant position first.
--
-- The inputs of a sorter are the carry literals followed by every literal
-- repeated as many times as its digit at the current position.  For each
-- comparator @(i,j)@ of 'genSorterCircuit', position @i@ receives the
-- disjunction and position @j@ the conjunction of the two inputs.  When a
-- radix @b@ remains, every @b@-th output is passed on as carry.
genSorters :: PrimMonad m => Tseitin.Encoder m -> Base -> [(UNumber, SAT.Lit)] -> [SAT.Lit] -> m [Vector SAT.Lit]
genSorters enc base lhs carry = do
  -- terms whose digit list is already exhausted are skipped by the pattern
  let is = V.fromList carry <> V.concat [V.replicate d l | (d:_,l) <- lhs, d /= 0]
  buf <- V.thaw is
  forM_ (genSorterCircuit (V.length is)) $ \(i,j) -> do
    vi <- MV.read buf i
    vj <- MV.read buf j
    MV.write buf i =<< Tseitin.encodeDisj enc [vi,vj]
    MV.write buf j =<< Tseitin.encodeConj enc [vi,vj]
  os <- V.freeze buf
  case base of
    [] -> return [os]
    b:bs -> do
      -- outputs at positions b-1, 2b-1, ... carry into the next position
      oss <- genSorters enc bs [(ds,l) | (_:ds,l) <- lhs] [os!(i-1) | i <- takeWhile (<= V.length os) (iterate (+b) b)]
      return $ os : oss

-- | Formula for "at least @lim@ of the sorter outputs are set": trivially
-- true for @lim <= 0@, the @lim@-th output when it exists, false otherwise.
isGE :: Vector SAT.Lit -> Int -> Tseitin.Formula
isGE out lim
  | lim <= 0 = true
  | lim - 1 < V.length out = Atom $ out ! (lim - 1)
  | otherwise = false

-- | Formula built from 'isGE' for "the number of set outputs, taken modulo
-- @n@, is at least @lim@": for some multiple @j@ of @n@, at least @j + lim@
-- outputs are set but fewer than @j + n@.
isGEMod :: Int -> Vector SAT.Lit -> Int -> Tseitin.Formula
isGEMod _n _out lim | lim <= 0 = true
isGEMod n out lim =
  orB [isGE out (j + lim) .&&. notB (isGE out (j + n)) | j <- [0, n .. V.length out - 1]]

-- | Combine the per-position sorter outputs into a comparison against the
-- digits of the right-hand side, starting at the least significant position.
--
-- Expects exactly one more sorter output vector than there are radices in
-- the base.  Missing digits of the right-hand side are treated as 0.
lexComp :: Base -> [Vector SAT.Lit] -> UNumber -> Tseitin.Formula
lexComp base lhs rhs = f true base lhs rhs
  where
    -- digit at the current position, 0 when the digits are exhausted
    digit [] = 0
    digit (d:_) = d

    -- @ret@ is the comparison result of all lower positions
    f ret (b:bs) (out:os) ds = f (gt .||. (ge .&&. ret)) bs os (drop 1 ds)
      where
        d = digit ds
        gt = isGEMod b out (d+1)
        ge = isGEMod b out d
    f ret [] [out] ds = gt .||. (ge .&&. ret)
      where
        d = digit ds
        gt = isGE out (d+1)
        ge = isGE out d
    f _ _ _ _ = error "ToySolver.SAT.Encoder.PB.Internal.Sorter.lexComp: number of sorters does not match the base"

-- ------------------------------------------------------------------------