{-# LANGUAGE RecursiveDo, UndecidableInstances, FlexibleInstances, FunctionalDependencies, MultiParamTypeClasses, RankNTypes, MagicHash #-}
module Control.Monad.ST.Trans (STT, runSTT, module Control.Monad.ST.Class) where

import GHC.ST hiding (liftST)
import qualified GHC.ST as ST
import Control.Monad.ST.Class
import GHC.Prim
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Fix
import Control.Monad.State.Class
import Control.Monad.Reader.Class

-- | 'ST' monad transformer.
--
-- A computation is a function from a state token (@'State#' s@) to an action
-- in the underlying monad @m@. That action yields the result together with
-- the updated state token, packed in GHC's 'STret'. 'execSTT' unwraps this
-- function.
newtype STT s m a = STT {execSTT :: State# s -> m (STret s a)}

instance Monad m => Monad (STT s m) where
	-- Inject a value, handing the incoming state token straight back.
	return a = STT $ \ st -> return (STret st a)
	-- Run the first computation, then pass its result and the updated
	-- state token on to the continuation.
	STT run >>= next = STT $ \ st -> do
		STret st' a <- run st
		execSTT (next a) st'
	-- Failure is delegated to the underlying monad; the token is unused.
	fail msg = STT $ \ _ -> fail msg

instance MonadTrans (STT s) where
	-- Run the inner action and pair its result with the untouched state token.
	lift inner = STT $ \ st -> inner >>= \ x -> return (STret st x)

instance Monad m => MonadST s (STT s m) where
	-- Step the plain 'ST' action with the current token (via GHC.ST.liftST)
	-- and return the resulting 'STret' in the underlying monad.
	liftST action = STT $ \ st -> return $ ST.liftST action st

instance MonadState s m => MonadState s (STT s' m) where
	-- Both operations defer to the state of the underlying monad; the
	-- ST state token is left untouched.
	put newState = lift (put newState)
	get = lift get

instance MonadReader r m => MonadReader r (STT s m) where
	-- Read the environment of the underlying monad.
	ask = lift ask
	-- Run the wrapped computation itself under the modified environment.
	-- 'local' has to wrap the underlying action (@execSTT m s@). Binding @m@
	-- first and then applying 'local f' to a bare 'return' would leave @m@
	-- running in the unmodified environment.
	local f m = STT (\ s -> local f (execSTT m s))

instance MonadPlus m => MonadPlus (STT s m) where
	-- Failure: the underlying monad's 'mzero', lifted.
	mzero = lift mzero
	-- Choice: delegate to the underlying 'mplus'. Both alternatives are
	-- started from the same state token @st@.
	-- NOTE(review): with an underlying monad that can produce several
	-- results (e.g. lists), sharing one token across branches may make
	-- mutations in one branch visible to the other; confirm intended use.
	mplus left right = STT (\ st -> mplus (execSTT left st) (execSTT right st))

instance MonadFix m => MonadFix (STT s m) where
	-- Value recursion: the result @a@ produced by @f a@ is fed back in as
	-- @f@'s own argument. The recursion goes through the underlying
	-- monad's 'mfix' (via 'mdo'), and the final state token @s'@ is returned.
	-- NOTE(review): the recursive binding includes the unlifted token @s'@.
	-- Confirm this still compiles on the targeted GHC versions, since
	-- unlifted variables in recursive/lazy bindings are restricted.
	mfix f = STT $ \ s -> mdo	STret s' a <- execSTT (f a) s
					return (STret s' a)

instance MonadIO m => MonadIO (STT s m) where
	-- Perform the IO action in the underlying monad, then lift it into 'STT'.
	liftIO io = lift (liftIO io)

-- Apply a state-transformer function to the 'realWorld#' token and keep only
-- its result. The final token is discarded. Matching on 'STret' is strict, so
-- the constructor is forced before the result is returned.
-- NOTE(review): NOINLINE presumably mirrors GHC.ST.runSTRep (keeping
-- 'realWorld#' from being inlined/floated into callers); confirm.
{-# NOINLINE runSTTRep #-}
runSTTRep :: Monad m => (forall s . State# s -> m (STret s a)) -> m a
runSTTRep f = f realWorld# >>= \ (STret _ result) -> return result

-- | Run an 'STT' computation, executing its state-transformer part and
-- returning the final result in the underlying monad.
--
-- As with 'runST', the rank-2 type keeps the state thread @s@ from escaping.
-- NOTE(review): safety presumably also assumes the underlying monad runs each
-- continuation at most once; the 'MonadPlus' instance reuses a single state
-- token for both branches.
runSTT :: Monad m => (forall s . STT s m a) -> m a
runSTT action = runSTTRep (\ st -> execSTT action st)