{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE PatternSynonyms #-}

-- | A state monad which is strict in its state.
module GHC.Utils.Monad.State.Strict
  ( -- * The State monad
    State(State)
  , state
  , evalState
  , execState
  , runState
    -- * Operations
  , get
  , gets
  , put
  , modify
  ) where

import GHC.Prelude

import GHC.Exts (oneShot)

{- Note [Strict State monad]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A State monad can be strict in many ways. Which kind of strictness do we mean?

First off, since we represent the result pair as an unboxed pair, this State
monad is strict in the sense of "Control.Monad.Trans.State.Strict": The
computations and the sequencing thereof (through 'Applicative' and 'Monad'
instances) are forced strictly.

Beyond the manual unboxing of one level (which CPR could achieve similarly,
yet perhaps a bit less reliably), our 'State' is even stricter than the
transformers version:
It's also strict in the state `s` (but still lazy in the value `a`). What this
means is that whenever callers examine the state component (perhaps through
'runState'), they will find that the `s` has already been evaluated.

This additional strictness is maintained in a single place, by the ubiquitous
'State' pattern synonym, by forcing the state component *after* any state action
has been run. The INVARIANT is:

> Any `s` that makes it into the unboxed pair representation is evaluated.

This invariant has another nice effect: Because the evaluatedness is quite
apparent, Nested CPR will try to unbox the state component `s` nestedly if
feasible. Detecting evaluatedness of nested components is a necessary
condition for Nested CPR to trigger; see the user's guide entry on that:
https://ghc.gitlab.haskell.org/ghc/doc/users_guide/using-optimisation.html#ghc-flag--fcpr-anal

Note that this doesn't have any effects on whether Nested CPR will unbox the `a`
component (which is still lazy by default). The user still has to use the
`return $!` idiom from the user's guide to encourage Nested CPR to unbox the `a`
result of a stateful computation.
-}

-- | A state monad which is strict in the state `s`, but lazy in the value `a`.
--
-- See Note [Strict State monad] for the particular notion of strictness and
-- implementation details.
--
-- The representation is a function from the incoming state to an unboxed
-- pair of result and outgoing state. The raw constructor @State'@ and its
-- field @runState'@ do not appear in this module's export list; the exported
-- way to build and match values is the 'State' pattern synonym.
newtype State s a = State' { runState' :: s -> (# a, s #) }

pattern State :: (s -> (# a, s #))
              -> State s a

-- This pattern synonym makes the monad eta-expand,
-- which has a very beneficial effect on compiler performance.
-- See #18202.
-- See Note [The one-shot state monad trick] in GHC.Utils.Monad
-- It also implements the particular notion of strictness of this monad;
-- see Note [Strict State monad].
pattern State m <- State' m
  where
    -- Builder direction: wrap the state function in 'oneShot' and pass the
    -- pair it returns through 'forceState', so the outgoing state is
    -- evaluated after the action has run.
    State m = State' (oneShot $ \s -> forceState (m s))

-- | Forces the state component of the unboxed representation pair of 'State'.
-- See Note [Strict State monad]. This is The Place doing the forcing!
--
-- Only the second (state) component carries a bang; the first (value)
-- component is passed through without being evaluated.
forceState :: (# a, s #) -> (# a, s #)
forceState (# a, !s #) = (# a, s #)

-- | 'fmap' runs the underlying computation and applies the function to its
-- result. The application @f x@ is stored in the result pair without being
-- forced here; the state produced by the computation is passed on as is.
instance Functor (State s) where
  fmap f m = State $ \s -> case runState' m s  of (# x, s' #) -> (# f x, s' #)

-- | 'pure' returns its argument and leaves the state untouched. '<*>' runs the
-- left computation first, feeds the state it produces into the right
-- computation, and pairs the application of the two results with the final
-- state.
instance Applicative (State s) where
  pure x  = State $ \s -> (# x, s #)
  m <*> n = State $ \s ->
    case runState' m s  of { (# f, s' #) ->
    case runState' n s' of { (# x, s'' #) ->
                             (# f x, s'' #) }}

-- | '>>=' runs the first computation, then runs the continuation applied to
-- its result, starting from the state the first computation produced. That
-- intermediate state is matched with a bang pattern before it is handed on.
instance Monad (State s) where
  m >>= n = State $ \s -> case runState' m s of
    (# r, !s' #) -> runState' (n r) s'

-- | Builds a 'State' action from a function that returns an ordinary (boxed)
-- pair of result and new state. The boxed pair is taken apart by a @case@ and
-- repackaged as the unboxed pair used by the representation.
state :: (s -> (a, s)) -> State s a
state f = State $ \s -> case f s of (r, s') -> (# r, s' #)

-- | Returns the current state as the result, leaving the state unchanged.
get :: State s s
get = State $ \s -> (# s, s #)

-- | Returns the given function applied to the current state as the result,
-- leaving the state unchanged. The application @f s@ is placed in the result
-- pair without being forced here.
gets :: (s -> a) -> State s a
gets f = State $ \s -> (# f s, s #)

-- | Replaces the state with the given value; the previous state is ignored.
put :: s -> State s ()
put s' = State $ \_ -> (# (), s' #)

-- | Replaces the state with the result of applying the given function to the
-- current state.
modify :: (s -> s) -> State s ()
modify f = State $ \s -> (# (), f s #)

-- | Runs the computation @s@ from the initial state @i@ and returns only the
-- result value, discarding the final state.
evalState :: State s a -> s -> a
evalState s i = case runState' s i of (# a, _ #) -> a

-- | Runs the computation @s@ from the initial state @i@ and returns only the
-- final state, discarding the result value.
execState :: State s a -> s -> s
execState s i = case runState' s i of (# _, s' #) -> s'

-- | Runs the computation @s@ from the initial state @i@ and returns both the
-- result value and the final state as an ordinary (boxed) pair. The final
-- state is matched with a bang pattern before the pair is built.
runState :: State s a -> s -> (a, s)
runState s i = case runState' s i of (# a, !s' #) -> (a, s')