{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | A thin convenience layer over the MiniSat binding.
--
-- Problem variables are arbitrary ordered keys; each key is mapped to a
-- MiniSat literal the first time it is mentioned (see @var@). Formulas are
-- built with 'Form' and turned into clauses by @addForm@.
module Jukebox.Sat.Easy where

import MiniSat hiding (neg)
import qualified MiniSat
import Jukebox.Form(Signed(..), neg)
import qualified Data.Map.Strict as Map
import Data.Map(Map)
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Reader
import Data.Maybe
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
import Data.Traversable(traverse)
#endif

-- | Computation against a single solver.
--
-- Carries, from the outside in: the solver handle (reader), the 'Watch'
-- callback (reader) and the map from problem variables to solver literals
-- (state).
newtype Sat1 a b = Sat1 { runSat1_ :: ReaderT Solver (ReaderT (Watch a) (StateT (Map a Lit) IO)) b }
  deriving (Functor, Applicative, Monad, MonadIO)

-- | Computation against a family of solvers, each identified by an index of
-- type @b@. Carries the 'Watch' callback (reader) and one 'SatState' per
-- index (state).
newtype Sat a b c = Sat { runSat_ :: ReaderT (Watch a) (StateT (Map b (SatState a)) IO) c }
  deriving (Functor, Applicative, Monad, MonadIO)

-- | Per-index state: a solver handle and the variable-to-literal map that
-- belongs to it.
data SatState a = SatState Solver (Map a Lit)

-- | Callback run by @var@ each time a fresh literal is allocated for a
-- problem variable, after the variable has been recorded in the map.
type Watch a = a -> Sat1 a ()

-- | Propositional formulas over signed problem variables.
data Form a
  = Lit (Signed a)
  | And [Form a]
  | Or [Form a]

-- | Negate a formula by pushing the negation down to the literals
-- (swapping 'And' and 'Or' on the way); literals are negated with @neg@
-- from "Jukebox.Form".
nt :: Form a -> Form a
nt (Lit x) = Lit (neg x)
nt (And xs) = Or (fmap nt xs)
nt (Or xs) = And (fmap nt xs)

-- | The trivially true formula (empty conjunction) and the trivially false
-- formula (empty disjunction).
true, false :: Form a
true = And []
false = Or []

-- | A formula stating that at most one of the given formulas holds:
-- for every element, either it is false or every later element is false.
unique :: [Form a] -> Form a
unique = u
  where
    u [] = true
    u [_] = true
    u (x:xs) = And [Or [nt x, And (map nt xs)], u xs]

-- | Run a multi-solver computation.
--
-- One solver is created with @withNewSolver@ for every index in the list;
-- the scopes are nested, so all solvers are available while the action runs.
-- Each index starts with an empty variable map. If an index occurs more than
-- once, the later occurrence replaces the earlier entry in the state map.
runSat :: Ord b => Watch a -> [b] -> Sat a b c -> IO c
runSat w idxs x = go idxs Map.empty
  where
    -- All solvers allocated: run the action with the completed state map.
    go [] m = evalStateT (runReaderT (runSat_ x) w) m
    -- Allocate the solver for this index, then continue inside its scope.
    go (idx:rest) m =
      withNewSolver $ \s ->
        go rest (Map.insert idx (SatState s Map.empty) m)

-- | Run a single-solver computation (a 'runSat' with the sole index @()@).
runSat1 :: Ord a => Watch a -> Sat1 a b -> IO b
runSat1 w x = runSat w [()] (atIndex () x)

-- | Run a single-solver computation against the solver stored at the given
-- index, writing the updated variable map back afterwards.
--
-- Calls @error@ if the index was not among those passed to 'runSat'.
atIndex :: (Ord a, Ord b) => b -> Sat1 a c -> Sat a b c
atIndex !idx m = do
  watch <- Sat ask
  SatState s ls <- Sat (lift (gets (Map.findWithDefault (error "atIndex: index not found") idx)))
  (x, ls') <- liftIO (runStateT (runReaderT (runReaderT (runSat1_ m) s) watch) ls)
  Sat (lift (modify (Map.insert idx (SatState s ls'))))
  return x

-- | Translate the given signed variables to literals (allocating fresh ones
-- as needed) and pass them to @MiniSat.solve@, returning its result.
solve :: Ord a => [Signed a] -> Sat1 a Bool
solve xs = do
  s <- Sat1 ask
  ls <- mapM lit xs
  liftIO (MiniSat.solve s ls)

-- | Snapshot the solver's current model as a function on problem variables.
--
-- Variables that have never been allocated a literal, and variables for
-- which the solver reports no value, are mapped to @False@.
model :: Ord a => Sat1 a (a -> Bool)
model = do
  s <- Sat1 ask
  m <- Sat1 (lift (lift get))
  vals <- liftIO (traverse (MiniSat.modelValue s) m)
  return (\v -> fromMaybe False (Map.findWithDefault Nothing v vals))

-- | Look up the value of one variable in the solver's current model,
-- allocating a literal for it if it has none yet.
--
-- Raises an @IOError@ when the solver reports no value for the literal.
-- The failure is raised explicitly rather than through a refutable
-- do-pattern, because 'Sat1' has no @MonadFail@ instance.
modelValue :: Ord a => a -> Sat1 a Bool
modelValue x = do
  s <- Sat1 ask
  l <- var x
  mb <- liftIO (MiniSat.modelValue s l)
  case mb of
    Just b -> return b
    Nothing ->
      liftIO (ioError (userError "Jukebox.Sat.Easy.modelValue: literal has no value in the current model"))

-- | Convert a formula to clauses with 'flatten' and add each clause to the
-- solver. The per-clause results of @MiniSat.addClause@ are discarded.
addForm :: Ord a => Form a -> Sat1 a ()
addForm f = do
  s <- Sat1 ask
  cs <- flatten f
  liftIO (mapM_ (MiniSat.addClause s) cs)

-- | Convert a formula to a list of clauses (a conjunction of disjunctions of
-- literals) by plain distribution; no auxiliary variables are introduced.
--
-- A disjunction is handled by picking one clause from every disjunct in all
-- possible ways, so the output can grow multiplicatively with nested 'Or's.
flatten :: Ord a => Form a -> Sat1 a [[Lit]]
flatten (Lit l) = fmap (return . return) (lit l)
flatten (And fs) = fmap concat (mapM flatten fs)
flatten (Or fs) = fmap (fmap concat . sequence) (mapM flatten fs)

-- | The solver literal for a signed variable: the literal of the variable
-- itself for @Pos@, its MiniSat negation for @Neg@.
lit :: Ord a => Signed a -> Sat1 a Lit
lit (Pos x) = var x
lit (Neg x) = liftM MiniSat.neg (var x)

-- | The solver literal for a problem variable.
--
-- On first use a fresh literal is allocated, recorded in the variable map,
-- and the 'Watch' callback is run on the variable; later calls return the
-- recorded literal.
var :: Ord a => a -> Sat1 a Lit
var x = do
  s <- Sat1 ask
  m <- Sat1 (lift (lift get))
  case Map.lookup x m of
    Nothing -> do
      l <- liftIO (MiniSat.newLit s)
      -- Record the literal before running the callback, so that a callback
      -- which itself asks for this variable finds it already present.
      Sat1 (lift (lift (put (Map.insert x l m))))
      w <- Sat1 (lift ask)
      w x
      return l
    Just l -> return l