{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TupleSections #-}

module Control.Monad.Trans.RevState
  ( -- * Monad Transformer
    StateT (StateT)
  , runStateT
  , evalStateT
  , execStateT
  , mapStateT
  , withStateT
  , liftStateT

    -- * Monad
  , State
  , runState
  , evalState
  , execState
  , mapState
  , withState

    -- * Primitives and basic combinators
  , get
  , put
  , state
  , gets  
  , modify
  ) where

import Control.Arrow (first)
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Identity


-- | A reverse state transformer.
--
-- Wraps a function from an input state to an action in @m@ that yields a
-- result together with an output state. The 'Monad' instance for this type
-- threads the state backwards through '>>=' rather than forwards.
newtype StateT s m a = StateT
  { runStateT :: s -> m (a, s)
    -- ^ Run the computation with the given state, producing its result and
    -- the resulting state.
  }

-- | Run a computation with the given state and keep only its result,
-- discarding the resulting state.
evalStateT :: Functor m => StateT s m a -> s -> m a
evalStateT m s = fst <$> runStateT m s

-- | Run a computation with the given state and keep only the resulting
-- state, discarding the result value.
execStateT :: Functor m => StateT s m a -> s -> m s
execStateT m s = snd <$> runStateT m s

-- | The non-transformer variant: 'StateT' specialised to the 'Identity'
-- monad.
type State s = StateT s Identity

-- | Run a 'State' computation with the given state, returning the result
-- together with the resulting state.
runState :: State s a -> s -> (a, s)
runState m = runIdentity . runStateT m

-- | Run a 'State' computation with the given state and keep only its
-- result.
evalState :: State s a -> s -> a
evalState m s = fst (runState m s)

-- | Run a 'State' computation with the given state and keep only the
-- resulting state.
execState :: State s a -> s -> s
execState m s = snd (runState m s)

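-- We can't declare StateT an instance of MonadTrans: since transformers-0.6
-- (the version bundled with GHC >= 9.6) the class carries the superclass
-- (forall m. Monad m => Monad (t m)), i.e. it would require
-- (forall m. Monad m => Monad (StateT s m)), but we need (MonadFix m) to
-- guarantee that (StateT s m) is a Monad.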
-- | Lift an action of the underlying functor into 'StateT'.
--
-- The state is passed through untouched: whatever state is supplied is
-- returned as the output state alongside the action's result.
liftStateT :: Functor m => m a -> StateT s m a
liftStateT m = StateT $ \s -> (, s) <$> m

-- | Sequencing threads the state /backwards/.
--
-- The state given to @m >>= f@ is fed to the continuation @f x@. The state
-- the continuation produces is fed to @m@. The state @m@ produces is the
-- overall output state.
--
-- Effects of the underlying monad still run in program order (@m@ first,
-- then @f x@). Because @f@ needs @m@'s result while @m@ needs @f x@'s output
-- state, the two steps are tied together with a recursive @do@ block
-- (@rec@), which is why @m@ must be a 'MonadFix'.
instance MonadFix m => Monad (StateT s m) where
  m >>= f = StateT $ \sIn -> do
    rec
      -- m runs on the state that comes back from the continuation.
      (x, sOut) <- runStateT m sMid
      -- The continuation runs on the incoming state.
      (y, sMid) <- runStateT (f x) sIn
    return (y, sOut)

-- | '<*>' is defined via 'ap', so it shares the 'Monad' instance's
-- backwards state threading and likewise requires 'MonadFix'. 'pure'
-- returns its argument and leaves the state unchanged.
instance MonadFix m => Applicative (StateT s m) where
  pure x = state $ \s -> (x, s)
  (<*>) = ap

-- | Maps over the result only; the state is passed through untouched.
instance Functor m => Functor (StateT s m) where
  -- this instance is hand-written
  -- so we don't have to rely on m being MonadFix
  fmap f m = StateT $ \s -> first f <$> runStateT m s


-- | Feeds a computation's own result back into it using the underlying
-- monad's 'mfix'. Only the result takes part in the fixpoint; the incoming
-- state is passed straight to @f x@.
instance MonadFix m => MonadFix (StateT s m) where
  mfix f = StateT $ \s ->
    -- The lazy pattern stops 'mfix' from forcing the pair it is still in
    -- the middle of computing.
    mfix (\ ~(x, _) -> runStateT (f x) s)


-- | Return the state as the result, leaving the state unchanged.
get :: Applicative m => StateT s m s
get = state $ \s -> (s, s)

-- | Replace the state with the given value, ignoring the state passed in.
put :: Applicative m => s -> StateT s m ()
put s' = state $ const ((), s')

-- | Apply a function to the state. The application is not forced here.
modify :: Applicative m => (s -> s) -> StateT s m ()
modify f = state $ \s -> ((), f s)

-- | Embed a pure state function into 'StateT'. The function's
-- (result, state) pair is returned in the underlying applicative via
-- 'pure'.
state :: Applicative m => (s -> (a, s)) -> StateT s m a
state f = StateT (pure . f)


-- | Transform the underlying action of a computation: @f@ rewrites the
-- @m (a, s)@ produced for each input state into an @n (b, s)@.
mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT $ f . runStateT m

-- | Apply @f@ to the state before it is handed to the computation.
withStateT :: (s -> s) -> StateT s m a -> StateT s m a
withStateT f m = StateT $ runStateT m . f

-- | 'mapStateT' for the 'Identity'-based 'State': rewrite the
-- (result, state) pair a computation produces.
mapState :: ((a, s) -> (b, s)) -> State s a -> State s b
-- For the 'Identity' newtype, @fmap f@ is exactly
-- @Identity . f . runIdentity@.
mapState f = mapStateT (fmap f)

-- | 'withStateT' specialised to 'State': apply a function to the state
-- before it is handed to the computation.
withState :: (s -> s) -> State s a -> State s a
withState = withStateT

-- | Return a projection of the state, leaving the state unchanged.
gets :: Applicative m => (s -> a) -> StateT s m a
gets f = f <$> get