module Language.Lambda.SystemF.State
  ( TypecheckState(..),
    Typecheck(),
    Context(),
    Binding(..),
    Globals(),
    runTypecheck,
    execTypecheck,
    unsafeRunTypecheck,
    unsafeExecTypecheck,
    mkTypecheckState,
    _context,
    _globals,
    _varUniques,
    _tyUniques,
    getContext,
    getGlobals,
    getVarUniques,
    getTyUniques,
    modifyGlobals,
    modifyVarUniques,
    modifyTyUniques,
    setGlobals,
    setVarUniques,
    setTyUniques
  ) where

import Language.Lambda.Shared.Errors (LambdaException(..))
import Language.Lambda.SystemF.Expression

import Control.Monad.Except (Except(), runExcept)
import RIO
import RIO.State
import qualified RIO.Map as Map

-- | Mutable state threaded through a 'Typecheck' computation.
data TypecheckState name = TypecheckState
  { tsGlobals    :: Globals name  -- ^ Global definitions in scope
  , tsVarUniques :: [name]        -- ^ A unique supply of term-level variables
  , tsTyUniques  :: [name]        -- ^ A unique supply of type-level variables
  }
  deriving (Eq, Show)

-- | The typechecking monad: a 'TypecheckState' threaded over a computation
-- that may abort with a 'LambdaException'.
type Typecheck name = StateT (TypecheckState name) (Except LambdaException)

-- | Global definitions, mapping each name to its 'TypedExpr'.
type Globals name
  = Map name (TypedExpr name)

-- | A typing context, mapping each name to its 'Binding'.
type Context name
  = Map name (Binding name)

-- | What a name in a 'Context' is bound to.
data Binding name
  = BindTerm (Ty name)  -- ^ A term binding, carrying the term's type
  | BindTy              -- ^ A binding with no payload (presumably a type variable)
  deriving (Eq, Show)

-- | Run a typechecking computation from an initial state.
--
-- Returns the 'LambdaException' that aborted the computation, or the result
-- together with the final state.
runTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> Either LambdaException (result, TypecheckState name)
runTypecheck computation initialState =
  runExcept (runStateT computation initialState)

-- | Like 'runTypecheck', but discards the final state and keeps only the result.
execTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> Either LambdaException result
execTypecheck computation initialState =
  runExcept (evalStateT computation initialState)

-- | Like 'runTypecheck', but a failure is rethrown via 'impureThrow'.
--
-- The 'LambdaException' therefore surfaces only when the returned value is
-- evaluated.
unsafeRunTypecheck
  :: Typecheck name result
  -> TypecheckState name
  -> (result, TypecheckState name)
unsafeRunTypecheck computation state' =
  case runTypecheck computation state' of
    Left err      -> impureThrow err
    Right outcome -> outcome

-- | Like 'execTypecheck', but a failure is rethrown via 'impureThrow'
-- instead of being returned in 'Left'.
unsafeExecTypecheck :: Typecheck name result -> TypecheckState name -> result
unsafeExecTypecheck computation state' =
  case execTypecheck computation state' of
    Left err     -> impureThrow err
    Right result -> result

-- | Build an initial state with no globals.
--
-- The two arguments are the term-level and type-level unique supplies,
-- in that order.
mkTypecheckState :: [name] -> [name] -> TypecheckState name
mkTypecheckState varSupply tySupply =
  TypecheckState
    { tsGlobals    = Map.empty
    , tsVarUniques = varSupply
    , tsTyUniques  = tySupply
    }

-- | Read-only view of the typing context derived from the globals.
--
-- Each global name is bound with 'BindTerm', using the type stored in its
-- 'TypedExpr'.
_context :: SimpleGetter (TypecheckState name) (Context name)
_context = to (\st -> globalsToContext (tsGlobals st))
  where
    globalsToContext :: Globals name -> Context name
    globalsToContext = fmap (\expr -> BindTerm (expr ^. _ty))
        
-- | Lens onto the global definitions held in the state.
_globals :: Lens' (TypecheckState name) (Globals name)
_globals = lens tsGlobals (\st newGlobals -> st { tsGlobals = newGlobals })

-- | Lens onto the supply of unique term-level variable names.
_varUniques :: Lens' (TypecheckState name) [name]
_varUniques = lens tsVarUniques (\st supply -> st { tsVarUniques = supply })

-- | Lens onto the supply of unique type-level variable names.
_tyUniques :: Lens' (TypecheckState name) [name]
_tyUniques = lens tsTyUniques (\st supply -> st { tsTyUniques = supply })

-- | Read the current supply of unique term-level variable names.
getVarUniques :: Typecheck name [name]
getVarUniques = do
  st <- get
  pure (st ^. _varUniques)

-- | Read the current supply of unique type-level variable names.
getTyUniques :: Typecheck name [name]
getTyUniques = do
  st <- get
  pure (st ^. _tyUniques)

-- | Read the typing context derived from the current globals.
getContext :: Typecheck name (Context name)
getContext = do
  st <- get
  pure (st ^. _context)

-- | Read the current global definitions.
getGlobals :: Typecheck name (Globals name)
getGlobals = do
  st <- get
  pure (st ^. _globals)

-- | Apply a function to the global definitions in the state.
modifyGlobals :: (Globals name -> Globals name) -> Typecheck name ()
modifyGlobals f = modify (over _globals f)

-- | Apply a function to the term-level unique supply.
modifyVarUniques :: ([name] -> [name]) -> Typecheck name ()
modifyVarUniques f = modify (over _varUniques f)

-- | Apply a function to the type-level unique supply.
modifyTyUniques :: ([name] -> [name]) -> Typecheck name ()
modifyTyUniques f = modify (over _tyUniques f)

-- | Replace the term-level unique supply.
setVarUniques :: [name] -> Typecheck name ()
setVarUniques uniques' = modify (set _varUniques uniques')

-- | Replace the type-level unique supply.
setTyUniques :: [name] -> Typecheck name ()
setTyUniques uniques' = modify (set _tyUniques uniques')

-- | Replace the global definitions.
setGlobals :: Globals name -> Typecheck name ()
setGlobals globals' = modify (set _globals globals')