module Language.Lambda.Untyped.State
  ( EvalState(..),
    Eval(),
    Globals(),
    runEval,
    execEval,
    unsafeExecEval,
    unsafeRunEval,
    globals,
    uniques,
    mkEvalState,
    getGlobals,
    getUniques,
    setGlobals,
    setUniques
  ) where

import Language.Lambda.Shared.Errors
import Language.Lambda.Untyped.Expression 

import Control.Monad.Except
import RIO
import RIO.State
import qualified RIO.Map as Map

-- | The evaluation state: the global bindings in scope together with a
-- supply of names that have not yet been used.
data EvalState name = EvalState
  { -- | Global variables and the expressions bound to them
    esGlobals :: Globals name,
    -- | Unused unique names
    esUniques :: [name]
  }

-- | A computation that threads an 'EvalState' through its steps and may
-- abort with a 'LambdaException'.
type Eval name = StateT (EvalState name) (Except LambdaException)

-- | Global variable bindings, keyed by variable name; each entry holds the
-- lambda expression the name is bound to.
type Globals name = Map name (LambdaExpr name)

-- | Run an evaluation from the given initial state.
--
-- On success, returns the computation's result paired with the final state.
-- A 'LambdaException' raised by the computation is returned as 'Left'.
runEval :: Eval name result -> EvalState name -> Either LambdaException (result, EvalState name)
runEval computation = runExcept . runStateT computation

-- | Run an evaluation from the given initial state, throwing away the final
-- state.
--
-- A 'LambdaException' raised by the computation is returned as 'Left'.
execEval :: Eval name result -> EvalState name -> Either LambdaException result
execEval computation = runExcept . evalStateT computation

-- | Run an evaluation, returning the result together with the final state.
--
-- If the evaluation fails, this calls 'error' with the 'show'n
-- 'LambdaException'. The exception raised is therefore an 'ErrorCall'
-- carrying that text, not the 'LambdaException' value itself.
--
-- NOTE(review): 'unsafeExecEval' throws the 'LambdaException' directly via
-- 'impureThrow'. Consider unifying the two, but callers that catch
-- 'ErrorCall' here would need updating.
unsafeRunEval :: Eval name result -> EvalState name -> (result, EvalState name)
unsafeRunEval computation state' =
  case runEval computation state' of
    Left err -> error $ show err
    Right res -> res
  
-- | Run an evaluation, throwing away the final state.
--
-- If the evaluation fails, the 'LambdaException' is thrown from pure code
-- with 'impureThrow'.
unsafeExecEval :: Eval name result -> EvalState name -> result
unsafeExecEval computation state' =
  either impureThrow id (execEval computation state')

-- | Create an 'EvalState' with no global bindings and the given supply of
-- unique names.
mkEvalState :: [name] -> EvalState name
mkEvalState uniques' =
  EvalState
    { esGlobals = Map.empty,
      esUniques = uniques'
    }

-- | Lens onto the global bindings ('esGlobals') of an 'EvalState'.
globals :: Lens' (EvalState name) (Globals name)
globals f state' = rebuild <$> f (esGlobals state')
  where
    -- Put the (possibly updated) globals back into the original state.
    rebuild globals' = state' {esGlobals = globals'}

-- | Lens onto the unused unique-name supply ('esUniques') of an 'EvalState'.
uniques :: Lens' (EvalState name) [name]
uniques f state' = rebuild <$> f (esUniques state')
  where
    -- Put the (possibly updated) name supply back into the original state.
    rebuild uniques' = state' {esUniques = uniques'}

-- | Read the current global bindings from the evaluation state.
getGlobals :: Eval name (Globals name)
getGlobals = gets (^. globals)

-- | Read the current supply of unused unique names from the evaluation state.
getUniques :: Eval name [name]
getUniques = gets (^. uniques)

-- | Replace the global bindings in the evaluation state, leaving the
-- unique-name supply untouched.
setGlobals :: Globals name -> Eval name ()
setGlobals globals' = modify (& globals .~ globals')

-- | Replace the supply of unused unique names in the evaluation state,
-- leaving the global bindings untouched.
setUniques :: [name] -> Eval name ()
setUniques uniques' = modify (& uniques .~ uniques')