module Ersatz.Monad
(
SAT(..)
, MonadSAT(..)
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Reader
import Control.Monad.RWS.Strict as Strict
import Control.Monad.RWS.Lazy as Lazy
import Control.Monad.State.Lazy as Lazy
import Control.Monad.State.Strict as Strict
import Control.Monad.Writer.Lazy as Lazy
import Control.Monad.Writer.Strict as Strict
import Data.IntSet as IntSet
import Data.HashMap.Strict as HashMap
import Ersatz.Internal.Formula
import Ersatz.Internal.Literal
import Ersatz.Internal.StableName
import Ersatz.Problem
import System.IO.Unsafe
-- | A continuation-passing state monad threading a 'Problem'.
--
-- An action is given a continuation (which takes the action's result and the
-- updated 'Problem') together with the current 'Problem'. It produces whatever
-- the continuation produces in the underlying @m@.
newtype SAT m a = SAT
  { runSAT :: forall res. (a -> Problem -> m res) -> Problem -> m res
  }
-- | Mapping applies the function to the result just before it reaches the
-- continuation; the 'Problem' is passed along untouched.
instance Functor (SAT m) where
  fmap g (SAT run) = SAT (\k -> run (\x -> k (g x)))
-- | 'pure' hands its value straight to the continuation.
--
-- '<*>' runs the function-producing action first, then the argument-producing
-- action, and passes the applied result on to the continuation.
instance Applicative (SAT m) where
  pure x = SAT (\k -> k x)
  SAT runF <*> SAT runX = SAT $ \k -> runF (\f -> runX (k . f))
-- | Sequencing runs @m@, feeds its result to @f@, and gives the resulting
-- action the same final continuation @k@.
--
-- 'return' is intentionally left undefined here so that it falls back to its
-- class default, 'pure'. A hand-written copy duplicates the 'Applicative'
-- instance and could drift out of sync with it. GHC also flags it via
-- @-Wnoncanonical-monad-instances@.
instance Monad (SAT m) where
  SAT m >>= f = SAT $ \k -> m (\a -> runSAT (f a) k)
-- | Lifting runs the inner action and continues with its result and the
-- unchanged 'Problem'.
instance MonadTrans SAT where
  lift inner = SAT (\k p -> inner >>= \x -> k x p)
-- | Runs the 'IO' action through the underlying monad's 'liftIO', then
-- continues with its result and the unchanged 'Problem'.
instance MonadIO m => MonadIO (SAT m) where
  liftIO io = SAT (\k p -> liftIO io >>= flip k p)
-- | Monads that can build up a 'Problem'.
--
-- The single primitive is 'sat', a pure state transition over the current
-- 'Problem'; every other method here is defined in terms of it. For a
-- transformer @t n@ over a 'MonadSAT' @n@, 'sat' defaults to lifting the
-- inner monad's 'sat', so such instances can be left empty.
class (Applicative m, Monad m) => MonadSAT m where
  -- | Apply a state transition to the 'Problem', returning its result.
  sat :: (Problem -> (a, Problem)) -> m a
  default sat
    :: (MonadTrans t, MonadSAT n, m ~ t n)
    => (Problem -> (a, Problem)) -> m a
  sat f = lift (sat f)

  -- | Allocate a fresh atom.
  --
  -- Increments 'qbfLastAtom' (forcing the new value) and returns the new atom
  -- wrapped as a 'Literal'.
  literalExists :: m Literal
  literalExists = sat $ \problem ->
    let !atom = qbfLastAtom problem + 1
    in (Literal atom, problem { qbfLastAtom = atom })

  -- | Like 'literalExists', but also inserts the new atom into
  -- 'qbfUniversals'.
  literalForall :: m Literal
  literalForall = sat $ \problem ->
    let !atom = qbfLastAtom problem + 1
        universals = IntSet.insert atom (qbfUniversals problem)
    in (Literal atom, problem { qbfLastAtom = atom, qbfUniversals = universals })

  -- | Append a 'Formula' to the problem's 'qbfFormula' using '(<>)'.
  assertFormula :: Formula -> m ()
  assertFormula formula = sat $ \problem ->
    ((), problem { qbfFormula = qbfFormula problem <> formula })

  -- | Look up @a@'s stable name in 'qbfSNMap'.
  --
  -- * On a hit, return the stored literal and leave the 'Problem' as is.
  -- * On a miss, allocate the next atom and record it in 'qbfSNMap' /before/
  --   running @f@ on it. The problem state is then threaded through @f@, and
  --   the new literal is returned.
  --
  -- NOTE(review): recording first presumably lets lookups of the same value
  -- made from inside @f@ reuse the literal; confirm against callers.
  -- NOTE(review): the stable name comes from 'unsafePerformIO'; this assumes
  -- 'makeStableName'' is safe to call purely here.
  generateLiteral :: a -> (forall n. Literal -> SAT n ()) -> m Literal
  generateLiteral a f = sat lookupOrAllocate
    where
      sn = unsafePerformIO (makeStableName' a)
      lookupOrAllocate problem = case HashMap.lookup sn (qbfSNMap problem) of
        Just known -> (known, problem)
        Nothing ->
          let !atom = qbfLastAtom problem + 1
              !fresh = Literal atom
              problem' = problem
                { qbfSNMap = HashMap.insert sn fresh (qbfSNMap problem)
                , qbfLastAtom = atom
                }
          in runSAT (f fresh) (\_ s -> (fresh, s)) problem'
-- | Runs the state transition on the current 'Problem' and passes its result
-- and the new 'Problem' to the continuation.
instance MonadSAT (SAT m) where
  sat step = SAT $ \k -> resume k . step
    where
      -- A function-argument pattern, so the pair is forced to WHNF before
      -- the continuation runs.
      resume k (result, problem') = k result problem'
-- Transformer instances. Each one uses the class's default 'sat', which lifts
-- the underlying monad's 'sat' through the transformer.

-- Reader.
instance MonadSAT m => MonadSAT (ReaderT r m)

-- State, in both strictness variants.
instance MonadSAT m => MonadSAT (Strict.StateT s m)
instance MonadSAT m => MonadSAT (Lazy.StateT s m)

-- Writer and RWS, in both strictness variants. These additionally carry a
-- 'Monoid w' constraint; the transformers' 'MonadTrans' instances for these
-- types require it.
instance (MonadSAT m, Monoid w) => MonadSAT (Strict.WriterT w m)
instance (MonadSAT m, Monoid w) => MonadSAT (Lazy.WriterT w m)
instance (MonadSAT m, Monoid w) => MonadSAT (Strict.RWST r w s m)
instance (MonadSAT m, Monoid w) => MonadSAT (Lazy.RWST r w s m)