{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- For 'Ord IntVar' instance

-- |
-- Module      :  Swarm.Language.Typecheck
-- Copyright   :  Brent Yorgey
-- Maintainer  :  byorgey@gmail.com
--
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Type inference for the Swarm language.  For the approach used here,
-- see
-- https://byorgey.wordpress.com/2021/09/08/implementing-hindley-milner-with-the-unification-fd-library/ .
module Swarm.Language.Typecheck (
  -- * Type errors
  TypeErr (..),
  InvalidAtomicReason (..),
  getTypeErrLocation,

  -- * Inference monad
  Infer,
  runInfer,
  lookup,
  fresh,

  -- * Unification
  substU,
  (=:=),
  HasBindings (..),
  instantiate,
  skolemize,
  generalize,

  -- * Type inference
  inferTop,
  inferModule,
  infer,
  inferConst,
  check,
  decomposeCmdTy,
  decomposeFunTy,
  isSimpleUType,
) where

import Control.Category ((>>>))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Unification hiding (applyBindings, (=:=))
import Control.Unification qualified as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Identity
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Swarm.Language.Context hiding (lookup)
import Swarm.Language.Context qualified as Ctx
import Swarm.Language.Parse.QQ (tyQ)
import Swarm.Language.Syntax
import Swarm.Language.Types
import Prelude hiding (lookup)

------------------------------------------------------------
-- Inference monad

-- | The concrete monad used for type inference.  'IntBindingT' is a
--   monad transformer provided by the @unification-fd@ library which
--   supports various operations such as generating fresh variables
--   and unifying things.
--
--   The stack, from the outside in:
--
--   * 'ReaderT' 'UCtx': the ambient typing context.
--   * 'ExceptT' 'TypeErr': type errors abort inference.
--   * 'IntBindingT' 'TypeF' over 'Identity': unification variable bindings.
type Infer = ReaderT UCtx (ExceptT TypeErr (IntBindingT TypeF Identity))

-- | Run a top-level inference computation, returning either a
--   'TypeErr' or a fully resolved 'TModule'.
--
--   The pipeline, in order:
--
--   1. apply all pending unification bindings to the inferred module;
--   2. 'generalize' the module's type and convert both the type and
--      the context out of their unification representation ('fromU');
--   3. run the 'ReaderT' layer with the given context (converted with
--      'toU');
--   4. peel off the 'ExceptT', 'IntBindingT' and 'Identity' layers.
runInfer :: TCtx -> Infer UModule -> Either TypeErr TModule
runInfer ctx =
  (>>= applyBindings)
    >>> (>>= \(Module uty uctx) -> Module <$> (fromU <$> generalize uty) <*> pure (fromU uctx))
    >>> flip runReaderT (toU ctx)
    >>> runExceptT
    >>> evalIntBindingT
    >>> runIdentity

-- | Look up a variable in the ambient type context, either throwing
--   an 'UnboundVar' error (tagged with the given location) if it is
--   not found, or opening its associated 'UPolytype' with fresh
--   unification variables via 'instantiate'.
lookup :: Location -> Var -> Infer UType
lookup loc x = do
  ctx <- ask
  maybe (throwError $ UnboundVar loc x) instantiate (Ctx.lookup x ctx)

------------------------------------------------------------
-- Dealing with variables: free variables, fresh variables,
-- substitution

-- | @unification-fd@ does not provide an 'Ord' instance for 'IntVar',
--   so we must provide our own, in order to be able to store
--   'IntVar's in a 'Set'.  This is an orphan instance, which is why
--   this module disables the orphan-instance warning.
deriving instance Ord IntVar

-- | A class for getting the free unification variables of a thing.
--   The result is computed in the 'Infer' monad.
class FreeVars a where
  -- | Collect the free unification variables ('IntVar's) of a value.
  freeVars :: a -> Infer (Set IntVar)

-- | We can get the free unification variables of a 'UType': ask
--   @unification-fd@ for them ('getFreeVars'), lift the answer through
--   the two outer layers of 'Infer', and collect it into a 'Set'.
instance FreeVars UType where
  freeVars ut = fmap S.fromList . lift . lift $ getFreeVars ut

-- | We can also get the free variables of a polytype: these are the
--   free unification variables of its body (the list of quantified
--   names is ignored).
instance FreeVars t => FreeVars (Poly t) where
  freeVars (Forall _ t) = freeVars t

-- | We can get the free variables in any polytype in a context: the
--   union of the free variables of every polytype stored in it.
instance FreeVars UCtx where
  freeVars = fmap S.unions . mapM freeVars . M.elems . unCtx

-- | Generate a fresh unification variable, wrapped as a 'UType'.
--   The variable comes from @unification-fd@'s 'freeVar', lifted
--   through the two outer layers of 'Infer'.
fresh :: Infer UType
fresh = UVar <$> lift (lift freeVar)

-- | Perform a substitution over a 'UType', substituting for both type
--   and unification variables.  Note that since 'UType's do not have
--   any binding constructs, we don't have to worry about ignoring
--   bound variables; all variables in a 'UType' are free.
--
--   Keys of the form @Left x@ replace the type variable @x@; keys of
--   the form @Right v@ replace the unification variable @v@.
--   Variables without an entry in the map are left unchanged.
substU :: Map (Either Var IntVar) UType -> UType -> UType
substU m =
  ucata
    -- Unification variables: look up under 'Right'.
    (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
    ( \case
        -- Type variables: look up under 'Left'.
        TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m)
        -- Any other node: rebuild it from its already-substituted children.
        f -> UTerm f
    )

------------------------------------------------------------
-- Lifted stuff from unification-fd

infix 4 =:=

-- | Constrain two types to be equal.  This lifts @unification-fd@'s
--   'U.=:=' into the 'Infer' monad and discards the unified term it
--   returns.
(=:=) :: UType -> UType -> Infer ()
s =:= t = void (lift $ s U.=:= t)

-- | @unification-fd@ provides a function 'U.applyBindings' which
--   fully substitutes for any bound unification variables (for
--   efficiency, it does not perform such substitution as it goes
--   along).  The 'HasBindings' class is for anything which has
--   unification variables in it and to which we can usefully apply
--   'U.applyBindings'.
class HasBindings u where
  -- | Substitute for all bound unification variables inside a value.
  applyBindings :: u -> Infer u

-- | For a plain 'UType', defer directly to 'U.applyBindings', lifted
--   into the 'Infer' monad.
instance HasBindings UType where
  applyBindings = lift . U.applyBindings

-- | For a polytype, apply bindings to the body and keep the list of
--   quantified variables as it is.
instance HasBindings UPolytype where
  applyBindings (Forall xs u) = Forall xs <$> applyBindings u

-- | For a context, apply bindings to every polytype stored in it.
instance HasBindings UCtx where
  applyBindings = mapM applyBindings

-- | For a module, apply bindings to both its type and its context.
instance HasBindings UModule where
  applyBindings (Module uty uctx) = Module <$> applyBindings uty <*> applyBindings uctx

------------------------------------------------------------
-- Converting between mono- and polytypes

-- | To 'instantiate' a 'UPolytype', we generate a fresh unification
--   variable for each variable bound by the 'Forall', and then
--   substitute them throughout the type.
instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
  -- One fresh unification variable per quantified name, in order.
  xs' <- mapM (const fresh) xs
  return $ substU (M.fromList (zip (map Left xs) xs')) uty

-- | 'skolemize' is like 'instantiate', except we substitute fresh
--   /type/ variables instead of unification variables.  Such
--   variables cannot unify with anything other than themselves.  This
--   is used when checking something with a polytype explicitly
--   specified by the user.
--
--   Each Skolem variable is named by applying 'mkVarName' with the
--   prefix @"s"@ to a freshly generated 'IntVar'.
skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do
  skolems <- mapM (const freshSkolem) xs
  return $ substU (M.fromList (zip (map Left xs) skolems)) uty
 where
  -- Draw the 'IntVar' directly from @unification-fd@ instead of
  -- pattern-matching on the result of 'fresh'; this way no partial
  -- "impossible" case is needed.
  freshSkolem :: Infer UType
  freshSkolem = UTyVar . mkVarName "s" <$> lift (lift freeVar)

-- | 'generalize' is the opposite of 'instantiate': add a 'Forall'
--   which closes over the unification variables that are free in the
--   type (after applying bindings) but /not/ free in the ambient
--   context.  Each such variable is replaced by a type variable named
--   by 'mkVarName' with the prefix @"a"@.
generalize :: UType -> Infer UPolytype
generalize uty = do
  uty' <- applyBindings uty
  ctx <- ask
  tmfvs <- freeVars uty'
  ctxfvs <- freeVars ctx
  let fvs = S.toList $ tmfvs \\ ctxfvs
      xs = map (mkVarName "a") fvs
  return $ Forall xs (substU (M.fromList (zip (map Right fvs) (map UTyVar xs))) uty')

------------------------------------------------------------
-- Type errors

-- | Errors that can occur during type checking.  The idea is that
--   each error carries information that can be used to help explain
--   what went wrong (though the amount of information carried can and
--   should be very much improved in the future); errors can then
--   separately be pretty-printed to display them to the user.
data TypeErr
  = -- | An undefined variable was encountered.
    UnboundVar Location Var
  | -- | A Skolem variable escaped its local context.
    EscapedSkolem Location Var
  | -- | Unification's occurs check failed (this constructor is the
    --   'occursFailure' of the 'Fallible' instance).  It carries no
    --   source location.
    Infinite IntVar UType
  | -- | The given term was expected to have a certain type, but has a
    -- different type instead.
    Mismatch Location (TypeF UType) (TypeF UType)
  | -- | A definition was encountered not at the top level.
    DefNotTopLevel Location Term
  | -- | A term was encountered which we cannot infer the type of.
    --   This should never happen.
    CantInfer Location Term
  | -- | An invalid argument was provided to @atomic@.
    InvalidAtomic Location InvalidAtomicReason Term
  deriving (Show)

-- | Various reasons the body of an @atomic@ might be invalid.
data InvalidAtomicReason
  = -- | The argument has too many tangible commands.
    TooManyTicks Int
  | -- | The argument uses some way to duplicate code: @def@, @let@, or lambda.
    AtomicDupingThing
  | -- | The argument referred to a variable with a non-simple type.
    NonSimpleVarType Var UPolytype
  | -- | The argument had a nested @atomic@.
    NestedAtomic
  | -- | The argument contained a long command.
    LongConst
  deriving (Show)

-- | Tell @unification-fd@ how to report its failures as 'TypeErr's.
--   Mismatches are created with 'NoLoc', since no source location is
--   available at this point; 'addLocToTypeErr' can fill one in later.
instance Fallible TypeF IntVar TypeErr where
  occursFailure = Infinite
  mismatchFailure = Mismatch NoLoc

-- | Extract the source location stored in a 'TypeErr'.  Every
--   constructor except 'Infinite' carries a 'Location'; for
--   'Infinite' the result is 'Nothing'.
getTypeErrLocation :: TypeErr -> Maybe Location
getTypeErrLocation te = case te of
  UnboundVar l _ -> Just l
  EscapedSkolem l _ -> Just l
  Infinite _ _ -> Nothing
  Mismatch l _ _ -> Just l
  DefNotTopLevel l _ -> Just l
  CantInfer l _ -> Just l
  InvalidAtomic l _ _ -> Just l

------------------------------------------------------------
-- Type inference / checking

-- | Top-level type inference function: given a context of definition
--   types and a top-level term, either return a type error or its
--   type as a 'TModule'.  This is 'inferModule' run with 'runInfer'.
inferTop :: TCtx -> Syntax -> Either TypeErr TModule
inferTop ctx = runInfer ctx . inferModule

-- | Infer the signature of a top-level expression which might
--   contain definitions.  The result pairs the type of the expression
--   with a context of the names it defines or binds.
--
--   Any 'Mismatch' error raised without a location is tagged with the
--   location of the term being inferred (see 'addLocToTypeErr').
inferModule :: Syntax -> Infer UModule
inferModule s@(Syntax _ t) = (`catchError` addLocToTypeErr s) $ case t of
  -- For definitions with no type signature, make up a fresh type
  -- variable for the body, infer the body under an extended context,
  -- and unify the two.  Then generalize the type and return an
  -- appropriate context.
  SDef _ x Nothing t1 -> do
    xTy <- fresh
    ty <- withBinding x (Forall [] xTy) $ infer t1
    xTy =:= ty
    pty <- generalize ty
    return $ Module (UTyCmd UTyUnit) (singleton x pty)

  -- If a (poly)type signature has been provided, skolemize it and
  -- check the definition.
  SDef _ x (Just pty) t1 -> do
    let upty = toU pty
    uty <- skolemize upty
    withBinding x upty $ check t1 uty
    return $ Module (UTyCmd UTyUnit) (singleton x upty)

  -- To handle an 'SBind', infer the types of both sides, combining the
  -- returned modules appropriately.  Have to be careful to use the
  -- correct context when checking the right-hand side in particular.
  SBind mx c1 c2 -> do
    -- First, infer the left side.
    Module cmda ctx1 <- inferModule c1
    a <- decomposeCmdTy cmda

    -- Now infer the right side under an extended context: things in
    -- scope on the right-hand side include both any definitions
    -- created by the left-hand side, as well as a variable as in @x
    -- <- c1; c2@.  The order of extensions here matters: in theory,
    -- c1 could define something with the same name as x, in which
    -- case the bound x should shadow the defined one; hence, we apply
    -- that binding /after/ (i.e. /within/) the application of @ctx1@.
    withBindings ctx1 $
      maybe id (`withBinding` Forall [] a) mx $ do
        Module cmdb ctx2 <- inferModule c2

        -- We don't actually need the result type since we're just going
        -- to return cmdb, but it's important to ensure it's a command
        -- type anyway.  Otherwise something like 'move; 3' would be
        -- accepted with type int.
        _ <- decomposeCmdTy cmdb

        -- Ctx.union is right-biased, so ctx1 `union` ctx2 means later
        -- definitions will shadow previous ones.  Include the binder
        -- (if any) as well, since binders are made available at the top
        -- level, just like definitions. e.g. if the user writes `r <- build {move}`,
        -- then they will be able to refer to r again later.
        let ctxX = maybe Ctx.empty (`Ctx.singleton` Forall [] a) mx
        return $ Module cmdb (ctx1 `Ctx.union` ctxX `Ctx.union` ctx2)

  -- In all other cases, there can no longer be any definitions in the
  -- term, so delegate to 'infer'.
  _anyOtherTerm -> trivMod <$> infer s

-- | Infer the type of a term which does not contain definitions.
--
--   Any 'Mismatch' error raised without a location is tagged with the
--   location of the term being inferred (see 'addLocToTypeErr').  A
--   definition ('SDef') encountered here is rejected with
--   'DefNotTopLevel'.
infer :: Syntax -> Infer UType
infer s@(Syntax l t) = (`catchError` addLocToTypeErr s) $ case t of
  TUnit -> return UTyUnit
  TConst c -> instantiate . toU $ inferConst c
  TDir _ -> return UTyDir
  TInt _ -> return UTyInt
  TAntiInt _ -> return UTyInt
  TText _ -> return UTyText
  TAntiText _ -> return UTyText
  TBool _ -> return UTyBool
  TRobot _ -> return UTyRobot
  -- We should never encounter a TRef since they do not show up in
  -- surface syntax, only as values while evaluating (*after*
  -- typechecking).
  TRef _ -> throwError $ CantInfer l t
  TRequireDevice _ -> return $ UTyCmd UTyUnit
  TRequire _ _ -> return $ UTyCmd UTyUnit
  -- To infer the type of a pair, just infer both components.
  SPair t1 t2 -> UTyProd <$> infer t1 <*> infer t2
  -- if t : ty, then  {t} : {ty}.
  -- Note that in theory, if the @Maybe Var@ component of the @SDelay@
  -- is @Just@, we should typecheck the body under a context extended
  -- with a type binding for the variable, and ensure that the type of
  -- the variable is the same as the type inferred for the overall
  -- @SDelay@.  However, we rely on the invariant that such recursive
  -- @SDelay@ nodes are never generated from the surface syntax, only
  -- dynamically at runtime when evaluating recursive let or def expressions,
  -- so we don't have to worry about typechecking them here.
  SDelay _ dt -> UTyDelay <$> infer dt
  -- We need a special case for checking the argument to 'atomic'.
  -- 'atomic t' has the same type as 't', which must have a type of
  -- the form 'cmd a'.  't' must also be syntactically free of
  -- variables.
  TConst Atomic :$: at -> do
    argTy <- fresh
    check at (UTyCmd argTy)
    -- It's important that we typecheck the subterm @at@ *before* we
    -- check that it is a valid argument to @atomic@: this way we can
    -- ensure that we have already inferred the types of any variables
    -- referenced.
    validAtomic at
    return $ UTyCmd argTy

  -- Just look up variables in the context.
  TVar x -> lookup l x
  -- To infer the type of a lambda if the type of the argument is
  -- provided, just infer the body under an extended context and return
  -- the appropriate function type.
  SLam x (Just argTy) lt -> do
    let uargTy = toU argTy
    resTy <- withBinding x (Forall [] uargTy) $ infer lt
    return $ UTyFun uargTy resTy

  -- If the type of the argument is not provided, create a fresh
  -- unification variable for it and proceed.
  SLam x Nothing lt -> do
    argTy <- fresh
    resTy <- withBinding x (Forall [] argTy) $ infer lt
    return $ UTyFun argTy resTy

  -- To infer the type of an application:
  SApp f x -> do
    -- Infer the type of the left-hand side and make sure it has a function type.
    fTy <- infer f
    (ty1, ty2) <- decomposeFunTy fTy

    -- Then check that the argument has the right type.
    check x ty1 `catchError` addLocToTypeErr x
    return ty2

  -- We can infer the type of a let whether a type has been provided for
  -- the variable or not.
  SLet _ x Nothing t1 t2 -> do
    xTy <- fresh
    uty <- withBinding x (Forall [] xTy) $ infer t1
    xTy =:= uty
    upty <- generalize uty
    withBinding x upty $ infer t2
  SLet _ x (Just pty) t1 t2 -> do
    let upty = toU pty
    -- If an explicit polytype has been provided, skolemize it and check
    -- definition and body under an extended context.
    uty <- skolemize upty
    resTy <- withBinding x upty $ do
      check t1 uty `catchError` addLocToTypeErr t1
      infer t2
    -- Make sure no skolem variables have escaped.
    ask >>= mapM_ noSkolems
    return resTy
  SDef {} -> throwError $ DefNotTopLevel l t
  SBind mx c1 c2 -> do
    ty1 <- infer c1
    a <- decomposeCmdTy ty1
    ty2 <- maybe id (`withBinding` Forall [] a) mx $ infer c2
    -- The result is not needed, but this forces @ty2@ to be a command type.
    _ <- decomposeCmdTy ty2
    return ty2
 where
  -- Throw 'EscapedSkolem' if, after applying bindings, the body of
  -- the given polytype mentions a type variable which is not among
  -- its quantified variables.
  noSkolems :: UPolytype -> Infer ()
  noSkolems (Forall xs upty) = do
    upty' <- applyBindings upty
    let tyvs =
          ucata
            (const S.empty)
            (\case TyVarF v -> S.singleton v; f -> fold f)
            upty'
        ftyvs = tyvs `S.difference` S.fromList xs
    -- Report the smallest offending variable, if there is one.
    -- 'S.lookupMin' is total, unlike taking the 'head' of 'S.toList'.
    mapM_ (throwError . EscapedSkolem l) (S.lookupMin ftyvs)

-- | Re-throw a type error, filling in a source location if it lacks one.
--
--   A 'Mismatch' error whose location is 'NoLoc' is re-thrown with the
--   location of the given syntax node; every other error is re-thrown
--   unchanged.  This never returns normally, hence the polymorphic
--   result type.
addLocToTypeErr :: Syntax -> TypeErr -> Infer a
addLocToTypeErr s te = case te of
  Mismatch NoLoc a b -> throwError $ Mismatch (sLoc s) a b
  _ -> throwError te

-- | Decompose a type that is supposed to be a command type.
--
--   If the type is already syntactically of the form @cmd a@, return
--   @a@ directly.  Otherwise, generate a fresh type variable @a@, unify
--   the given type with @cmd a@ (which fails with a type error if the
--   type cannot be a command type), and return @a@.
decomposeCmdTy :: UType -> Infer UType
decomposeCmdTy (UTyCmd a) = return a
decomposeCmdTy ty = do
  a <- fresh
  ty =:= UTyCmd a
  return a

-- | Decompose a type that is supposed to be a function type.
--
--   If the type is already syntactically of the form @ty1 -> ty2@,
--   return the pair @(ty1, ty2)@ directly.  Otherwise, generate two
--   fresh type variables, unify the given type with the function type
--   built from them, and return the pair of variables (argument type
--   first, result type second).
decomposeFunTy :: UType -> Infer (UType, UType)
decomposeFunTy (UTyFun ty1 ty2) = return (ty1, ty2)
decomposeFunTy ty = do
  ty1 <- fresh
  ty2 <- fresh
  ty =:= UTyFun ty1 ty2
  return (ty1, ty2)

-- | Infer the type of a constant.
--
--   This is a plain lookup table: each 'Const' is mapped to a
--   'Polytype' written with the 'tyQ' quasiquoter.  Comparison
--   operators share 'cmpBinT' and arithmetic operators share
--   'arithBinT'.
inferConst :: Const -> Polytype
inferConst c = case c of
  Wait -> [tyQ| int -> cmd unit |]
  Noop -> [tyQ| cmd unit |]
  Selfdestruct -> [tyQ| cmd unit |]
  Move -> [tyQ| cmd unit |]
  Turn -> [tyQ| dir -> cmd unit |]
  Grab -> [tyQ| cmd text |]
  Harvest -> [tyQ| cmd text |]
  Place -> [tyQ| text -> cmd unit |]
  Give -> [tyQ| robot -> text -> cmd unit |]
  Install -> [tyQ| robot -> text -> cmd unit |]
  Make -> [tyQ| text -> cmd unit |]
  Has -> [tyQ| text -> cmd bool |]
  Installed -> [tyQ| text -> cmd bool |]
  Count -> [tyQ| text -> cmd int |]
  Reprogram -> [tyQ| robot -> {cmd a} -> cmd unit |]
  Build -> [tyQ| {cmd a} -> cmd robot |]
  Drill -> [tyQ| dir -> cmd unit |]
  Salvage -> [tyQ| cmd unit |]
  Say -> [tyQ| text -> cmd unit |]
  Listen -> [tyQ| cmd text |]
  Log -> [tyQ| text -> cmd unit |]
  View -> [tyQ| robot -> cmd unit |]
  Appear -> [tyQ| text -> cmd unit |]
  Create -> [tyQ| text -> cmd unit |]
  Time -> [tyQ| cmd int |]
  Whereami -> [tyQ| cmd (int * int) |]
  Blocked -> [tyQ| cmd bool |]
  Scan -> [tyQ| dir -> cmd (unit + text) |]
  Upload -> [tyQ| robot -> cmd unit |]
  Ishere -> [tyQ| text -> cmd bool |]
  Self -> [tyQ| robot |]
  Parent -> [tyQ| robot |]
  Base -> [tyQ| robot |]
  Whoami -> [tyQ| cmd text |]
  Setname -> [tyQ| text -> cmd unit |]
  Random -> [tyQ| int -> cmd int |]
  Run -> [tyQ| text -> cmd unit |]
  If -> [tyQ| bool -> {a} -> {a} -> a |]
  Inl -> [tyQ| a -> a + b |]
  Inr -> [tyQ| b -> a + b |]
  Case -> [tyQ|a + b -> (a -> c) -> (b -> c) -> c |]
  Fst -> [tyQ| a * b -> a |]
  Snd -> [tyQ| a * b -> b |]
  Force -> [tyQ| {a} -> a |]
  Return -> [tyQ| a -> cmd a |]
  Try -> [tyQ| {cmd a} -> {cmd a} -> cmd a |]
  Undefined -> [tyQ| a |]
  Fail -> [tyQ| text -> a |]
  Not -> [tyQ| bool -> bool |]
  Neg -> [tyQ| int -> int |]
  Eq -> cmpBinT
  Neq -> cmpBinT
  Lt -> cmpBinT
  Gt -> cmpBinT
  Leq -> cmpBinT
  Geq -> cmpBinT
  And -> [tyQ| bool -> bool -> bool|]
  Or -> [tyQ| bool -> bool -> bool|]
  Add -> arithBinT
  Sub -> arithBinT
  Mul -> arithBinT
  Div -> arithBinT
  Exp -> arithBinT
  Format -> [tyQ| a -> text |]
  Concat -> [tyQ| text -> text -> text |]
  Chars -> [tyQ| text -> int |]
  Split -> [tyQ| int -> text -> (text * text) |]
  AppF -> [tyQ| (a -> b) -> a -> b |]
  Swap -> [tyQ| text -> cmd text |]
  Atomic -> [tyQ| cmd a -> cmd a |]
  Teleport -> [tyQ| robot -> (int * int) -> cmd unit |]
  As -> [tyQ| robot -> {cmd a} -> cmd a |]
  RobotNamed -> [tyQ| text -> cmd robot |]
  RobotNumbered -> [tyQ| int -> cmd robot |]
  Knows -> [tyQ| text -> cmd bool |]
 where
  -- Shared type of the binary comparison operators.
  cmpBinT :: Polytype
  cmpBinT = [tyQ| a -> a -> bool |]
  -- Shared type of the binary arithmetic operators.
  arithBinT :: Polytype
  arithBinT = [tyQ| int -> int -> int |]

-- | @check t ty@ checks that @t@ has type @ty@.
--
--   This is implemented by inferring a type for @t@ and then unifying
--   the expected type @ty@ with the inferred one; a unification failure
--   surfaces as a type error in the 'Infer' monad.
check :: Syntax -> UType -> Infer ()
check t ty = do
  ty' <- infer t
  -- Expected type goes on the left, inferred type on the right.
  ty =:= ty'

-- | Ensure a term is a valid argument to @atomic@.  Valid arguments
--   may not contain @def@, @let@, or lambda. Any variables which are
--   referenced must have a primitive, first-order type such as
--   @text@ or @int@ (in particular, no functions, @cmd@, or
--   @delay@).  We simply assume that any locally bound variables are
--   OK without checking their type: the only way to bind a variable
--   locally is with a binder of the form @x <- c1; c2@, where @c1@ is
--   some primitive command (since we can't refer to external
--   variables of type @cmd a@).  If we wanted to do something more
--   sophisticated with locally bound variables we would have to
--   inline this analysis into typechecking proper, instead of having
--   it be a separate, out-of-band check.
--
--   The goal is to ensure that any argument to @atomic@ is guaranteed
--   to evaluate and execute in some small, finite amount of time, so
--   that it's impossible to write a term which runs atomically for an
--   indefinite amount of time and freezes the rest of the game.  Of
--   course, nothing prevents one from writing a large amount of code
--   inside an @atomic@ block; but we want the execution time to be
--   linear in the size of the code.
--
--   We also ensure that the atomic block takes at most one tick,
--   i.e. contains at most one tangible command. For example, @atomic
--   (move; move)@ is invalid, since that would allow robots to move
--   twice as fast as usual by doing both actions in one tick.
validAtomic :: Syntax -> Infer ()
validAtomic s@(Syntax l t) = do
  -- Start with no locally bound variables; the result is the number of
  -- tangible commands the block may execute.
  n <- analyzeAtomic S.empty s
  when (n > 1) $ throwError (InvalidAtomic l (TooManyTicks n) t)

-- | Analyze an argument to @atomic@: ensure it contains no nested
--   atomic blocks, no long commands, no lambda, @let@, or @def@, and
--   no references to external variables of non-simple type, and count
--   how many tangible commands it will execute.
--
--   The first argument is the set of variables bound locally (by a
--   bind) inside the atomic block so far; such variables are accepted
--   without inspecting their type.  The result is the number of
--   tangible commands; any violation is reported by throwing an
--   'InvalidAtomic' error located at the offending subterm.
analyzeAtomic :: Set Var -> Syntax -> Infer Int
analyzeAtomic locals (Syntax l t) = case t of
  -- Literals, primitives, etc. that are fine and don't require a tick
  -- to evaluate
  TUnit {} -> return 0
  TDir {} -> return 0
  TInt {} -> return 0
  TAntiInt {} -> return 0
  TText {} -> return 0
  TAntiText {} -> return 0
  TBool {} -> return 0
  TRobot {} -> return 0
  TRequireDevice {} -> return 0
  TRequire {} -> return 0
  -- Constants.
  TConst c
    -- Nested 'atomic' is not allowed.
    | c == Atomic -> invalid NestedAtomic
    -- We cannot allow long commands (commands that may require more
    -- than one tick to execute) since that could freeze the game.
    | isLong c -> invalid LongConst
    -- Otherwise, return 1 or 0 depending on whether the command is
    -- tangible.
    | otherwise -> return $ if isTangible c then 1 else 0
  -- Special case for if: number of tangible commands is the *max* of
  -- the branches instead of the sum, since exactly one of them will be
  -- executed.  This must come before the generic application case.
  TConst If :$: tst :$: thn :$: els ->
    (+)
      <$> analyzeAtomic locals tst
      <*> (max <$> analyzeAtomic locals thn <*> analyzeAtomic locals els)
  -- Pairs, application, and delay are simple: just recurse and sum the results.
  SPair s1 s2 -> (+) <$> analyzeAtomic locals s1 <*> analyzeAtomic locals s2
  SApp s1 s2 -> (+) <$> analyzeAtomic locals s1 <*> analyzeAtomic locals s2
  SDelay _ s1 -> analyzeAtomic locals s1
  -- Bind is similarly simple except that we have to keep track of a local variable
  -- bound in the RHS.
  SBind mx s1 s2 ->
    (+)
      <$> analyzeAtomic locals s1
      <*> analyzeAtomic (maybe id S.insert mx locals) s2
  -- Variables are allowed if bound locally, or if they have a simple type.
  TVar x
    | x `S.member` locals -> return 0
    | otherwise -> do
      mxTy <- asks $ Ctx.lookup x
      case mxTy of
        -- If the variable is undefined, return 0 to indicate the
        -- atomic block is valid, because we'd rather have the error
        -- caught by the real name+type checking.
        Nothing -> return 0
        Just xTy -> do
          -- Use applyBindings to make sure that we apply as much
          -- information as unification has learned at this point.  In
          -- theory, continuing to typecheck other terms elsewhere in
          -- the program could give us further information about xTy,
          -- so we might have incomplete information at this point.
          -- However, since variables referenced in an atomic block
          -- must necessarily have simple types, it's unlikely this
          -- will really make a difference.  The alternative, more
          -- "correct" way to do this would be to simply emit some
          -- constraints at this point saying that xTy must be a
          -- simple type, and check later that the constraint holds,
          -- after performing complete type inference.  However, since
          -- the current approach is much simpler, we'll stick with
          -- this until such time as we have concrete examples showing
          -- that the more correct, complex way is necessary.
          xTy' <- applyBindings xTy
          if isSimpleUPolytype xTy'
            then return 0
            else invalid (NonSimpleVarType x xTy')
  -- No lambda, `let` or `def` allowed!
  SLam {} -> invalid AtomicDupingThing
  SLet {} -> invalid AtomicDupingThing
  SDef {} -> invalid AtomicDupingThing
  -- We should never encounter a TRef since they do not show up in
  -- surface syntax, only as values while evaluating (*after*
  -- typechecking).
  TRef {} -> throwError (CantInfer l t)
 where
  -- Reject the current subterm for the given reason, reporting its
  -- location and the term itself.
  invalid :: InvalidAtomicReason -> Infer Int
  invalid reason = throwError (InvalidAtomic l reason t)

-- | A simple polytype is a simple type with no quantifiers.
--
--   Returns 'True' only when the list of quantified variables is empty
--   and the body satisfies 'isSimpleUType'; any polytype that binds at
--   least one variable is not simple.
isSimpleUPolytype :: UPolytype -> Bool
isSimpleUPolytype (Forall [] ty) = isSimpleUType ty
isSimpleUPolytype _ = False

-- | A simple type is a sum or product of base types.
--
--   Base types are simple; sums and products are simple exactly when
--   both components are.  Type variables, function types, command
--   types, and delay types are never simple, and neither is an
--   unresolved unification variable.
isSimpleUType :: UType -> Bool
isSimpleUType = \case
  UTyBase {} -> True
  UTyVar {} -> False
  UTySum ty1 ty2 -> isSimpleUType ty1 && isSimpleUType ty2
  UTyProd ty1 ty2 -> isSimpleUType ty1 && isSimpleUType ty2
  UTyFun {} -> False
  UTyCmd {} -> False
  UTyDelay {} -> False
  -- Make the pattern-match coverage checker happy
  UVar {} -> False
  UTerm {} -> False