{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

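{- | The state effect: a single value of type @s@ that a 'Prog' can read
('get'), overwrite ('put') or update ('modify'). The effect is interpreted
by 'handleState', which threads the value through the computation and
returns it alongside the result.
-}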

module Effects.State (
    State(..)
  , get
  , put
  , modify
  , handleState) where

import Prog ( discharge, Member(inj), Prog(..) )

-- | Requests understood by the state effect. The parameter @s@ is the
-- type of the stored state; the index @a@ is the type of the answer
-- that a request produces.
data State s a where
  -- | Ask for the current state; the answer is the state itself.
  Get :: forall s. State s s
  -- | Provide a new state value; the answer is @()@.
  Put :: forall s. s -> State s ()

-- | Issue a 'Get' request, yielding the current state.
--
-- The request is injected into the effect row with 'inj' and its answer
-- is returned unchanged via 'Val'.
get :: Member (State s) es => Prog es s
get = Op (inj Get) Val

-- | Issue a 'Put' request carrying the new state value.
--
-- The request is injected into the effect row with 'inj'; its @()@ answer
-- is returned via 'Val'.
put :: Member (State s) es => s -> Prog es ()
put s = Op (inj (Put s)) Val

-- | Apply a function to the state: read it with 'get', then write the
-- transformed value back with 'put'.
modify :: Member (State s) es => (s -> s) -> Prog es ()
modify f = get >>= put . f

-- | Handle the @State s@ effect at the head of the effect row.
--
-- 'Get' requests are answered with the current state, 'Put' requests
-- replace it, and every other effect is re-emitted unchanged with the
-- state threaded through its continuation.
handleState
  :: s                       -- ^ Initial state
  -> Prog (State s ': es) a  -- ^ Computation that may use the state effect
  -> Prog es (a, s)          -- ^ (Output, final state)
handleState s m = loop s m
  where
    -- Interpret the program one step at a time, carrying the current state.
    loop :: s -> Prog (State s ': es) a -> Prog es (a, s)
    loop s (Val x)  = return (x, s)
    loop s (Op u k) = case discharge u of
      Right Get      -> loop s (k s)
      Right (Put s') -> loop s' (k ())
      -- Not a state request: forward it and keep handling its continuation.
      Left  u'       -> Op u' (loop s . k)