-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Kind checking for the Swarm language.
module Swarm.Language.Kindcheck (
  KindError (..),
  checkPolytypeKind,
  checkKind,
) where

import Control.Algebra (Has)
import Control.Effect.Reader (Reader, ask)
import Control.Effect.Throw (Throw, throwError)
import Control.Monad.Extra (unlessM)
import Data.Fix (Fix (..))
import Swarm.Language.Types

-- | Kind checking errors that can occur.
data KindError
  = -- | A type constructor expects the given number of arguments (the
    --   'Int'), but was applied to the listed arguments instead.
    ArityMismatch TyCon Int [Type]
  | -- | An undefined type constructor was encountered in the given type.
    UndefinedTyCon TyCon Type
  | -- | A trivial recursive type (one that does not use its bound
    --   variable) was encountered.
    TrivialRecTy Var Type
  | -- | A vacuous recursive type (one that expands immediately to
    --   itself) was encountered.
    VacuousRecTy Var Type
  deriving (Eq, Show)

-- | Check that a polytype is well-kinded.
--
--   On success, returns a 'TydefInfo' holding the polytype itself
--   together with its arity, i.e. the number of variables quantified
--   by its 'Forall'.  Any 'KindError' raised while checking the body
--   is propagated.
checkPolytypeKind :: (Has (Reader TDCtx) sig m, Has (Throw KindError) sig m) => Polytype -> m TydefInfo
checkPolytypeKind pty@(Forall xs t) = do
  checkKind t
  pure $ TydefInfo pty (Arity (length xs))
t

-- | Check that a type is well-kinded. For now, we don't allow
--   higher-kinded types, *i.e.* all kinds will be of the form @Type
--   -> Type -> ... -> Type@ which can be represented by a number (the
--   arity); every type constructor must also be fully applied. So, we
--   only have to check that each type constructor is applied to the
--   correct number of type arguments.  In the future, we might very
--   well want to generalize to arbitrary higher kinds (e.g. @(Type ->
--   Type) -> Type@ etc.) which would require generalizing this
--   checking code a bit.
--
--   Here we also check that any recursive types are non-vacuous,
--   /i.e./ not of the form @rec t. t@, and non-trivial, /i.e./ the
--   variable bound by the @rec@ actually occurs somewhere in the
--   body.
--
--   Throws 'UndefinedTyCon' for a constructor unknown to the 'TDCtx'
--   in scope, 'ArityMismatch' for a constructor applied to the wrong
--   number of arguments, and 'TrivialRecTy' / 'VacuousRecTy' for
--   ill-formed recursive types.
checkKind :: (Has (Reader TDCtx) sig m, Has (Throw KindError) sig m) => Type -> m ()
checkKind ty@(Fix tyF) = case tyF of
  TyConF c tys -> do
    tdCtx <- ask
    case getArity <$> tcArity tdCtx c of
      -- The constructor is not known in the current tydef context.
      Nothing -> throwError $ UndefinedTyCon c ty
      Just a
        -- Constructors must be fully applied: exactly @a@ arguments.
        | length tys == a -> mapM_ checkKind tys
        | otherwise -> throwError $ ArityMismatch c a tys
  TyVarF {} -> pure ()
  TyRcdF m -> mapM_ checkKind m
  TyRecF x t -> do
    -- It's important to call checkKind first, to rule out undefined
    -- type constructors. Within the recursive kind check, we
    -- substitute the given variable name for the bound de Bruijn
    -- index 0 in the body.  This doesn't affect the checking but it
    -- does ensure that error messages will use the variable name and
    -- not de Bruijn indices.
    checkKind (substRec (TyVarF x) t NZ)
    -- Now check that the recursive type is well-formed.  We call this
    -- with the *unsubstituted* t because the check will be looking
    -- for de Bruijn variables specifically.
    checkRecTy x t
  TyRecVarF _ -> pure ()

-- | Check that the body of a recursive type actually contains the
--   bound variable at least once (otherwise there's no point in using
--   @rec@) and does not consist solely of that variable.
--
--   The body is expected to still refer to its bound variable by the
--   de Bruijn index 'NZ'; the name @x@ is used only in the error
--   values that are thrown.
checkRecTy :: (Has (Reader TDCtx) sig m, Has (Throw KindError) sig m) => Var -> Type -> m ()
checkRecTy x ty = do
  -- Reject bodies that never mention the binder, e.g. @rec x. Int@.
  unlessM (containsVar NZ ty) $ throwError (TrivialRecTy x ty)
  -- Reject bodies that unfold only to themselves, e.g. @rec x. x@.
  unlessM (nonVacuous NZ ty) $ throwError (VacuousRecTy x ty)

-- Note, in theory it would be more efficient to combine containsVar
-- and nonVacuous into a single check that walks over the type only
-- once, but we keep them separate just to simplify things.  This
-- won't make much difference in the grand scheme of things since
-- types are small.

-- | Check whether a type contains a specific bound recursive type
--   variable, given as a de Bruijn index.
--
--   User-defined type constructors are expanded (via 'expandTydef')
--   before looking inside them, so a tydef that discards one of its
--   arguments does not count as an occurrence of that argument.
containsVar :: Has (Reader TDCtx) sig m => Nat -> Type -> m Bool
containsVar i (Fix tyF) = case tyF of
  TyRecVarF j -> pure (i == j)
  TyVarF {} -> pure False
  TyConF (TCUser u) tys -> do
    ty' <- expandTydef u tys
    containsVar i ty'
  TyConF _ tys -> or <$> mapM (containsVar i) tys
  TyRcdF m -> or <$> mapM (containsVar i) m
  -- Under another @rec@ binder, the variable we are looking for is one
  -- index further away.
  TyRecF _ ty -> containsVar (NS i) ty

-- | @nonVacuous i ty@ checks that the recursive type @rec x. ty@ (with
--   @x@ referred to by de Bruijn index @i@) is non-vacuous, /i.e./
--   that it doesn't look like @rec x. x@.  Put another way, we make
--   sure the recursive type is "productive" in the sense that
--   unfolding it will result in a well-defined infinite type (as
--   opposed to @rec x. x@ which just unfolds to itself).  However, we
--   can't just check whether it literally looks like @rec x. x@ since
--   we must also (1) expand type aliases and (2) ignore additional
--   intervening @rec@s.  For example, given @tydef Id a = a@, the type
--   @rec x. rec y. Id x@ is also vacuous.
nonVacuous :: (Has (Reader TDCtx) sig m) => Nat -> Type -> m Bool
nonVacuous i (Fix tyF) = case tyF of
  -- The type simply consists of a variable bound by some @rec@.
  -- Check if it's the variable we're currently looking for.
  TyRecVarF j -> pure (i /= j)
  -- Expand a user-defined type and keep looking.
  TyConF (TCUser u) tys -> do
    ty' <- expandTydef u tys
    nonVacuous i ty'
  -- Increment the variable we're looking for when going under a @rec@
  -- binder.
  TyRecF _ ty -> nonVacuous (NS i) ty
  -- If we encounter any other kind of type constructor or record
  -- type, rejoice!
  TyConF {} -> pure True
  TyRcdF {} -> pure True
  -- This last case can't actually happen if we already checked that
  -- the recursive type actually contains its bound variable (with
  -- 'containsVar'), since it would correspond to something like @rec
  -- x. y@.  However, it's still correct to return True.
  TyVarF {} -> pure True