{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Language.NameCheck
-- Description : checks that there are no name conflicts. / 名前衝突がないか検査します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Language.NameCheck
  ( namecheckProgram,
    namecheckToplevelExpr,
    namecheckExpr,
  )
where

import Control.Monad.State.Strict
import Jikka.Common.Error
import Jikka.Core.Format (formatType)
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Util

-- | Registers the binding @x : t@ in the environment held in the state.
--
-- Throws an internal error ("name conflict") if @x@ is already present in the
-- environment, whatever type it was bound with. Otherwise @(x, t)@ is
-- prepended to the environment.
define :: (MonadState [(VarName, Type)] m, MonadError Error m) => VarName -> Type -> m ()
define x t = do
  existing <- gets (lookup x)
  case existing of
    Just t' ->
      throwInternalError $
        "name conflict: " ++ formatVarName x ++ ": " ++ formatType t ++ " and " ++ formatVarName x ++ ": " ++ formatType t'
    Nothing -> modify ((x, t) :)

-- | Walks an expression. Checks that every variable occurrence is present in
-- the environment held in the state, and registers each binder with 'define'
-- (which rejects names that are already present).
--
-- NOTE(review): bindings introduced by 'Lam' and 'Let' are never removed from
-- the state, so they remain visible after their syntactic scope ends. This
-- makes the conflict check program-wide, but it also means a reference outside
-- a binder's scope is not reported as undefined. Confirm this is intended.
namecheckExpr' :: (MonadState [(VarName, Type)] m, MonadError Error m) => Expr -> m ()
namecheckExpr' = \case
  Var x -> do
    bound <- gets (lookup x)
    case bound of
      Nothing -> throwInternalError $ "undefined variable: " ++ formatVarName x
      Just _ -> return ()
  Lit _ -> return ()
  App f e -> do
    namecheckExpr' f
    namecheckExpr' e
  Lam x t e -> do
    define x t
    namecheckExpr' e
  Let x t e1 e2 -> do
    -- e1 is checked before x is registered, so a reference to x inside e1
    -- cannot resolve to this binding (the let is non-recursive).
    namecheckExpr' e1
    define x t
    namecheckExpr' e2
  Assert e1 e2 -> do
    namecheckExpr' e1
    namecheckExpr' e2

-- | Checks an expression for undefined variables and name conflicts.
--
-- @env@ is the list of bindings treated as already in scope. Binders inside
-- the expression must not reuse any of these names. Any error is passed
-- through 'wrapError'' with this function's qualified name.
namecheckExpr :: MonadError Error m => [(VarName, Type)] -> Expr -> m ()
namecheckExpr env e =
  wrapError' "Jikka.Core.Language.NameCheck.namecheckExpr" $
    evalStateT (namecheckExpr' e) env

-- | Walks a chain of toplevel definitions. Each definition is checked with
-- 'namecheckExpr'', and each toplevel name is registered with 'define'.
--
-- NOTE(review): as in 'namecheckExpr'', nothing is ever removed from the
-- state. In particular, the parameters of a 'ToplevelLetRec' remain registered
-- while its continuation is checked. Confirm this is intended.
namecheckToplevelExpr' :: (MonadState [(VarName, Type)] m, MonadError Error m) => ToplevelExpr -> m ()
namecheckToplevelExpr' = \case
  ResultExpr e -> namecheckExpr' e
  ToplevelLet x t e cont -> do
    -- The right-hand side is checked before x is registered (non-recursive let).
    namecheckExpr' e
    define x t
    namecheckToplevelExpr' cont
  ToplevelLetRec f args ret body cont -> do
    -- f is registered before its parameters and body so that the body may
    -- refer to f recursively.
    define f (curryFunTy (map snd args) ret)
    forM_ args (uncurry define)
    namecheckExpr' body
    namecheckToplevelExpr' cont
  ToplevelAssert e1 e2 -> do
    namecheckExpr' e1
    namecheckToplevelExpr' e2

-- | Checks a toplevel expression for undefined variables and name conflicts.
--
-- @env@ is the list of bindings treated as already in scope. Names bound
-- inside the toplevel expression must not reuse any of these. Any error is
-- passed through 'wrapError'' with this function's qualified name.
namecheckToplevelExpr :: MonadError Error m => [(VarName, Type)] -> ToplevelExpr -> m ()
namecheckToplevelExpr env e =
  wrapError' "Jikka.Core.Language.NameCheck.namecheckToplevelExpr" $
    evalStateT (namecheckToplevelExpr' e) env

-- | Checks a whole program for undefined variables and name conflicts,
-- starting from an empty environment. Any error is passed through
-- 'wrapError'' with this function's qualified name.
namecheckProgram :: MonadError Error m => Program -> m ()
namecheckProgram prog =
  wrapError' "Jikka.Core.Language.NameCheck.namecheckProgram" $
    evalStateT (namecheckToplevelExpr' prog) []