module LambdaCube.STLC.TypeChecker
  ( infer
  ) where

import           Data.List           (uncons)
import           LambdaCube.STLC.Ast

-- | Infer the 'LCType' of an 'LCTerm'.
--
-- Variables are resolved positionally against the list of binder types that
-- is accumulated while descending through 'LCLam' (innermost binder first).
--
-- Calls 'error' with @"Out-of-scope variable"@ when a variable index has no
-- corresponding enclosing binder, and with
-- @"Function argument type mismatch"@ when the function part of an 'LCApp'
-- does not have an 'LCArr' type whose argument type equals the inferred type
-- of the actual argument.
infer :: LCTerm -> LCType
infer = go []
  where
    -- The first argument holds the types of the enclosing binders,
    -- innermost first; each 'LCLam' pushes its annotation onto the front.
    go :: [LCType] -> LCTerm -> LCType
    go l (LCVar x)
      -- 'drop' treats a negative count as zero, so without this guard a
      -- negative index would silently resolve to the innermost binder.
      | x < 0     = outOfScope
      | otherwise = maybe outOfScope fst . uncons $ drop x l
    go l (LCLam t b) = t `LCArr` go (t : l) b
    go l (LCApp f a)
      -- Succeeds only when the function has an arrow type and its argument
      -- type matches the type inferred for the actual argument.
      | LCArr at rt <- go l f
      , at == go l a
      = rt
      | otherwise
      = error "Function argument type mismatch"

    -- Shared failure for variable indices that have no enclosing binder.
    outOfScope :: LCType
    outOfScope = error "Out-of-scope variable"