{-# LANGUAGE TemplateHaskell #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Effect.State
-- Copyright   :  (c) Michael Szvetits, 2020
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  typedbyte@qualified.name
-- Stability   :  stable
-- Portability :  portable
--
-- The state effect, similar to the @MonadState@ type class from the @mtl@
-- library.
--
-- Lazy and strict interpretations of the effect are available here:
-- "Control.Effect.State.Lazy" and "Control.Effect.State.Strict".
-----------------------------------------------------------------------------
module Control.Effect.State
  ( -- * Tagged State Effect
    State'(..)
    -- * Convenience Functions
  , gets'
  , modify'
  , modifyStrict'
    -- * Untagged State Effect
    -- | If you don't require disambiguation of multiple state effects
    -- (i.e., you only have one state effect in your monadic context),
    -- it is recommended to always use the untagged state effect.
  , State
  , get
  , put
  , state
  , gets
  , modify
  , modifyStrict
    -- * Tagging and Untagging
    -- | Conversion functions between the tagged and untagged state effect,
    -- usually used in combination with type applications, like:
    --
    -- @
    --     'tagState'' \@\"newTag\" program
    --     'retagState'' \@\"oldTag\" \@\"newTag\" program
    --     'untagState'' \@\"erasedTag\" program
    -- @
    -- 
  , tagState'
  , retagState'
  , untagState'
  ) where

-- base
import Data.Tuple (swap)

-- transformers
import qualified Control.Monad.Trans.RWS.CPS      as Strict
import qualified Control.Monad.Trans.RWS.Lazy     as Lazy
import qualified Control.Monad.Trans.State.Lazy   as L
import qualified Control.Monad.Trans.State.Strict as S

import Control.Effect.Machinery

-- | An effect that adds a mutable state to a given computation.
--
-- The @tag@ parameter disambiguates several state effects living in the
-- same monadic context. The functional dependency @tag m -> s@ fixes the
-- state type once the tag and the monad are known.
--
-- Instances must define either both 'get'' and 'put'', or 'state''. The
-- default implementations are written in terms of each other, so an
-- instance defining neither set would make every method loop forever.
class Monad m => State' tag s m | tag m -> s where
  {-# MINIMAL get', put' | state' #-}

  -- | Gets the current state.
  get' :: m s
  -- Keep the state unchanged and return it as the result.
  get' = state' @tag (\s -> (s, s))
  {-# INLINE get' #-}

  -- | Replaces the state with a new value.
  put' :: s -> m ()
  -- Ignore the old state and install the given one.
  put' s = state' @tag (\_ -> (s, ()))
  {-# INLINE put' #-}

  -- | Updates the state and produces a value based on the current state.
  --
  -- The function returns the new state in the first component of the pair
  -- and the result in the second one.
  state' :: (s -> (s, a)) -> m a
  state' f = do
    s <- get' @tag
    -- Like every let-bound pattern this match is lazy (the @~@ only makes it
    -- explicit): @f s@ is not evaluated until @s'@ or @a@ is demanded.
    let ~(s', a) = f s
    put' @tag s'
    pure a
  {-# INLINE state' #-}

-- Template Haskell splice over the State' class (''State' is a TH type-name
-- quote). This module exports untagged names (State, get, put, state) and the
-- conversions tagState', retagState' and untagState', none of which is
-- defined in this file. Presumably this splice generates them; confirm
-- against Control.Effect.Machinery.
makeTaggedEffect ''State'

-- | The lazy state monad transformer from @transformers@.
--
-- 'get'' and 'put'' delegate directly. 'state'' has to reorder the pair,
-- because @transformers@ expects @(result, new state)@ while this class
-- uses @(new state, result)@.
instance Monad m => State' tag s (L.StateT s m) where
  get' = L.get
  {-# INLINE get' #-}
  put' = L.put
  {-# INLINE put' #-}
  -- 'fmap' on functions is post-composition: @fmap swap f = swap . f@.
  state' = L.state . fmap swap
  {-# INLINE state' #-}

-- | The strict state monad transformer from @transformers@.
--
-- 'get'' and 'put'' delegate directly. 'state'' has to reorder the pair,
-- because @transformers@ expects @(result, new state)@ while this class
-- uses @(new state, result)@.
instance Monad m => State' tag s (S.StateT s m) where
  get' = S.get
  {-# INLINE get' #-}
  put' = S.put
  {-# INLINE put' #-}
  -- 'fmap' on functions is post-composition: @fmap swap f = swap . f@.
  state' = S.state . fmap swap
  {-# INLINE state' #-}

-- | The state component of the lazy reader-writer-state transformer from
-- @transformers@.
--
-- The lazy RWST state operations require @Monoid w@, hence the extra
-- constraint. 'state'' reorders the pair because @transformers@ expects
-- @(result, new state)@ while this class uses @(new state, result)@.
instance (Monad m, Monoid w) => State' tag s (Lazy.RWST r w s m) where
  get' = Lazy.get
  {-# INLINE get' #-}
  put' = Lazy.put
  {-# INLINE put' #-}
  -- 'fmap' on functions is post-composition: @fmap swap f = swap . f@.
  state' = Lazy.state . fmap swap
  {-# INLINE state' #-}

-- | The state component of the CPS-based reader-writer-state transformer
-- from @transformers@ ("Control.Monad.Trans.RWS.CPS").
--
-- Unlike the lazy RWST instance, these state operations need no @Monoid w@
-- constraint. 'state'' reorders the pair because @transformers@ expects
-- @(result, new state)@ while this class uses @(new state, result)@.
instance Monad m => State' tag s (Strict.RWST r w s m) where
  get' = Strict.get
  {-# INLINE get' #-}
  put' = Strict.put
  {-# INLINE put' #-}
  -- 'fmap' on functions is post-composition: @fmap swap f = swap . f@.
  state' = Strict.state . fmap swap
  {-# INLINE state' #-}

-- | Gets a specific component of the state, using the provided projection
-- function.
--
-- The tag is the first type variable, so it can be selected with a type
-- application, e.g. @gets' \@\"counter\" fst@.
gets' :: forall tag s m a. State' tag s m => (s -> a) -> m a
gets' f = f <$> get' @tag
{-# INLINE gets' #-}

-- | Modifies the state, using the provided function.
--
-- This function does not itself force the new state: @f s@ is handed to
-- 'put'' as is. Use 'modifyStrict'' to evaluate it to weak head normal
-- form first.
modify' :: forall tag s m. State' tag s m => (s -> s) -> m ()
modify' f = do
  s <- get' @tag
  put' @tag (f s)
{-# INLINE modify' #-}

-- | Modifies the state, using the provided function.
-- The computation is strict in the new state: @f s@ is evaluated to weak
-- head normal form (via '$!') before it is passed to 'put''.
modifyStrict' :: forall tag s m. State' tag s m => (s -> s) -> m ()
modifyStrict' f = do
  s <- get' @tag
  put' @tag $! f s
{-# INLINE modifyStrict' #-}

-- Template Haskell splice. The list elements are TH name quotes, not
-- character literals: the leading quote refers to the tagged functions
-- gets', modify' and modifyStrict' defined in this module. The untagged
-- gets, modify and modifyStrict exported by this module are not defined in
-- this file, so presumably this splice generates them; confirm against
-- Control.Effect.Machinery.
makeUntagged ['gets', 'modify', 'modifyStrict']