{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-
    This file is part of funsat.

    funsat is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    funsat is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with funsat.  If not, see <http://www.gnu.org/licenses/>.

    Copyright 2008 Denis Bueno
-}

{-|
The main SAT solver monad.  Embeds 'ST'.  See type 'SSTErrMonad', which
stands for /State ST Error Monad/.
-}
module Funsat.Monad
    ( liftST
    , runSSTErrMonad
    , evalSSTErrMonad
    , SSTErrMonad
    ) where

-- Functor/Applicative/Alternative are superclasses of Monad/MonadPlus on
-- GHC >= 7.10; importing them explicitly keeps older compilers working too.
import Control.Applicative ( Alternative(..), Applicative(..) )
-- NOTE(review): Control.Monad.Error is deprecated in mtl-2.2 and removed in
-- mtl-2.3; 'mzero' below depends on its 'Error' class ('noMsg').
import Control.Monad.Error
import Control.Monad.ST.Strict
import Control.Monad.State.Class
import Control.Monad.MonadST

-- | 'ST' actions are lifted into the solver monad with 'dpllST'.
instance MonadST s (SSTErrMonad e st s) where
    liftST = dpllST

-- | Perform an @ST@ action in the DPLL monad.  The action's result is handed
-- directly to the continuation; the threaded state is passed through
-- unchanged.
dpllST :: ST s a -> SSTErrMonad e st s a
{-# INLINE dpllST #-}
dpllST action = SSTErrMonad (\k s -> action >>= \x -> k x s)

-- | @runSSTErrMonad m s@ executes a 'SSTErrMonad' action with initial state
-- @s@ until an error occurs or a result is returned.
--
-- Returns 'Left' with the thrown error or 'Right' with the result, paired
-- with the final state.  The state is returned in both cases: changes made
-- before an error was thrown are not rolled back.
runSSTErrMonad :: SSTErrMonad e st s a -> (st -> ST s (Either e a, st))
runSSTErrMonad m = unSSTErrMonad m (\x s -> return (Right x, s))

-- | @evalSSTErrMonad m s@ behaves like 'runSSTErrMonad' but discards the
-- final state, returning only the error or the result.
evalSSTErrMonad :: SSTErrMonad e st s a -> st -> ST s (Either e a)
evalSSTErrMonad m s = do
    (result, _) <- runSSTErrMonad m s
    return result

-- | @SSTErrMonad e st s a@: the error type @e@, state type @st@, @ST@ thread
-- @s@ and result type @a@.
--
-- This is a monad embedding @ST@ and supporting error handling and state
-- threading.  It uses CPS to avoid checking 'Left' and 'Right' for every
-- '>>='; the result is only inspected by 'catchError' and 'mplus'.  Idea
-- adapted from an external reference whose link is missing from this copy
-- of the file (TODO: restore the link).
--
-- The wrapped function takes a success continuation (result and current
-- state to final outcome) and an input state, and produces the final
-- outcome paired with the final state.
newtype SSTErrMonad e st s a = SSTErrMonad
    { unSSTErrMonad :: forall r.
                       (a -> (st -> ST s (Either e r, st)))
                    -> (st -> ST s (Either e r, st))
    }

-- | Mapping post-composes the function onto the continuation; equivalent to
-- 'liftM' for the 'Monad' instance.
instance Functor (SSTErrMonad e st s) where
    fmap f m = SSTErrMonad (\k -> unSSTErrMonad m (k . f))

-- | 'pure' feeds its value straight to the continuation.  '<*>' runs the
-- function computation, then the argument computation, in that order;
-- equivalent to 'ap' for the 'Monad' instance.
instance Applicative (SSTErrMonad e st s) where
    pure x = SSTErrMonad (\k -> k x)
    mf <*> mx =
        SSTErrMonad (\k -> unSSTErrMonad mf (\f -> unSSTErrMonad mx (k . f)))

instance Monad (SSTErrMonad e st s) where
    return = pure
    (>>=)  = bindSSTErrMonad

-- | Bind in CPS: run @m@ with a continuation that passes its result to @f@,
-- whose computation then continues with the outer continuation @k@.  No
-- 'Left'/'Right' inspection happens here.
bindSSTErrMonad :: SSTErrMonad e st s a
                -> (a -> SSTErrMonad e st s b)
                -> SSTErrMonad e st s b
{-# INLINE bindSSTErrMonad #-}
bindSSTErrMonad m f =
    {-# SCC "bindSSTErrMonad" #-}
    SSTErrMonad (\k -> unSSTErrMonad m (\a -> unSSTErrMonad (f a) k))

-- | 'get' passes the current state to the continuation as the result;
-- 'put' replaces the state handed to the continuation.
instance MonadState st (SSTErrMonad e st s) where
    get    = SSTErrMonad (\k s -> k s s)
    put s' = SSTErrMonad (\k _ -> k () s')

-- | 'throwError' discards the continuation and stops with 'Left', keeping
-- the current state.  'catchError' runs the protected action to completion
-- and then either resumes the continuation with its result or runs the
-- handler.  Either way it continues from the state the action left behind.
instance (Error e) => MonadError e (SSTErrMonad e st s) where
    throwError err =
        -- throw away continuation
        SSTErrMonad (\_ s -> return (Left err, s))
    catchError action handler =
        {-# SCC "catchErrorSSTErrMonad" #-}
        SSTErrMonad (\k s -> do
            (x, s') <- runSSTErrMonad action s
            case x of
              Left err     -> unSSTErrMonad (handler err) k s'
              Right result -> k result s')

-- | Delegates to the 'MonadPlus' instance below.
instance (Error e) => Alternative (SSTErrMonad e st s) where
    empty = mzero
    (<|>) = mplus

-- | 'mzero' fails with 'noMsg'.  'mplus' runs the first alternative to
-- completion.  If that alternative fails, its error is discarded and the
-- second alternative runs from the state the first left behind.
instance (Error e) => MonadPlus (SSTErrMonad e st s) where
    mzero = SSTErrMonad (\_ s -> return (Left noMsg, s))
    mplus m n = SSTErrMonad (\k s -> do
        (r, s') <- runSSTErrMonad m s
        case r of
          Left _  -> unSSTErrMonad n k s'
          Right x -> k x s')