module Control.Monad.Trans.State.Plus (
StatePlusT, runStatePlusT, execStatePlusT, evalStatePlusT) where
import Control.Monad.State
import Control.Applicative
-- | A 'StateT'-based transformer that adds failure ('mzero') and
-- backtracking choice ('mplus') on top of a user state @s@.
--
-- The wrapped 'StateT' threads a pair: the 'Bool' is a success flag that
-- becomes 'False' on 'mzero', and the second component is the user state.
--
-- NOTE(review): deriving 'MonadTrans' and 'MonadIO' relies on
-- GeneralizedNewtypeDeriving; no LANGUAGE pragma is visible in this file,
-- so confirm it is enabled (e.g. via default-extensions in the cabal file).
newtype StatePlusT s m a = MkSPT
  { unSPT :: StateT (Bool, s) m a
  }
  deriving (Functor, MonadIO, MonadTrans)
-- | Placeholder result carried by a failed computation.
--
-- It is not meant to be inspected: 'runStatePlusT' reports failure as
-- 'Nothing' instead of exposing this value. Forcing it raises an 'error'.
mzeroError :: a
mzeroError = error message
  where
    message = "StatePlusT mzero value"
-- | Sequencing that stops at the first failure.
--
-- 'return' lifts a value without touching the flag or the state. In '>>=',
-- the continuation runs only while the success flag is 'True'; otherwise
-- the 'mzeroError' placeholder is returned together with the failed state.
instance Monad m => Monad (StatePlusT s m) where
  return v = lift (return v)
  act >>= next = MkSPT (StateT step)
    where
      step st0 = runStateT (unSPT act) st0 >>= continue
      continue (a, st1@(ok, _))
        | ok = runStateT (unSPT (next a)) st1
        | otherwise = return (mzeroError, st1)
-- | Applicative structure taken from the 'Monad' instance: 'pure' is
-- 'return' and '<*>' is 'ap'. Effects are therefore sequenced left to
-- right, and '<*>' inherits the short-circuiting behaviour of '>>='.
instance (Monad m, Functor m) => Applicative (StatePlusT s m) where
  pure x = return x
  mf <*> mx = ap mf mx
-- | Expose the user state (the second component of the internal pair)
-- through 'MonadState'. Neither method changes the success flag.
instance Monad m => MonadState s (StatePlusT s m) where
  get = MkSPT (gets snd)
  put v = MkSPT (modify (\st -> (fst st, v)))
-- | Whether a raw 'StateT' result carries a 'True' success flag.
succeeded :: (a, (Bool, s)) -> Bool
succeeded (_, (ok, _)) = ok

-- | Failure and backtracking choice for 'StatePlusT'.
--
-- 'mzero' marks the computation as failed (flag 'False'). It keeps the
-- current user state and yields the 'mzeroError' placeholder as its value.
--
-- @mplus a b@ runs @a@ first:
--
-- * If @a@ succeeds, its result and state are returned and @b@ is never
--   run, so none of @b@'s effects in the underlying monad are performed.
-- * If @a@ fails, its state changes are discarded and @b@ is run from the
--   state that was current before @a@. Effects @a@ already performed in
--   the underlying monad are not undone.
instance Monad m => MonadPlus (StatePlusT s m) where
  mzero = MkSPT . StateT $ \s -> return (mzeroError, (False, snd s))
  mplus a b = MkSPT . StateT $ \s -> do
    ra <- runStateT (unSPT a) s
    -- Run b only when a failed, so b's underlying effects are skipped
    -- whenever its result would be thrown away anyway.
    if succeeded ra
      then return ra
      else runStateT (unSPT b) s
-- | 'Alternative' mirrors the 'MonadPlus' instance: '<|>' is 'mplus' and
-- 'empty' is 'mzero'.
instance (Monad m, Functor m) => Alternative (StatePlusT s m) where
  a <|> b = mplus a b
  empty = mzero
-- | Run a 'StatePlusT' computation from the initial user state @s@, with
-- the success flag initially 'True'.
--
-- Returns @'Just' v@ and the final state when the computation succeeded,
-- or 'Nothing' and the state recorded when it failed.
runStatePlusT :: Monad m => StatePlusT s m a -> s -> m (Maybe a, s)
runStatePlusT sm s =
  runStateT (unSPT sm) (True, s) >>= \(v, (ok, st)) ->
    return (if ok then Just v else Nothing, st)
-- | Run a 'StatePlusT' computation and keep only the final user state,
-- regardless of whether the computation succeeded.
execStatePlusT :: Monad m => StatePlusT s m a -> s -> m s
execStatePlusT sm s = runStatePlusT sm s >>= \(_, final) -> return final
-- | Run a 'StatePlusT' computation and keep only its outcome:
-- @'Just' v@ on success, 'Nothing' on failure.
evalStatePlusT :: Monad m => StatePlusT s m a -> s -> m (Maybe a)
evalStatePlusT sm s = runStatePlusT sm s >>= \(result, _) -> return result