{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}

-- | An 'ST' monad transformer, 'STT', which threads a 'State#' token
-- through an arbitrary underlying monad so that 'ST' actions can be
-- lifted into that monad with 'liftST'.
module Control.Monad.ST.Trans
  ( STT
  , runSTT
  , module Control.Monad.ST.Class
  ) where

import Control.Applicative (Alternative (..), Applicative (..))
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Reader.Class
import Control.Monad.ST.Class
import Control.Monad.State.Class
import Control.Monad.Trans
import GHC.Prim
import GHC.ST hiding (liftST)
import qualified GHC.ST as ST
#if __GLASGOW_HASKELL__ >= 800
import qualified Control.Monad.Fail as Fail
#endif

-- | 'ST' monad transformer.  A computation is a function from the incoming
-- 'State#' token to an action in the underlying monad @m@ that yields the
-- outgoing token together with the result.
newtype STT s m a = STT {execSTT :: State# s -> m (STret s a)}

-- | Maps over the result; the state token is passed through unchanged.
instance Monad m => Functor (STT s m) where
  fmap = liftM

-- | 'pure' returns its argument without touching the state token.
instance Monad m => Applicative (STT s m) where
  pure x = STT (\s -> return (STret s x))
  (<*>) = ap

-- | Sequencing feeds the state token produced by the first step into the
-- second, so 'ST' effects happen in program order.
instance Monad m => Monad (STT s m) where
  return = pure
  m >>= k = STT $ \s -> do
    STret s' x <- execSTT m s
    execSTT (k x) s'
#if __GLASGOW_HASKELL__ < 808
  -- 'fail' was a 'Monad' method before GHC 8.8; delegate to @m@'s 'fail'.
  fail err = STT (\_ -> fail err)
#endif

#if __GLASGOW_HASKELL__ >= 800
-- | Failure is delegated to the underlying monad's 'Fail.fail'.
instance Fail.MonadFail m => Fail.MonadFail (STT s m) where
  fail err = STT (\_ -> Fail.fail err)
#endif

-- | Lifting runs the underlying action and leaves the state token untouched.
instance MonadTrans (STT s) where
  lift m = STT (\s -> liftM (STret s) m)

-- | Embeds a strict 'ST' action by running it on the current state token.
instance Monad m => MonadST s (STT s m) where
  liftST m = STT (\s -> return (ST.liftST m s))

-- | State operations are lifted from the underlying monad.  Here @st@ is the
-- state type of @m@; @s@ is the 'ST' thread of the transformer.
instance MonadState st m => MonadState st (STT s m) where
  get = lift get
  put = lift . put

-- | 'ask' is lifted from @m@.  'local' runs the wrapped computation's
-- underlying step under @m@'s 'local', so every 'ask' inside it observes the
-- modified environment.
instance MonadReader r m => MonadReader r (STT s m) where
  ask = lift ask
  local f m = STT $ \s -> local f (execSTT m s)

-- | Defined in terms of the 'MonadPlus' instance below.
instance MonadPlus m => Alternative (STT s m) where
  empty = mzero
  (<|>) = mplus

-- | Choice is delegated to @m@.
--
-- NOTE(review): both branches of 'mplus' receive the same incoming state
-- token @s@, so in-place 'ST' mutations performed by one branch are not
-- rolled back before the other runs.  Using this with a backtracking @m@
-- presumably breaks the purity 'runSTT' is meant to guarantee — review.
instance MonadPlus m => MonadPlus (STT s m) where
  mzero = lift mzero
  m `mplus` k = STT $ \s -> execSTT m s `mplus` execSTT k s

-- | Value recursion is delegated to @m@'s 'mfix'.  The lazy pattern binds only
-- the (lifted) result, never the unlifted 'State#' field, which a recursive
-- binding is not allowed to bind lazily.
instance MonadFix m => MonadFix (STT s m) where
  mfix f = STT $ \s -> mfix (\ ~(STret _ a) -> execSTT (f a) s)

-- | IO actions are lifted through the underlying monad.
instance MonadIO m => MonadIO (STT s m) where
  liftIO = lift . liftIO

-- Runs a state-passing computation seeded with 'realWorld#', discarding the
-- final state token.  NOINLINE presumably keeps GHC from floating or sharing
-- the 'realWorld#' application across call sites — NOTE(review): confirm this
-- is still sufficient on current GHC, whose own 'runST' uses 'runRW#'.
{-# NOINLINE runSTTRep #-}
runSTTRep :: Monad m => (forall s . State# s -> m (STret s a)) -> m a
runSTTRep f = do
  STret _ x <- f realWorld#
  return x

-- | Executes the state-transformer part of a computation in the 'STT' monad
-- transformer, returning the result in the underlying monad.  As with
-- 'runST', the rank-2 type keeps the state thread @s@ from escaping.
--
-- NOTE(review): this is only as safe as the underlying monad's use of the
-- state token.  See the 'MonadPlus' instance for a case where one token is
-- handed to more than one branch.
runSTT :: Monad m => (forall s . STT s m a) -> m a
runSTT m = runSTTRep (execSTT m)