{-# LANGUAGE AllowAmbiguousTypes #-}

-- SPDX-License-Identifier: MPL-2.0

{- |
Copyright   :  (c) 2023 Sayo Koyoneda
License     :  MPL-2.0 (see the LICENSE file)
Maintainer  :  ymdfield@outlook.jp

Interpreter for the t'Data.Effect.State.State' effect.
-}
module Control.Monad.Hefty.State (
    module Control.Monad.Hefty.State,
    module Data.Effect.State,
)
where

import Control.Arrow ((>>>))
import Control.Monad.Hefty (
    Eff,
    StateInterpreter,
    interpose,
    interposeStateBy,
    interpret,
    interpretBy,
    interpretRecWith,
    interpretStateBy,
    interpretStateRecWith,
    raiseUnder,
    (&),
    type (<|),
    type (~>),
 )
import Control.Monad.Hefty.Reader (runAsk)
import Data.Effect.Reader (Ask (Ask), ask)
import Data.Effect.State
import Data.Functor ((<&>))
import UnliftIO (newIORef, readIORef, writeIORef)

{- |
Interpret the 'State' effect, starting from the initial state @s0@.

The result pairs the final state with the computation's return value.
-}
runState :: forall s ef a. s -> Eff '[] (State s ': ef) a -> Eff '[] ef (s, a)
-- The value handler receives the final state and the result; 'curry pure'
-- packs them into a tuple.
runState s0 = interpretStateBy s0 (curry pure) handleState

{- |
Interpret the 'State' effect, starting from the initial state @s0@.
Do not include the final state in the return value.
-}
evalState :: forall s ef a. s -> Eff '[] (State s ': ef) a -> Eff '[] ef a
-- 'const pure' drops the final state and returns only the result.
evalState s0 = interpretStateBy s0 (const pure) handleState

{- |
Interpret the 'State' effect, starting from the initial state @s0@.
Do not include the final result in the return value; only the final state is returned.
-}
execState :: forall s ef a. s -> Eff '[] (State s ': ef) a -> Eff '[] ef s
-- The value handler keeps the final state and discards the result.
execState s0 = interpretStateBy s0 (\s _ -> pure s) handleState

{- |
Interpret the 'State' effect, starting from the initial state @s0@.

Interpretation is performed recursively with respect to the scopes of unelaborated higher-order effects @eh@.
Note that the state is reset and does not persist beyond the scopes.
-}
evalStateRec :: forall s ef eh. s -> Eff eh (State s ': ef) ~> Eff eh ef
evalStateRec s0 = interpretStateRecWith s0 handleState

{- |
A handler function for the 'State' effect, in state-passing form.

Each operation receives the current state and a continuation that expects
the next state and the operation's result:

* @'Put' s@ ignores the current state and continues with @s@ as the new
  state and @()@ as the result.

* 'Get' keeps the current state unchanged and also returns it as the result.
-}
handleState :: StateInterpreter s (State s) (Eff eh r) ans
handleState = \case
    Put s -> \_ k -> k s ()
    Get -> \s k -> k s s
{-# INLINE handleState #-}

{- |
Interpret the 'State' effect based on an IO-fused semantics using t'Data.IORef.IORef'.

The state lives in a single mutable reference initialised with @s0@.
'Get' reads the reference and 'Put' overwrites it. When the computation
returns, the reference is read once more so that the final state can be
paired with the result.
-}
runStateIORef
    :: forall s ef eh a
     . (IO <| ef)
    => s
    -> Eff eh (State s ': ef) a
    -> Eff eh ef (s, a)
runStateIORef s0 m = do
    ref <- newIORef s0
    a <-
        m & interpret \case
            Get -> readIORef ref
            Put s -> writeIORef ref s
    -- Read the final state only after the whole computation has run.
    readIORef ref <&> (,a)

{- |
Interpret the 'State' effect based on an IO-fused semantics using t'Data.IORef.IORef'.
Do not include the final state in the return value.

The state lives in a single mutable reference initialised with @s0@.
'Get' reads the reference and 'Put' overwrites it.
-}
evalStateIORef
    :: forall s ef eh a
     . (IO <| ef)
    => s
    -> Eff eh (State s ': ef) a
    -> Eff eh ef a
evalStateIORef s0 m = do
    ref <- newIORef s0
    m & interpret \case
        Get -> readIORef ref
        Put s -> writeIORef ref s

{- |
Within the given scope, make the state roll back to the beginning of the scope in case of exceptions, etc.

The scope's 'State' operations are intercepted and run against a local copy
of the state, seeded with the current value. The outer state is updated with
'put' only after the scope has returned normally. If the scope is aborted
before that point, the outer state keeps the value it had at the start of
the scope.
-}
transactState :: forall s ef. (State s <| ef) => Eff '[] ef ~> Eff '[] ef
transactState m = do
    pre <- get @s
    -- Run the scope on a local state; 'curry pure' returns (final state, result).
    (post, a) <- interposeStateBy pre (curry pure) handleState m
    -- Commit step: only reached when the scope completes.
    put post
    pure a

{- |
A naive but somewhat slower version of 'runState' that does not use ad-hoc optimizations.

The effect is interpreted into a state-passing function. Every handler
returns a function from the current state to the final @(state, result)@
pair, and that function is finally applied to the initial state @s0@.
-}
runStateNaive :: forall s ef a. s -> Eff '[] (State s ': ef) a -> Eff '[] ef (s, a)
runStateNaive s0 m = do
    f <-
        m & interpretBy (\a -> pure \s -> pure (s, a)) \case
            -- Resume with the current state as the result; the state is unchanged.
            Get -> \k -> pure \s -> k s >>= ($ s)
            -- Ignore the current state and continue from the new state @s@.
            Put s -> \k -> pure \_ -> k () >>= ($ s)
    f s0

{- |
A naive but somewhat slower version of 'evalStateRec' that does not use ad-hoc optimizations.

The state is encoded with an 'Ask' effect inserted just under 'State':

* 'Get' is answered by 'ask'.

* @'Put' s@ resumes the continuation with its 'Ask' operations answered by
  @s@ (via 'interpose').

The remaining 'Ask' is finally answered by the initial state @s0@.
-}
evalStateNaiveRec :: forall s ef eh. s -> Eff eh (State s ': ef) ~> Eff eh ef
evalStateNaiveRec s0 =
    raiseUnder
        >>> interpretRecWith \case
            Get -> (ask @s >>=)
            Put s -> \k -> k () & interpose @(Ask s) \Ask -> pure s
        >>> runAsk @s s0