-- | Type checking of patterns.
module Language.Futhark.TypeChecker.Terms.Pat
  ( binding,
    bindingParams,
    bindingPat,
    bindingIdent,
    bindingSizes,
    doNotShadow,
    boundAliases,
  )
where

import Control.Monad.Except
import Control.Monad.State
import Data.Bitraversable
import Data.Either
import Data.List (find, isPrefixOf, sort)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Terms.Monad
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify hiding (Usage)
import Prelude hiding (mod)

-- | Names that may not be shadowed.
--
-- A pattern that tries to bind one of these names (the @&&@ and @||@
-- operators) is rejected with a type error by @checkPat'@.
doNotShadow :: [String]
doNotShadow = words "&& ||"

-- | Replace every named size in a type that refers to one of the given
-- size binders with a fresh size name.
--
-- All occurrences of the same binder are mapped to the same fresh name.
-- Each fresh name is registered with a @Size Nothing@ constraint whose
-- usage location is that of the corresponding binder.
nonrigidFor :: [SizeBinder VName] -> StructType -> TermTypeM StructType
nonrigidFor [] t = pure t -- Minor optimisation.
nonrigidFor sizes t = evalStateT (bitraverse onDim pure t) M.empty
  where
    -- Only named sizes that refer to one of our binders are replaced;
    -- every other dimension is returned unchanged.
    onDim (NamedSize (QualName _ v))
      | Just binder <- find ((== v) . sizeName) sizes =
          NamedSize . qualName <$> freshFor v binder
    onDim d = pure d

    -- Reuse the replacement already chosen for @v@, or create, constrain
    -- and remember a new one.
    freshFor v binder = do
      known <- gets (M.lookup v)
      case known of
        Just v' -> pure v'
        Nothing -> do
          v' <- lift $ newID $ baseName v
          lift $ constrain v' $ Size Nothing $ mkUsage' $ srclocOf binder
          modify (M.insert v v')
          pure v'

-- | The set of in-scope variables that are being aliased.
--
-- Only 'AliasBound' entries contribute; 'AliasFree' entries are
-- discarded.
boundAliases :: Aliasing -> S.Set VName
boundAliases als =
  S.fromList [aliasVar alias | alias@AliasBound {} <- S.toList als]

-- | Check how a single bound identifier was used within its scope.
--
-- * Unless @allow_consume@ is set, consuming an identifier whose type
--   is not consumable (see @consumable@ below) is a type error.
--
-- * An identifier that never occurs, and whose name does not start
--   with an underscore, produces an "unused variable" warning.
checkIfUsed :: Bool -> Occurrences -> Ident -> TermTypeM ()
checkIfUsed allow_consume occs v
  | not allow_consume,
    not $ consumable $ unInfo $ identType v,
    Just occ <- find consumes occs =
      typeError (srclocOf occ) mempty $
        "Consuming"
          <+> dquotes (prettyName $ identName v)
          <+> textwrap ("which is a non-consumable parameter bound at " <> locText (locOf v) <> ".")
  | not $ identName v `S.member` allOccurring occs,
    not $ "_" `isPrefixOf` baseString (identName v) =
      -- The full stop is attached with '<>': '<+>' would put a stray
      -- space in front of it.
      warn (srclocOf v) $
        "Unused variable" <+> dquotes (prettyName $ identName v) <> "."
  | otherwise =
      pure ()
  where
    -- Does this occurrence consume the identifier?
    consumes = maybe False (identName v `S.member`) . consumed

    -- Records and sums are consumable when all their components are;
    -- arrays and type variables only when unique; functions and
    -- primitives always.
    consumable :: TypeBase dim as -> Bool
    consumable (Scalar (Record fs)) = all consumable fs
    consumable (Scalar (Sum cs)) = all (all consumable) cs
    consumable (Scalar (TypeVar _ u _ _)) = u == Unique
    consumable (Scalar Arrow {}) = True
    consumable (Scalar Prim {}) = True
    consumable (Array _ u _ _) = u == Unique

-- | Bind these identifiers locally while running the provided action.
-- Checks that the identifiers are used properly within the scope
-- (e.g. consumption).
binding ::
  -- | Allow consumption, even if the type is not unique.
  Bool ->
  [Ident] ->
  TermTypeM a ->
  TermTypeM a
binding allow_consume idents body = checkUses (withVarsInScope body)
  where
    -- Extend the scope with every identifier and register a
    -- 'ParamSize' constraint for each before running the action.
    withVarsInScope action =
      localScope (\scope -> foldl bindVar scope idents) $ do
        -- Those identifiers that can potentially also be sizes are
        -- added as type constraints.  This is necessary so that we
        -- can properly detect scope violations during unification.
        -- We do this for *all* identifiers, not just those that are
        -- integers, because they may become integers later due to
        -- inference...
        mapM_ (\ident -> constrain (identName ident) (ParamSize (srclocOf ident))) idents
        action

    -- Add one identifier to the scope.  Its type is made to alias the
    -- identifier itself; if it is an array, every bound variable it
    -- aliases is also made to alias it back.
    bindVar :: TermScope -> Ident -> TermScope
    bindVar scope (Ident name (Info tp) _) =
      scope {scopeVtable = M.insert name (BoundV Local [] (selfAliased tp)) vtable}
      where
        selfAliased = (`addAliases` S.insert (AliasBound name))
        vtable =
          foldl (flip (M.adjust aliasBack)) (scopeVtable scope) $
            boundAliases (aliases tp)
        aliasBack (BoundV l tparams in_t)
          | Array {} <- tp = BoundV l tparams (selfAliased in_t)
        aliasBack other = other

    -- Run the action, then check whether the bound variables have
    -- been used correctly within their scope.
    checkUses action = do
      (result, usages) <- collectBindingsOccurrences action
      checkOccurrences usages
      mapM_ (checkIfUsed allow_consume usages) idents
      pure result

    -- Collect and remove all occurrences of @idents@.  This relies
    -- on the fact that no variables shadow any other.
    collectBindingsOccurrences action = do
      (x, usage) <- collectOccurrences action
      let (relevant, rest) = unzip $ map splitOcc usage
      occur rest
      pure (x, relevant)

    names = S.fromList $ map identName idents

    -- Split an occurrence into the part mentioning our identifiers and
    -- the part mentioning everything else.
    splitOcc occ =
      ( restrict (`S.intersection` names) occ,
        restrict (`S.difference` names) occ
      )

    restrict f occ = occ {observed = f (observed occ), consumed = f <$> consumed occ}

-- | Run an action with extra type bindings in scope and extra
-- constraints recorded.
--
-- 'Left' entries are added to the type table and shadow any existing
-- entry with the same name.  'Right' entries are tagged with the
-- current level and added to the constraints.  Constraints that
-- already exist for a name take precedence, because the map union is
-- left-biased.
bindingTypes ::
  [Either (VName, TypeBinding) (VName, Constraint)] ->
  TermTypeM a ->
  TermTypeM a
bindingTypes types m = do
  lvl <- curLevel
  let leveled = M.fromList [(v, (lvl, c)) | (v, c) <- constraints]
  modifyConstraints (\existing -> existing <> leveled)
  localScope addTypes m
  where
    (tbinds, constraints) = partitionEithers types
    addTypes scope =
      scope {scopeTypeTable = M.fromList tbinds `M.union` scopeTypeTable scope}

-- | Bind type parameters while running an action.
--
-- A type parameter becomes a type abbreviation for a type variable of
-- the same name and gets a 'ParamType' constraint.  A size parameter
-- gets a 'ParamSize' constraint and is also bound as a term-level
-- identifier (see 'typeParamIdent'), without allowing consumption.
bindingTypeParams :: [TypeParam] -> TermTypeM a -> TermTypeM a
bindingTypeParams tparams m =
  binding False size_idents $ bindingTypes type_bindings m
  where
    size_idents = mapMaybe typeParamIdent tparams
    type_bindings = concatMap typeParamType tparams

    typeParamType :: TypeParam -> [Either (VName, TypeBinding) (VName, Constraint)]
    typeParamType tparam = case tparam of
      TypeParamType l v loc ->
        [ Left (v, TypeAbbr l [] $ RetType [] $ Scalar $ TypeVar () Nonunique (qualName v) []),
          Right (v, ParamType l loc)
        ]
      TypeParamDim v loc ->
        [Right (v, ParamSize loc)]

-- | The term-level identifier introduced by a type parameter, if any.
-- Only size parameters introduce one, and it has type @i64@.
typeParamIdent :: TypeParam -> Maybe Ident
typeParamIdent tparam =
  case tparam of
    TypeParamDim v loc -> Just $ Ident v (Info i64) loc
    _ -> Nothing
  where
    i64 = Scalar $ Prim $ Signed Int64

-- | Bind a single term-level identifier.
--
-- The name is brought into the term namespace and resolved with
-- 'checkName'.  It is given the provided type and bound (with
-- consumption allowed) while the continuation runs on the resulting
-- identifier.
bindingIdent ::
  IdentBase NoInfo Name ->
  PatType ->
  (Ident -> TermTypeM a) ->
  TermTypeM a
bindingIdent (Ident v NoInfo vloc) t m =
  bindSpaced [(Term, v)] $
    checkName Term v vloc >>= \v' ->
      let ident = Ident v' (Info t) vloc
       in binding True [ident] (m ident)

-- | Bind @let@-bound sizes.  This is usually followed by 'bindingPat'
-- immediately afterwards.
--
-- Binding the same size name twice in the list is a type error.  Each
-- size is bound as an @i64@ identifier, without allowing consumption,
-- while the continuation runs on the resolved binders.
bindingSizes :: [SizeBinder Name] -> ([SizeBinder VName] -> TermTypeM a) -> TermTypeM a
bindingSizes [] m = m [] -- Minor optimisation.
bindingSizes sizes m = do
  foldM_ rejectDuplicate M.empty sizes
  bindSpaced [(Term, sizeName size) | size <- sizes] $ do
    sizes' <- mapM resolve sizes
    binding False (map asIdent sizes') $ m sizes'
  where
    -- Track where each size name was bound, and complain on a repeat.
    rejectDuplicate seen size =
      case M.lookup (sizeName size) seen of
        Just prevloc ->
          typeError size mempty $
            "Size name also bound at "
              <> pretty (locStrRel (srclocOf size) prevloc)
              <> "."
        Nothing ->
          pure $ M.insert (sizeName size) (srclocOf size) seen

    resolve (SizeBinder v loc) = do
      v' <- checkName Term v loc
      pure $ SizeBinder v' loc

    asIdent size =
      Ident (sizeName size) (Info $ Scalar $ Prim $ Signed Int64) (srclocOf size)

-- | View a size binder as an unchecked size type parameter with the
-- same base name and location.  Used when checking for duplicate
-- names in 'bindingPat'.
sizeBinderToParam :: SizeBinder VName -> UncheckedTypeParam
sizeBinderToParam (SizeBinder name loc) = TypeParamDim (baseName name) loc

-- | Check and bind a @let@-pattern.
--
-- The sizes are typically those produced by 'bindingSizes'.  Names may
-- not be repeated across the sizes and the pattern.  Every size must
-- be mentioned in the type of the checked pattern; otherwise the
-- first unmentioned one is reported via 'unusedSize'.
bindingPat ::
  [SizeBinder VName] ->
  PatBase NoInfo Name ->
  InferredType ->
  (Pat -> TermTypeM a) ->
  TermTypeM a
bindingPat sizes p t m = do
  checkForDuplicateNames (map sizeBinderToParam sizes) [p]
  checkPat sizes p t $ \p' ->
    binding True (S.toList $ patIdents p') $ do
      -- Perform an observation of every declared dimension.  This
      -- prevents unused-name warnings for otherwise unused dimensions.
      forM_ sizes $ \(SizeBinder v loc) ->
        observe $ Ident v (Info $ Scalar $ Prim $ Signed Int64) loc

      let used_sizes = freeInType $ patternStructType p'
      case find ((`S.notMember` used_sizes) . sizeName) sizes of
        Nothing -> m p'
        Just unused -> unusedSize unused

-- | The type of a literal appearing in a pattern.
--
-- All this complexity is just so we can handle un-suffixed numeric
-- literals in patterns.  An integer literal gets a fresh type variable
-- restricted to any numeric type, and a float literal gets one
-- restricted to float types.  A literal of known primitive type gets
-- exactly that type.
patLitMkType :: PatLit -> SrcLoc -> TermTypeM StructType
patLitMkType lit loc =
  case lit of
    PatLitInt _ -> overloaded anyNumberType "integer literal"
    PatLitFloat _ -> overloaded anyFloatType "float literal"
    PatLitPrim v -> pure $ Scalar $ Prim $ primValueType v
  where
    -- A fresh type variable constrained to one of the allowed types.
    overloaded allowed what = do
      t <- newTypeVar loc "t"
      mustBeOneOf allowed (mkUsage loc what) t
      pure t

checkPat' ::
  [SizeBinder VName] ->
  UncheckedPat ->
  InferredType ->
  TermTypeM Pat
checkPat' :: [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes (PatParens PatBase NoInfo Name
p SrcLoc
loc) InferredType
t =
  forall (f :: * -> *) vn. PatBase f vn -> SrcLoc -> PatBase f vn
PatParens forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
t forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
sizes (PatAttr AttrInfo Name
attr PatBase NoInfo Name
p SrcLoc
loc) InferredType
t =
  forall (f :: * -> *) vn.
AttrInfo vn -> PatBase f vn -> SrcLoc -> PatBase f vn
PatAttr forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *).
MonadTypeChecker m =>
AttrInfo Name -> m (AttrInfo VName)
checkAttr AttrInfo Name
attr forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
t forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
_ (Id Name
name NoInfo PatType
_ SrcLoc
loc) InferredType
_
  | [Char]
name' forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [[Char]]
doNotShadow =
      forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc () -> m a
typeError SrcLoc
loc forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$ Doc ()
"The" forall ann. Doc ann -> Doc ann -> Doc ann
<+> forall a ann. Pretty a => a -> Doc ann
pretty [Char]
name' forall ann. Doc ann -> Doc ann -> Doc ann
<+> Doc ()
"operator may not be redefined."
  where
    name' :: [Char]
name' = Name -> [Char]
nameToString Name
name
checkPat' [SizeBinder VName]
_ (Id Name
name NoInfo PatType
NoInfo SrcLoc
loc) (Ascribed PatType
t) = do
  VName
name' <- forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID Name
name
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn. vn -> f PatType -> SrcLoc -> PatBase f vn
Id VName
name' (forall a. a -> Info a
Info PatType
t) SrcLoc
loc
checkPat' [SizeBinder VName]
_ (Id Name
name NoInfo PatType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  VName
name' <- forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID Name
name
  PatType
t <- forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn. vn -> f PatType -> SrcLoc -> PatBase f vn
Id VName
name' (forall a. a -> Info a
Info PatType
t) SrcLoc
loc
checkPat' [SizeBinder VName]
_ (Wildcard NoInfo PatType
_ SrcLoc
loc) (Ascribed PatType
t) =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn. f PatType -> SrcLoc -> PatBase f vn
Wildcard (forall a. a -> Info a
Info forall a b. (a -> b) -> a -> b
$ PatType
t forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Nonunique) SrcLoc
loc
checkPat' [SizeBinder VName]
_ (Wildcard NoInfo PatType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  PatType
t <- forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn. f PatType -> SrcLoc -> PatBase f vn
Wildcard (forall a. a -> Info a
Info PatType
t) SrcLoc
loc
checkPat' [SizeBinder VName]
sizes (TuplePat [PatBase NoInfo Name]
ps SrcLoc
loc) (Ascribed PatType
t)
  | Just [PatType]
ts <- forall dim as. TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord PatType
t,
    forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatType]
ts forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatBase NoInfo Name]
ps =
      forall (f :: * -> *) vn. [PatBase f vn] -> SrcLoc -> PatBase f vn
TuplePat
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ([SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes) [PatBase NoInfo Name]
ps (forall a b. (a -> b) -> [a] -> [b]
map PatType -> InferredType
Ascribed [PatType]
ts)
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
sizes p :: PatBase NoInfo Name
p@(TuplePat [PatBase NoInfo Name]
ps SrcLoc
loc) (Ascribed PatType
t) = do
  [StructType]
ps_t <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatBase NoInfo Name]
ps) (forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t")
  forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"matching a tuple pattern") (forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (forall dim as. [TypeBase dim as] -> ScalarTypeBase dim as
tupleRecord [StructType]
ps_t)) forall a b. (a -> b) -> a -> b
$ forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t
  PatType
t' <- forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
t
  [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p forall a b. (a -> b) -> a -> b
$ PatType -> InferredType
Ascribed PatType
t'
checkPat' [SizeBinder VName]
sizes (TuplePat [PatBase NoInfo Name]
ps SrcLoc
loc) InferredType
NoneInferred =
  forall (f :: * -> *) vn. [PatBase f vn] -> SrcLoc -> PatBase f vn
TuplePat forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\PatBase NoInfo Name
p -> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
NoneInferred) [PatBase NoInfo Name]
ps forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
_ (RecordPat [(Name, PatBase NoInfo Name)]
p_fs SrcLoc
_) InferredType
_
  | Just (Name
f, PatBase NoInfo Name
fp) <- forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (([Char]
"_" `isPrefixOf`) forall b c a. (b -> c) -> (a -> b) -> a -> c
. Name -> [Char]
nameToString forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(Name, PatBase NoInfo Name)]
p_fs =
      forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc () -> m a
typeError PatBase NoInfo Name
fp forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$
        Doc ()
"Underscore-prefixed fields are not allowed."
          forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc ()
"Did you mean" forall a. Semigroup a => a -> a -> a
<> forall ann. Doc ann -> Doc ann
dquotes (forall a ann. Pretty a => a -> Doc ann
pretty (forall a. Int -> [a] -> [a]
drop Int
1 (Name -> [Char]
nameToString Name
f)) forall a. Semigroup a => a -> a -> a
<> Doc ()
"=_") forall a. Semigroup a => a -> a -> a
<> Doc ()
"?"
checkPat' [SizeBinder VName]
sizes (RecordPat [(Name, PatBase NoInfo Name)]
p_fs SrcLoc
loc) (Ascribed (Scalar (Record Map Name PatType
t_fs)))
  | forall a. Ord a => [a] -> [a]
sort (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(Name, PatBase NoInfo Name)]
p_fs) forall a. Eq a => a -> a -> Bool
== forall a. Ord a => [a] -> [a]
sort (forall k a. Map k a -> [k]
M.keys Map Name PatType
t_fs) =
      forall (f :: * -> *) vn.
[(Name, PatBase f vn)] -> SrcLoc -> PatBase f vn
RecordPat forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k a. Map k a -> [(k, a)]
M.toList forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TermTypeM (Map Name Pat)
check forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
  where
    check :: TermTypeM (Map Name Pat)
check =
      forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ([SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes)) forall a b. (a -> b) -> a -> b
$
        forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith (,) (forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, PatBase NoInfo Name)]
p_fs) (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PatType -> InferredType
Ascribed Map Name PatType
t_fs)
checkPat' [SizeBinder VName]
sizes p :: PatBase NoInfo Name
p@(RecordPat [(Name, PatBase NoInfo Name)]
fields SrcLoc
loc) (Ascribed PatType
t) = do
  Map Name StructType
fields' <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall a b. a -> b -> a
const forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t") forall a b. (a -> b) -> a -> b
$ forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, PatBase NoInfo Name)]
fields

  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall a. Ord a => [a] -> [a]
sort (forall k a. Map k a -> [k]
M.keys Map Name StructType
fields') forall a. Eq a => a -> a -> Bool
/= forall a. Ord a => [a] -> [a]
sort (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(Name, PatBase NoInfo Name)]
fields)) forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc () -> m a
typeError SrcLoc
loc forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$
      Doc ()
"Duplicate fields in record pattern" forall ann. Doc ann -> Doc ann -> Doc ann
<+> forall a ann. Pretty a => a -> Doc ann
pretty PatBase NoInfo Name
p forall a. Semigroup a => a -> a -> a
<> Doc ()
"."

  forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"matching a record pattern") (forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (forall dim as. Map Name (TypeBase dim as) -> ScalarTypeBase dim as
Record Map Name StructType
fields')) forall a b. (a -> b) -> a -> b
$ forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t
  PatType
t' <- forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
t
  [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p forall a b. (a -> b) -> a -> b
$ PatType -> InferredType
Ascribed PatType
t'
checkPat' [SizeBinder VName]
sizes (RecordPat [(Name, PatBase NoInfo Name)]
fs SrcLoc
loc) InferredType
NoneInferred =
  forall (f :: * -> *) vn.
[(Name, PatBase f vn)] -> SrcLoc -> PatBase f vn
RecordPat forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall k a. Map k a -> [(k, a)]
M.toList
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (\PatBase NoInfo Name
p -> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
NoneInferred) (forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, PatBase NoInfo Name)]
fs)
    forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
sizes (PatAscription PatBase NoInfo Name
p TypeExp Name
t SrcLoc
loc) InferredType
maybe_outer_t = do
  (TypeExp VName
t', StructType
st, [VName]
_) <- TypeExp Name -> TermTypeM (TypeExp VName, StructType, [VName])
checkTypeExpNonrigid TypeExp Name
t

  case InferredType
maybe_outer_t of
    Ascribed PatType
outer_t -> do
      StructType
st_forunify <- [SizeBinder VName] -> StructType -> TermTypeM StructType
nonrigidFor [SizeBinder VName]
sizes StructType
st
      forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"explicit type ascription") StructType
st_forunify (forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
outer_t)

      PatType
outer_t' <- forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
outer_t
      forall (f :: * -> *) vn.
PatBase f vn -> TypeExp vn -> SrcLoc -> PatBase f vn
PatAscription
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p (PatType -> InferredType
Ascribed (StructType -> PatType -> PatType
addAliasesFromType StructType
st PatType
outer_t'))
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure TypeExp VName
t'
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
    InferredType
NoneInferred ->
      forall (f :: * -> *) vn.
PatBase f vn -> TypeExp vn -> SrcLoc -> PatBase f vn
PatAscription
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p (PatType -> InferredType
Ascribed (forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
st))
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure TypeExp VName
t'
        forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPat' [SizeBinder VName]
_ (PatLit PatLit
l NoInfo PatType
NoInfo SrcLoc
loc) (Ascribed PatType
t) = do
  StructType
t' <- PatLit -> SrcLoc -> TermTypeM StructType
patLitMkType PatLit
l SrcLoc
loc
  forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"matching against literal") StructType
t' (forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t)
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
PatLit -> f PatType -> SrcLoc -> PatBase f vn
PatLit PatLit
l (forall a. a -> Info a
Info (forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t')) SrcLoc
loc
checkPat' [SizeBinder VName]
_ (PatLit PatLit
l NoInfo PatType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  StructType
t' <- PatLit -> SrcLoc -> TermTypeM StructType
patLitMkType PatLit
l SrcLoc
loc
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
PatLit -> f PatType -> SrcLoc -> PatBase f vn
PatLit PatLit
l (forall a. a -> Info a
Info (forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t')) SrcLoc
loc
checkPat' [SizeBinder VName]
sizes (PatConstr Name
n NoInfo PatType
NoInfo [PatBase NoInfo Name]
ps SrcLoc
loc) (Ascribed (Scalar (Sum Map Name [PatType]
cs)))
  | Just [PatType]
ts <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Name
n Map Name [PatType]
cs = do
      [Pat]
ps' <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ([SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes) [PatBase NoInfo Name]
ps forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map PatType -> InferredType
Ascribed [PatType]
ts
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
Name -> f PatType -> [PatBase f vn] -> SrcLoc -> PatBase f vn
PatConstr Name
n (forall a. a -> Info a
Info (forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (forall dim as. Map Name [TypeBase dim as] -> ScalarTypeBase dim as
Sum Map Name [PatType]
cs))) [Pat]
ps' SrcLoc
loc
checkPat' [SizeBinder VName]
sizes (PatConstr Name
n NoInfo PatType
NoInfo [PatBase NoInfo Name]
ps SrcLoc
loc) (Ascribed PatType
t) = do
  StructType
t' <- forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
  [Pat]
ps' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\PatBase NoInfo Name
p -> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
NoneInferred) [PatBase NoInfo Name]
ps
  forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> StructType -> [StructType] -> m ()
mustHaveConstr Usage
usage Name
n StructType
t' (Pat -> StructType
patternStructType forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Pat]
ps')
  forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify Usage
usage StructType
t' (forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t)
  PatType
t'' <- forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
t
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
Name -> f PatType -> [PatBase f vn] -> SrcLoc -> PatBase f vn
PatConstr Name
n (forall a. a -> Info a
Info PatType
t'') [Pat]
ps' SrcLoc
loc
  where
    usage :: Usage
usage = SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"matching against constructor"
checkPat' [SizeBinder VName]
sizes (PatConstr Name
n NoInfo PatType
NoInfo [PatBase NoInfo Name]
ps SrcLoc
loc) InferredType
NoneInferred = do
  [Pat]
ps' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\PatBase NoInfo Name
p -> [SizeBinder VName]
-> PatBase NoInfo Name -> InferredType -> TermTypeM Pat
checkPat' [SizeBinder VName]
sizes PatBase NoInfo Name
p InferredType
NoneInferred) [PatBase NoInfo Name]
ps
  StructType
t <- forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> Name -> m (TypeBase dim als)
newTypeVar SrcLoc
loc Name
"t"
  forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> StructType -> [StructType] -> m ()
mustHaveConstr Usage
usage Name
n StructType
t (Pat -> StructType
patternStructType forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Pat]
ps')
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
Name -> f PatType -> [PatBase f vn] -> SrcLoc -> PatBase f vn
PatConstr Name
n (forall a. a -> Info a
Info forall a b. (a -> b) -> a -> b
$ forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t) [Pat]
ps' SrcLoc
loc
  where
    usage :: Usage
usage = SrcLoc -> Text -> Usage
mkUsage SrcLoc
loc Text
"matching against constructor"

-- | Build the name map for every variable bound by a checked pattern.
-- Each bound 'VName' is keyed by its base name in the 'Term' namespace
-- and maps to the unqualified name itself.
patNameMap :: Pat -> NameMap
patNameMap pat =
  M.fromList
    [ ((Term, baseName v), qualName v)
      | v <- S.toList (patNames pat)
    ]

-- | Type-check a single pattern and run a continuation with the names
-- it binds in scope.
--
-- The steps are:
--
-- 1. Reject duplicate names across the size binders and the pattern.
-- 2. Check the pattern against the given 'InferredType'. Failures are
--    reported in a 'CheckingPat' context.
-- 3. Reject the first size binder whose name appears in
--    'mustBeExplicitInType' of the checked pattern's type.
-- 4. Run the continuation on the checked pattern, with 'patNameMap' of
--    that pattern bound.
checkPat ::
  [SizeBinder VName] ->
  UncheckedPat ->
  InferredType ->
  (Pat -> TermTypeM a) ->
  TermTypeM a
checkPat sizes p t m = do
  checkForDuplicateNames (sizeBinderToParam <$> sizes) [p]
  p' <- onFailure (CheckingPat p t) (checkPat' sizes p t)

  let explicit = mustBeExplicitInType (patternStructType p')
      -- Size binders that cannot be bound by this pattern.
      unbindable = [size | size <- sizes, sizeName size `S.member` explicit]

  case unbindable of
    [] -> bindNameMap (patNameMap p') (m p')
    size : _ ->
      typeError size mempty $
        "Cannot bind"
          <+> pretty size
          <+> "as it is never used as the size of a concrete (non-function) value."

-- | Check and bind type and value parameters.
--
-- Duplicate names across the type and value parameters are rejected
-- first. The type parameters are then checked and bound. After that,
-- the value patterns are checked left to right with no inferred type.
-- The identifiers of each checked pattern are bound before the next
-- pattern is checked. The continuation receives the checked type
-- parameters and the checked value patterns in their original order.
bindingParams ::
  [UncheckedTypeParam] ->
  [UncheckedPat] ->
  ([TypeParam] -> [Pat] -> TermTypeM a) ->
  TermTypeM a
bindingParams tps orig_ps m = do
  checkForDuplicateNames tps orig_ps
  checkTypeParams tps $ \tps' ->
    bindingTypeParams tps' $ go tps' [] orig_ps
  where
    -- 'done' holds the already-checked patterns, most recent first.
    go tps' done [] = do
      -- Observe every type parameter so that otherwise unused
      -- dimensions do not trigger unused-name warnings.
      mapM_ observe (mapMaybe typeParamIdent tps')
      m tps' (reverse done)
    go tps' done (p : rest) =
      checkPat [] p NoneInferred $ \p' ->
        binding False (S.toList (patIdents p')) $
          go tps' (p' : done) rest