{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{- | A carrier for the 'State' effect. It evaluates its inner state strictly, which is the correct choice for the majority of use cases.

Note that the parameter order in 'runState', 'evalState', and 'execState' is reversed compared to the equivalent functions provided by @transformers@. This is an intentional decision made to enable the composition of effect handlers with '.' without invoking 'flip'.

@since 1.0.0.0
-}
module Control.Carrier.State.Strict
( -- * Strict state carrier
  runState
, evalState
, execState
, StateC(..)
  -- * State effect
, module Control.Effect.State
) where

import Control.Algebra
import Control.Applicative (Alternative(..))
import Control.Effect.State
import Control.Monad (MonadPlus)
import Control.Monad.Fail as Fail
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Trans.Class

-- | Run a 'State' effect starting from the passed value.
--
-- The initial state is the first argument; the result pairs the final state
-- with the value produced by the computation.
--
-- @
-- 'runState' s ('pure' a) = 'pure' (s, a)
-- @
-- @
-- 'runState' s 'get' = 'pure' (s, s)
-- @
-- @
-- 'runState' s ('put' t) = 'pure' (t, ())
-- @
--
-- @since 1.0.0.0
runState :: s -> StateC s m a -> m (s, a)
-- Unwrap the carrier and apply its state-passing function to the initial state.
runState s (StateC runStateC) = runStateC s
{-# INLINE[3] runState #-}

-- | Run a 'State' effect, yielding the result value and discarding the final state.
--
-- @
-- 'evalState' s m = 'fmap' 'snd' ('runState' s m)
-- @
--
-- @since 1.0.0.0
evalState :: forall s m a . Functor m => s -> StateC s m a -> m a
-- 'runState' yields @(state, value)@; keep only the second component.
evalState s = fmap snd . runState s
{-# INLINE[3] evalState #-}

-- | Run a 'State' effect, yielding the final state and discarding the return value.
--
-- @
-- 'execState' s m = 'fmap' 'fst' ('runState' s m)
-- @
--
-- @since 1.0.0.0
execState :: forall s m a . Functor m => s -> StateC s m a -> m s
-- 'runState' yields @(state, value)@; keep only the first component.
execState s = fmap fst . runState s
{-# INLINE[3] execState #-}


-- | The carrier for the 'State' effect: a function from an incoming state to
-- an action in the underlying monad @m@ that yields the outgoing state paired
-- with a result.
--
-- The derived 'Functor' instance maps over the result only and leaves the
-- state component untouched.
--
-- @since 1.0.0.0
newtype StateC s m a = StateC (s -> m (s, a))
  deriving (Functor)

-- | State-threading 'Applicative'. A 'Monad' constraint on @m@ is required
-- because the state produced by the first computation must be fed into the
-- second.
instance Monad m => Applicative (StateC s m) where
  -- Leave the state as it was and return the value.
  pure a = StateC (\ s -> pure (s, a))
  {-# INLINE pure #-}

  -- Run the function-producing computation first, then run the
  -- argument-producing computation from the state it left behind.
  StateC f <*> StateC a = StateC $ \ s -> do
    (s', f') <- f s
    (s'', a') <- a s'
    pure (s'', f' a')
  {-# INLINE (<*>) #-}

  -- Sequence via '>>=', ignoring the first result.
  m *> k = m >>= const k
  {-# INLINE (*>) #-}

-- | 'Alternative' lifted from the underlying monad @m@.
instance (Alternative m, Monad m) => Alternative (StateC s m) where
  -- Ignore the incoming state and defer to the underlying 'empty'.
  empty = StateC (const empty)
  {-# INLINE empty #-}

  -- Both branches are given the same incoming state; the choice between them
  -- is made by the underlying '<|>'.
  StateC l <|> StateC r = StateC (\ s -> l s <|> r s)
  {-# INLINE (<|>) #-}

-- | State-threading 'Monad'.
instance Monad m => Monad (StateC s m) where
  -- Run the first computation, then run the continuation applied to its
  -- result, starting from the state the first computation produced.
  StateC m >>= f = StateC $ \ s -> do
    (s', a) <- m s
    runState s' (f a)
  {-# INLINE (>>=) #-}

-- | 'Fail.MonadFail' lifted from the underlying monad @m@.
instance Fail.MonadFail m => Fail.MonadFail (StateC s m) where
  -- Ignore the incoming state and fail in the underlying monad with the same
  -- message.
  fail s = StateC (const (Fail.fail s))
  {-# INLINE fail #-}

-- | 'MonadFix' lifted from the underlying monad @m@.
instance MonadFix m => MonadFix (StateC s m) where
  -- Tie the knot in the underlying monad over the @(state, value)@ pair: only
  -- the value component ('snd') is fed back into @f@, and every iteration is
  -- run from the same incoming state @s@.
  mfix f = StateC (\ s -> mfix (runState s . f . snd))
  {-# INLINE mfix #-}

-- | 'MonadIO' lifted from the underlying monad @m@.
instance MonadIO m => MonadIO (StateC s m) where
  -- Run the 'IO' action in the underlying monad and pair its result with the
  -- unchanged incoming state.
  liftIO io = StateC (\ s -> (,) s <$> liftIO io)
  {-# INLINE liftIO #-}

-- | 'MonadPlus' for 'StateC'. No methods are defined here, so the class
-- defaults apply; those defer to the 'Alternative' instance
-- ('mzero' = 'empty', 'mplus' = '<|>').
instance (Alternative m, Monad m) => MonadPlus (StateC s m)

-- | Lift actions of the underlying monad into 'StateC'.
instance MonadTrans (StateC s) where
  -- Run the underlying action and pair its result with the unchanged incoming
  -- state.
  lift m = StateC (\ s -> (,) s <$> m)
  {-# INLINE lift #-}

-- | Interpret the 'State' effect at the head of the signature and forward every
-- other effect to the algebra of the underlying monad @m@.
instance Algebra sig m => Algebra (State s :+: sig) (StateC s m) where
  alg hdl sig ctx = StateC $ \ s -> case sig of
    -- 'Get': the state is unchanged, and the current state is returned inside
    -- the incoming context.
    L Get       -> pure (s, s <$ ctx)
    -- 'Put': the supplied value becomes the outgoing state; the context is
    -- returned as is.
    L (Put new) -> pure (new, ctx)
    -- Any other effect: hand it to the underlying algebra via 'thread'. The
    -- handler is composed with @'uncurry' 'runState'@, and the current state is
    -- paired with the context as the initial context.
    R other     -> thread (uncurry runState ~<~ hdl) other (s, ctx)
  {-# INLINE alg #-}