module Language.Lambda.SystemF.TypeCheck where

import Language.Lambda.Shared.Errors (LambdaException(..))
import Language.Lambda.SystemF.Expression
import Language.Lambda.SystemF.State

import Control.Monad.Except (MonadError(..))
import Prettyprinter
import RIO
import qualified RIO.Map as Map

-- | A list of names. NOTE(review): presumably a supply of fresh/unique
-- names; it is not referenced elsewhere in this module — confirm usage.
type UniqueSupply n = [n]
-- | A map keyed by names with values of type @t@. NOTE(review): not
-- referenced elsewhere in this module (which uses 'Context'); confirm usage.
type Context' n t = Map n t

-- | Infer the type of an expression, using the context obtained from the
-- typechecker state via 'getContext'.
--
-- The expression may be a top-level form such as 'Let'; see
-- 'typecheckTopLevel'.
typecheck
  :: (Ord name, Pretty name)
  => SystemFExpr name
  -> Typecheck name (Ty name)
typecheck expr = do
  ctx <- getContext
  typecheckTopLevel ctx expr

-- | Typecheck an expression that may appear at the top level.
--
-- 'Let' is dispatched to 'typecheckLet' ('typecheckExpr' itself rejects
-- 'Let'); every other expression goes straight to 'typecheckExpr'.
typecheckTopLevel
  :: (Ord name, Pretty name)
  => Context name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckTopLevel ctx (Let n expr) = typecheckLet ctx n expr
typecheckTopLevel ctx expr = typecheckExpr ctx expr

-- | Typecheck the right-hand side of a top-level @let@.
--
-- The bound name is ignored: the result is simply the type of the bound
-- expression. This function does not extend the context with the binding.
typecheckLet
  :: (Pretty name, Ord name)
  => Context name
  -> name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckLet ctx _ = typecheckExpr ctx
  
-- | Infer the type of a (non-top-level) expression under the given context.
--
-- Dispatches on the expression constructor to the matching @typecheck*@
-- helper. 'Let' is only handled by 'typecheckTopLevel', so reaching it
-- here throws 'ImpossibleError'.
typecheckExpr
  :: (Ord name, Pretty name)
  => Context name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckExpr ctx (Var v) = typecheckVar ctx v
typecheckExpr ctx (VarAnn v ty) = typecheckVarAnn ctx v ty
typecheckExpr ctx (Abs n t body) = typecheckAbs ctx n t body
typecheckExpr ctx (App e1 e2) = typecheckApp ctx e1 e2
typecheckExpr ctx (TyAbs t body) = typecheckTyAbs ctx t body
typecheckExpr ctx (TyApp e ty) = typecheckTyApp ctx e ty
typecheckExpr _ (Let _ _) = throwError ImpossibleError

-- | Look up the type of a term variable.
--
-- When 'typecheckVar'' yields no type (the name is unbound, or bound as a
-- type variable), the variable is given a fresh type variable drawn from
-- 'tyUnique'.
typecheckVar :: Ord name => Context name -> name -> Typecheck name (Ty name)
typecheckVar ctx = defaultToUnique . typecheckVar' ctx
  where defaultToUnique = maybe (TyVar <$> tyUnique) pure
    
-- | Typecheck a variable carrying an explicit type annotation.
--
-- If the context gives the variable a type (via 'typecheckVar''), that type
-- must equal the annotation, otherwise a 'TyMismatchError' is thrown. When
-- the context yields no type, the annotation is taken as the variable's type.
typecheckVarAnn
  :: (Ord name, Pretty name)
  => Context name
  -> name
  -> Ty  name
  -> Typecheck name (Ty name)
typecheckVarAnn ctx var ty = maybe (pure ty) checkContextType maybeTy
  where
    -- The variable's type according to the context, if it has one.
    maybeTy = typecheckVar' ctx var

    -- NOTE(review): this uses structural '==', so alpha-equivalent
    -- quantified types are rejected; 'isTyEquivalent' may be intended.
    checkContextType ty'
      | ty' == ty = pure ty
      | otherwise = throwError $ tyMismatchError ty' ty

-- | Typecheck a lambda abstraction binding @name@ at type @ty@ over @body@.
--
-- The parameter type is first normalised with 'liftForAlls' (call it
-- @ty'@), and @name@ is bound to @ty'@ while checking the body.
--
-- * If @ty'@ is quantified, @forall X. T@, then @X@ (only the outermost
--   quantifier) is also brought into scope as a type variable, and the
--   result is @forall X. T -> <body type>@.
-- * Otherwise the result is the plain arrow @ty' -> <body type>@.
typecheckAbs
  :: (Ord name, Pretty name)
  => Context name
  -> name
  -> Ty name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckAbs ctx name ty body
  = typecheckAbs' ty' (Map.insert name (BindTerm ty') ctx)
  where
    ty' = liftForAlls ty

    typecheckAbs' (TyForAll tyName tyBody) ctx' = do
      inner <- typecheckExpr (Map.insert tyName BindTy ctx') body
      pure $ TyForAll tyName (TyArrow tyBody inner)
    typecheckAbs' t ctx' = TyArrow t <$> typecheckExpr ctx' body
      
-- | Typecheck an application of @e1@ to @e2@.
--
-- @e1@ must have an arrow type @a -> b@, giving result type @b@. As a
-- special case, when @e1@ has type @forall X. a -> _@, the expected
-- argument type is @forall X. a@ and the result is the type of @e2@.
-- In both cases the type of @e2@ must be equivalent (see 'isTyEquivalent')
-- to the expected argument type; otherwise a 'TyMismatchError' is thrown.
typecheckApp
  :: (Ord name, Pretty name)
  => Context name
  -> SystemFExpr name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckApp ctx e1 e2 = do
  -- Typecheck expressions
  t1 <- typecheckExpr ctx e1
  t2 <- typecheckExpr ctx e2

  -- Split e1's type into the argument it expects and the result it yields
  (t1AppInput, t1AppOutput) <- case t1 of
    TyArrow appInput appOutput -> pure (appInput, appOutput)
    -- NOTE(review): the result here is e2's type rather than the arrow's
    -- output; confirm this is the intended rule for quantified functions.
    TyForAll n1 (TyArrow appInput _) -> pure (TyForAll n1 appInput, t2)
    _ -> throwError $ TyMismatchError "Not Arrow"

  -- Verify the input of e1 matches the type of e2
  if t1AppInput `isTyEquivalent` t2
    then pure t1AppOutput
    -- e1 should have had type (t2 -> output); report the type it really has
    -- rather than wrapping it in a further arrow.
    else throwError $ tyMismatchError (TyArrow t2 t1AppOutput) t1

-- | Typecheck a type abstraction over the type variable @ty@.
--
-- @ty@ is bound as a type variable while checking @body@, and the result
-- is @forall ty. <body type>@.
typecheckTyAbs
  :: (Ord name, Pretty name)
  => Context name
  -> name
  -> SystemFExpr name
  -> Typecheck name (Ty name)
typecheckTyAbs ctx ty body = TyForAll ty <$> typecheckExpr ctx' body
  where ctx' = Map.insert ty BindTy ctx

-- | Typecheck the application of expression @expr@ to the type @ty@.
--
-- @expr@ is checked in a context with all type-variable bindings removed.
-- Its type must be quantified, @forall X. T@, and the result is @T@ with
-- @ty@ substituted for @X@ (via 'substituteTy'). Any other type is reported
-- with 'tyAppMismatchError'.
typecheckTyApp
  :: (Ord name, Pretty name)
  => Context name
  -> SystemFExpr name
  -> Ty name
  -> Typecheck name (Ty name)
typecheckTyApp ctx expr ty = do
  -- Clear in-scope type variables
  let ctx' = Map.filter isTermBind ctx

  typecheckExpr ctx' expr >>= \case
    TyForAll tyName tyBody -> pure $ substituteTy ty tyName tyBody
    -- Build the error under the same filtered context that produced the
    -- rejected type, so the message describes that type.
    -- NOTE(review): this re-typechecks expr, which may draw fresh names
    -- from 'tyUnique' again.
    _ -> tyAppMismatchError ctx' expr ty >>= throwError

  where
    -- True for every binding except a type-variable binding.
    isTermBind BindTy = False
    isTermBind _ = True

-- | Look up the type bound to a term variable, if any.
--
-- Returns 'Nothing' when the name is unbound or bound as a type variable.
-- When the variable's type is @forall X. T@ and @X@ is itself bound in the
-- context, the outer quantifier is dropped and @T@ is returned.
typecheckVar' :: Ord name => Context name -> name -> Maybe (Ty name)
typecheckVar' ctx var = Map.lookup var ctx >>= \case
  BindTerm ty@(TyForAll tyName tyBody)
    | Map.member tyName ctx -> Just tyBody
    | otherwise -> Just ty
  BindTerm ty -> Just ty
  BindTy -> Nothing

-- | Move every quantifier in a type to the front (prenex form).
--
-- The quantified names collected by 'liftForAlls'' are re-wrapped, in
-- order and outermost first, around the quantifier-free remainder; e.g.
-- @A -> forall B. B@ becomes @forall B. A -> B@.
liftForAlls :: Ty name -> Ty name
liftForAlls ty = foldr TyForAll res tyNames
  where (tyNames, res) = liftForAlls' ty

-- | Strip every quantifier out of a type.
--
-- Returns the quantified names in left-to-right order together with the
-- type that remains once all 'TyForAll' nodes are removed.
--
-- NOTE(review): quantifiers are also pulled out of the argument side of an
-- arrow, and no renaming is performed, so @(forall X. X) -> Y@ becomes
-- @([X], X -> Y)@ and repeated names may end up shadowing each other;
-- confirm this is intended.
liftForAlls' :: Ty name -> ([name], Ty name)
liftForAlls' (TyVar name) = ([], TyVar name)
liftForAlls' (TyForAll name body) = (name : names, body')
  where (names, body') = liftForAlls' body
liftForAlls' (TyArrow t1 t2) = (n1 ++ n2, TyArrow t1' t2')
  where (n1, t1') = liftForAlls' t1
        (n2, t2') = liftForAlls' t2

-- | Decide whether two types are equivalent.
--
-- Structurally equal types are equivalent. Two quantified types are also
-- compared up to renaming of their outermost bound variable (see
-- 'areForAllsEquivalent'). All other pairs are considered different.
isTyEquivalent :: Ord name => Ty name -> Ty name -> Bool
isTyEquivalent t1 t2
  | t1 == t2 = True
  | otherwise = case (t1, t2) of
      (TyForAll n1 t1', TyForAll n2 t2') ->
        (n1, t1') `areForAllsEquivalent` (n2, t2')
      _ -> False

-- | Decide whether @forall n1. t1@ and @forall n2. t2@ are alpha-equivalent.
--
-- The second body is renamed to use @n1@ (via 'substituteTy') and then
-- compared structurally with the first. That renaming is only sound when
-- @n1@ does not already occur free in @t2@. Otherwise the free occurrence
-- would be captured: @forall X. X -> X@ and @forall Y. Y -> X@ would wrongly
-- compare equal. Such pairs are therefore rejected; they can never be
-- equivalent, because @n1@ is bound in the first type but free in the second.
areForAllsEquivalent :: Ord name => (name, Ty name) -> (name, Ty name) -> Bool
areForAllsEquivalent (n1, t1) (n2, t2)
  = not capturesFreeVar && t1 == substituteTy (TyVar n1) n2 t2
  where
    -- Renaming n2 to n1 would capture a free occurrence of n1 in t2.
    capturesFreeVar = n1 /= n2 && n1 `occursFreeIn` t2

    -- Does the name occur outside any binder of the same name?
    occursFreeIn n (TyVar v) = n == v
    occursFreeIn n (TyArrow a b) = occursFreeIn n a || occursFreeIn n b
    occursFreeIn n (TyForAll v body) = n /= v && occursFreeIn n body

-- | Draw the next fresh type-variable name from the typechecker's supply.
--
-- The head of the supply is returned and the tail is stored back with
-- 'setTyUniques'. An exhausted supply throws 'ImpossibleError'.
tyUnique :: Typecheck name name
tyUnique = getTyUniques >>= tyUnique'
  where tyUnique' (u:us) = setTyUniques us $> u
        tyUnique' _ = throwError ImpossibleError

-- | Build a 'TyMismatchError' naming the expected and the actual type,
-- both rendered with 'prettyPrint'.
tyMismatchError
  :: Pretty ty => ty -> ty -> LambdaException
tyMismatchError expected actual
  = TyMismatchError
  $ "Couldn't match expected type "
  <> prettyPrint expected
  <> " with actual type "
  <> prettyPrint actual

-- | Build the error for applying type @appTy@ to an expression whose type
-- is not quantified.
--
-- The expression's type is recomputed with 'typecheckExpr' under @ctx@ so
-- that it can be shown in the message. Because it re-runs the
-- typechecker, this may draw fresh names from 'tyUnique'.
tyAppMismatchError
  :: (Ord name, Pretty name)
  => Context name
  -> SystemFExpr name
  -> Ty name
  -> Typecheck name LambdaException
tyAppMismatchError ctx expr appTy = tyAppMismatchError' <$> typecheckExpr ctx expr
  where
    tyAppMismatchError' actual = TyMismatchError
      $ "Cannot apply type "
      <> prettyPrint appTy
      <> " to non-polymorphic type "
      <> prettyPrint actual