module Control.Monad.Trans.State.Ref
( StateRefT
, runStateRefT
, runStateIORefT
, runStateSTRefT
, module Control.Monad.State.Class
) where
import Control.Applicative (Applicative (..))
import Control.Monad.Catch (MonadCatch (..), MonadMask (..),
MonadThrow (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Reader.Class (MonadReader (..))
import Control.Monad.State.Class
import Control.Monad.Trans.Control (defaultLiftBaseWith,
defaultRestoreM)
import Control.Monad.Trans.Unlift
import Control.Monad.Trans.Resource (MonadResource (..))
import Data.Mutable (IORef, MCState, MutableRef,
PrimMonad, PrimState, RealWorld,
RefElement, STRef, newRef,
readRef, writeRef)
-- | A state transformer whose state lives in a mutable reference of type
-- @ref s@ instead of being threaded through return values.
--
-- A computation is simply a function from the state reference to an action
-- in the underlying monad @m@.
newtype StateRefT ref s m a = StateRefT
  { unStateRefT :: ref s -> m a
    -- ^ Run the computation against an already-allocated state reference.
  }
  deriving Functor
-- | Run a 'StateRefT' computation from the initial state @v0@.
--
-- A fresh reference is allocated in the base monad @b@ and seeded with
-- @v0@. The computation runs against that reference. Afterwards the
-- reference is read once more, and its final contents are returned
-- alongside the computation's result.
runStateRefT
  :: ( Monad m
     , s ~ RefElement (ref s)
     , MCState (ref s) ~ PrimState b
     , MonadBase b m
     , MutableRef (ref s)
     , PrimMonad b
     )
  => StateRefT ref s m a
  -> s
  -> m (a, s)
runStateRefT (StateRefT body) v0 = do
  stateRef <- liftBase (newRef v0)
  result   <- body stateRef
  final    <- liftBase (readRef stateRef)
  return (result, final)
-- | 'runStateRefT' specialised to keep the state in an 'IORef'.
--
-- The base monad @b@ must have 'RealWorld' as its 'PrimState'.
runStateIORefT
  :: ( Monad m
     , RealWorld ~ PrimState b
     , MonadBase b m
     , PrimMonad b
     )
  => StateRefT IORef s m a
  -> s
  -> m (a, s)
runStateIORefT action s0 = runStateRefT action s0
-- | 'runStateRefT' specialised to keep the state in an 'STRef'.
--
-- The reference's state token @ps@ is tied to the 'PrimState' of the base
-- monad @b@.
runStateSTRefT
  :: ( Monad m
     , ps ~ PrimState b
     , MonadBase b m
     , PrimMonad b
     )
  => StateRefT (STRef ps) s m a
  -> s
  -> m (a, s)
runStateSTRefT action s0 = runStateRefT action s0
-- | Both operands of '<*>' receive the same state reference. Their effects
-- are combined by the underlying 'Applicative'.
instance Applicative m => Applicative (StateRefT ref s m) where
  pure a = StateRefT (\_ -> pure a)
  StateRefT mf <*> StateRefT mx = StateRefT (\sref -> mf sref <*> mx sref)
-- | Binding runs the first action and then the continuation against the
-- very same state reference. State changes therefore travel through the
-- reference, not through return values.
instance Monad m => Monad (StateRefT ref s m) where
  -- NOTE(review): 'return' delegates to the underlying 'return' rather than
  -- being 'pure'. GHC reports this as non-canonical
  -- (-Wnoncanonical-monad-instances). It is presumably kept for pre-AMP
  -- (GHC < 7.10) compatibility; consider 'return = pure' once that is dropped.
  return a = StateRefT (\_ -> return a)
  m >>= k = StateRefT $ \sref ->
    unStateRefT m sref >>= \a -> unStateRefT (k a) sref
-- | 'get' reads the reference's current contents and 'put' overwrites them.
-- Both operations are lifted from the base monad @b@ into @m@ via
-- 'liftBase'.
instance ( MCState (ref s) ~ PrimState b
         , Monad m
         , s ~ RefElement (ref s)
         , MutableRef (ref s)
         , PrimMonad b
         , MonadBase b m
         )
         => MonadState s (StateRefT ref s m) where
  get = StateRefT (\sref -> liftBase (readRef sref))
  -- The new state is forced to WHNF as soon as the 'put' action itself is
  -- evaluated. The reference is therefore never handed an unevaluated
  -- top-level thunk.
  put x = x `seq` StateRefT (\sref -> liftBase (writeRef sref x))
-- | Reader operations are delegated to the underlying monad. The state
-- reference is passed through untouched.
instance MonadReader r m => MonadReader r (StateRefT ref s m) where
  ask = StateRefT (\_ -> ask)
  local f (StateRefT inner) = StateRefT (\sref -> local f (inner sref))
  reader f = StateRefT (\_ -> reader f)
-- | Lifting ignores the state reference entirely.
instance MonadTrans (StateRefT ref s) where
  lift m = StateRefT (\_ -> m)
-- | IO actions are lifted into @m@ and run without touching the reference.
instance MonadIO m => MonadIO (StateRefT ref s m) where
  liftIO io = StateRefT (\_ -> liftIO io)
-- | Base-monad actions are lifted into @m@ and run without touching the
-- reference.
instance MonadBase b m => MonadBase b (StateRefT ref s m) where
  liftBase base = StateRefT (\_ -> liftBase base)
-- | The transformer adds no wrapping around results: 'StT' is the bare
-- result type. "Running" a computation just means applying it to the
-- captured state reference.
instance MonadTransControl (StateRefT ref s) where
  type StT (StateRefT ref s) a = a
  liftWith withRun = StateRefT $ \sref -> withRun (\(StateRefT g) -> g sref)
  restoreT m = StateRefT (\_ -> m)
-- | Implemented with monad-control's default helpers. Because 'StT' for
-- 'StateRefT' is the bare result type, the saved monadic state is exactly
-- that of the underlying monad @m@.
instance MonadBaseControl b m => MonadBaseControl b (StateRefT ref s m) where
  type StM (StateRefT ref s m) a = StM m a
  restoreM st = defaultRestoreM st
  liftBaseWith = defaultLiftBaseWith
-- | Throwing is delegated to the underlying monad.
instance MonadThrow m => MonadThrow (StateRefT ref s m) where
  throwM e = StateRefT (\_ -> throwM e)
-- | The protected action and the handler both run against the same state
-- reference. Writes made before the exception are not undone by this
-- instance.
instance MonadCatch m => MonadCatch (StateRefT ref s m) where
  catch (StateRefT body) handler = StateRefT $ \sref ->
    catch (body sref) (\e -> unStateRefT (handler e) sref)
-- | Masking and bracketing are delegated to the underlying monad. Every
-- sub-action runs against the same state reference, including acquire,
-- release and use in 'generalBracket'. Writes made through the reference
-- are therefore not rolled back when an exception occurs.
--
-- 'generalBracket' is defined because exceptions >= 0.10 declares it as a
-- 'MonadMask' method with no default. In that version 'bracket',
-- 'bracket_', 'finally' and 'onException' are implemented on top of it.
-- Without this definition those combinators fail at runtime for
-- 'StateRefT'. The implementation mirrors the one for 'ReaderT', since
-- 'StateRefT' is a reader of the state reference.
instance MonadMask m => MonadMask (StateRefT ref s m) where
  mask body = StateRefT $ \sref ->
    mask $ \restore -> unStateRefT (body (liftRestore restore)) sref
    where
      -- Lift the underlying monad's restore function to 'StateRefT'.
      liftRestore :: (m x -> m x) -> StateRefT ref s m x -> StateRefT ref s m x
      liftRestore restore (StateRefT inner) = StateRefT (restore . inner)

  uninterruptibleMask body = StateRefT $ \sref ->
    uninterruptibleMask $ \restore ->
      unStateRefT (body (liftRestore restore)) sref
    where
      -- Lift the underlying monad's restore function to 'StateRefT'.
      liftRestore :: (m x -> m x) -> StateRefT ref s m x -> StateRefT ref s m x
      liftRestore restore (StateRefT inner) = StateRefT (restore . inner)

  generalBracket acquire release use = StateRefT $ \sref ->
    generalBracket
      (unStateRefT acquire sref)
      (\resource exitCase -> unStateRefT (release resource exitCase) sref)
      (\resource -> unStateRefT (use resource) sref)
-- | 'ResourceT' actions are lifted into @m@ without touching the reference.
instance MonadResource m => MonadResource (StateRefT ref s m) where
  liftResourceT rt = StateRefT (\_ -> liftResourceT rt)