module LambdaCube.SystemFw.TypeChecker
  ( reduceType

  , infer
  , inferKind
  ) where

import           Data.List                        (uncons)
import           LambdaCube.SystemFw.Ast
import           LambdaCube.SystemFw.Substitution

-- | Normalise a type by contracting every type-level beta redex, including
-- redexes under 'LCUniv' and 'LCTTLam' binders.
--
-- The input is expected to be well-kinded (see 'inferKind'). An application
-- whose head normalises to a type variable, or to another such stuck
-- application, has no redex to contract. It is kept as a neutral 'LCTTApp'
-- with a normalised argument. An application whose head normalises to
-- something that can never be applied ('LCBase', 'LCArr' or 'LCUniv') is
-- reported with 'error'.
reduceType :: LCType -> LCType
reduceType = go
  where
    go :: LCType -> LCType
    go LCBase        = LCBase
    go e@(LCTVar _)  = e
    go (LCArr a b)   = go a `LCArr` go b
    go (LCUniv k a)  = LCUniv k $ go a
    go (LCTTLam k b) = LCTTLam k $ go b
    go (LCTTApp f a) =
      case go f of
        -- Beta redex: substitute the normalised argument into the body and
        -- keep normalising, because the substitution may expose new redexes.
        LCTTLam _ b      -> go $ substituteTypeInType v 0 b
        -- Stuck (neutral) application, e.g. @f Base@ where @f :: * -> *@ is
        -- a bound type variable. This is already a normal form.
        f'@(LCTVar _)    -> LCTTApp f' v
        f'@(LCTTApp _ _) -> LCTTApp f' v
        _                -> error "Did you really kind check this?"
      where
        v = go a

-- | Infer the type of a closed term (both contexts start empty).
--
-- Every type produced here is normalised with 'reduceType'. This matters
-- because the application rule compares argument types with structural '=='.
-- Ill-typed or ill-kinded terms are reported with 'error'.
infer :: LCTerm -> LCType
infer = go [] []
  where
    -- @kl@: kinds of the type variables in scope, innermost binder first
    --       (de Bruijn index 0 is the head of the list).
    -- @tl@: types of the term variables in scope, innermost binder first.
    go :: [LCKind] -> [LCType] -> LCTerm -> LCType
    go _  tl (LCVar n) =
      maybe (error "Out-of-scope variable") fst . uncons $ drop n tl
    go kl tl (LCLam t b)
      | LCStar <- inferKind kl t
      = v `LCArr` go kl (v : tl) b
      | otherwise
      = error "Function argument kind mismatch"
      where
        v = reduceType t
    go kl tl (LCApp f a)
      | LCArr at rt <- go kl tl f
      , at == go kl tl a
      = rt
      | otherwise
      = error "Function argument type mismatch"
    go kl tl (LCTLam k b) =
      -- Entering a type binder: the types already in the term context are
      -- shifted by 'shiftType' with amount 1 (presumably lifting their free
      -- type-variable indices past the new binder; see the Substitution
      -- module).
      LCUniv k $ go (k : kl) (fmap (\ty -> shiftType (ty, 1)) tl) b
    go kl tl (LCTApp f t)
      | LCUniv tk rt <- go kl tl f
      , tk == inferKind kl t
        -- Instantiating the binder can create new redexes, e.g. when @v@ is a
        -- type-level lambda and @rt@ applies the bound variable. The result
        -- is normalised again so later '==' comparisons see normal forms.
      = reduceType $ substituteTypeInType v 0 rt
      | otherwise
      = error "Function argument kind mismatch"
      where
        v = reduceType t

-- | Infer the kind of a type under a context of type-variable kinds.
--
-- The context lists the innermost binder first, so de Bruijn index 0 is the
-- head of the list. Ill-kinded types are reported with 'error'.
inferKind :: [LCKind] -> LCType -> LCKind
inferKind = go
  where
    go :: [LCKind] -> LCType -> LCKind
    go _  LCBase = LCStar
    go kl (LCTVar n) =
      maybe (error "Out-of-scope variable") fst . uncons $ drop n kl
    go kl (LCArr a b)
      | LCStar <- go kl a
      , LCStar <- go kl b
      = LCStar
      | otherwise
      = error "Arrow kind mismatch"
    go kl (LCUniv k a)
      -- A universally quantified type classifies terms. Its body must have
      -- kind *, and the quantified type itself has kind *. Returning the
      -- body's kind instead would admit e.g. @forall a. (\b. b)@ at kind
      -- @* -> *@, which 'reduceType' cannot apply.
      | LCStar <- go (k : kl) a
      = LCStar
      | otherwise
      = error "Universal type body kind mismatch"
    go kl (LCTTLam k b) = LCKArr k $ go (k : kl) b
    go kl (LCTTApp f a)
      | LCKArr ak rk <- go kl f
      , ak == go kl a
      = rk
      | otherwise
      = error "Function argument kind mismatch"