module Funsat.Monad
( liftST
, runSSTErrMonad
, evalSSTErrMonad
, SSTErrMonad )
where
import Control.Monad.Error
import Control.Monad.ST.Strict
import Control.Monad.State.Class
import Control.Monad.MonadST
-- | 'ST' computations can be embedded in 'SSTErrMonad'; the embedding is
-- implemented by 'dpllST'.
instance MonadST s (SSTErrMonad e st s) where
  liftST action = dpllST action
-- | Lift a plain 'ST' action into 'SSTErrMonad'.
--
-- The action is executed, its result is handed to the continuation, and the
-- threaded state is passed through unchanged.
dpllST :: ST s a -> SSTErrMonad e st s a
dpllST action =
  SSTErrMonad
    ( \k state -> do
        result <- action
        k result state
    )
-- | @runSSTErrMonad m s@ runs the computation @m@ starting from state @s@.
--
-- The result is an 'ST' action yielding the outcome paired with the final
-- state. The computation is started with a terminal continuation that wraps
-- its result in 'Right'; a 'Left' outcome can only arise from an action that
-- discards its continuation (as 'throwError' does).
runSSTErrMonad :: (Error e) => SSTErrMonad e st s a -> (st -> ST s (Either e a, st))
runSSTErrMonad m = unSSTErrMonad m finish
  where
    -- Terminal continuation: report success together with the final state.
    finish x finalState = return (Right x, finalState)
-- | Like 'runSSTErrMonad', but discards the final state and yields only the
-- outcome of the computation.
evalSSTErrMonad :: (Error e) => SSTErrMonad e st s a -> st -> ST s (Either e a)
evalSSTErrMonad m s =
  runSSTErrMonad m s >>= \(result, _) -> return result
-- | @SSTErrMonad e st s a@: error type @e@, threaded state type @st@, 'ST'
-- thread @s@ and result type @a@.
--
-- A computation is represented in continuation-passing style: given a
-- continuation that consumes the result @a@ and the current state, it
-- produces a function from the initial state to an 'ST' action yielding
-- either an error or the continuation's answer @r@, paired with the final
-- state. The answer type @r@ is universally quantified, so a computation
-- cannot depend on what its continuation returns.
newtype SSTErrMonad e st s a = SSTErrMonad
  { unSSTErrMonad
      :: forall r.
         (a -> st -> ST s (Either e r, st))
      -> st
      -> ST s (Either e r, st)
  }
-- | Sequencing for 'SSTErrMonad'; '>>=' is implemented by 'bindSSTErrMonad'.
instance Monad (SSTErrMonad e st s) where
  -- A pure value is passed straight to the continuation.
  return x = SSTErrMonad (\k -> k x)
  m >>= f = bindSSTErrMonad m f
-- | Monadic bind for 'SSTErrMonad'.
--
-- Runs @m@ with a continuation that feeds its result to @f@ and then runs
-- the resulting computation with the caller's continuation. No 'Either' is
-- inspected here; the continuations are simply composed.
bindSSTErrMonad :: SSTErrMonad e st s a -> (a -> SSTErrMonad e st s b)
                -> SSTErrMonad e st s b
bindSSTErrMonad m f =
  SSTErrMonad
    ( \k ->
        let next a = unSSTErrMonad (f a) k
         in unSSTErrMonad m next
    )
-- | Access to the threaded state of type @st@.
instance MonadState st (SSTErrMonad e st s) where
  -- The current state is both the result and the state passed on.
  get = SSTErrMonad (\k current -> k current current)

  -- The incoming state is dropped and replaced by the supplied one.
  put replacement = SSTErrMonad (\k _ -> k () replacement)
-- | Error throwing and handling for 'SSTErrMonad'.
instance (Error e) => MonadError e (SSTErrMonad e st s) where
  -- The continuation is thrown away: the computation ends immediately with
  -- 'Left', keeping the state current at the point of the throw.
  throwError err =
    SSTErrMonad (\_ s -> return (Left err, s))

  -- The protected action is run to completion via 'runSSTErrMonad', and only
  -- then is its outcome inspected. On failure the handler runs with the state
  -- the failed action left behind (its state changes are not rolled back); on
  -- success the result goes to the outer continuation. Errors raised by that
  -- outer continuation are not caught here.
  catchError action handler =
    SSTErrMonad
      ( \k s ->
          runSSTErrMonad action s >>= \(outcome, s') ->
            either
              (\err -> unSSTErrMonad (handler err) k s')
              (\result -> k result s')
              outcome
      )
-- | Failure and choice for 'SSTErrMonad', expressed through the 'MonadError'
-- instance.
instance (Error e) => MonadPlus (SSTErrMonad e st s) where
  -- Failure is an error carrying 'noMsg'.
  mzero = throwError noMsg

  -- Try @m@; if it ends in an error, ignore that error and run @n@ instead.
  -- @n@ starts from the state left behind by the failed @m@, not from the
  -- state @m@ started with.
  mplus m n = m `catchError` \_ -> n