{-# LANGUAGE Strict #-}

-- | Facilities for type-checking terms.  Factored out of
-- "Language.Futhark.TypeChecker.Terms" to prevent the module from
-- being gigantic.
--
-- Incidentally also a nice place to put Haddock comments to make the
-- internal API of the type checker easier to browse.
module Language.Futhark.TypeChecker.Terms.Monad
  ( TermTypeM,
    runTermTypeM,
    ValBinding (..),
    SizeSource (SourceSlice),
    Inferred (..),
    Checking (..),
    withEnv,
    localScope,
    TermEnv (..),
    TermScope (..),
    TermTypeState (..),
    onFailure,
    extSize,
    expType,
    expTypeFully,
    constrain,
    newArrayType,
    allDimsFreshInType,
    updateTypes,
    Names,

    -- * Primitive checking
    unifies,
    require,
    checkTypeExpNonrigid,
    lookupVar,
    lookupMod,

    -- * Sizes
    isInt64,

    -- * Control flow
    incLevel,

    -- * Errors
    unusedSize,
  )
where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bitraversable
import Data.Char (isAscii)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.FreshNames hiding (newName)
import Futhark.FreshNames qualified
import Futhark.Util.Pretty hiding (space)
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource)
import Language.Futhark.TypeChecker.Monad qualified as TypeM
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify
import Prelude hiding (mod)

-- | A set of variable names.
type Names = S.Set VName

-- | What a name in the value table of a 'TermScope' is bound to.
data ValBinding
  = -- | A binding with the given type parameters and type.
    BoundV [TypeParam] StructType
  | -- | An overloaded function.  NOTE(review): presumably the allowed
    -- primitive types, the parameter types ('Nothing' meaning "the
    -- overloaded type"), and the result type — confirm at construction
    -- sites.
    OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType)
  | -- | NOTE(review): presumably the built-in equality operators —
    -- confirm at construction sites.
    EqualityF
  deriving (Show)

-- | Fail with a type error, located at the given size binder, stating
-- that the size is not used in its pattern.  The message is tagged
-- with the @unused-size@ index link via 'withIndexLink'.
unusedSize :: (MonadTypeChecker m) => SizeBinder VName -> m a
unusedSize p =
  typeError p mempty $ withIndexLink "unused-size" msg
  where
    msg = "Size" <+> pretty p <+> "unused in pattern."

-- | An optional type accompanying something being checked; used by
-- 'CheckingPat' to say what type (if any) a pattern was matched
-- against.
data Inferred t
  = -- | No type is available.
    NoneInferred
  | -- | The given type is available.
    Ascribed t

instance Functor Inferred where
  fmap _ NoneInferred = NoneInferred
  fmap f (Ascribed t) = Ascribed (f t)

-- | A description of what is currently being type-checked.  The
-- 'Pretty' instance renders each case as a human-readable explanation
-- of a mismatch, usually showing the two types involved.
data Checking
  = CheckingApply (Maybe (QualName VName)) Exp StructType StructType
  -- ^ Applying a function (named, or 'Nothing' if anonymous) to the
  -- given argument expression; then the expected and actual types.
  | CheckingReturn ResType StructType
  -- ^ A function body against its expected type; then the actual type.
  | CheckingAscription StructType StructType
  -- ^ An expression against an explicit type ascription (expected,
  -- actual).
  | CheckingLetGeneralise Name
  -- ^ Generalising the type of the binding with this name.
  | CheckingParams (Maybe Name)
  -- ^ The parameters of a function; 'Nothing' is rendered as
  -- "anonymous function".
  | CheckingPat (PatBase NoInfo VName StructType) (Inferred StructType)
  -- ^ A pattern, possibly together with the type it must match.
  | CheckingLoopBody StructType StructType
  -- ^ A loop body against its expected type (expected, actual).
  | CheckingLoopInitial StructType StructType
  -- ^ Initial loop values against their expected type (expected,
  -- actual).
  | CheckingRecordUpdate [Name] StructType StructType
  -- ^ Updating the record field at the given path (existing field
  -- type, new type).
  | CheckingRequired [StructType] StructType
  -- ^ An expression that must have one of the listed types; then its
  -- actual type.
  | CheckingBranches StructType StructType
  -- ^ Two branches whose types differ (former, latter).

-- | A headline followed by the expected and actual types, each on its
-- own line and wrapped in 'align' so multi-line types stay indented
-- under their label.
mismatchDoc :: Doc ann -> Doc ann -> Doc ann -> Doc ann
mismatchDoc headline expected actual =
  headline
    </> "Expected:"
    <+> align expected
    </> "Actual:  "
    <+> align actual

-- | Render a 'Checking' context as an explanation of what failed to
-- type-check.
instance Pretty Checking where
  pretty (CheckingApply f e expected actual) =
    mismatchDoc header (pretty expected) (pretty actual)
    where
      header =
        case f of
          Nothing ->
            "Cannot apply function to"
              <+> dquotes (shorten $ group $ pretty e)
              <> " (invalid type)."
          Just fname ->
            "Cannot apply"
              <+> dquotes (pretty fname)
              <+> "to"
              <+> dquotes (align $ shorten $ group $ pretty e)
              <> " (invalid type)."
  pretty (CheckingReturn expected actual) =
    mismatchDoc
      "Function body does not have expected type."
      (pretty expected)
      (pretty actual)
  pretty (CheckingAscription expected actual) =
    mismatchDoc
      "Expression does not have expected type from explicit ascription."
      (pretty expected)
      (pretty actual)
  pretty (CheckingLetGeneralise fname) =
    "Cannot generalise type of" <+> dquotes (pretty fname) <> "."
  pretty (CheckingParams fname) =
    "Invalid use of parameters in" <+> dquotes fname' <> "."
    where
      fname' = maybe "anonymous function" pretty fname
  pretty (CheckingPat pat NoneInferred) =
    "Invalid pattern" <+> dquotes (pretty pat) <> "."
  pretty (CheckingPat pat (Ascribed t)) =
    "Pattern"
      </> indent 2 (pretty pat)
      </> "cannot match value of type"
      </> indent 2 (pretty t)
  pretty (CheckingLoopBody expected actual) =
    mismatchDoc
      "Loop body does not have expected type."
      (pretty expected)
      (pretty actual)
  pretty (CheckingLoopInitial expected actual) =
    mismatchDoc
      "Initial loop values do not have expected type."
      (pretty expected)
      (pretty actual)
  pretty (CheckingRecordUpdate fs expected actual) =
    "Type mismatch when updating record field"
      <+> dquotes fs'
      <> "."
        </> "Existing:"
        <+> align (pretty expected)
        </> "New:     "
        <+> align (pretty actual)
    where
      -- The field path, rendered as e.g. @a.b.c@.
      fs' = mconcat $ punctuate "." $ map pretty fs
  pretty (CheckingRequired [expected] actual) =
    "Expression must have type"
      <+> pretty expected
      <> "."
        </> "Actual type:"
        <+> align (pretty actual)
  pretty (CheckingRequired expected actual) =
    -- No trailing space in the literal: '<+>' already inserts one.
    "Type of expression must be one of"
      <+> expected'
      <> "."
        </> "Actual type:"
        <+> align (pretty actual)
    where
      expected' = commasep (map pretty expected)
  pretty (CheckingBranches t1 t2) =
    "Branches differ in type."
      </> "Former:"
      <+> align (pretty t1)
      </> "Latter:"
      <+> align (pretty t2)

-- | Type checking happens with access to this environment.  The
-- 'TermScope' will be extended during type-checking as bindings come into
-- scope.
data TermEnv = TermEnv
  { -- | The bindings currently in scope.
    termScope :: TermScope,
    -- | What is currently being checked, if known.
    termChecking :: Maybe Checking,
    -- | The current level, as reported by 'curLevel'.
    termLevel :: Level,
    -- | A checker for untyped expressions.  NOTE(review): presumably
    -- the main expression checker, stored here so this module need not
    -- import it — confirm.
    termChecker :: ExpBase NoInfo VName -> TermTypeM Exp,
    -- | The type-checker 'Env' surrounding the terms being checked.
    termOuterEnv :: Env,
    -- | NOTE(review): presumably the import currently being checked —
    -- confirm.
    termImportName :: ImportName
  }

-- | The bindings visible to the term checker, one table per namespace.
data TermScope = TermScope
  { -- | Value-level bindings.
    scopeVtable :: M.Map VName ValBinding,
    -- | Type-level bindings.
    scopeTypeTable :: M.Map VName TypeBinding,
    -- | Module bindings.
    scopeModTable :: M.Map VName Mod
  }
  deriving (Show)

-- | Combine two scopes.  In every table, a binding in the right-hand
-- scope takes precedence over a binding with the same name in the
-- left-hand scope ('M.union' is left-biased, so the right operand is
-- passed first).  Previously the module table was combined the other
-- way round, so newer module bindings could not shadow older ones.
instance Semigroup TermScope where
  TermScope vt1 tt1 mt1 <> TermScope vt2 tt2 mt2 =
    TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (mt2 `M.union` mt1)

-- | View a type-checker 'Env' as a 'TermScope'.  The type and module
-- tables are copied unchanged; every value binding is converted from
-- the type-checker's @BoundV@ to this module's 'BoundV'.
envToTermScope :: Env -> TermScope
envToTermScope env =
  TermScope
    { scopeVtable = M.map toValBinding (envVtable env),
      scopeTypeTable = envTypeTable env,
      scopeModTable = envModTable env
    }
  where
    toValBinding (TypeM.BoundV tps t) = BoundV tps t

-- | Extend the scope of a 'TermEnv' with the bindings of an 'Env'.  The
-- scopes are combined with '<>', with the 'Env''s scope as the
-- right-hand operand.
withEnv :: TermEnv -> Env -> TermEnv
withEnv tenv env =
  tenv {termScope = termScope tenv <> envToTermScope env}

-- | Wrap a function name to give it a vacuous 'Eq' (and 'Ord') instance,
-- so that the function name never affects comparisons of 'SizeSource'.
newtype FName = FName (Maybe (QualName VName))
  deriving (Show)

-- | Every 'FName' is equal to every other 'FName'.
instance Eq FName where
  _ == _ = True

-- | Consistent with the vacuous 'Eq' instance: any two 'FName's
-- compare as 'EQ'.
instance Ord FName where
  compare _ _ = EQ

-- | What was the source of some existential size?  It is used to reuse
-- the same existential size variable when the same source is
-- encountered in multiple locations.
data SizeSource
  = -- | An expression passed to a function.  The function name is
    -- wrapped in 'FName', so it is ignored by comparisons.
    -- NOTE(review): "argument" is inferred from the constructor name —
    -- confirm.
    SourceArg FName (ExpBase NoInfo VName)
  | -- | NOTE(review): presumably a slice — the size of the sliced
    -- dimension and the optional start, end and stride expressions.
    -- Confirm at construction sites.
    SourceSlice
      (Maybe Size)
      (Maybe (ExpBase NoInfo VName))
      (Maybe (ExpBase NoInfo VName))
      (Maybe (ExpBase NoInfo VName))
  deriving (Eq, Ord, Show)

-- | The mutable state of 'TermTypeM': the current type constraints, a
-- counter for generating type names, the warnings emitted so far, and a
-- source of fresh names.  The type-name counter is distinct from the
-- usual counter we use for generating unique names, as the type names
-- it produces will be user-visible.
data TermTypeState = TermTypeState
  { -- | Constraints on type variables, each stored together with the
    -- level at which it was added (see 'constrain').
    stateConstraints :: Constraints,
    -- | Advanced by 'incCounter'; numbers the generated type names.
    stateCounter :: !Int,
    -- | Warnings emitted so far.
    stateWarnings :: Warnings,
    -- | Supply of fresh 'VName's.
    stateNameSource :: VNameSource
  }

-- | The monad in which terms are type-checked: read-only access to a
-- 'TermEnv', mutable 'TermTypeState', and failure with a 'TypeError'
-- paired with the warnings emitted before the failure.
newtype TermTypeM a
  = TermTypeM
      ( ReaderT
          TermEnv
          (StateT TermTypeState (Except (Warnings, TypeError)))
          a
      )
  deriving
    ( Monad,
      Functor,
      Applicative,
      MonadReader TermEnv,
      MonadState TermTypeState
    )

-- | Type errors are thrown together with the warnings held in the state
-- at the time of the throw.  NOTE(review): presumably so that the
-- warnings survive even though the state is discarded on failure —
-- confirm at the place where 'TermTypeM' is run.  'catchError' drops
-- those saved warnings and passes only the 'TypeError' to the handler.
instance MonadError TypeError TermTypeM where
  throwError e = TermTypeM $ do
    ws <- gets stateWarnings
    throwError (ws, e)

  catchError (TermTypeM m) handler =
    TermTypeM $ catchError m onError
    where
      onError (_, e) = let TermTypeM m' = handler e in m'

-- | Increment the type-name counter, returning its value from before
-- the increment.
incCounter :: TermTypeM Int
incCounter =
  state $ \s ->
    let old = stateCounter s
     in (old, s {stateCounter = old + 1})

-- | Record constraint @c@ for type variable @v@, paired with the
-- current level.  Any constraint previously recorded for @v@ is
-- replaced ('M.insert' overwrites existing keys).
constrain :: VName -> Constraint -> TermTypeM ()
constrain v c =
  curLevel >>= \lvl -> modifyConstraints (M.insert v (lvl, c))

-- | Unification in 'TermTypeM'.  Constraints are kept in
-- 'TermTypeState', the current level is read from 'TermEnv', and
-- error messages are prefixed with the active 'Checking' context,
-- if there is one.
instance MonadUnify TermTypeM where
  getConstraints = gets stateConstraints

  putConstraints cs = modify $ \st -> st {stateConstraints = cs}

  -- A fresh type variable with no constraint beyond being 'Lifted'.
  -- Its name is derived from the description and the state counter.
  newTypeVar loc desc = do
    n <- incCounter
    tv <- newID $ mkTypeVarName desc n
    constrain tv . NoConstraint Lifted $ mkUsage' loc
    pure . Scalar $ TypeVar mempty (qualName tv) []

  curLevel = asks termLevel

  -- A fresh size variable.  A rigid one is recorded as 'UnknownSize',
  -- a non-rigid one as a 'Size' with no known expression.
  newDimVar usage rigidity name = do
    dim <- newTypeName name
    constrain dim $ case rigidity of
      Rigid rsrc -> UnknownSize (locOf usage) rsrc
      Nonrigid -> Size Nothing usage
    pure dim

  unifyError loc notes bcs doc = do
    ctx <- asks termChecking
    throwError . TypeError (locOf loc) notes $ case ctx of
      Nothing -> doc <> pretty bcs
      Just chk -> pretty chk <> line </> doc <> pretty bcs

  -- With a 'Checking' context and no breadcrumbs, only the context is
  -- reported.  Otherwise the "Types ... do not match." text is
  -- included, followed by the breadcrumbs.
  matchError loc notes bcs t1 t2 = do
    ctx <- asks termChecking
    throwError . TypeError (locOf loc) notes $ case ctx of
      Nothing -> mismatch <> pretty bcs
      Just chk
        | hasNoBreadCrumbs bcs -> pretty chk
        | otherwise -> pretty chk <> line </> mismatch <> pretty bcs
    where
      mismatch =
        "Types"
          </> indent 2 (pretty t1)
          </> "and"
          </> indent 2 (pretty t2)
          </> "do not match."

-- | Instantiate a type scheme with fresh type variables for its type
-- parameters.  Returns the names of the fresh type variables (one per
-- type parameter, in the same order as the parameters) and the
-- instantiated type.
instantiateTypeScheme ::
  QualName VName ->
  SrcLoc ->
  [TypeParam] ->
  StructType ->
  TermTypeM ([VName], StructType)
instantiateTypeScheme qn loc tparams t = do
  (fresh_names, fresh_substs) <-
    mapAndUnzipM (instantiateTypeParam qn loc) tparams
  -- Map each original parameter name to the substitution for its
  -- fresh replacement, then apply that substitution to the type.
  let substs = M.fromList $ zip (map typeParamName tparams) fresh_substs
      t' = applySubst (`M.lookup` substs) t
  pure (fresh_names, t')

-- | Create a fresh name for a single type parameter and record a
-- constraint for it:
--
-- * a type parameter becomes a type variable constrained only by the
--   parameter's own liftedness ('NoConstraint');
--
-- * a size parameter becomes a size variable with no known value
--   ('Size' 'Nothing').
--
-- Returns the fresh name together with the substitution that replaces
-- the parameter by it.
instantiateTypeParam ::
  (Monoid as) =>
  QualName VName ->
  SrcLoc ->
  TypeParam ->
  TermTypeM (VName, Subst (RetTypeBase dim as))
instantiateTypeParam qn loc tparam = do
  n <- incCounter
  v <- newID $ mkTypeVarName hint n
  case tparam of
    TypeParamType l _ _ -> do
      constrain v . NoConstraint l $
        usageOf ("instantiated type parameter of " <> dquotes (pretty qn))
      pure (v, Subst [] . RetType [] . Scalar $ TypeVar mempty (qualName v) [])
    TypeParamDim {} -> do
      constrain v . Size Nothing $
        usageOf ("instantiated size parameter of " <> dquotes (pretty qn))
      pure (v, ExpSubst $ sizeFromName (qualName v) loc)
  where
    -- Only the ASCII prefix of the parameter's base name is used as
    -- the hint for the fresh name.
    hint = nameFromString . takeWhile isAscii . baseString $ typeParamName tparam
    usageOf = mkUsage loc . docText

-- | The term scope in which the leaf of a qualified name should be
-- looked up.  It is found by following the name's qualifiers through
-- the module tables, starting from the current term scope.  Calls
-- 'error' if a qualifier does not name a 'ModEnv' in scope.
lookupQualNameEnv :: QualName VName -> TermTypeM TermScope
lookupQualNameEnv qn@(QualName quals _) = do
  scope <- asks termScope
  case quals of
    -- Magical intrinsic module: resolve directly in the current term
    -- scope without descending.
    [q] | baseTag q <= maxIntrinsicTag -> pure scope
    _ -> foldM enter scope quals
  where
    enter s q
      | Just (ModEnv q_scope) <- M.lookup q $ scopeModTable s =
          pure $ envToTermScope q_scope
      | otherwise =
          error $ "lookupQualNameEnv " <> show qn

-- | Look up the module bound to a qualified name, searching the
-- module table of the scope returned by 'lookupQualNameEnv'.  Calls
-- 'error' if the leaf name is not in that table.
lookupMod :: QualName VName -> TermTypeM Mod
lookupMod qn@(QualName _ name) = do
  mods <- scopeModTable <$> lookupQualNameEnv qn
  maybe (error $ "lookupMod: " <> show qn) pure $ M.lookup name mods

-- | Run an action with the 'termScope' of the environment transformed
-- by the given function.
localScope :: (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
localScope f = local onEnv
  where
    onEnv tenv = tenv {termScope = f (termScope tenv)}

-- | General type-checker services for 'TermTypeM'.  Warnings and the
-- name source live in 'TermTypeState'.  Value and type lookups go
-- through the 'TermScope' of 'TermEnv'.
instance MonadTypeChecker TermTypeM where
  warnings ws = modify $ \st -> st {stateWarnings = stateWarnings st <> ws}

  warn loc problem = warnings (singleWarning (locOf loc) problem)

  newName v = do
    src <- gets stateNameSource
    let (v', src') = Futhark.FreshNames.newName src v
    modify $ \st -> st {stateNameSource = src'}
    pure v'

  newTypeName name = do
    n <- incCounter
    newID $ mkTypeVarName name n

  bindVal v (TypeM.BoundV tps t) = localScope extend
    where
      extend scope =
        scope {scopeVtable = M.insert v (BoundV tps t) (scopeVtable scope)}

  -- The definition is returned with its type variables qualified
  -- relative to the outer environment (see 'qualifyTypeVars').
  lookupType qn = do
    outer_env <- asks termOuterEnv
    scope <- lookupQualNameEnv qn
    case M.lookup (qualLeaf qn) (scopeTypeTable scope) of
      Just (TypeAbbr l ps (RetType dims def)) -> do
        let qualify = qualifyTypeVars outer_env (map typeParamName ps) (qualQuals qn)
        pure (ps, RetType dims (qualify def), l)
      Nothing -> error $ "lookupType: " <> show qn

  typeError loc notes s = do
    ctx <- asks termChecking
    throwError . TypeError (locOf loc) notes $
      maybe s (\chk -> pretty chk <> line </> s) ctx

lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType
lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType
lookupVar SrcLoc
loc qn :: QualName VName
qn@(QualName [VName]
qs VName
name) = do
  TermScope
scope <- QualName VName -> TermTypeM TermScope
lookupQualNameEnv QualName VName
qn
  let usage :: Usage
usage = SrcLoc -> Text -> Usage
forall a. Located a => a -> Text -> Usage
mkUsage SrcLoc
loc (Text -> Usage) -> Text -> Usage
forall a b. (a -> b) -> a -> b
$ Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> Text) -> Doc Any -> Text
forall a b. (a -> b) -> a -> b
$ Doc Any
"use of " Doc Any -> Doc Any -> Doc Any
forall a. Semigroup a => a -> a -> a
<> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann
dquotes (QualName VName -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. QualName VName -> Doc ann
pretty QualName VName
qn)

  case VName -> Map VName ValBinding -> Maybe ValBinding
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name (Map VName ValBinding -> Maybe ValBinding)
-> Map VName ValBinding -> Maybe ValBinding
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName ValBinding
scopeVtable TermScope
scope of
    Maybe ValBinding
Nothing ->
      [Char] -> TermTypeM StructType
forall a. HasCallStack => [Char] -> a
error ([Char] -> TermTypeM StructType) -> [Char] -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$ [Char]
"lookupVar: " [Char] -> ShowS
forall a. Semigroup a => a -> a -> a
<> QualName VName -> [Char]
forall a. Show a => a -> [Char]
show QualName VName
qn
    Just (BoundV [TypeParam]
tparams StructType
t) -> do
      if [TypeParam] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [TypeParam]
tparams Bool -> Bool -> Bool
&& [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
qs
        then StructType -> TermTypeM StructType
forall a. a -> TermTypeM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure StructType
t
        else do
          ([VName]
tnames, StructType
t') <- QualName VName
-> SrcLoc
-> [TypeParam]
-> StructType
-> TermTypeM ([VName], StructType)
instantiateTypeScheme QualName VName
qn SrcLoc
loc [TypeParam]
tparams StructType
t
          Env
outer_env <- (TermEnv -> Env) -> TermTypeM Env
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks TermEnv -> Env
termOuterEnv
          StructType -> TermTypeM StructType
forall a. a -> TermTypeM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (StructType -> TermTypeM StructType)
-> StructType -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$ Env -> [VName] -> [VName] -> StructType -> StructType
forall as.
Env -> [VName] -> [VName] -> TypeBase Exp as -> TypeBase Exp as
qualifyTypeVars Env
outer_env [VName]
tnames [VName]
qs StructType
t'
    Just ValBinding
EqualityF -> do
      StructType
argtype <- SrcLoc -> Name -> TermTypeM StructType
forall als a dim.
(Monoid als, Located a) =>
a -> Name -> TermTypeM (TypeBase dim als)
forall (m :: * -> *) als a dim.
(MonadUnify m, Monoid als, Located a) =>
a -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
      Usage -> StructType -> TermTypeM ()
forall (m :: * -> *) dim u.
(MonadUnify m, Pretty (Shape dim), Pretty u) =>
Usage -> TypeBase dim u -> m ()
equalityType Usage
usage StructType
argtype
      StructType -> TermTypeM StructType
forall a. a -> TermTypeM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (StructType -> TermTypeM StructType)
-> StructType -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$
        ScalarTypeBase Exp NoUniqueness -> StructType
forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar (ScalarTypeBase Exp NoUniqueness -> StructType)
-> (ResType -> ScalarTypeBase Exp NoUniqueness)
-> ResType
-> StructType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. NoUniqueness
-> PName
-> Diet
-> StructType
-> RetTypeBase Exp Uniqueness
-> ScalarTypeBase Exp NoUniqueness
forall dim u.
u
-> PName
-> Diet
-> TypeBase dim NoUniqueness
-> RetTypeBase dim Uniqueness
-> ScalarTypeBase dim u
Arrow NoUniqueness
forall a. Monoid a => a
mempty PName
Unnamed Diet
Observe StructType
argtype (RetTypeBase Exp Uniqueness -> ScalarTypeBase Exp NoUniqueness)
-> (ResType -> RetTypeBase Exp Uniqueness)
-> ResType
-> ScalarTypeBase Exp NoUniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> ResType -> RetTypeBase Exp Uniqueness
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] (ResType -> StructType) -> ResType -> StructType
forall a b. (a -> b) -> a -> b
$
          ScalarTypeBase Exp Uniqueness -> ResType
forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar (ScalarTypeBase Exp Uniqueness -> ResType)
-> ScalarTypeBase Exp Uniqueness -> ResType
forall a b. (a -> b) -> a -> b
$
            Uniqueness
-> PName
-> Diet
-> StructType
-> RetTypeBase Exp Uniqueness
-> ScalarTypeBase Exp Uniqueness
forall dim u.
u
-> PName
-> Diet
-> TypeBase dim NoUniqueness
-> RetTypeBase dim Uniqueness
-> ScalarTypeBase dim u
Arrow Uniqueness
forall a. Monoid a => a
mempty PName
Unnamed Diet
Observe StructType
argtype (RetTypeBase Exp Uniqueness -> ScalarTypeBase Exp Uniqueness)
-> RetTypeBase Exp Uniqueness -> ScalarTypeBase Exp Uniqueness
forall a b. (a -> b) -> a -> b
$
              [VName] -> ResType -> RetTypeBase Exp Uniqueness
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] (ResType -> RetTypeBase Exp Uniqueness)
-> ResType -> RetTypeBase Exp Uniqueness
forall a b. (a -> b) -> a -> b
$
                ScalarTypeBase Exp Uniqueness -> ResType
forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar (ScalarTypeBase Exp Uniqueness -> ResType)
-> ScalarTypeBase Exp Uniqueness -> ResType
forall a b. (a -> b) -> a -> b
$
                  PrimType -> ScalarTypeBase Exp Uniqueness
forall dim u. PrimType -> ScalarTypeBase dim u
Prim PrimType
Bool
    Just (OverloadedF [PrimType]
ts [Maybe PrimType]
pts Maybe PrimType
rt) -> do
      StructType
argtype <- SrcLoc -> Name -> TermTypeM StructType
forall als a dim.
(Monoid als, Located a) =>
a -> Name -> TermTypeM (TypeBase dim als)
forall (m :: * -> *) als a dim.
(MonadUnify m, Monoid als, Located a) =>
a -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
      [PrimType] -> Usage -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
[PrimType] -> Usage -> StructType -> m ()
mustBeOneOf [PrimType]
ts Usage
usage StructType
argtype
      let ([StructType]
pts', StructType
rt') = StructType
-> [Maybe PrimType] -> Maybe PrimType -> ([StructType], StructType)
forall {dim} {u}.
TypeBase dim u
-> [Maybe PrimType]
-> Maybe PrimType
-> ([TypeBase dim NoUniqueness], TypeBase dim NoUniqueness)
instOverloaded StructType
argtype [Maybe PrimType]
pts Maybe PrimType
rt
      StructType -> TermTypeM StructType
forall a. a -> TermTypeM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (StructType -> TermTypeM StructType)
-> StructType -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$ [ParamType] -> RetTypeBase Exp Uniqueness -> StructType
foldFunType ((StructType -> ParamType) -> [StructType] -> [ParamType]
forall a b. (a -> b) -> [a] -> [b]
map (Diet -> StructType -> ParamType
forall u. Diet -> TypeBase Exp u -> ParamType
toParam Diet
Observe) [StructType]
pts') (RetTypeBase Exp Uniqueness -> StructType)
-> RetTypeBase Exp Uniqueness -> StructType
forall a b. (a -> b) -> a -> b
$ [VName] -> ResType -> RetTypeBase Exp Uniqueness
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] (ResType -> RetTypeBase Exp Uniqueness)
-> ResType -> RetTypeBase Exp Uniqueness
forall a b. (a -> b) -> a -> b
$ Uniqueness -> StructType -> ResType
forall u. Uniqueness -> TypeBase Exp u -> ResType
toRes Uniqueness
Nonunique StructType
rt'
  where
    instOverloaded :: TypeBase dim u
-> [Maybe PrimType]
-> Maybe PrimType
-> ([TypeBase dim NoUniqueness], TypeBase dim NoUniqueness)
instOverloaded TypeBase dim u
argtype [Maybe PrimType]
pts Maybe PrimType
rt =
      ( (Maybe PrimType -> TypeBase dim NoUniqueness)
-> [Maybe PrimType] -> [TypeBase dim NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase dim NoUniqueness
-> (PrimType -> TypeBase dim NoUniqueness)
-> Maybe PrimType
-> TypeBase dim NoUniqueness
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (TypeBase dim u -> TypeBase dim NoUniqueness
forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct TypeBase dim u
argtype) (ScalarTypeBase dim NoUniqueness -> TypeBase dim NoUniqueness
forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar (ScalarTypeBase dim NoUniqueness -> TypeBase dim NoUniqueness)
-> (PrimType -> ScalarTypeBase dim NoUniqueness)
-> PrimType
-> TypeBase dim NoUniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> ScalarTypeBase dim NoUniqueness
forall dim u. PrimType -> ScalarTypeBase dim u
Prim)) [Maybe PrimType]
pts,
        TypeBase dim NoUniqueness
-> (PrimType -> TypeBase dim NoUniqueness)
-> Maybe PrimType
-> TypeBase dim NoUniqueness
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (TypeBase dim u -> TypeBase dim NoUniqueness
forall dim u. TypeBase dim u -> TypeBase dim NoUniqueness
toStruct TypeBase dim u
argtype) (ScalarTypeBase dim NoUniqueness -> TypeBase dim NoUniqueness
forall dim u. ScalarTypeBase dim u -> TypeBase dim u
Scalar (ScalarTypeBase dim NoUniqueness -> TypeBase dim NoUniqueness)
-> (PrimType -> ScalarTypeBase dim NoUniqueness)
-> PrimType
-> TypeBase dim NoUniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> ScalarTypeBase dim NoUniqueness
forall dim u. PrimType -> ScalarTypeBase dim u
Prim) Maybe PrimType
rt
      )

-- | Run an action with the given 'Checking' stored as the current
-- checking context ('termChecking'), replacing any outer context for
-- the duration of the action.
onFailure :: Checking -> TermTypeM a -> TermTypeM a
onFailure c = local withChecking
  where
    withChecking env = env {termChecking = Just c}

-- | Create a fresh rigid size variable named after @"n"@, whose
-- 'RigidSource' is derived from the given 'SizeSource'.
--
-- Returns a size expression referring to the new variable (located at
-- @loc@) together with the variable's name.
extSize :: SrcLoc -> SizeSource -> TermTypeM (Size, Maybe VName)
extSize loc e = do
  d <- newRigidDim loc (rigidSourceOf e) "n"
  pure (sizeFromName (qualName d) loc, Just d)
  where
    -- Render the originating expression (or slice) as one line of
    -- text for the rigid source description.
    rigidSourceOf (SourceArg (FName fname) arg) =
      RigidArg fname $ prettyTextOneLine arg
    rigidSourceOf (SourceSlice known i j s) =
      RigidSlice known $ prettyTextOneLine $ DimSlice i j s

-- | Run an action with 'termLevel' increased by one.
incLevel :: TermTypeM a -> TermTypeM a
incLevel = local deeper
  where
    deeper env = env {termLevel = 1 + termLevel env}

-- | Get the type of an expression, with top-level type variables
-- substituted (via 'normType').  Never call 'typeOf' directly (except
-- in a few carefully inspected locations)!
expType :: Exp -> TermTypeM StructType
expType e = normType $ typeOf e

-- | Get the type of an expression, with all type variables
-- substituted (via 'normTypeFully').  Slower than 'expType', but
-- sometimes necessary.  Never call 'typeOf' directly (except in a few
-- carefully inspected locations)!
expTypeFully :: Exp -> TermTypeM StructType
expTypeFully e = normTypeFully $ typeOf e

-- | @newArrayType usage desc r@ creates a fresh type variable (named
-- from @desc@, constrained as 'NoConstraint' 'Unlifted') and @r@ fresh
-- nonrigid size variables.
--
-- Returns the rank-@r@ array type whose elements are that type
-- variable, together with the element type itself.
newArrayType :: Usage -> Name -> Int -> TermTypeM (StructType, StructType)
newArrayType usage desc r = do
  v <- newTypeName desc
  constrain v $ NoConstraint Unlifted usage
  dims <- replicateM r $ newDimVar usage Nonrigid "dim"
  let elem_t = TypeVar mempty (qualName v) []
      -- Every dimension refers to its fresh size variable, located at
      -- the usage site.
      shape = Shape [sizeFromName (qualName d) (srclocOf usage) | d <- dims]
  pure (Array mempty shape elem_t, Scalar elem_t)

-- | Replace *all* dimensions with distinct fresh size variables.
--
-- Each fresh variable is created with the given rigidity and name
-- description.  The returned map sends every fresh variable to the
-- size it replaced.
allDimsFreshInType ::
  Usage ->
  Rigidity ->
  Name ->
  TypeBase Size als ->
  TermTypeM (TypeBase Size als, M.Map VName Size)
allDimsFreshInType usage r desc t =
  runStateT (bitraverse freshen pure t) M.empty
  where
    -- Allocate a new size variable, remember what it replaces, and
    -- return a size expression referring to it.
    freshen old = do
      v <- lift $ newDimVar usage r desc
      modify $ M.insert v old
      pure $ sizeFromName (qualName v) (srclocOf usage)

-- | Replace all type variables with their concrete types.
--
-- Traverses the value with 'astMap', recursing into subexpressions and
-- normalising every struct type, parameter type and result return type
-- with 'normTypeFully'.  Names are left unchanged.
updateTypes :: (ASTMappable e) => e -> TermTypeM e
updateTypes x = astMap substituter x
  where
    substituter :: ASTMapper TermTypeM
    substituter =
      ASTMapper
        { mapOnName = pure,
          mapOnExp = astMap substituter,
          mapOnStructType = normTypeFully,
          mapOnParamType = normTypeFully,
          mapOnResRetType = normTypeFully
        }

--- Basic checking

-- | @unifies why t e@ unifies @t@ with the type of @e@, using a usage
-- located at @e@ and described by @why@, then returns @e@ unchanged.
unifies :: T.Text -> StructType -> Exp -> TermTypeM Exp
unifies why t e = do
  e_t <- expType e
  unify (mkUsage (srclocOf e) why) t e_t
  pure e

-- | @require why ts e@ requires the type of @e@ to be one of the
-- primitive types in @ts@ (via 'mustBeOneOf'), causing a 'TypeError'
-- if it cannot be.  The requirement is reported with a usage located
-- at @e@ and described by @why@.  Otherwise, simply returns @e@.
require :: T.Text -> [PrimType] -> Exp -> TermTypeM Exp
require why ts e = do
  e_t <- expType e
  mustBeOneOf ts (mkUsage (srclocOf e) why) e_t
  pure e

-- | Check an untyped size expression with the term checker stored in
-- the environment ('termChecker'), unify its type with @i64@, and
-- return the checked expression with all type variables substituted.
checkExpForSize :: ExpBase NoInfo VName -> TermTypeM Exp
checkExpForSize e = do
  check <- asks termChecker
  checked <- check e
  unify (mkUsage (locOf checked) "Size expression") (typeOf checked) i64_t
  updateTypes checked
  where
    i64_t = Scalar $ Prim $ Signed Int64

-- | Check a type expression, with size expressions inside it checked
-- by 'checkExpForSize'.
--
-- The sizes reported by 'checkTypeExp' and the (renamed) sizes bound by
-- the resulting return type are all given a 'Size' 'Nothing'
-- constraint.  Returns the checked type expression, its type, and the
-- names of all those size variables.
checkTypeExpNonrigid ::
  TypeExp (ExpBase NoInfo VName) VName ->
  TermTypeM (TypeExp Exp VName, ResType, [VName])
checkTypeExpNonrigid te = do
  (te', svars, rettype, _) <- checkTypeExp checkExpForSize te
  -- The sizes bound in rettype are not necessarily globally unique,
  -- and they are about to become size variables, so rename them first.
  RetType dims st <- renameRetType rettype
  let sizes = svars ++ dims
      usage = mkUsage (srclocOf te) "anonymous size in type expression"
  mapM_ (\v -> constrain v $ Size Nothing usage) sizes
  pure (te', st, sizes)

--- Sizes

-- | Recognise an integer constant, possibly parenthesised and/or
-- negated.
--
-- Accepts an @i64@ literal or an integer literal ('IntLit').  An
-- 'IntLit' is converted with 'fromInteger', so values outside the
-- 'Int64' range wrap around, as does negation.  Returns 'Nothing' for
-- any other expression.
isInt64 :: Exp -> Maybe Int64
isInt64 expr = case expr of
  Literal (SignedValue (Int64Value k)) _ -> Just k
  IntLit k _ _ -> Just (fromInteger k)
  Negate x _ -> negate <$> isInt64 x
  Parens x _ -> isInt64 x
  _ -> Nothing

-- Running

-- | The scope in which term checking starts.
--
-- The value table holds a binding for every entry of 'intrinsics' that
-- corresponds to a value (monomorphic, overloaded and polymorphic
-- functions, and equality).  The type and module tables are empty.
initialTermScope :: TermScope
initialTermScope =
  TermScope
    { scopeVtable = initialVtable,
      scopeTypeTable = mempty,
      scopeModTable = mempty
    }
  where
    initialVtable =
      M.fromList
        [ (name, binding)
        | (name, intrinsic) <- M.toList intrinsics,
          Just binding <- [intrinsicBinding intrinsic]
        ]

    prim = Scalar . Prim

    -- The value binding for an intrinsic, if it is a value at all.
    intrinsicBinding (IntrinsicMonoFun pts t) =
      Just . BoundV [] . Scalar $
        Arrow mempty Unnamed Observe param_t (RetType [] (prim t))
      where
        -- A single parameter is passed as-is; several are packed into
        -- a tuple.
        param_t = case pts of
          [pt] -> prim pt
          _ -> Scalar $ tupleRecord $ map prim pts
    intrinsicBinding (IntrinsicOverloadedFun ts pts rts) =
      Just $ OverloadedF ts pts rts
    intrinsicBinding (IntrinsicPolyFun tvs pts rt) =
      Just $ BoundV tvs $ foldFunType pts rt
    intrinsicBinding IntrinsicEquality =
      Just EqualityF
    intrinsicBinding _ =
      Nothing

-- | Run a 'TermTypeM' computation inside 'TypeM'.
--
-- The computation starts at 'termLevel' 0 with no 'termChecking'
-- context and with the given term checker.  Its scope is
-- 'initialTermScope' combined ('<>') with the scope derived from the
-- current 'TypeM' environment.  It begins with no constraints, a
-- counter of 0, no warnings, and the name source taken from 'TypeM'.
--
-- Accumulated warnings are passed on to 'TypeM' whether or not checking
-- succeeds.  On failure the type error is rethrown.  On success the
-- updated name source is written back.
runTermTypeM :: (ExpBase NoInfo VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a
runTermTypeM checker (TermTypeM m) = do
  outer_env <- askEnv
  import_name <- askImportName
  src <- gets TypeM.stateNameSource
  let initial_tenv =
        TermEnv
          { termScope = initialTermScope <> envToTermScope outer_env,
            termChecking = Nothing,
            termLevel = 0,
            termChecker = checker,
            termImportName = import_name,
            termOuterEnv = outer_env
          }
      initial_state =
        TermTypeState
          { stateConstraints = mempty,
            stateCounter = 0,
            stateWarnings = mempty,
            stateNameSource = src
          }
      result = runExcept $ runStateT (runReaderT m initial_tenv) initial_state
  case result of
    Left (ws, err) -> do
      warnings ws
      throwError err
    Right (a, TermTypeState {stateNameSource, stateWarnings}) -> do
      warnings stateWarnings
      modify $ \s -> s {TypeM.stateNameSource = stateNameSource}
      pure a