{-# LANGUAGE AllowAmbiguousTypes #-}

-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at https://mozilla.org/MPL/2.0/.

{- |
Copyright   :  (c) 2023 Yamada Ryo
License     :  MPL-2.0 (see the file LICENSE)
Maintainer  :  ymdfield@outlook.jp
Stability   :  experimental
Portability :  portable

Interpreter for the t'Data.Effect.State.State' effect class.
-}
module Control.Effect.Interpreter.Heftia.State where

import Control.Arrow ((>>>))
import Control.Effect (type (~>))
import Control.Effect.Hefty (Eff, injectF, interpose, interposeT, interpret, interpretFin, interpretK, raiseUnder)
import Control.Effect.Interpreter.Heftia.Reader (runAsk)
import Control.Freer (Freer)
import Control.Monad.Freer (MonadFreer)
import Control.Monad.State (StateT)
import Control.Monad.Trans.State qualified as T
import Data.Effect.HFunctor (HFunctor)
import Data.Effect.Reader (Ask (Ask), LAsk, ask)
import Data.Effect.State (LState, State (Get, Put), get, put)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Hefty.Union (Member, Union)
import Data.Tuple (swap)
import UnliftIO (MonadIO, newIORef, readIORef, writeIORef)

-- | Interpret the 'Get'/'Put' effects using the 'StateT' monad transformer.
--
-- The computation is run from the given initial state via 'runStateT'.
-- 'T.runStateT' yields a @(value, state)@ pair; it is swapped so that the
-- final state comes first and the computation's return value second.
runState ::
    forall s r a fr u c.
    (Freer c fr, Union u, c (Eff u fr '[] r), c (StateT s (Eff u fr '[] r)), Applicative (Eff u fr '[] r)) =>
    s ->
    Eff u fr '[] (LState s ': r) a ->
    Eff u fr '[] r (s, a)
runState initial action =
    fmap swap (T.runStateT (runStateT action) initial)
{-# INLINE runState #-}

-- | Interpret the 'Get'/'Put' effects like 'runState', starting from the given
-- initial state, but keep only the computation's return value and discard the
-- final state.
evalState ::
    forall s r fr u c.
    (Freer c fr, Union u, c (Eff u fr '[] r), c (StateT s (Eff u fr '[] r)), Applicative (Eff u fr '[] r)) =>
    s ->
    Eff u fr '[] (LState s ': r) ~> Eff u fr '[] r
evalState initial action =
    fmap snd (runState initial action)
{-# INLINE evalState #-}

-- | Interpret the 'Get'/'Put' effects like 'runState', starting from the given
-- initial state, but keep only the final state and discard the computation's
-- return value.
execState ::
    forall s r a fr u c.
    (Freer c fr, Union u, c (Eff u fr '[] r), c (StateT s (Eff u fr '[] r)), Applicative (Eff u fr '[] r)) =>
    s ->
    Eff u fr '[] (LState s ': r) a ->
    Eff u fr '[] r s
execState initial action =
    runState initial action <&> fst
{-# INLINE execState #-}

-- | Interpret the 'Get'/'Put' effects using the 'StateT' monad transformer.
--
-- 'Get' and 'Put' are translated by 'fuseStateEffect'. Every remaining effect
-- is re-sent to the underlying 'Eff' with 'injectF', and the threaded state is
-- passed through unchanged around it.
runStateT ::
    forall s r fr u c.
    (Freer c fr, Union u, c (StateT s (Eff u fr '[] r)), c (Eff u fr '[] r), Applicative (Eff u fr '[] r)) =>
    Eff u fr '[] (LState s ': r) ~> StateT s (Eff u fr '[] r)
runStateT =
    interpretFin
        (\other -> T.StateT \current -> injectF other <&> (,current))
        fuseStateEffect

-- | Interpret the 'Get'/'Put' effects using delimited continuations.
--
-- The interpretation works in three stages:
--
-- * 'raiseUnder' inserts an auxiliary 'LAsk' layer beneath @'LState' s@.
--   This layer carries the current state.
-- * 'interpretK' handles each operation. 'Get' asks for the current state and
--   resumes with it. 'Put' resumes the captured continuation with every 'Ask'
--   inside it answered by the new state, via 'interpose'. When the computation
--   returns, the current state is asked once more and paired with the result.
-- * 'runAsk' finally answers the remaining 'Ask' with the initial state.
--
-- The result pairs the final state with the computation's return value.
runStateK ::
    forall s r a fr u c.
    ( MonadFreer c fr
    , Union u
    , HFunctor (u '[])
    , Member u (Ask s) (LAsk s ': r)
    , c (Eff u fr '[] (LAsk s ': r))
    , Applicative (Eff u fr '[] r)
    ) =>
    s ->
    Eff u fr '[] (LState s ': r) a ->
    Eff u fr '[] r (s, a)
runStateK initialState =
    raiseUnder
        >>> interpretK
            (\result -> ask <&> (,result))
            ( \resume -> \case
                Get -> ask >>= resume
                -- Each 'Put' shadows the state seen by the rest of the computation.
                Put newState -> resume () & interpose @(Ask s) (\Ask -> pure newState)
            )
        >>> runAsk initialState

-- | Interpret the 'Get'/'Put' effects by keeping the state in a mutable
-- reference, created with 'newIORef' from the initial state.
--
-- 'Get' reads the reference and 'Put' overwrites it. Once the computation
-- returns, the reference is read one last time, and that final state is paired
-- with the computation's return value.
--
-- NOTE(review): 'writeIORef' does not force the value it stores, so long runs
-- of 'Put's with unevaluated states may build up thunks.
runStateIORef ::
    forall s r a fr u c.
    (Freer c fr, Union u, MonadIO (Eff u fr '[] r)) =>
    s ->
    Eff u fr '[] (LState s ': r) a ->
    Eff u fr '[] r (s, a)
runStateIORef initial action = do
    cell <- newIORef initial
    result <-
        interpret
            ( \case
                Get -> readIORef cell
                Put next -> writeIORef cell next
            )
            action
    final <- readIORef cell
    pure (final, result)

-- | Run a computation with its @'State' s@ operations intercepted
-- ('interposeT') and handled locally by a 'StateT' seeded with the current
-- state.
--
-- The state reached by the local 'StateT' is written back with a single 'put'
-- only after the computation has returned. The computation's own result is
-- then returned.
--
-- NOTE(review): presumably, if the computation aborts before returning (e.g.
-- through an exception-like effect in @r@), the outer state is left untouched.
-- Confirm this against the semantics of 'interposeT'.
transactState ::
    forall s r fr u c.
    (Freer c fr, Union u, Member u (State s) r, Monad (Eff u fr '[] r), c (StateT s (Eff u fr '[] r))) =>
    Eff u fr '[] r ~> Eff u fr '[] r
transactState action = do
    (result, committed) <- T.runStateT (interposeT fuseStateEffect action) =<< get @s
    result <$ put committed

-- | Translate a single 'State' operation into the equivalent 'StateT' action.
--
-- 'Get' returns the current state and leaves it unchanged. 'Put' ignores the
-- current state, replaces it with the given one and returns @()@.
fuseStateEffect :: Applicative f => State s ~> StateT s f
fuseStateEffect Get = T.StateT \current -> pure (current, current)
fuseStateEffect (Put next) = T.StateT (const (pure ((), next)))