{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}

-- | Facilities for type-checking Futhark terms.  Checking a term
-- requires a little more context to track uniqueness and such.
--
-- Type inference is implemented through a variation of
-- Hindley-Milner.  The main complication is supporting the rich
-- number of built-in language constructs, as well as uniqueness
-- types.  This is mostly done in an ad hoc way, and many programs
-- will require the programmer to fall back on type annotations.
module Language.Futhark.TypeChecker.Terms
  ( checkOneExp,
    checkFunDef,
  )
where

import Control.Monad.Except
import Control.Monad.RWS hiding (Sum)
import Control.Monad.State
import Control.Monad.Writer hiding (Sum)
import Data.Bifunctor
import Data.Char (isAscii)
import Data.Either
import Data.List (find, foldl', group, isPrefixOf, nub, sort, transpose, (\\))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.IR.Primitive (intByteSize)
import Futhark.Util.Pretty hiding (bool, group, space)
import Language.Futhark hiding (unscopeType)
import Language.Futhark.Semantic (includeToString)
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import qualified Language.Futhark.TypeChecker.Monad as TypeM
import Language.Futhark.TypeChecker.Types hiding (checkTypeDecl)
import qualified Language.Futhark.TypeChecker.Types as Types
import Language.Futhark.TypeChecker.Unify hiding (Usage)
import Prelude hiding (mod)

--- Uniqueness

-- | A single use of a variable, tagged with the source location at
-- which the use happened.
data Usage
  = -- | The variable was consumed at this location.
    Consumed SrcLoc
  | -- | The variable was only observed at this location.
    Observed SrcLoc
  deriving (Eq, Ord, Show)

-- | A set of variable names, as used for the observed and consumed
-- sets of an 'Occurence'.
type Names = S.Set VName

-- | The variables observed and consumed at one source location.
--
-- The consumption set is a 'Maybe' so that we can tell two situations
-- apart: a consumption took place, but the consumed variables have
-- since gone out of scope (@Just@ an empty set); or no consumption took
-- place at all (@Nothing@).
data Occurence = Occurence
  { -- | Variables observed at this location.
    observed :: Names,
    -- | Variables consumed at this location, if any consumption happened.
    consumed :: Maybe Names,
    -- | Where the occurence happened.
    location :: SrcLoc
  }
  deriving (Eq, Show)

-- | An occurence is located at its recorded 'location'.
instance Located Occurence where
  locOf occ = locOf (location occ)

-- | An occurence that observes (but does not consume) every variable
-- named in the given aliasing set.
observation :: Aliasing -> SrcLoc -> Occurence
observation aliases loc = Occurence (S.map aliasVar aliases) Nothing loc

-- | An occurence that consumes every variable named in the given
-- aliasing set and observes nothing.
consumption :: Aliasing -> SrcLoc -> Occurence
consumption aliases loc =
  Occurence S.empty (Just (S.map aliasVar aliases)) loc

-- | A null occurence observes nothing and records no consumption at
-- all, so it can be removed without affecting anything.
nullOccurence :: Occurence -> Bool
nullOccurence (Occurence obs cons _) = S.null obs && isNothing cons

-- | A seminull occurence refers to no variables that are still in
-- scope.  Unlike a null occurence, it may still denote a consumption,
-- as long as the consumed array has gone out of scope, i.e. the
-- consumption set is @Just@ an empty set.
seminullOccurence :: Occurence -> Bool
seminullOccurence (Occurence obs cons _) =
  -- 'all' over a Maybe is True for Nothing, and otherwise tests the set.
  S.null obs && all S.null cons

-- | A list of occurences, combined with 'seqOccurences' and
-- 'altOccurences'.
type Occurences = [Occurence]

-- | For each variable, the list of uses recorded for it (built by
-- 'usageMap').
type UsageMap = M.Map VName [Usage]

-- | Group the uses recorded in a list of occurences by variable.
--
-- Occurences are processed left to right; within one occurence the
-- observations are inserted before the consumptions.  Each new use is
-- prepended to the variable's list ('M.insertWith' with @(++)@ puts the
-- new singleton first), so later-inserted uses come first.
usageMap :: Occurences -> UsageMap
usageMap = foldl' comb M.empty -- strict fold: no thunk chain over the list
  where
    comb m (Occurence obs cons loc) =
      let m' = S.foldl' (ins $ Observed loc) m obs
       in S.foldl' (ins $ Consumed loc) m' $ fromMaybe mempty cons
    ins v m k = M.insertWith (++) k [v] m

-- | Combine two uses of the same variable.
--
-- Two observations are compatible and yield an observation at the first
-- location.  An observation paired with a consumption (in either order)
-- is reported via 'useAfterConsume'.  Two consumptions are reported via
-- 'consumeAfterConsume', passing the larger and then the smaller of the
-- two locations.
combineOccurences :: VName -> Usage -> Usage -> TermTypeM Usage
combineOccurences name u1 u2 =
  case (u1, u2) of
    (Observed loc, Observed _) ->
      pure $ Observed loc
    (Consumed wloc, Observed rloc) ->
      useAfterConsume (baseName name) rloc wloc
    (Observed rloc, Consumed wloc) ->
      useAfterConsume (baseName name) rloc wloc
    (Consumed loc1, Consumed loc2) ->
      consumeAfterConsume (baseName name) (max loc1 loc2) (min loc1 loc2)

-- | Check that the uses of each variable are mutually compatible.
--
-- Variables are visited in ascending key order, and their uses are
-- folded with 'combineOccurences', which rejects any combination
-- involving a consumption.
checkOccurences :: Occurences -> TermTypeM ()
checkOccurences occs = mapM_ (uncurry check) (M.toList (usageMap occs))
  where
    check _ [] = pure ()
    check name (u0 : us) = foldM_ (combineOccurences name) u0 us

-- | Every variable observed by any of the occurences.
allObserved :: Occurences -> Names
allObserved = foldMap observed

-- | Every variable consumed by any of the occurences.  Occurences with
-- no consumption (@Nothing@) contribute nothing.
allConsumed :: Occurences -> Names
allConsumed = S.unions . mapMaybe consumed

-- | Every variable that is either consumed or observed by the
-- occurences.
allOccuring :: Occurences -> Names
allOccuring occs = S.union (allConsumed occs) (allObserved occs)

-- | The first occurence that records a consumption (even one whose
-- consumed variables are out of scope), if there is any.
anyConsumption :: Occurences -> Maybe Occurence
anyConsumption occs = listToMaybe [occ | occ <- occs, isJust (consumed occ)]

-- | Sequential composition: the first list of occurences happens
-- before the second.
--
-- Variables consumed by the second list are removed from the
-- observation sets of the first list.  Presumably this keeps an
-- observation followed by a later consumption from being flagged by
-- 'checkOccurences'.  Null occurences are then dropped.
seqOccurences :: Occurences -> Occurences -> Occurences
seqOccurences occurs1 occurs2 =
  [ occ
    | occ <- map forgetConsumedLater occurs1 ++ occurs2,
      not (nullOccurence occ)
  ]
  where
    consumedLater = allConsumed occurs2
    forgetConsumedLater occ =
      occ {observed = observed occ `S.difference` consumedLater}

-- | Alternative composition of two lists of occurences (presumably for
-- mutually exclusive code paths).
--
-- * In the first list, variables consumed by the second list are removed
--   from both the observation and the consumption sets.
-- * In the second list, variables consumed by the first list are removed
--   from the observation sets only.
--
-- Null occurences are then dropped.
altOccurences :: Occurences -> Occurences -> Occurences
altOccurences occurs1 occurs2 =
  [ occ
    | occ <- map fromFirst occurs1 ++ map fromSecond occurs2,
      not (nullOccurence occ)
  ]
  where
    consumedByFirst = allConsumed occurs1
    consumedBySecond = allConsumed occurs2
    fromFirst occ =
      occ
        { consumed = fmap (`S.difference` consumedBySecond) (consumed occ),
          observed = observed occ `S.difference` consumedBySecond
        }
    fromSecond occ =
      occ {observed = observed occ `S.difference` consumedByFirst}

--- Scope management

-- | A description of what the type checker was checking.  Its 'Pretty'
-- instance renders it as explanatory text.  Presumably it is attached to
-- errors via 'termChecking'; confirm at the use sites.
data Checking
  = -- | Applying a function (named, if known) to an argument expression;
    -- carries the expected and the actual type.
    CheckingApply (Maybe (QualName VName)) Exp StructType StructType
  | -- | A function body against the expected type (expected, actual).
    CheckingReturn StructType StructType
  | -- | An expression against an explicit type ascription (expected, actual).
    CheckingAscription StructType StructType
  | -- | Generalising the type of the named binding.
    CheckingLetGeneralise Name
  | -- | The parameters of a named or anonymous function.
    CheckingParams (Maybe Name)
  | -- | A pattern, possibly against an ascribed type.
    CheckingPattern UncheckedPattern InferredType
  | -- | A loop body against the expected type (expected, actual).
    CheckingLoopBody StructType StructType
  | -- | Initial loop values against the expected type (expected, actual).
    CheckingLoopInitial StructType StructType
  | -- | Updating the record field at the given path (existing, new type).
    CheckingRecordUpdate [Name] StructType StructType
  | -- | An expression whose type must be one of the given types (allowed, actual).
    CheckingRequired [StructType] StructType
  | -- | The two branches of a conditional (former, latter).
    CheckingBranches StructType StructType

-- | Human-readable description of what was being checked, including the
-- expected and actual types where they are known.
instance Pretty Checking where
  ppr (CheckingApply f e expected actual) =
    header
      </> "Expected:" <+> align (ppr expected)
      </> "Actual:  " <+> align (ppr actual)
    where
      -- The argument is flattened to one line and shortened so that a
      -- large expression does not swamp the message.
      argDoc = pquote (shorten $ pretty $ flatten $ ppr e) <> " (invalid type)."
      header =
        case f of
          Nothing ->
            "Cannot apply function to" <+> argDoc
          Just fname ->
            "Cannot apply" <+> pquote (ppr fname) <+> "to" <+> argDoc
  ppr (CheckingReturn expected actual) =
    "Function body does not have expected type."
      </> "Expected:" <+> align (ppr expected)
      </> "Actual:  " <+> align (ppr actual)
  ppr (CheckingAscription expected actual) =
    "Expression does not have expected type from explicit ascription."
      </> "Expected:" <+> align (ppr expected)
      </> "Actual:  " <+> align (ppr actual)
  ppr (CheckingLetGeneralise fname) =
    "Cannot generalise type of" <+> pquote (ppr fname) <> "."
  ppr (CheckingParams fname) =
    "Invalid use of parameters in" <+> pquote fname' <> "."
    where
      fname' = maybe "anonymous function" ppr fname
  ppr (CheckingPattern pat NoneInferred) =
    "Invalid pattern" <+> pquote (ppr pat) <> "."
  ppr (CheckingPattern pat (Ascribed t)) =
    "Pattern" <+> pquote (ppr pat)
      <+> "cannot match value of type"
      </> indent 2 (ppr t)
  ppr (CheckingLoopBody expected actual) =
    "Loop body does not have expected type."
      </> "Expected:" <+> align (ppr expected)
      </> "Actual:  " <+> align (ppr actual)
  ppr (CheckingLoopInitial expected actual) =
    "Initial loop values do not have expected type."
      </> "Expected:" <+> align (ppr expected)
      </> "Actual:  " <+> align (ppr actual)
  ppr (CheckingRecordUpdate fs expected actual) =
    "Type mismatch when updating record field" <+> pquote fs' <> "."
      </> "Existing:" <+> align (ppr expected)
      </> "New:     " <+> align (ppr actual)
    where
      fs' = mconcat $ punctuate "." $ map ppr fs
  ppr (CheckingRequired [expected] actual) =
    "Expression must have type" <+> ppr expected <> "."
      </> "Actual type:" <+> align (ppr actual)
  ppr (CheckingRequired expected actual) =
    -- No trailing space in the literal: '<+>' already inserts one.
    "Type of expression must be one of" <+> expected' <> "."
      </> "Actual type:" <+> align (ppr actual)
    where
      expected' = commasep (map ppr expected)
  ppr (CheckingBranches t1 t2) =
    "Conditional branches differ in type."
      </> "Former:" <+> ppr t1
      </> "Latter:" <+> ppr t2

-- | Whether a value binding is a global or a local variable.  Bindings
-- imported from a module-level environment are 'Global'.
data Locality
  = Local
  | Global
  deriving (Show)

-- | What a value name is bound to in the term scope.
data ValBinding
  = -- | An ordinary binding with its locality, type parameters and type.
    -- Aliases in parameters indicate the lexical closure.
    BoundV Locality [TypeParam] PatternType
  | -- | Presumably an overloaded built-in: the primitive types it is
    -- defined for, its parameter types and its result type, where
    -- @Nothing@ may stand for the overloaded type (TODO confirm).
    OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType)
  | -- | Presumably the built-in equality operator (TODO confirm).
    EqualityF
  | -- | A binding that has been consumed; the location presumably
    -- records where (TODO confirm).
    WasConsumed SrcLoc
  deriving (Show)

-- | The environment available while type-checking a term.  Its
-- 'TermScope' is extended during type-checking as bindings come into
-- scope.
data TermEnv = TermEnv
  { -- | The bindings currently in scope.
    termScope :: TermScope,
    -- | What is currently being checked, if known.
    termChecking :: Maybe Checking,
    -- | The current level (semantics defined elsewhere; TODO document).
    termLevel :: Level
  }

-- | The bindings visible while checking a term.
data TermScope = TermScope
  { -- | Value bindings.
    scopeVtable :: M.Map VName ValBinding,
    -- | Type bindings.
    scopeTypeTable :: M.Map VName TypeBinding,
    -- | Module bindings.
    scopeModTable :: M.Map VName Mod,
    -- | Mapping from source names to their resolved names.
    scopeNameMap :: NameMap
  }
  deriving (Show)

-- | Combine two scopes.  In every table, entries in the right-hand
-- scope shadow entries for the same key in the left-hand scope
-- ('M.union' is left-biased, so the right operand is passed first).
-- Previously the module table used the opposite bias; it is now
-- consistent with the value, type and name tables.
instance Semigroup TermScope where
  TermScope vt1 tt1 mt1 nt1 <> TermScope vt2 tt2 mt2 nt2 =
    TermScope
      (vt2 `M.union` vt1)
      (tt2 `M.union` tt1)
      (mt2 `M.union` mt1)
      (nt2 `M.union` nt1)

-- | Convert a module-level 'Env' into a 'TermScope'.
--
-- Every value binding becomes a 'Global' 'BoundV'.  A value of array
-- type (rank > 0) is given itself as its only alias; other values get no
-- aliases.
envToTermScope :: Env -> TermScope
envToTermScope env =
  TermScope
    (M.mapWithKey toValBinding (envVtable env))
    (envTypeTable env)
    (envModTable env)
    (envNameMap env)
  where
    toValBinding vn (TypeM.BoundV tparams t) =
      BoundV Global tparams (t `setAliases` selfAlias vn t)
    selfAlias vn t
      | arrayRank t > 0 = S.singleton (AliasBound vn)
      | otherwise = mempty

-- | Extend the scope of a term environment with the bindings of a
-- module-level 'Env'.  The 'Env' is the right operand of '<>', so its
-- value, type and name entries shadow existing ones.
withEnv :: TermEnv -> Env -> TermEnv
withEnv tenv env = tenv {termScope = extended}
  where
    extended = termScope tenv <> envToTermScope env

-- | The type variables that occur in the field types of 'HasFields'
-- constraints.  All other constraints contribute nothing.
overloadedTypeVars :: Constraints -> Names
overloadedTypeVars = foldMap fieldTypeVars
  where
    fieldTypeVars (_, HasFields fs _) = foldMap typeVars fs
    fieldTypeVars _ = mempty

-- | Get the type of an expression, with top-level type variables
-- substituted.  Never call 'typeOf' directly (except in a few
-- carefully inspected locations)!
expType :: Exp -> TermTypeM PatternType
expType e = normPatternType (typeOf e)

-- | Get the type of an expression, with all type variables substituted.
-- Slower than 'expType', but sometimes necessary.  Never call 'typeOf'
-- directly (except in a few carefully inspected locations)!
expTypeFully :: Exp -> TermTypeM PatternType
expTypeFully e = normTypeFully (typeOf e)

-- | A function name, wrapped so that it can be stored inside a
-- 'SizeSource' without affecting that type's equality or ordering:
-- the 'Eq' and 'Ord' instances for 'FName' are vacuous, so any two
-- 'FName's compare equal.
newtype FName = FName (Maybe (QualName VName))
  deriving (Show)

-- | Vacuous equality: every 'FName' is equal to every other.
instance Eq FName where
  (==) _ _ = True

-- | Vacuous ordering, consistent with the 'Eq' instance: any two
-- 'FName's compare as 'EQ'.
instance Ord FName where
  compare _ = const EQ

-- | Where an existential size came from.  'extSize' uses this as a key
-- so that the same existential size variable is reused whenever the
-- same source is encountered in multiple locations.
data SizeSource
  = -- | A size arising from the given argument expression passed to a
    -- function (the function name does not affect equality; see
    -- 'FName').
    SourceArg FName (ExpBase NoInfo VName)
  | -- | A size arising from the given expression; 'extSize' describes
    -- it with 'RigidBound'.
    SourceBound (ExpBase NoInfo VName)
  | -- | A size arising from a slice.  The first field is an optional
    -- size handed to 'RigidSlice' (presumably of the sliced dimension
    -- -- confirm); the other three are the components that 'extSize'
    -- renders via 'DimSlice'.
    SourceSlice
      (Maybe (DimDecl VName))
      (Maybe (ExpBase NoInfo VName))
      (Maybe (ExpBase NoInfo VName))
      (Maybe (ExpBase NoInfo VName))
  deriving (Show, Eq, Ord)

-- | The state of term type checking: the type constraints, a counter
-- for generating type names, and a table of existential sizes.  The
-- counter is distinct from the usual counter used for generating
-- unique names, because the names it produces are user-visible.
data TermTypeState = TermTypeState
  { -- | The constraints collected so far.
    stateConstraints :: Constraints,
    -- | Counter used (via 'incCounter') to name fresh type and size
    -- variables.
    stateCounter :: !Int,
    -- | Mapping function arguments encountered to
    -- the sizes they ended up generating (when
    -- they could not be substituted directly).
    -- This happens for function arguments that are
    -- not constants or names.
    stateDimTable :: M.Map SizeSource VName
  }

-- | The monad for type checking terms.  It is a read-only 'TermEnv',
-- an accumulated 'Occurences' output, and a mutable 'TermTypeState',
-- all stacked on top of 'TypeM'.
newtype TermTypeM a
  = TermTypeM (RWST TermEnv Occurences TermTypeState TypeM a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader TermEnv,
      MonadWriter Occurences,
      MonadState TermTypeState,
      MonadError TypeError
    )

instance MonadUnify TermTypeM where
  getConstraints = gets stateConstraints

  putConstraints cs = modify $ \st -> st {stateConstraints = cs}

  -- A fresh type variable, named from the description and the
  -- counter, constrained only to be (possibly) lifted.
  newTypeVar loc desc = do
    v <- newID . mkTypeVarName desc =<< incCounter
    constrain v . NoConstraint Lifted $ mkUsage' loc
    pure . Scalar $ TypeVar mempty Nonunique (typeName v) []

  curLevel = asks termLevel

  -- A fresh size variable, named from the given name and the counter.
  -- Rigid sizes record where they came from; non-rigid ones are
  -- unconstrained sizes.
  newDimVar loc rigidity name = do
    dim <- newID . mkTypeVarName name =<< incCounter
    constrain dim $ case rigidity of
      Rigid rsrc -> UnknowableSize loc rsrc
      Nonrigid -> Size Nothing (mkUsage' loc)
    pure dim

  -- If we are inside 'onFailure', the 'Checking' description is shown
  -- before the actual message.
  unifyError loc notes bcs doc = do
    checking <- asks termChecking
    let body = doc <> ppr bcs
    throwError . TypeError (srclocOf loc) notes $
      maybe body (\c -> ppr c <> line </> body) checking

  -- When a 'Checking' description is present and there are no
  -- breadcrumbs, the description alone is reported.
  matchError loc notes bcs t1 t2 = do
    checking <- asks termChecking
    throwError . TypeError (srclocOf loc) notes $ case checking of
      Nothing -> mismatch <> ppr bcs
      Just c
        | hasNoBreadCrumbs bcs -> ppr c
        | otherwise -> ppr c <> line </> mismatch <> ppr bcs
    where
      mismatch :: Doc
      mismatch =
        "Types"
          </> indent 2 (ppr t1)
          </> "and"
          </> indent 2 (ppr t2)
          </> "do not match."

-- | Run an action with the given 'Checking' description installed,
-- replacing any outer one.  'unifyError' and 'matchError' include it
-- in the errors they raise.
onFailure :: Checking -> TermTypeM a -> TermTypeM a
onFailure c = local withChecking
  where
    withChecking env = env {termChecking = Just c}

-- | Run a term type-checking computation inside 'TypeM'.
--
-- The scope is 'initialTermScope' combined with the current
-- environment.  There is no 'Checking' context, the level is zero,
-- and the state starts with no constraints, a zero counter and an
-- empty size table.  The final state is discarded; the result is
-- returned along with the accumulated 'Occurences'.
runTermTypeM :: TermTypeM a -> TypeM (a, Occurences)
runTermTypeM (TermTypeM m) = do
  env <- askEnv
  let tenv =
        TermEnv
          { termScope = initialTermScope <> envToTermScope env,
            termChecking = Nothing,
            termLevel = 0
          }
      st0 =
        TermTypeState
          { stateConstraints = mempty,
            stateCounter = 0,
            stateDimTable = mempty
          }
  evalRWST m tenv st0

-- | Lift a 'TypeM' computation into 'TermTypeM'.
liftTypeM :: TypeM a -> TermTypeM a
liftTypeM m = TermTypeM (lift m)

-- | Run an action with the term scope transformed by the given
-- function.
localScope :: (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
localScope f = local modifyScope
  where
    modifyScope env = env {termScope = f (termScope env)}

-- | Return the current value of the name counter, and bump the stored
-- counter by one.
incCounter :: TermTypeM Int
incCounter = state $ \s ->
  let i = stateCounter s
   in (i, s {stateCounter = i + 1})

-- | Get the existential size for the given source.
--
-- If the source is already in 'stateDimTable', its size variable is
-- reused and the second component is 'Nothing'.  Otherwise a fresh
-- rigid size variable is created and recorded in the table, and it is
-- also returned as the second component.
extSize :: SrcLoc -> SizeSource -> TermTypeM (DimDecl VName, Maybe VName)
extSize loc e = do
  known <- gets (M.lookup e . stateDimTable)
  case known of
    Just d -> pure (NamedDim (qualName d), Nothing)
    Nothing -> do
      d <- newDimVar loc (Rigid (rigidSource e)) "argdim"
      modify $ \s -> s {stateDimTable = M.insert e d (stateDimTable s)}
      pure (NamedDim (qualName d), Just d)
  where
    -- Describe where the rigid size came from.
    rigidSource (SourceArg (FName fname) arg) =
      RigidArg fname (prettyOneLine arg)
    rigidSource (SourceBound bound) =
      RigidBound (prettyOneLine bound)
    rigidSource (SourceSlice dim start end stride) =
      RigidSlice dim (prettyOneLine (DimSlice start end stride))

-- | Run an action, then restore 'stateDimTable' to what it was
-- beforehand.  Any argument sizes created with 'extSize' inside the
-- action are therefore forgotten once it finishes.  This ensures that
-- @n+1@ appearing as a size in one branch of a conditional does not
-- make it available in the other branch.
noSizeEscape :: TermTypeM a -> TermTypeM a
noSizeEscape m = do
  saved <- gets stateDimTable
  m <* modify (\s -> s {stateDimTable = saved})

-- | Record a constraint for a variable, tagged with the current level.
-- Any constraint previously stored for that variable is replaced.
constrain :: VName -> Constraint -> TermTypeM ()
constrain v c =
  curLevel >>= \lvl -> modifyConstraints (M.insert v (lvl, c))

-- | Run an action with 'termLevel' increased by one.
incLevel :: TermTypeM a -> TermTypeM a
incLevel = local bump
  where
    bump env = env {termLevel = termLevel env + 1}

-- | The term scope before anything has been bound.  It contains the
-- top-level name map, no types or modules, and a value table holding
-- those 'intrinsics' that can be represented as value bindings.
initialTermScope :: TermScope
initialTermScope =
  TermScope
    { scopeVtable = M.fromList (mapMaybe intrinsicBinding (M.toList intrinsics)),
      scopeTypeTable = mempty,
      scopeNameMap = topLevelNameMap,
      scopeModTable = mempty
    }
  where
    -- Monomorphic and polymorphic intrinsics become ordinary function
    -- bindings whose parameters are packed into a single parameter
    -- type (a tuple record unless there is exactly one parameter).
    intrinsicBinding (name, IntrinsicMonoFun pts t) =
      Just (name, BoundV Global [] $ funType (paramType (map primType pts)) (primType t))
    intrinsicBinding (name, IntrinsicPolyFun tvs pts rt) =
      Just (name, BoundV Global tvs $ fromStruct $ funType (paramType pts) rt)
    intrinsicBinding (name, IntrinsicOverloadedFun ts pts rts) =
      Just (name, OverloadedF ts pts rts)
    intrinsicBinding (name, IntrinsicEquality) =
      Just (name, EqualityF)
    intrinsicBinding _ =
      Nothing

    paramType [t] = t
    paramType ts = tupleRecord ts

    primType = Scalar . Prim

    funType param result = Scalar $ Arrow mempty Unnamed param result

instance MonadTypeChecker TermTypeM where
  warn :: loc -> String -> TermTypeM ()
warn loc
loc String
problem = TypeM () -> TermTypeM ()
forall a. TypeM a -> TermTypeM a
liftTypeM (TypeM () -> TermTypeM ()) -> TypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ loc -> String -> TypeM ()
forall (m :: * -> *) loc.
(MonadTypeChecker m, Located loc) =>
loc -> String -> m ()
warn loc
loc String
problem
  newName :: VName -> TermTypeM VName
newName = TypeM VName -> TermTypeM VName
forall a. TypeM a -> TermTypeM a
liftTypeM (TypeM VName -> TermTypeM VName)
-> (VName -> TypeM VName) -> VName -> TermTypeM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> TypeM VName
forall (m :: * -> *). MonadTypeChecker m => VName -> m VName
newName
  newID :: Name -> TermTypeM VName
newID = TypeM VName -> TermTypeM VName
forall a. TypeM a -> TermTypeM a
liftTypeM (TypeM VName -> TermTypeM VName)
-> (Name -> TypeM VName) -> Name -> TermTypeM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Name -> TypeM VName
forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID

  checkQualName :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (QualName VName)
checkQualName Namespace
space QualName Name
name SrcLoc
loc = (TermScope, QualName VName) -> QualName VName
forall a b. (a, b) -> b
snd ((TermScope, QualName VName) -> QualName VName)
-> TermTypeM (TermScope, QualName VName)
-> TermTypeM (QualName VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv Namespace
space QualName Name
name SrcLoc
loc

  bindNameMap :: NameMap -> TermTypeM a -> TermTypeM a
bindNameMap NameMap
m = (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
forall a. (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
localScope ((TermScope -> TermScope) -> TermTypeM a -> TermTypeM a)
-> (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
forall a b. (a -> b) -> a -> b
$ \TermScope
scope ->
    TermScope
scope {scopeNameMap :: NameMap
scopeNameMap = NameMap
m NameMap -> NameMap -> NameMap
forall a. Semigroup a => a -> a -> a
<> TermScope -> NameMap
scopeNameMap TermScope
scope}

  bindVal :: VName -> BoundV -> TermTypeM a -> TermTypeM a
bindVal VName
v (TypeM.BoundV [TypeParam]
tps StructType
t) = (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
forall a. (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
localScope ((TermScope -> TermScope) -> TermTypeM a -> TermTypeM a)
-> (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a
forall a b. (a -> b) -> a -> b
$ \TermScope
scope ->
    TermScope
scope {scopeVtable :: Map VName ValBinding
scopeVtable = VName -> ValBinding -> Map VName ValBinding -> Map VName ValBinding
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v ValBinding
vb (Map VName ValBinding -> Map VName ValBinding)
-> Map VName ValBinding -> Map VName ValBinding
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName ValBinding
scopeVtable TermScope
scope}
    where
      vb :: ValBinding
vb = Locality -> [TypeParam] -> PatternType -> ValBinding
BoundV Locality
Local [TypeParam]
tps (PatternType -> ValBinding) -> PatternType -> ValBinding
forall a b. (a -> b) -> a -> b
$ StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t

  lookupType :: SrcLoc
-> QualName Name
-> TermTypeM (QualName VName, [TypeParam], StructType, Liftedness)
lookupType SrcLoc
loc QualName Name
qn = do
    Env
outer_env <- TypeM Env -> TermTypeM Env
forall a. TypeM a -> TermTypeM a
liftTypeM TypeM Env
askEnv
    (TermScope
scope, qn' :: QualName VName
qn'@(QualName [VName]
qs VName
name)) <- Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv Namespace
Type QualName Name
qn SrcLoc
loc
    case VName -> Map VName TypeBinding -> Maybe TypeBinding
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name (Map VName TypeBinding -> Maybe TypeBinding)
-> Map VName TypeBinding -> Maybe TypeBinding
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName TypeBinding
scopeTypeTable TermScope
scope of
      Maybe TypeBinding
Nothing -> SrcLoc
-> QualName Name
-> TermTypeM (QualName VName, [TypeParam], StructType, Liftedness)
forall (m :: * -> *) a.
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m a
unknownType SrcLoc
loc QualName Name
qn
      Just (TypeAbbr Liftedness
l [TypeParam]
ps StructType
def) ->
        (QualName VName, [TypeParam], StructType, Liftedness)
-> TermTypeM (QualName VName, [TypeParam], StructType, Liftedness)
forall (m :: * -> *) a. Monad m => a -> m a
return (QualName VName
qn', [TypeParam]
ps, Env -> [VName] -> [VName] -> StructType -> StructType
forall as.
Env
-> [VName]
-> [VName]
-> TypeBase (DimDecl VName) as
-> TypeBase (DimDecl VName) as
qualifyTypeVars Env
outer_env ((TypeParam -> VName) -> [TypeParam] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map TypeParam -> VName
forall vn. TypeParamBase vn -> vn
typeParamName [TypeParam]
ps) [VName]
qs StructType
def, Liftedness
l)

  lookupMod :: SrcLoc -> QualName Name -> TermTypeM (QualName VName, Mod)
lookupMod SrcLoc
loc QualName Name
qn = do
    (TermScope
scope, qn' :: QualName VName
qn'@(QualName [VName]
_ VName
name)) <- Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv Namespace
Term QualName Name
qn SrcLoc
loc
    case VName -> Map VName Mod -> Maybe Mod
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name (Map VName Mod -> Maybe Mod) -> Map VName Mod -> Maybe Mod
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName Mod
scopeModTable TermScope
scope of
      Maybe Mod
Nothing -> Namespace
-> QualName Name -> SrcLoc -> TermTypeM (QualName VName, Mod)
forall (m :: * -> *) a.
MonadTypeChecker m =>
Namespace -> QualName Name -> SrcLoc -> m a
unknownVariable Namespace
Term QualName Name
qn SrcLoc
loc
      Just Mod
m -> (QualName VName, Mod) -> TermTypeM (QualName VName, Mod)
forall (m :: * -> *) a. Monad m => a -> m a
return (QualName VName
qn', Mod
m)

  lookupVar :: SrcLoc -> QualName Name -> TermTypeM (QualName VName, PatternType)
lookupVar SrcLoc
loc QualName Name
qn = do
    Env
outer_env <- TypeM Env -> TermTypeM Env
forall a. TypeM a -> TermTypeM a
liftTypeM TypeM Env
askEnv
    (TermScope
scope, qn' :: QualName VName
qn'@(QualName [VName]
qs VName
name)) <- Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv Namespace
Term QualName Name
qn SrcLoc
loc
    let usage :: Usage
usage = SrcLoc -> String -> Usage
mkUsage SrcLoc
loc (String -> Usage) -> String -> Usage
forall a b. (a -> b) -> a -> b
$ String
"use of " String -> ShowS
forall a. [a] -> [a] -> [a]
++ ShowS
quote (QualName Name -> String
forall a. Pretty a => a -> String
pretty QualName Name
qn)

    PatternType
t <- case VName -> Map VName ValBinding -> Maybe ValBinding
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name (Map VName ValBinding -> Maybe ValBinding)
-> Map VName ValBinding -> Maybe ValBinding
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName ValBinding
scopeVtable TermScope
scope of
      Maybe ValBinding
Nothing ->
        SrcLoc -> Notes -> Doc -> TermTypeM PatternType
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM PatternType) -> Doc -> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$
          Doc
"Unknown variable" Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (QualName Name -> Doc
forall a. Pretty a => a -> Doc
ppr QualName Name
qn) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
      Just (WasConsumed SrcLoc
wloc) -> Name -> SrcLoc -> SrcLoc -> TermTypeM PatternType
forall a. Name -> SrcLoc -> SrcLoc -> TermTypeM a
useAfterConsume (VName -> Name
baseName VName
name) SrcLoc
loc SrcLoc
wloc
      Just (BoundV Locality
_ [TypeParam]
tparams PatternType
t)
        | String
"_" String -> String -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` VName -> String
baseString VName
name -> SrcLoc -> QualName Name -> TermTypeM PatternType
forall (m :: * -> *) a.
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m a
underscoreUse SrcLoc
loc QualName Name
qn
        | Bool
otherwise -> do
          ([VName]
tnames, PatternType
t') <- SrcLoc
-> [TypeParam] -> PatternType -> TermTypeM ([VName], PatternType)
instantiateTypeScheme SrcLoc
loc [TypeParam]
tparams PatternType
t
          PatternType -> TermTypeM PatternType
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternType -> TermTypeM PatternType)
-> PatternType -> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$ Env -> [VName] -> [VName] -> PatternType -> PatternType
forall as.
Env
-> [VName]
-> [VName]
-> TypeBase (DimDecl VName) as
-> TypeBase (DimDecl VName) as
qualifyTypeVars Env
outer_env [VName]
tnames [VName]
qs PatternType
t'
      Just ValBinding
EqualityF -> do
        PatternType
argtype <- SrcLoc -> String -> TermTypeM PatternType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
        Usage -> PatternType -> TermTypeM ()
forall (m :: * -> *) dim as.
(MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
Usage -> TypeBase dim as -> m ()
equalityType Usage
usage PatternType
argtype
        PatternType -> TermTypeM PatternType
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternType -> TermTypeM PatternType)
-> PatternType -> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$
          ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$
            Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed PatternType
argtype (PatternType -> ScalarTypeBase (DimDecl VName) Aliasing)
-> PatternType -> ScalarTypeBase (DimDecl VName) Aliasing
forall a b. (a -> b) -> a -> b
$
              ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed PatternType
argtype (PatternType -> ScalarTypeBase (DimDecl VName) Aliasing)
-> PatternType -> ScalarTypeBase (DimDecl VName) Aliasing
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as. PrimType -> ScalarTypeBase dim as
Prim PrimType
Bool
      Just (OverloadedF [PrimType]
ts [Maybe PrimType]
pts Maybe PrimType
rt) -> do
        StructType
argtype <- SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
        [PrimType] -> Usage -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
[PrimType] -> Usage -> StructType -> m ()
mustBeOneOf [PrimType]
ts Usage
usage StructType
argtype
        let ([StructType]
pts', StructType
rt') = StructType
-> [Maybe PrimType] -> Maybe PrimType -> ([StructType], StructType)
forall dim as.
TypeBase dim as
-> [Maybe PrimType]
-> Maybe PrimType
-> ([TypeBase dim ()], TypeBase dim ())
instOverloaded StructType
argtype [Maybe PrimType]
pts Maybe PrimType
rt
            arrow :: TypeBase dim as -> TypeBase dim as -> TypeBase dim as
arrow TypeBase dim as
xt TypeBase dim as
yt = ScalarTypeBase dim as -> TypeBase dim as
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase dim as -> TypeBase dim as)
-> ScalarTypeBase dim as -> TypeBase dim as
forall a b. (a -> b) -> a -> b
$ as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow as
forall a. Monoid a => a
mempty PName
Unnamed TypeBase dim as
xt TypeBase dim as
yt
        PatternType -> TermTypeM PatternType
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternType -> TermTypeM PatternType)
-> PatternType -> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$ StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct (StructType -> PatternType) -> StructType -> PatternType
forall a b. (a -> b) -> a -> b
$ (StructType -> StructType -> StructType)
-> StructType -> [StructType] -> StructType
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr StructType -> StructType -> StructType
forall as dim.
Monoid as =>
TypeBase dim as -> TypeBase dim as -> TypeBase dim as
arrow StructType
rt' [StructType]
pts'

    Ident -> TermTypeM ()
observe (Ident -> TermTypeM ()) -> Ident -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ VName -> Info PatternType -> SrcLoc -> Ident
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> IdentBase f vn
Ident VName
name (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t) SrcLoc
loc
    (QualName VName, PatternType)
-> TermTypeM (QualName VName, PatternType)
forall (m :: * -> *) a. Monad m => a -> m a
return (QualName VName
qn', PatternType
t)
    where
      instOverloaded :: TypeBase dim as
-> [Maybe PrimType]
-> Maybe PrimType
-> ([TypeBase dim ()], TypeBase dim ())
instOverloaded TypeBase dim as
argtype [Maybe PrimType]
pts Maybe PrimType
rt =
        ( (Maybe PrimType -> TypeBase dim ())
-> [Maybe PrimType] -> [TypeBase dim ()]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase dim ()
-> (PrimType -> TypeBase dim ())
-> Maybe PrimType
-> TypeBase dim ()
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (TypeBase dim as -> TypeBase dim ()
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct TypeBase dim as
argtype) (ScalarTypeBase dim () -> TypeBase dim ()
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase dim () -> TypeBase dim ())
-> (PrimType -> ScalarTypeBase dim ())
-> PrimType
-> TypeBase dim ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> ScalarTypeBase dim ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim)) [Maybe PrimType]
pts,
          TypeBase dim ()
-> (PrimType -> TypeBase dim ())
-> Maybe PrimType
-> TypeBase dim ()
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (TypeBase dim as -> TypeBase dim ()
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct TypeBase dim as
argtype) (ScalarTypeBase dim () -> TypeBase dim ()
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase dim () -> TypeBase dim ())
-> (PrimType -> ScalarTypeBase dim ())
-> PrimType
-> TypeBase dim ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> ScalarTypeBase dim ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim) Maybe PrimType
rt
        )

  checkNamedDim :: SrcLoc -> QualName Name -> TermTypeM (QualName VName)
checkNamedDim SrcLoc
loc QualName Name
v = do
    (QualName VName
v', PatternType
t) <- SrcLoc -> QualName Name -> TermTypeM (QualName VName, PatternType)
forall (m :: * -> *).
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m (QualName VName, PatternType)
lookupVar SrcLoc
loc QualName Name
v
    Checking -> TermTypeM () -> TermTypeM ()
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure ([StructType] -> StructType -> Checking
CheckingRequired [ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim (PrimType -> ScalarTypeBase (DimDecl VName) ())
-> PrimType -> ScalarTypeBase (DimDecl VName) ()
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
Signed IntType
Int32] (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t)) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
      Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"use as array size") (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t) (StructType -> TermTypeM ()) -> StructType -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
        ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim (PrimType -> ScalarTypeBase (DimDecl VName) ())
-> PrimType -> ScalarTypeBase (DimDecl VName) ()
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
Signed IntType
Int32
    QualName VName -> TermTypeM (QualName VName)
forall (m :: * -> *) a. Monad m => a -> m a
return QualName VName
v'

  typeError :: loc -> Notes -> Doc -> TermTypeM a
typeError loc
loc Notes
notes Doc
s = do
    Maybe Checking
checking <- (TermEnv -> Maybe Checking) -> TermTypeM (Maybe Checking)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks TermEnv -> Maybe Checking
termChecking
    case Maybe Checking
checking of
      Just Checking
checking' ->
        TypeError -> TermTypeM a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> TermTypeM a) -> TypeError -> TermTypeM a
forall a b. (a -> b) -> a -> b
$ SrcLoc -> Notes -> Doc -> TypeError
TypeError (loc -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf loc
loc) Notes
notes (Checking -> Doc
forall a. Pretty a => a -> Doc
ppr Checking
checking' Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
line Doc -> Doc -> Doc
</> Doc
s)
      Maybe Checking
Nothing ->
        TypeError -> TermTypeM a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TypeError -> TermTypeM a) -> TypeError -> TermTypeM a
forall a b. (a -> b) -> a -> b
$ SrcLoc -> Notes -> Doc -> TypeError
TypeError (loc -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf loc
loc) Notes
notes Doc
s

checkQualNameWithEnv :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv :: Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv Namespace
space qn :: QualName Name
qn@(QualName [Name]
quals Name
name) SrcLoc
loc = do
  TermScope
scope <- (TermEnv -> TermScope) -> TermTypeM TermScope
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks TermEnv -> TermScope
termScope
  TermScope -> [Name] -> TermTypeM (TermScope, QualName VName)
descend TermScope
scope [Name]
quals
  where
    descend :: TermScope -> [Name] -> TermTypeM (TermScope, QualName VName)
descend TermScope
scope []
      | Just QualName VName
name' <- (Namespace, Name) -> NameMap -> Maybe (QualName VName)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Namespace
space, Name
name) (NameMap -> Maybe (QualName VName))
-> NameMap -> Maybe (QualName VName)
forall a b. (a -> b) -> a -> b
$ TermScope -> NameMap
scopeNameMap TermScope
scope =
        (TermScope, QualName VName)
-> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (TermScope
scope, QualName VName
name')
      | Bool
otherwise =
        Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a.
MonadTypeChecker m =>
Namespace -> QualName Name -> SrcLoc -> m a
unknownVariable Namespace
space QualName Name
qn SrcLoc
loc
    descend TermScope
scope (Name
q : [Name]
qs)
      | Just (QualName [VName]
_ VName
q') <- (Namespace, Name) -> NameMap -> Maybe (QualName VName)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Namespace
Term, Name
q) (NameMap -> Maybe (QualName VName))
-> NameMap -> Maybe (QualName VName)
forall a b. (a -> b) -> a -> b
$ TermScope -> NameMap
scopeNameMap TermScope
scope,
        Just Mod
res <- VName -> Map VName Mod -> Maybe Mod
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
q' (Map VName Mod -> Maybe Mod) -> Map VName Mod -> Maybe Mod
forall a b. (a -> b) -> a -> b
$ TermScope -> Map VName Mod
scopeModTable TermScope
scope =
        case Mod
res of
          -- Check if we are referring to the magical intrinsics
          -- module.
          Mod
_
            | VName -> Int
baseTag VName
q' Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Int
maxIntrinsicTag ->
              Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkIntrinsic Namespace
space QualName Name
qn SrcLoc
loc
          ModEnv Env
q_scope -> do
            (TermScope
scope', QualName [VName]
qs' VName
name') <- TermScope -> [Name] -> TermTypeM (TermScope, QualName VName)
descend (Env -> TermScope
envToTermScope Env
q_scope) [Name]
qs
            (TermScope, QualName VName)
-> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (TermScope
scope', [VName] -> VName -> QualName VName
forall vn. [vn] -> vn -> QualName vn
QualName (VName
q' VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
qs') VName
name')
          ModFun {} -> SrcLoc -> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a. MonadTypeChecker m => SrcLoc -> m a
unappliedFunctor SrcLoc
loc
      | Bool
otherwise =
        Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a.
MonadTypeChecker m =>
Namespace -> QualName Name -> SrcLoc -> m a
unknownVariable Namespace
space QualName Name
qn SrcLoc
loc

checkIntrinsic :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkIntrinsic :: Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkIntrinsic Namespace
space qn :: QualName Name
qn@(QualName [Name]
_ Name
name) SrcLoc
loc
  | Just QualName VName
v <- (Namespace, Name) -> NameMap -> Maybe (QualName VName)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Namespace
space, Name
name) NameMap
intrinsicsNameMap = do
    ImportName
me <- TypeM ImportName -> TermTypeM ImportName
forall a. TypeM a -> TermTypeM a
liftTypeM TypeM ImportName
askImportName
    Bool -> TermTypeM () -> TermTypeM ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (String
"/prelude" String -> String -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf` ImportName -> String
includeToString ImportName
me) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
      SrcLoc -> String -> TermTypeM ()
forall (m :: * -> *) loc.
(MonadTypeChecker m, Located loc) =>
loc -> String -> m ()
warn SrcLoc
loc String
"Using intrinsic functions directly can easily crash the compiler or result in wrong code generation."
    TermScope
scope <- (TermEnv -> TermScope) -> TermTypeM TermScope
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks TermEnv -> TermScope
termScope
    (TermScope, QualName VName)
-> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (TermScope
scope, QualName VName
v)
  | Bool
otherwise =
    Namespace
-> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
forall (m :: * -> *) a.
MonadTypeChecker m =>
Namespace -> QualName Name -> SrcLoc -> m a
unknownVariable Namespace
space QualName Name
qn SrcLoc
loc

-- | Wrap 'Types.checkTypeDecl' to also perform an observation of
-- every size in the type.
checkTypeDecl :: TypeDeclBase NoInfo Name -> TermTypeM (TypeDeclBase Info VName)
checkTypeDecl :: TypeDeclBase NoInfo Name -> TermTypeM (TypeDeclBase Info VName)
checkTypeDecl TypeDeclBase NoInfo Name
tdecl = do
  (TypeDeclBase Info VName
tdecl', Liftedness
_) <- TypeDeclBase NoInfo Name
-> TermTypeM (TypeDeclBase Info VName, Liftedness)
forall (m :: * -> *).
MonadTypeChecker m =>
TypeDeclBase NoInfo Name -> m (TypeDeclBase Info VName, Liftedness)
Types.checkTypeDecl TypeDeclBase NoInfo Name
tdecl
  (DimDecl VName -> TermTypeM ()) -> [DimDecl VName] -> TermTypeM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ DimDecl VName -> TermTypeM ()
observeDim ([DimDecl VName] -> TermTypeM ())
-> [DimDecl VName] -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ StructType -> [DimDecl VName]
forall as. TypeBase (DimDecl VName) as -> [DimDecl VName]
nestedDims (StructType -> [DimDecl VName]) -> StructType -> [DimDecl VName]
forall a b. (a -> b) -> a -> b
$ Info StructType -> StructType
forall a. Info a -> a
unInfo (Info StructType -> StructType) -> Info StructType -> StructType
forall a b. (a -> b) -> a -> b
$ TypeDeclBase Info VName -> Info StructType
forall (f :: * -> *) vn. TypeDeclBase f vn -> f StructType
expandedType TypeDeclBase Info VName
tdecl'
  TypeDeclBase Info VName -> TermTypeM (TypeDeclBase Info VName)
forall (m :: * -> *) a. Monad m => a -> m a
return TypeDeclBase Info VName
tdecl'
  where
    observeDim :: DimDecl VName -> TermTypeM ()
observeDim (NamedDim QualName VName
v) =
      Ident -> TermTypeM ()
observe (Ident -> TermTypeM ()) -> Ident -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ VName -> Info PatternType -> SrcLoc -> Ident
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> IdentBase f vn
Ident (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v) (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as. PrimType -> ScalarTypeBase dim as
Prim (PrimType -> ScalarTypeBase (DimDecl VName) Aliasing)
-> PrimType -> ScalarTypeBase (DimDecl VName) Aliasing
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
Signed IntType
Int32) SrcLoc
forall a. Monoid a => a
mempty
    observeDim DimDecl VName
_ = () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

-- | Instantiate a type scheme with fresh type variables for its type
-- parameters. Returns the names of the fresh type variables, the
-- instance list, and the instantiated type.
instantiateTypeScheme ::
  SrcLoc ->
  [TypeParam] ->
  PatternType ->
  TermTypeM ([VName], PatternType)
instantiateTypeScheme :: SrcLoc
-> [TypeParam] -> PatternType -> TermTypeM ([VName], PatternType)
instantiateTypeScheme SrcLoc
loc [TypeParam]
tparams PatternType
t = do
  let tnames :: [VName]
tnames = (TypeParam -> VName) -> [TypeParam] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map TypeParam -> VName
forall vn. TypeParamBase vn -> vn
typeParamName [TypeParam]
tparams
  ([VName]
tparam_names, [Subst StructType]
tparam_substs) <- [(VName, Subst StructType)] -> ([VName], [Subst StructType])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, Subst StructType)] -> ([VName], [Subst StructType]))
-> TermTypeM [(VName, Subst StructType)]
-> TermTypeM ([VName], [Subst StructType])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TypeParam -> TermTypeM (VName, Subst StructType))
-> [TypeParam] -> TermTypeM [(VName, Subst StructType)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SrcLoc -> TypeParam -> TermTypeM (VName, Subst StructType)
forall as dim.
Monoid as =>
SrcLoc -> TypeParam -> TermTypeM (VName, Subst (TypeBase dim as))
instantiateTypeParam SrcLoc
loc) [TypeParam]
tparams
  let substs :: Map VName (Subst StructType)
substs = [(VName, Subst StructType)] -> Map VName (Subst StructType)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, Subst StructType)] -> Map VName (Subst StructType))
-> [(VName, Subst StructType)] -> Map VName (Subst StructType)
forall a b. (a -> b) -> a -> b
$ [VName] -> [Subst StructType] -> [(VName, Subst StructType)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tnames [Subst StructType]
tparam_substs
      t' :: PatternType
t' = (VName -> Maybe (Subst StructType)) -> PatternType -> PatternType
forall a.
Substitutable a =>
(VName -> Maybe (Subst StructType)) -> a -> a
applySubst (VName -> Map VName (Subst StructType) -> Maybe (Subst StructType)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructType)
substs) PatternType
t
  ([VName], PatternType) -> TermTypeM ([VName], PatternType)
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName]
tparam_names, PatternType
t')

-- | Create a new type name and insert it (unconstrained) in the
-- substitution map.
instantiateTypeParam :: Monoid as => SrcLoc -> TypeParam -> TermTypeM (VName, Subst (TypeBase dim as))
instantiateTypeParam :: SrcLoc -> TypeParam -> TermTypeM (VName, Subst (TypeBase dim as))
instantiateTypeParam SrcLoc
loc TypeParam
tparam = do
  Int
i <- TermTypeM Int
incCounter
  VName
v <- Name -> TermTypeM VName
forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID (Name -> TermTypeM VName) -> Name -> TermTypeM VName
forall a b. (a -> b) -> a -> b
$ String -> Int -> Name
mkTypeVarName ((Char -> Bool) -> ShowS
forall a. (a -> Bool) -> [a] -> [a]
takeWhile Char -> Bool
isAscii (VName -> String
baseString (TypeParam -> VName
forall vn. TypeParamBase vn -> vn
typeParamName TypeParam
tparam))) Int
i
  case TypeParam
tparam of
    TypeParamType Liftedness
x VName
_ SrcLoc
_ -> do
      VName -> Constraint -> TermTypeM ()
constrain VName
v (Constraint -> TermTypeM ()) -> Constraint -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Liftedness -> Usage -> Constraint
NoConstraint Liftedness
x (Usage -> Constraint) -> Usage -> Constraint
forall a b. (a -> b) -> a -> b
$ SrcLoc -> Usage
mkUsage' SrcLoc
loc
      (VName, Subst (TypeBase dim as))
-> TermTypeM (VName, Subst (TypeBase dim as))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
v, TypeBase dim as -> Subst (TypeBase dim as)
forall t. t -> Subst t
Subst (TypeBase dim as -> Subst (TypeBase dim as))
-> TypeBase dim as -> Subst (TypeBase dim as)
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase dim as -> TypeBase dim as
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase dim as -> TypeBase dim as)
-> ScalarTypeBase dim as -> TypeBase dim as
forall a b. (a -> b) -> a -> b
$ as
-> Uniqueness -> TypeName -> [TypeArg dim] -> ScalarTypeBase dim as
forall dim as.
as
-> Uniqueness -> TypeName -> [TypeArg dim] -> ScalarTypeBase dim as
TypeVar as
forall a. Monoid a => a
mempty Uniqueness
Nonunique (VName -> TypeName
typeName VName
v) [])
    TypeParamDim {} -> do
      VName -> Constraint -> TermTypeM ()
constrain VName
v (Constraint -> TermTypeM ()) -> Constraint -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Maybe (DimDecl VName) -> Usage -> Constraint
Size Maybe (DimDecl VName)
forall a. Maybe a
Nothing (Usage -> Constraint) -> Usage -> Constraint
forall a b. (a -> b) -> a -> b
$ SrcLoc -> Usage
mkUsage' SrcLoc
loc
      (VName, Subst (TypeBase dim as))
-> TermTypeM (VName, Subst (TypeBase dim as))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
v, DimDecl VName -> Subst (TypeBase dim as)
forall t. DimDecl VName -> Subst t
SizeSubst (DimDecl VName -> Subst (TypeBase dim as))
-> DimDecl VName -> Subst (TypeBase dim as)
forall a b. (a -> b) -> a -> b
$ QualName VName -> DimDecl VName
forall vn. QualName vn -> DimDecl vn
NamedDim (QualName VName -> DimDecl VName)
-> QualName VName -> DimDecl VName
forall a b. (a -> b) -> a -> b
$ VName -> QualName VName
forall v. v -> QualName v
qualName VName
v)

-- | @newArrayType loc desc r@ builds a fresh array type of rank @r@.
--
-- The element type is a new type variable whose name is derived from
-- @desc@ and which is registered with a 'NoConstraint' 'Unlifted'
-- constraint.  Each of the @r@ dimensions is a fresh nonrigid size
-- variable.  Returns the array type paired with its element (row) type.
newArrayType :: SrcLoc -> String -> Int -> TermTypeM (StructType, StructType)
newArrayType loc desc r = do
  elem_v <- newID (nameFromString desc)
  constrain elem_v (NoConstraint Unlifted (mkUsage' loc))
  dim_vs <- replicateM r (newDimVar loc Nonrigid "dim")
  let row_t = TypeVar () Nonunique (typeName elem_v) []
      shape = ShapeDecl [NamedDim (qualName d) | d <- dim_vs]
  pure (Array () Nonunique row_t shape, Scalar row_t)

--- Errors

-- | Report that variable @name@ is used at @rloc@ after it was
-- consumed at @wloc@.  The error is raised at @rloc@, and the
-- consumption site is formatted with 'locStrRel'.
useAfterConsume :: Name -> SrcLoc -> SrcLoc -> TermTypeM a
useAfterConsume name rloc wloc = typeError rloc mempty msg
  where
    consumed_at = text (locStrRel rloc wloc)
    msg =
      "Variable" <+> pquote (pprName name) <+> "previously consumed at"
        <+> consumed_at <> ".  (Possibly through aliasing.)"

-- | Report that variable @name@ is consumed at @loc2@ after already
-- having been consumed at @loc1@.  The error is raised at @loc2@, and
-- the earlier site is formatted with 'locStrRel'.
--
-- The name is wrapped in 'pquote' so that the message is formatted
-- consistently with 'useAfterConsume'.
consumeAfterConsume :: Name -> SrcLoc -> SrcLoc -> TermTypeM a
consumeAfterConsume name loc1 loc2 =
  typeError loc2 mempty $
    "Variable" <+> pquote (pprName name) <+> "previously consumed at"
      <+> text (locStrRel loc2 loc1) <> "."

-- | Reject a let-with whose new element value shares data with the
-- array being updated.  The error is raised at @loc@.
badLetWithValue :: SrcLoc -> TermTypeM a
badLetWithValue loc = typeError loc mempty explanation
  where
    explanation =
      "New value for elements in let-with shares data with source array.  This is illegal, as it prevents in-place modification."

-- | Report that the unique return value of function @fname@ is aliased
-- to @name@, which is not consumed.  The error is raised at @loc@.
--
-- Like the other error helpers here, the result type is polymorphic,
-- because 'typeError' is polymorphic in its result.  Existing callers
-- that expect @TermTypeM ()@ still type-check.
returnAliased :: Name -> Name -> SrcLoc -> TermTypeM a
returnAliased fname name loc =
  typeError loc mempty $
    "Unique return value of" <+> pquote (pprName fname)
      <+> "is aliased to"
      <+> pquote (pprName name) <> ", which is not consumed."

-- | Report that a unique tuple element of @fname@'s return value is
-- aliased to another component of the same tuple.  The error is raised
-- at @loc@.
uniqueReturnAliased :: Name -> SrcLoc -> TermTypeM a
uniqueReturnAliased fname loc = typeError loc mempty msg
  where
    msg =
      "A unique tuple element of return value of"
        <+> pquote (pprName fname)
        <+> "is aliased to some other tuple component."

-- | Report that the expression at @loc@ has type @t@ when one of @ts@
-- was required.  An empty @ts@ means no type is acceptable at all,
-- which the message flags as a likely type checker bug.
unexpectedType :: MonadTypeChecker m => SrcLoc -> StructType -> [StructType] -> m a
unexpectedType loc t ts =
  typeError loc mempty $
    "Type of expression at" <+> text (locStr loc) <+> complaint
  where
    complaint
      | null ts =
          "cannot have any type - possibly a bug in the type checker."
      | otherwise =
          "must be one of" <+> commasep (map ppr ts) <> ", but is"
            <+> ppr t <> "."

--- Basic checking

-- | Determine whether the two branch types are identical, ignoring
-- uniqueness.  Mismatched dimensions are turned into fresh rigid type
-- variables.  Causes a 'TypeError' if the types fail to match.
-- Otherwise it returns the type computed by 'unifyMostCommon', together
-- with the names that function reports (presumably the fresh size
-- variables mentioned above).
unifyBranchTypes :: SrcLoc -> PatternType -> PatternType -> TermTypeM (PatternType, [VName])
unifyBranchTypes loc t1 t2 =
  onFailure branch_ctx (unifyMostCommon branch_usage t1 t2)
  where
    branch_ctx = CheckingBranches (toStruct t1) (toStruct t2)
    branch_usage = mkUsage loc "unification of branch results"

-- | Fully resolve the types of two branch expressions (first @e1@, then
-- @e2@) and unify them with 'unifyBranchTypes'.
unifyBranches :: SrcLoc -> Exp -> Exp -> TermTypeM (PatternType, [VName])
unifyBranches loc e1 e2 = do
  t1 <- expTypeFully e1
  unifyBranchTypes loc t1 =<< expTypeFully e2

--- General binding.

-- | Operator names that a pattern may not bind.  'checkPattern''
-- rejects an identifier pattern with one of these names, reporting
-- that the operator may not be redefined.
doNotShadow :: [String]
doNotShadow =
  [ "&&",
    "||"
  ]

-- | What is known about the type of a pattern when 'checkPattern''
-- checks it.
data InferredType
  = NoneInferred
  -- ^ Nothing is known: identifiers and wildcards get a fresh type
  -- variable, and sub-patterns are checked with 'NoneInferred'.
  | Ascribed PatternType
  -- ^ The pattern is expected to have this type (for example, from an
  -- enclosing pattern or an explicit type ascription).

checkPattern' ::
  UncheckedPattern ->
  InferredType ->
  TermTypeM Pattern
checkPattern' :: UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' (PatternParens UncheckedPattern
p SrcLoc
loc) InferredType
t =
  Pattern -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> SrcLoc -> PatternBase f vn
PatternParens (Pattern -> SrcLoc -> Pattern)
-> TermTypeM Pattern -> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' UncheckedPattern
p InferredType
t TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPattern' (Id Name
name NoInfo PatternType
_ SrcLoc
loc) InferredType
_
  | String
name' String -> [String] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [String]
doNotShadow =
    SrcLoc -> Notes -> Doc -> TermTypeM Pattern
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Pattern) -> Doc -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Doc
"The" Doc -> Doc -> Doc
<+> String -> Doc
text String
name' Doc -> Doc -> Doc
<+> Doc
"operator may not be redefined."
  where
    name' :: String
name' = Name -> String
nameToString Name
name
checkPattern' (Id Name
name NoInfo PatternType
NoInfo SrcLoc
loc) (Ascribed PatternType
t) = do
  VName
name' <- Name -> TermTypeM VName
forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID Name
name
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ VName -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> PatternBase f vn
Id VName
name' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t) SrcLoc
loc
checkPattern' (Id Name
name NoInfo PatternType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  VName
name' <- Name -> TermTypeM VName
forall (m :: * -> *). MonadTypeChecker m => Name -> m VName
newID Name
name
  PatternType
t <- SrcLoc -> String -> TermTypeM PatternType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ VName -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> PatternBase f vn
Id VName
name' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t) SrcLoc
loc
checkPattern' (Wildcard NoInfo PatternType
_ SrcLoc
loc) (Ascribed PatternType
t) =
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
f PatternType -> SrcLoc -> PatternBase f vn
Wildcard (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ PatternType
t PatternType -> Uniqueness -> PatternType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Nonunique) SrcLoc
loc
checkPattern' (Wildcard NoInfo PatternType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  PatternType
t <- SrcLoc -> String -> TermTypeM PatternType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
f PatternType -> SrcLoc -> PatternBase f vn
Wildcard (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t) SrcLoc
loc
checkPattern' (TuplePattern [UncheckedPattern]
ps SrcLoc
loc) (Ascribed PatternType
t)
  | Just [PatternType]
ts <- PatternType -> Maybe [PatternType]
forall dim as. TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord PatternType
t,
    [PatternType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PatternType]
ts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [UncheckedPattern] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [UncheckedPattern]
ps =
    [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[PatternBase f vn] -> SrcLoc -> PatternBase f vn
TuplePattern ([Pattern] -> SrcLoc -> Pattern)
-> TermTypeM [Pattern] -> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (UncheckedPattern -> InferredType -> TermTypeM Pattern)
-> [UncheckedPattern] -> [InferredType] -> TermTypeM [Pattern]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' [UncheckedPattern]
ps ((PatternType -> InferredType) -> [PatternType] -> [InferredType]
forall a b. (a -> b) -> [a] -> [b]
map PatternType -> InferredType
Ascribed [PatternType]
ts) TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPattern' p :: UncheckedPattern
p@(TuplePattern [UncheckedPattern]
ps SrcLoc
loc) (Ascribed PatternType
t) = do
  [StructType]
ps_t <- Int -> TermTypeM StructType -> TermTypeM [StructType]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([UncheckedPattern] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [UncheckedPattern]
ps) (SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t")
  Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"matching a tuple pattern") ([StructType] -> StructType
forall dim as. [TypeBase dim as] -> TypeBase dim as
tupleRecord [StructType]
ps_t) (StructType -> TermTypeM ()) -> StructType -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t
  PatternType
t' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
t
  UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' UncheckedPattern
p (InferredType -> TermTypeM Pattern)
-> InferredType -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ PatternType -> InferredType
Ascribed PatternType
t'
checkPattern' (TuplePattern [UncheckedPattern]
ps SrcLoc
loc) InferredType
NoneInferred =
  [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[PatternBase f vn] -> SrcLoc -> PatternBase f vn
TuplePattern ([Pattern] -> SrcLoc -> Pattern)
-> TermTypeM [Pattern] -> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (UncheckedPattern -> TermTypeM Pattern)
-> [UncheckedPattern] -> TermTypeM [Pattern]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (UncheckedPattern -> InferredType -> TermTypeM Pattern
`checkPattern'` InferredType
NoneInferred) [UncheckedPattern]
ps TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPattern' (RecordPattern [(Name, UncheckedPattern)]
p_fs SrcLoc
_) InferredType
_
  | Just (Name
f, UncheckedPattern
fp) <- ((Name, UncheckedPattern) -> Bool)
-> [(Name, UncheckedPattern)] -> Maybe (Name, UncheckedPattern)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((String
"_" String -> String -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isPrefixOf`) (String -> Bool)
-> ((Name, UncheckedPattern) -> String)
-> (Name, UncheckedPattern)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Name -> String
nameToString (Name -> String)
-> ((Name, UncheckedPattern) -> Name)
-> (Name, UncheckedPattern)
-> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Name, UncheckedPattern) -> Name
forall a b. (a, b) -> a
fst) [(Name, UncheckedPattern)]
p_fs =
    UncheckedPattern -> Notes -> Doc -> TermTypeM Pattern
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError UncheckedPattern
fp Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Pattern) -> Doc -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$
      Doc
"Underscore-prefixed fields are not allowed."
        Doc -> Doc -> Doc
</> Doc
"Did you mean" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc -> Doc
dquotes (String -> Doc
text (Int -> ShowS
forall a. Int -> [a] -> [a]
drop Int
1 (Name -> String
nameToString Name
f)) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"=_") Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"?"
checkPattern' (RecordPattern [(Name, UncheckedPattern)]
p_fs SrcLoc
loc) (Ascribed (Scalar (Record Map Name PatternType
t_fs)))
  | [Name] -> [Name]
forall a. Ord a => [a] -> [a]
sort (((Name, UncheckedPattern) -> Name)
-> [(Name, UncheckedPattern)] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map (Name, UncheckedPattern) -> Name
forall a b. (a, b) -> a
fst [(Name, UncheckedPattern)]
p_fs) [Name] -> [Name] -> Bool
forall a. Eq a => a -> a -> Bool
== [Name] -> [Name]
forall a. Ord a => [a] -> [a]
sort (Map Name PatternType -> [Name]
forall k a. Map k a -> [k]
M.keys Map Name PatternType
t_fs) =
    [(Name, Pattern)] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[(Name, PatternBase f vn)] -> SrcLoc -> PatternBase f vn
RecordPattern ([(Name, Pattern)] -> SrcLoc -> Pattern)
-> (Map Name Pattern -> [(Name, Pattern)])
-> Map Name Pattern
-> SrcLoc
-> Pattern
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map Name Pattern -> [(Name, Pattern)]
forall k a. Map k a -> [(k, a)]
M.toList (Map Name Pattern -> SrcLoc -> Pattern)
-> TermTypeM (Map Name Pattern) -> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TermTypeM (Map Name Pattern)
check TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
  where
    check :: TermTypeM (Map Name Pattern)
check =
      ((UncheckedPattern, InferredType) -> TermTypeM Pattern)
-> Map Name (UncheckedPattern, InferredType)
-> TermTypeM (Map Name Pattern)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse ((UncheckedPattern -> InferredType -> TermTypeM Pattern)
-> (UncheckedPattern, InferredType) -> TermTypeM Pattern
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern') (Map Name (UncheckedPattern, InferredType)
 -> TermTypeM (Map Name Pattern))
-> Map Name (UncheckedPattern, InferredType)
-> TermTypeM (Map Name Pattern)
forall a b. (a -> b) -> a -> b
$
        (UncheckedPattern
 -> InferredType -> (UncheckedPattern, InferredType))
-> Map Name UncheckedPattern
-> Map Name InferredType
-> Map Name (UncheckedPattern, InferredType)
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith
          (,)
          ([(Name, UncheckedPattern)] -> Map Name UncheckedPattern
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, UncheckedPattern)]
p_fs)
          ((PatternType -> InferredType)
-> Map Name PatternType -> Map Name InferredType
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PatternType -> InferredType
Ascribed Map Name PatternType
t_fs)
checkPattern' p :: UncheckedPattern
p@(RecordPattern [(Name, UncheckedPattern)]
fields SrcLoc
loc) (Ascribed PatternType
t) = do
  Map Name StructType
fields' <- (UncheckedPattern -> TermTypeM StructType)
-> Map Name UncheckedPattern -> TermTypeM (Map Name StructType)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (TermTypeM StructType -> UncheckedPattern -> TermTypeM StructType
forall a b. a -> b -> a
const (TermTypeM StructType -> UncheckedPattern -> TermTypeM StructType)
-> TermTypeM StructType -> UncheckedPattern -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$ SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t") (Map Name UncheckedPattern -> TermTypeM (Map Name StructType))
-> Map Name UncheckedPattern -> TermTypeM (Map Name StructType)
forall a b. (a -> b) -> a -> b
$ [(Name, UncheckedPattern)] -> Map Name UncheckedPattern
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, UncheckedPattern)]
fields

  Bool -> TermTypeM () -> TermTypeM ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([Name] -> [Name]
forall a. Ord a => [a] -> [a]
sort (Map Name StructType -> [Name]
forall k a. Map k a -> [k]
M.keys Map Name StructType
fields') [Name] -> [Name] -> Bool
forall a. Eq a => a -> a -> Bool
/= [Name] -> [Name]
forall a. Ord a => [a] -> [a]
sort (((Name, UncheckedPattern) -> Name)
-> [(Name, UncheckedPattern)] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map (Name, UncheckedPattern) -> Name
forall a b. (a, b) -> a
fst [(Name, UncheckedPattern)]
fields)) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
    SrcLoc -> Notes -> Doc -> TermTypeM ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM ()) -> Doc -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Doc
"Duplicate fields in record pattern" Doc -> Doc -> Doc
<+> UncheckedPattern -> Doc
forall a. Pretty a => a -> Doc
ppr UncheckedPattern
p Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."

  Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"matching a record pattern") (ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (Map Name StructType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. Map Name (TypeBase dim as) -> ScalarTypeBase dim as
Record Map Name StructType
fields')) (StructType -> TermTypeM ()) -> StructType -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t
  PatternType
t' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
t
  UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' UncheckedPattern
p (InferredType -> TermTypeM Pattern)
-> InferredType -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ PatternType -> InferredType
Ascribed PatternType
t'
checkPattern' (RecordPattern [(Name, UncheckedPattern)]
fs SrcLoc
loc) InferredType
NoneInferred =
  [(Name, Pattern)] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[(Name, PatternBase f vn)] -> SrcLoc -> PatternBase f vn
RecordPattern ([(Name, Pattern)] -> SrcLoc -> Pattern)
-> (Map Name Pattern -> [(Name, Pattern)])
-> Map Name Pattern
-> SrcLoc
-> Pattern
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map Name Pattern -> [(Name, Pattern)]
forall k a. Map k a -> [(k, a)]
M.toList (Map Name Pattern -> SrcLoc -> Pattern)
-> TermTypeM (Map Name Pattern) -> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (UncheckedPattern -> TermTypeM Pattern)
-> Map Name UncheckedPattern -> TermTypeM (Map Name Pattern)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (UncheckedPattern -> InferredType -> TermTypeM Pattern
`checkPattern'` InferredType
NoneInferred) ([(Name, UncheckedPattern)] -> Map Name UncheckedPattern
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, UncheckedPattern)]
fs) TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
checkPattern' (PatternAscription UncheckedPattern
p (TypeDecl TypeExp Name
t NoInfo StructType
NoInfo) SrcLoc
loc) InferredType
maybe_outer_t = do
  (TypeExp VName
t', StructType
st_nodims, Liftedness
_) <- TypeExp Name -> TermTypeM (TypeExp VName, StructType, Liftedness)
forall (m :: * -> *).
MonadTypeChecker m =>
TypeExp Name -> m (TypeExp VName, StructType, Liftedness)
checkTypeExp TypeExp Name
t
  (StructType
st, [VName]
_) <- SrcLoc
-> String
-> Rigidity
-> StructType
-> TermTypeM (StructType, [VName])
forall (m :: * -> *) als.
MonadUnify m =>
SrcLoc
-> String
-> Rigidity
-> TypeBase (DimDecl VName) als
-> m (TypeBase (DimDecl VName) als, [VName])
instantiateEmptyArrayDims SrcLoc
loc String
"impl" Rigidity
Nonrigid StructType
st_nodims

  let st' :: PatternType
st' = StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
st
  case InferredType
maybe_outer_t of
    Ascribed PatternType
outer_t -> do
      Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"explicit type ascription") (StructType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct StructType
st) (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
outer_t)

      -- We also have to make sure that uniqueness matches.  This is
      -- done explicitly, because it is ignored by unification.
      PatternType
st'' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
st'
      PatternType
outer_t' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
outer_t
      case (Uniqueness -> Uniqueness -> Maybe Uniqueness)
-> PatternType -> PatternType -> Maybe PatternType
forall als dim.
(Monoid als, ArrayDim dim) =>
(Uniqueness -> Uniqueness -> Maybe Uniqueness)
-> TypeBase dim als -> TypeBase dim als -> Maybe (TypeBase dim als)
unifyTypesU Uniqueness -> Uniqueness -> Maybe Uniqueness
unifyUniqueness PatternType
st'' PatternType
outer_t' of
        Just PatternType
outer_t'' ->
          Pattern -> TypeDeclBase Info VName -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> TypeDeclBase f vn -> SrcLoc -> PatternBase f vn
PatternAscription (Pattern -> TypeDeclBase Info VName -> SrcLoc -> Pattern)
-> TermTypeM Pattern
-> TermTypeM (TypeDeclBase Info VName -> SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' UncheckedPattern
p (PatternType -> InferredType
Ascribed PatternType
outer_t'')
            TermTypeM (TypeDeclBase Info VName -> SrcLoc -> Pattern)
-> TermTypeM (TypeDeclBase Info VName)
-> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeDeclBase Info VName -> TermTypeM (TypeDeclBase Info VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TypeExp VName -> Info StructType -> TypeDeclBase Info VName
forall (f :: * -> *) vn.
TypeExp vn -> f StructType -> TypeDeclBase f vn
TypeDecl TypeExp VName
t' (StructType -> Info StructType
forall a. a -> Info a
Info StructType
st))
            TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
        Maybe PatternType
Nothing ->
          SrcLoc -> Notes -> Doc -> TermTypeM Pattern
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Pattern) -> Doc -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$
            Doc
"Cannot match type" Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
outer_t') Doc -> Doc -> Doc
<+> Doc
"with expected type"
              Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
st'') Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
    InferredType
NoneInferred ->
      Pattern -> TypeDeclBase Info VName -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> TypeDeclBase f vn -> SrcLoc -> PatternBase f vn
PatternAscription (Pattern -> TypeDeclBase Info VName -> SrcLoc -> Pattern)
-> TermTypeM Pattern
-> TermTypeM (TypeDeclBase Info VName -> SrcLoc -> Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' UncheckedPattern
p (PatternType -> InferredType
Ascribed PatternType
st')
        TermTypeM (TypeDeclBase Info VName -> SrcLoc -> Pattern)
-> TermTypeM (TypeDeclBase Info VName)
-> TermTypeM (SrcLoc -> Pattern)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeDeclBase Info VName -> TermTypeM (TypeDeclBase Info VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TypeExp VName -> Info StructType -> TypeDeclBase Info VName
forall (f :: * -> *) vn.
TypeExp vn -> f StructType -> TypeDeclBase f vn
TypeDecl TypeExp VName
t' (StructType -> Info StructType
forall a. a -> Info a
Info StructType
st))
        TermTypeM (SrcLoc -> Pattern)
-> TermTypeM SrcLoc -> TermTypeM Pattern
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc
  where
    unifyUniqueness :: Uniqueness -> Uniqueness -> Maybe Uniqueness
unifyUniqueness Uniqueness
u1 Uniqueness
u2 = if Uniqueness
u2 Uniqueness -> Uniqueness -> Bool
`subuniqueOf` Uniqueness
u1 then Uniqueness -> Maybe Uniqueness
forall a. a -> Maybe a
Just Uniqueness
u1 else Maybe Uniqueness
forall a. Maybe a
Nothing
checkPattern' (PatternLit ExpBase NoInfo Name
e NoInfo PatternType
NoInfo SrcLoc
loc) (Ascribed PatternType
t) = do
  Exp
e' <- ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e
  PatternType
t' <- Exp -> TermTypeM PatternType
expTypeFully Exp
e'
  Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"matching against literal") (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t') (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t)
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Exp -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
ExpBase f vn -> f PatternType -> SrcLoc -> PatternBase f vn
PatternLit Exp
e' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t') SrcLoc
loc
checkPattern' (PatternLit ExpBase NoInfo Name
e NoInfo PatternType
NoInfo SrcLoc
loc) InferredType
NoneInferred = do
  Exp
e' <- ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e
  PatternType
t' <- Exp -> TermTypeM PatternType
expTypeFully Exp
e'
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Exp -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
ExpBase f vn -> f PatternType -> SrcLoc -> PatternBase f vn
PatternLit Exp
e' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t') SrcLoc
loc
checkPattern' (PatternConstr Name
n NoInfo PatternType
NoInfo [UncheckedPattern]
ps SrcLoc
loc) (Ascribed (Scalar (Sum Map Name [PatternType]
cs)))
  | Just [PatternType]
ts <- Name -> Map Name [PatternType] -> Maybe [PatternType]
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Name
n Map Name [PatternType]
cs = do
    [Pattern]
ps' <- (UncheckedPattern -> InferredType -> TermTypeM Pattern)
-> [UncheckedPattern] -> [InferredType] -> TermTypeM [Pattern]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM UncheckedPattern -> InferredType -> TermTypeM Pattern
checkPattern' [UncheckedPattern]
ps ([InferredType] -> TermTypeM [Pattern])
-> [InferredType] -> TermTypeM [Pattern]
forall a b. (a -> b) -> a -> b
$ (PatternType -> InferredType) -> [PatternType] -> [InferredType]
forall a b. (a -> b) -> [a] -> [b]
map PatternType -> InferredType
Ascribed [PatternType]
ts
    Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Name -> Info PatternType -> [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
Name
-> f PatternType
-> [PatternBase f vn]
-> SrcLoc
-> PatternBase f vn
PatternConstr Name
n (PatternType -> Info PatternType
forall a. a -> Info a
Info (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (Map Name [PatternType] -> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as. Map Name [TypeBase dim as] -> ScalarTypeBase dim as
Sum Map Name [PatternType]
cs))) [Pattern]
ps' SrcLoc
loc
checkPattern' (PatternConstr Name
n NoInfo PatternType
NoInfo [UncheckedPattern]
ps SrcLoc
loc) (Ascribed PatternType
t) = do
  StructType
t' <- SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
  [Pattern]
ps' <- (UncheckedPattern -> TermTypeM Pattern)
-> [UncheckedPattern] -> TermTypeM [Pattern]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (UncheckedPattern -> InferredType -> TermTypeM Pattern
`checkPattern'` InferredType
NoneInferred) [UncheckedPattern]
ps
  Usage -> Name -> StructType -> [StructType] -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> StructType -> [StructType] -> m ()
mustHaveConstr Usage
usage Name
n StructType
t' (Pattern -> StructType
patternStructType (Pattern -> StructType) -> [Pattern] -> [StructType]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Pattern]
ps')
  Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify Usage
usage StructType
t' (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t)
  PatternType
t'' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
t
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Name -> Info PatternType -> [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
Name
-> f PatternType
-> [PatternBase f vn]
-> SrcLoc
-> PatternBase f vn
PatternConstr Name
n (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t'') [Pattern]
ps' SrcLoc
loc
  where
    usage :: Usage
usage = SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"matching against constructor"
checkPattern' (PatternConstr Name
n NoInfo PatternType
NoInfo [UncheckedPattern]
ps SrcLoc
loc) InferredType
NoneInferred = do
  [Pattern]
ps' <- (UncheckedPattern -> TermTypeM Pattern)
-> [UncheckedPattern] -> TermTypeM [Pattern]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (UncheckedPattern -> InferredType -> TermTypeM Pattern
`checkPattern'` InferredType
NoneInferred) [UncheckedPattern]
ps
  StructType
t <- SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"t"
  Usage -> Name -> StructType -> [StructType] -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> StructType -> [StructType] -> m ()
mustHaveConstr Usage
usage Name
n StructType
t (Pattern -> StructType
patternStructType (Pattern -> StructType) -> [Pattern] -> [StructType]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Pattern]
ps')
  Pattern -> TermTypeM Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return (Pattern -> TermTypeM Pattern) -> Pattern -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$ Name -> Info PatternType -> [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
Name
-> f PatternType
-> [PatternBase f vn]
-> SrcLoc
-> PatternBase f vn
PatternConstr Name
n (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t) [Pattern]
ps' SrcLoc
loc
  where
    usage :: Usage
usage = SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"matching against constructor"

-- | Build a 'NameMap' exposing every variable bound by a checked
-- pattern in the 'Term' namespace.  Each entry is keyed by the
-- variable's base name and maps to the unqualified 'VName'.
patternNameMap :: Pattern -> NameMap
patternNameMap pat =
  M.fromList
    [ ((Term, baseName v), qualName v)
      | v <- S.toList (patternNames pat)
    ]

-- | Type-check a pattern, optionally against an inferred type, then run
-- the continuation with the pattern's variables visible in the name map.
--
-- Duplicate names within the pattern are rejected first.  The actual
-- checking ('checkPattern'') runs under @onFailure (CheckingPattern p t)@.
-- This only extends the name map; adding the bound identifiers to the
-- value table is left to callers (e.g. via 'binding').
checkPattern ::
  UncheckedPattern ->
  InferredType ->
  (Pattern -> TermTypeM a) ->
  TermTypeM a
checkPattern p t m = do
  checkForDuplicateNames [p]
  checked <- onFailure (CheckingPattern p t) (checkPattern' p t)
  let names = patternNameMap checked
  bindNameMap names (m checked)

-- | Bring the given identifiers into scope for the duration of an
-- action, and afterwards check how they were used.
--
-- Each identifier is added to the value table as a 'Local' binding
-- whose type also aliases the identifier itself.  If the new identifier
-- is array-typed, every variable it aliases has the new name added to
-- its own aliases.  Every identifier is also constrained with
-- 'ParamSize', so that scope violations can be detected during
-- unification.
--
-- After the action has run, the occurrences of the bound names are
-- split out of the writer output and checked with 'checkOccurences'
-- and 'checkIfUsed'.  Occurrences of other names are passed through.
binding :: [Ident] -> TermTypeM a -> TermTypeM a
binding bnds = check . handleVars
  where
    handleVars m =
      localScope (`bindVars` bnds) $ do
        -- Those identifiers that can potentially also be sizes are
        -- added as type constraints.  This is necessary so that we
        -- can properly detect scope violations during unification.
        -- We do this for *all* identifiers, not just those that are
        -- integers, because they may become integers later due to
        -- inference...
        forM_ bnds $ \ident ->
          constrain (identName ident) $ ParamSize $ srclocOf ident
        m

    -- Strict left fold: the scope is rebuilt once per identifier, so
    -- there is no reason to accumulate a chain of thunks.
    bindVars :: TermScope -> [Ident] -> TermScope
    bindVars = foldl' bindVar

    bindVar :: TermScope -> Ident -> TermScope
    bindVar scope (Ident name (Info tp) _) =
      let inedges = boundAliases $ aliases tp
          update (BoundV l tparams in_t)
            -- If 'name' is record or sum-typed, don't alias the
            -- components to 'name', because these have no identity
            -- beyond their components.
            | Array {} <- tp = BoundV l tparams (in_t `addAliases` S.insert (AliasBound name))
            | otherwise = BoundV l tparams in_t
          update b = b

          tp' = tp `addAliases` S.insert (AliasBound name)
       in scope
            { scopeVtable =
                M.insert name (BoundV Local [] tp') $
                  adjustSeveral update inedges $
                    scopeVtable scope
            }

    -- Apply 'f' to the entry of every key in 'ks'.  Keys absent from
    -- the map are ignored, as with 'M.adjust'.
    adjustSeveral f ks vtable = foldl' (flip (M.adjust f)) vtable ks

    -- Check whether the bound variables have been used correctly
    -- within their scope.
    check m = do
      (a, usages) <- collectBindingsOccurences m
      checkOccurences usages

      mapM_ (checkIfUsed usages) bnds

      return a

    -- Collect and remove all occurences in @bnds@.  This relies
    -- on the fact that no variables shadow any other.
    collectBindingsOccurences m = pass $ do
      (x, usage) <- listen m
      let (relevant, rest) = split usage
      return ((x, relevant), const rest)
      where
        split =
          unzip
            . map
              ( \occ ->
                  let (obs1, obs2) = divide $ observed occ
                      occ_cons = divide <$> consumed occ
                      con1 = fst <$> occ_cons
                      con2 = snd <$> occ_cons
                   in ( occ {observed = obs1, consumed = con1},
                        occ {observed = obs2, consumed = con2}
                      )
              )
        names = S.fromList $ map identName bnds
        divide s = (s `S.intersection` names, s `S.difference` names)

-- | Run an action with additional type-level bindings in scope.
--
-- 'Left' entries are added to the scope's type table and shadow any
-- existing entries for the same names.  'Right' entries are recorded as
-- constraints at the current level.  The new constraints sit on the
-- right of a left-biased map union, so a name that already has a
-- constraint keeps its existing one.
bindingTypes ::
  [Either (VName, TypeBinding) (VName, Constraint)] ->
  TermTypeM a ->
  TermTypeM a
bindingTypes types m = do
  lvl <- curLevel
  let leveled = M.fromList [(v, (lvl, c)) | (v, c) <- constraints]
  modifyConstraints $ \existing -> existing <> leveled
  localScope addTypes m
  where
    (tbinds, constraints) = partitionEithers types
    addTypes scope =
      scope {scopeTypeTable = M.fromList tbinds <> scopeTypeTable scope}

-- | Run an action with the given (already checked) type parameters in
-- scope.
--
-- A type parameter becomes an opaque type abbreviation plus a
-- 'ParamType' constraint.  A size parameter becomes a 'ParamSize'
-- constraint and is also bound as an identifier (see 'typeParamIdent').
bindingTypeParams :: [TypeParam] -> TermTypeM a -> TermTypeM a
bindingTypeParams tparams m =
  binding sizeIdents $ bindingTypes typeBindings m
  where
    sizeIdents = mapMaybe typeParamIdent tparams
    typeBindings = concatMap typeParamType tparams

    typeParamType :: TypeParam -> [Either (VName, TypeBinding) (VName, Constraint)]
    typeParamType (TypeParamDim v loc) =
      [Right (v, ParamSize loc)]
    typeParamType (TypeParamType l v loc) =
      [ Left (v, TypeAbbr l [] $ Scalar $ TypeVar () Nonunique (typeName v) []),
        Right (v, ParamType l loc)
      ]

-- | The term-level identifier bound by a size parameter, typed as a
-- signed 32-bit integer.  Type parameters bind no such identifier.
typeParamIdent :: TypeParam -> Maybe Ident
typeParamIdent tp =
  case tp of
    TypeParamDim v loc -> Just (Ident v (Info sizeType) loc)
    _ -> Nothing
  where
    sizeType = Scalar (Prim (Signed Int32))

-- | Bind a single unchecked identifier with the given type, and run the
-- continuation with the resolved identifier in scope.
--
-- The name is first bound in the 'Term' namespace, then resolved with
-- 'checkName', and finally added to the value table via 'binding'.
-- 'binding' also checks how the continuation used the identifier.
bindingIdent ::
  IdentBase NoInfo Name ->
  PatternType ->
  (Ident -> TermTypeM a) ->
  TermTypeM a
bindingIdent (Ident v NoInfo vloc) t m =
  bindSpaced [(Term, v)] $
    checkName Term v vloc >>= \v' ->
      let ident = Ident v' (Info t) vloc
       in binding [ident] (m ident)

-- | Check and bind function parameters: type parameters first, then the
-- value parameter patterns from left to right.
--
-- Each pattern is checked with the identifiers of all earlier patterns
-- already in scope.  Before the continuation runs, every size type
-- parameter is observed so it does not trigger an unused-name warning.
-- The continuation receives the checked type parameters and the checked
-- patterns in their original order.
bindingParams ::
  [UncheckedTypeParam] ->
  [UncheckedPattern] ->
  ([TypeParam] -> [Pattern] -> TermTypeM a) ->
  TermTypeM a
bindingParams tps orig_ps m = do
  checkForDuplicateNames orig_ps
  checkTypeParams tps $ \tps' ->
    bindingTypeParams tps' $
      let finish checked = do
            -- Perform an observation of every type parameter.  This
            -- prevents unused-name warnings for otherwise unused
            -- dimensions.
            mapM_ observe (mapMaybe typeParamIdent tps')
            m tps' (reverse checked)
          -- 'checked' accumulates the patterns in reverse order.
          go checked [] = finish checked
          go checked (p : rest) =
            checkPattern p NoneInferred $ \p' ->
              binding (S.toList (patternIdents p')) $
                go (p' : checked) rest
       in go [] orig_ps

-- | Check a pattern (optionally against an inferred type), add the
-- identifiers it binds to the value table, and run the continuation.
--
-- Dimensions reported by 'patternDims' are observed before the
-- continuation runs.  Note that 'checkPattern' repeats the duplicate-name
-- check performed here.
bindingPattern ::
  PatternBase NoInfo Name ->
  InferredType ->
  (Pattern -> TermTypeM a) ->
  TermTypeM a
bindingPattern p t m = do
  checkForDuplicateNames [p]
  checkPattern p t $ \p' -> do
    let idents = S.toList (patternIdents p')
    binding idents $ do
      -- Perform an observation of every declared dimension.  This
      -- prevents unused-name warnings for otherwise unused dimensions.
      mapM_ observe (patternDims p')
      m p'

-- | Collect dimension identifiers from the type ascriptions inside a
-- pattern, looking through parentheses and tuple patterns.  Other
-- pattern forms contribute nothing.
--
-- NOTE(review): as written, 'dimIdent' maps every 'DimDecl'
-- constructor ('AnyDim', 'ConstDim', 'NamedDim') to 'Nothing', so this
-- function always returns @[]@.  Confirm whether that is intended.
patternDims :: Pattern -> [Ident]
patternDims pat =
  case pat of
    PatternParens inner _ -> patternDims inner
    TuplePattern pats _ -> foldMap patternDims pats
    PatternAscription inner (TypeDecl _ (Info t)) _ ->
      let loc = srclocOf inner
       in patternDims inner ++ mapMaybe (dimIdent loc) (nestedDims t)
    _ -> []
  where
    dimIdent _ AnyDim = Nothing
    dimIdent _ (ConstDim _) = Nothing
    dimIdent _ NamedDim {} = Nothing

-- | Compute the type that results from slicing a value of type @t@
-- with @slice@, together with the existential size names created for
-- the result.
--
-- Each 'DimFix' removes one dimension.  Each 'DimSlice' keeps one
-- dimension.  Its size is chosen as follows:
--
-- * When @r@ carries a rigid size, @[0:n]@ with stride 1 (or omitted)
--   gets size @n@ if @n@ converts to a dimension, and @[:]@ with
--   stride 1 or -1 (or omitted) keeps the original size.
-- * Otherwise the size depends on @r@: a rigid @r@ yields an extended
--   size whose name (if any) is reported, a nonrigid @r@ yields a fresh
--   nonrigid size variable, and no @r@ yields 'AnyDim'.
--
-- Dimensions past the end of @slice@ are kept unchanged.  If no
-- dimensions remain, the element type is returned.  Non-array types
-- are returned unchanged with no new names.
sliceShape ::
  Maybe (SrcLoc, Rigidity) ->
  [DimIndex] ->
  TypeBase (DimDecl VName) as ->
  TermTypeM (TypeBase (DimDecl VName) as, [VName])
sliceShape r slice t@(Array als u et (ShapeDecl orig_dims)) =
  runWriterT $ fmap rebuild (resultDims slice orig_dims)
  where
    -- With no dimensions left, every dimension was indexed away.
    rebuild [] = stripArray (length orig_dims) t
    rebuild dims' = Array als u et (ShapeDecl dims')

    -- Only try to give slices known (non-existential) sizes when the
    -- result is expected to have a rigid size.  For AnyDim or nonrigid
    -- results this is not needed, and skipping it lets programs
    -- type-check without too much ceremony; see e.g. tests/inplace5.fut.
    refine_sizes =
      case r of
        Just (_, Rigid {}) -> True
        _ -> False

    resultDims (DimFix {} : rest) (_ : ds) =
      resultDims rest ds
    -- Slice shapes whose result size is known statically.
    resultDims (DimSlice i j stride : rest) (_ : ds)
      | refine_sizes,
        maybe True ((== Just 0) . isInt32) i,
        Just j' <- maybeDimFromExp =<< j,
        maybe True ((== Just 1) . isInt32) stride =
        (j' :) <$> resultDims rest ds
    resultDims (DimSlice Nothing Nothing stride : rest) (d : ds)
      | refine_sizes,
        maybe True (maybe False ((== 1) . abs) . isInt32) stride =
        (d :) <$> resultDims rest ds
    resultDims (DimSlice i j stride : rest) (d : ds) = do
      d' <- sliceSize d i j stride
      ds' <- resultDims rest ds
      pure (d' : ds')
    resultDims _ ds =
      pure ds

    sliceSize orig_d i j stride =
      case r of
        Just (loc, Rigid _) -> do
          -- The original size does not matter if the slice is fully
          -- specified.
          let orig_d'
                | isJust i, isJust j = Nothing
                | otherwise = Just orig_d
              source =
                SourceSlice
                  orig_d'
                  (bareExp <$> i)
                  (bareExp <$> j)
                  (bareExp <$> stride)
          (d, ext) <- lift (extSize loc source)
          tell (maybeToList ext)
          pure d
        Just (loc, Nonrigid) ->
          lift $ NamedDim . qualName <$> newDimVar loc Nonrigid "slice_dim"
        Nothing ->
          pure AnyDim
sliceShape _ _ t = pure (t, [])

--- Main checkers

-- | @require why ts e@ requires the type of @e@ to be one of the
-- primitive types in @ts@.  The check is done by 'mustBeOneOf', with
-- @why@ (located at @e@) describing the usage.  On success, @e@ is
-- returned unchanged.
require :: String -> [PrimType] -> Exp -> TermTypeM Exp
require why ts e = do
  e_t <- expType e
  mustBeOneOf ts (mkUsage (srclocOf e) why) (toStruct e_t)
  pure e

-- | @unifies why t e@ unifies @t@ with the type of @e@.  @why@ (located
-- at @e@) describes the usage.  Returns @e@ unchanged.
unifies :: String -> StructType -> Exp -> TermTypeM Exp
unifies why t e = do
  e_t <- expType e
  unify (mkUsage (srclocOf e) why) t (toStruct e_t)
  pure e

-- | The closure of a lambda or local function: the variables that
-- occur in it, excluding names bound by its own parameters, that are
-- bound as 'Local' in the current scope (that is, local to the current
-- top-level function).  They are returned as 'AliasBound' aliases.
lexicalClosure :: [Pattern] -> Occurences -> TermTypeM Aliasing
lexicalClosure params closure = do
  vtable <- asks (scopeVtable . termScope)
  let param_names = mconcat (map patternNames params)
      referenced = allOccuring closure S.\\ param_names
      isLocal v
        | Just (BoundV Local _ _) <- M.lookup v vtable = True
        | otherwise = False
  pure $ S.map AliasBound $ S.filter isLocal referenced

-- | If the type is a type variable with no type arguments, and the
-- current constraint on that variable is 'Overloaded', return the same
-- type variable with empty aliases.  Every other type is returned
-- unchanged.
noAliasesIfOverloaded :: PatternType -> TermTypeM PatternType
noAliasesIfOverloaded t@(Scalar (TypeVar _ u tn [])) = do
  constraints <- getConstraints
  pure $ case snd <$> M.lookup (typeLeaf tn) constraints of
    Just Overloaded {} -> Scalar (TypeVar mempty u tn [])
    _ -> t
noAliasesIfOverloaded t = pure t

-- | Check the parts common to type ascription and coercion.
--
-- The declaration and the expression are checked first.  The declared
-- type, after applying @shapef@ and instantiating its empty array
-- dimensions as nonrigid, is then unified with the expression's type.
-- Unification ignores uniqueness, so the expression's type is also
-- required to be a 'subtypeOf' the declared type (via 'anySizes'),
-- with a type error at @loc@ otherwise.  Returns the checked
-- declaration and expression.
checkAscript ::
  SrcLoc ->
  UncheckedTypeDecl ->
  UncheckedExp ->
  (StructType -> StructType) ->
  TermTypeM (TypeDecl, Exp)
checkAscript loc decl e shapef = do
  decl' <- checkTypeDecl decl
  e' <- checkExp e
  t <- expTypeFully e'
  let decl_t = unInfo (expandedType decl')

  (decl_t_nonrigid, _) <-
    instantiateEmptyArrayDims loc "impl" Nonrigid (shapef decl_t)

  onFailure (CheckingAscription decl_t (toStruct t)) $
    unify (mkUsage loc "type ascription") decl_t_nonrigid (toStruct t)

  -- Uniqueness is not handled by unification, so check it explicitly.
  t' <- normTypeFully t
  decl_t' <- normTypeFully decl_t
  unless (t' `subtypeOf` anySizes decl_t') $
    typeError loc mempty $
      "Type" <+> pquote (ppr t') <+> "is not a subtype of"
        <+> pquote (ppr decl_t') <> "."

  pure (decl', e')

-- | Rewrite a type so that it no longer refers to the variables in
-- @unscoped@.
--
-- A size named by an unscoped variable is handled by position.  In
-- immediate or parameter position it becomes a fresh rigid size
-- variable, created at @tloc@ and reused for repeated occurrences of
-- the same name.  Anywhere else it becomes 'AnyDim'.  Aliases bound to
-- unscoped variables become 'AliasFree'.
--
-- Also returns the fresh size variables that were created.
unscopeType ::
  SrcLoc ->
  M.Map VName Ident ->
  PatternType ->
  TermTypeM (PatternType, [VName])
unscopeType tloc unscoped t = do
  (t', replacements) <- runStateT (traverseDims onDim t) mempty
  pure (t' `addAliases` S.map unAlias, M.elems replacements)
  where
    onDim _ pos (NamedDim d)
      | Just ident <- M.lookup (qualLeaf d) unscoped =
        if pos `elem` [PosImmediate, PosParam]
          then inst (srclocOf ident) (qualLeaf d)
          else pure AnyDim
    onDim _ _ d = pure d

    -- Reuse the replacement for a name seen before; otherwise create
    -- and record a new one.
    inst loc d = do
      known <- gets (M.lookup d)
      replacement <- case known of
        Just prev -> pure prev
        Nothing -> do
          fresh_d <- lift $ newDimVar tloc (Rigid (RigidOutOfScope loc d)) "d"
          modify (M.insert d fresh_d)
          pure fresh_d
      pure $ NamedDim $ qualName replacement

    unAlias (AliasBound v)
      | v `M.member` unscoped = AliasFree v
    unAlias a = a

-- | Like 'checkExp', but also tracks the "root" function of a chain of
-- applications, for better error messages.
--
-- The returned 'ApplyOp' holds the root function's name (when the root
-- is a plain variable) and the number of arguments applied so far.
-- The root itself reports 0.  Each 'Apply' passes the current count to
-- 'checkApply' and then increments it.
checkApplyExp :: UncheckedExp -> TermTypeM (Exp, ApplyOp)
checkApplyExp (Apply e1 e2 _ _ loc) = do
  (e1', (fname, i)) <- checkApplyExp e1
  arg <- checkArg e2
  t <- expType e1'
  (t1, rt, argext, exts) <- checkApply loc (fname, i) t arg
  let applied =
        Apply e1' (argExp arg) (Info (diet t1, argext)) (Info rt, Info exts) loc
  pure (applied, (fname, i + 1))
checkApplyExp e = do
  e' <- checkExp e
  let root = case e' of
        Var qn _ _ -> Just qn
        _ -> Nothing
  pure (e', (root, 0))

checkExp :: UncheckedExp -> TermTypeM Exp
checkExp :: ExpBase NoInfo Name -> TermTypeM Exp
checkExp (Literal PrimValue
val SrcLoc
loc) =
  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SrcLoc -> Exp
forall (f :: * -> *) vn. PrimValue -> SrcLoc -> ExpBase f vn
Literal PrimValue
val SrcLoc
loc
checkExp (StringLit vs loc) =
  -- String literals are returned unchanged.
  pure (StringLit vs loc)
checkExp (IntLit val NoInfo loc) = do
  -- The literal's type is not known yet: introduce a fresh type
  -- variable and constrain it to be one of the numeric types.
  lit_t <- newTypeVar loc "t"
  mustBeOneOf anyNumberType (mkUsage loc "integer literal") lit_t
  pure (IntLit val (Info (fromStruct lit_t)) loc)
checkExp (FloatLit val NoInfo loc) = do
  -- As for integer literals, but the fresh type variable is
  -- restricted to the floating-point types.
  lit_t <- newTypeVar loc "t"
  mustBeOneOf anyFloatType (mkUsage loc "float literal") lit_t
  pure (FloatLit val (Info (fromStruct lit_t)) loc)
checkExp (TupLit es loc) = do
  -- Check each component in order; the tuple itself adds nothing.
  es' <- mapM checkExp es
  pure (TupLit es' loc)
checkExp (RecordLit fs loc) = do
  -- The state maps every field name seen so far to the location of its
  -- definition, so that a repeated field can be reported.
  fs' <- evalStateT (traverse checkField fs) M.empty
  pure $ RecordLit fs' loc
  where
    checkField ::
      FieldBase NoInfo Name ->
      StateT (M.Map Name SrcLoc) TermTypeM (FieldBase Info VName)
    checkField (RecordFieldExplicit f e rloc) = do
      errIfAlreadySet f rloc
      modify (M.insert f rloc)
      e' <- lift (checkExp e)
      pure (RecordFieldExplicit f e' rloc)
    checkField (RecordFieldImplicit name NoInfo rloc) = do
      errIfAlreadySet name rloc
      -- An implicit field takes its value (and type) from the variable
      -- of the same name in scope.
      (QualName _ name', t) <- lift (lookupVar rloc (qualName name))
      modify (M.insert name rloc)
      pure (RecordFieldImplicit name' (Info t) rloc)

    -- Raise a type error if field @f@ has already been defined.
    errIfAlreadySet ::
      Name -> SrcLoc -> StateT (M.Map Name SrcLoc) TermTypeM ()
    errIfAlreadySet f rloc =
      gets (M.lookup f) >>= maybe (pure ()) (lift . alreadyDefined)
      where
        alreadyDefined sloc =
          typeError rloc mempty $
            "Field" <+> pquote (ppr f)
              <+> "previously defined at"
              <+> text (locStrRel rloc sloc) <> "."
checkExp (ArrayLit all_es _ loc) =
  -- Construct the result type and unify all elements with it.  We
  -- only create a type variable for empty arrays; otherwise we use
  -- the type of the first element.  This significantly cuts down on
  -- the number of type variables generated for pathologically large
  -- multidimensional array literals.
  case all_es of
    [] -> do
      elem_t <- newTypeVar loc "t"
      arr_t <- arrayOfM loc elem_t (ShapeDecl [ConstDim 0]) Unique
      pure $ ArrayLit [] (Info arr_t) loc
    x : xs -> do
      x' <- checkExp x
      elem_t <- expType x'
      -- Every remaining element must have the first element's type.
      let checkElem =
            unifies "type of first array element" (toStruct elem_t) <=< checkExp
      xs' <- mapM checkElem xs
      elem_t' <- normTypeFully elem_t
      arr_t <- arrayOfM loc elem_t' (ShapeDecl [ConstDim $ length all_es]) Unique
      pure $ ArrayLit (x' : xs') (Info arr_t) loc
checkExp (Range start maybe_step end _ loc) = do
  -- The start determines the element type, which must be a signed
  -- integer type; the other components are unified with it.
  start' <- require "use in range expression" anySignedType =<< checkExp start
  start_t <- toStruct <$> expTypeFully start'
  maybe_step' <- case maybe_step of
    Nothing -> return Nothing
    Just step -> do
      let warning :: TermTypeM ()
          warning = warn loc "First and second element of range are identical, this will produce an empty array."
      case (start, step) of
        (Literal x _, Literal y _) -> when (x == y) warning
        -- Integer literals whose type is still to be inferred are
        -- 'IntLit' (see the 'IntLit' case of 'checkExp'), not
        -- 'Literal', so they need their own case for the warning to
        -- fire.
        (IntLit x _ _, IntLit y _ _) -> when (x == y) warning
        (Var x_name _ _, Var y_name _ _) -> when (x_name == y_name) warning
        _ -> return ()
      Just <$> (unifies "use in range expression" start_t =<< checkExp step)

  let unifyRange :: UncheckedExp -> TermTypeM Exp
      unifyRange e = unifies "use in range expression" start_t =<< checkExp e
  end' <- case end of
    DownToExclusive e -> DownToExclusive <$> unifyRange e
    UpToExclusive e -> UpToExclusive <$> unifyRange e
    ToInclusive e -> ToInclusive <$> unifyRange e

  -- Special case some ranges to give them a known size: the bound
  -- itself is the size when the (start, second element, end) shape
  -- matches one of the patterns below.  Otherwise the size is a fresh
  -- rigid dimension that is reported as an existential.
  let dimFromBound :: Exp -> TermTypeM (DimDecl VName, Maybe VName)
      dimFromBound = dimFromExp (SourceBound . bareExp)
  (dim, retext) <-
    case (isInt32 start', isInt32 <$> maybe_step', end') of
      (Just 0, Just (Just 1), UpToExclusive end'') ->
        dimFromBound end''
      (Just 0, Nothing, UpToExclusive end'') ->
        dimFromBound end''
      (Just 1, Just (Just 2), ToInclusive end'') ->
        dimFromBound end''
      _ -> do
        d <- newDimVar loc (Rigid RigidRange) "range_dim"
        return (NamedDim $ qualName d, Just d)

  t <- arrayOfM loc start_t (ShapeDecl [dim]) Unique
  let ret = (Info (t `setAliases` mempty), Info $ maybeToList retext)

  return $ Range start' maybe_step' end' ret loc
checkExp (Ascript e decl loc) = do
  -- Where the 'Coerce' case passes 'anySizes' to 'checkAscript',
  -- a plain ascription passes the identity, so the declared type is
  -- used unmodified.
  (decl', e') <- checkAscript loc decl e id
  pure $ Ascript e' decl' loc
checkExp (Coerce e decl _ loc) = do
  -- We instantiate the declared types with all dimensions as nonrigid
  -- fresh type variables, which we then use to unify with the type of
  -- 'e'.  This lets 'e' have whatever sizes it wants, but the overall
  -- type must still match.  Eventually we will throw away those sizes
  -- (they will end up being unified with various sizes in 'e', which
  -- is fine).
  (decl', e') <- checkAscript loc decl e anySizes

  -- Now we instantiate the declared type again, but this time we keep
  -- around the sizes as existentials.  This is the result of the
  -- ascription as a whole.  We use matchDims to obtain the aliasing
  -- of 'e'.
  (decl_t_rigid, ext) <-
    instantiateDimsInReturnType loc Nothing . unInfo $ expandedType decl'

  e_t <- expTypeFully e'
  t' <- matchDims (const pure) e_t (fromStruct decl_t_rigid)

  pure $ Coerce e' decl' (Info t', Info ext) loc
checkExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) NoInfo NoInfo loc) = do
  -- A binary operator is checked as the application of the operator's
  -- function type to its two operands, one argument at a time.
  (op', ftype) <- lookupVar oploc op
  e1_arg <- checkArg e1
  e2_arg <- checkArg e2

  -- Note that the application to the first operand cannot fix any
  -- existential sizes, because it must by necessity be a function.
  (p1_t, rt, p1_ext, _) <- checkApply loc (Just op', 0) ftype e1_arg
  (p2_t, rt', p2_ext, retext) <- checkApply loc (Just op', 1) rt e2_arg

  -- Pair a checked operand with its parameter type and extension.
  let operand arg p_t p_ext = (argExp arg, Info (toStruct p_t, p_ext))

  return $
    BinOp
      (op', oploc)
      (Info ftype)
      (operand e1_arg p1_t p1_ext)
      (operand e2_arg p2_t p2_ext)
      (Info rt')
      (Info retext)
      loc
checkExp (Project k e NoInfo loc) = do
  -- The projected expression must have a field named @k@; the
  -- projection's type is that field's type.
  e' <- checkExp e
  record_t <- expType e'
  let usage = mkUsage loc $ "projection of field " ++ quote (pretty k)
  field_t <- mustHaveField usage k record_t
  pure $ Project k e' (Info field_t) loc
checkExp (If e1 e2 e3 _ loc) =
  sequentially checkCond $ \e1' _ -> do
    -- The branches are checked as alternatives, while recording the
    -- occurrences they produce.
    ((e2', e3'), dflow) <-
      tapOccurences $ checkExp e2 `alternative` checkExp e3

    (brancht, retext) <- unifyBranches loc e2' e3'
    -- Anything consumed in the branches may not be aliased by the
    -- result, so those names are removed from its aliases.
    let consumed = S.map AliasBound $ allConsumed dflow
        t' = brancht `addAliases` (`S.difference` consumed)

    zeroOrderType
      (mkUsage loc "returning value of this type from 'if' expression")
      "type returned from branch"
      t'

    pure $ If e1' e2' e3' (Info t', Info retext) loc
  where
    -- Check the condition and require it to have type @bool@.
    checkCond = do
      e1' <- checkExp e1
      let bool = Scalar $ Prim Bool
      e1_t <- toStruct <$> expType e1'
      onFailure (CheckingRequired [bool] e1_t) $
        unify (mkUsage (srclocOf e1') "use as 'if' condition") bool e1_t
      pure e1'
-- Parenthesised expression: check the inner expression and rebuild the
-- 'Parens' node around it, keeping the original location.
checkExp (Parens e loc) = do
  e' <- checkExp e
  pure $ Parens e' loc
-- Module-qualified parentheses @M.(e)@: look up module @M@ and check @e@
-- in the current environment extended with M's environment, where every
-- name in M's name map is re-qualified by M's resolved name.  Only
-- non-parametric modules ('ModEnv') can be opened this way; a 'ModFun'
-- is reported as a type error.
checkExp (QualParens (modname, modnameloc) e loc) = do
  (modname', mod) <- lookupMod loc modname
  case mod of
    ModEnv env ->
      local (`withEnv` qualifyEnv modname' env) $ do
        e' <- checkExp e
        return $ QualParens (modname', modnameloc) e' loc
    ModFun {} ->
      -- '<+>' already inserts a separating space, so the trailing literal
      -- must not begin with one (that rendered as "Module M  is ...").
      typeError loc mempty $
        "Module" <+> ppr modname <+> "is a parametric module."
  where
    -- Re-qualify every entry of the module's name map so that names
    -- resolve to their module-qualified form inside the parentheses.
    qualifyEnv :: QualName VName -> Env -> Env
    qualifyEnv modname' env =
      env {envNameMap = M.map (qualify' modname') $ envNameMap env}

    -- Prefix a qualified name with the full path of the module
    -- (its qualifiers followed by its leaf name).
    qualify' :: QualName vn -> QualName vn -> QualName vn
    qualify' modname' (QualName qs name) =
      QualName (qualQuals modname' ++ [qualLeaf modname'] ++ qs) name
-- Variable reference, possibly followed by record field projections.
checkExp (Var qn NoInfo loc) = do
  -- The qualifiers of a variable are divided into two parts: first a
  -- possibly-empty sequence of module qualifiers, followed by a
  -- possibly-empty sequence of record field accesses.  We use scope
  -- information to perform the split, by taking qualifiers off the
  -- end until the remaining prefix resolves to a variable.
  (qn', t, fields) <- findRootVar (qualQuals qn) (qualLeaf qn)

  -- Apply the peeled-off qualifiers as field projections, left to right.
  foldM checkField (Var qn' (Info t) loc) fields
  where
    -- Try to resolve @qs.name@ as a variable.  On failure, fall back to
    -- 'notFound', which treats @name@ as a field of a shorter root.
    findRootVar qs name =
      (whenFound <$> lookupVar loc (QualName qs name))
        `catchError` notFound qs name

    -- A direct hit has no trailing field accesses.
    whenFound (qn', t) = (qn', t, [])

    -- Peel the last qualifier off and retry with it as the new leaf.  If
    -- there are no qualifiers left, or no shorter root resolves, the
    -- original lookup error for the full name is re-thrown.
    -- (Total case split instead of a 'null' guard plus partial
    -- 'init'/'last'.)
    notFound qs name err =
      case reverse qs of
        [] -> throwError err
        q : rest_rev -> do
          (qn', t, fields) <-
            findRootVar (reverse rest_rev) q
              `catchError` const (throwError err)
          return (qn', t, fields ++ [name])

    -- Project field @k@ out of @e@, requiring its type to have that field.
    checkField e k = do
      t <- expType e
      let usage = mkUsage loc $ "projection of field " ++ quote (pretty k)
      kt <- mustHaveField usage k t
      return $ Project k e (Info kt) loc
-- Numeric negation: the operand is checked first, then required to have
-- one of 'anyNumberType'.
checkExp (Negate arg loc) = do
  arg' <- checkExp arg >>= require "numeric negation" anyNumberType
  pure $ Negate arg' loc
-- Function application is delegated to 'checkApplyExp'; only the checked
-- expression is kept and the accompanying 'ApplyOp' is discarded.
checkExp e@Apply {} =
  checkApplyExp e >>= pure . fst
-- @let pat = e in body@: check @e@ (tracking its occurrences), bind the
-- pattern at exactly @e@'s type, then check @body@ one level deeper and
-- unscope the pattern's names from the body's type.
checkExp (LetPat pat e body _ loc) =
  sequentially (checkExp e) $ \e' e_occs -> do
    -- Not technically an ascription, but we want the pattern to have
    -- exactly the type of 'e'.
    t <- expType e'

    -- When the right-hand side performed a consumption, its type is
    -- additionally checked with 'zeroOrderType'.
    forM_ (anyConsumption e_occs) $ \c ->
      zeroOrderType
        (mkUsage loc "consumption in right-hand side of 'let'-binding")
        ("type computed with consumption at " ++ locStr (location c))
        t

    incLevel . bindingPattern pat (Ascribed t) $ \pat' -> do
      body' <- checkExp body
      body_full_t <- expTypeFully body'
      (body_t, retext) <- unscopeType loc (patternMap pat') body_full_t
      pure $ LetPat pat' e' body' (Info body_t, Info retext) loc
-- Local function definition @let f tparams params = e in body@.
-- The definition is checked by 'checkBinding' first, outside the scope
-- extension below.  The body is then checked with @f@ bound to its
-- arrow type, whose aliases are the definition's lexical closure.
checkExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body NoInfo loc) =
  sequentially (checkBinding (name, maybe_retdecl, tparams, params, e, loc)) $
    \(tparams', params', maybe_retdecl', rettype, _, e') closure -> do
      closure' <- lexicalClosure params' closure

      bindSpaced [(Term, name)] $ do
        name' <- checkName Term name loc

        let -- One arrow per parameter, ending in the return type.
            addParam p ret =
              case patternParam p of
                (xp, xt) -> Scalar $ Arrow () xp xt ret
            fun_t = foldr addParam rettype params'
            binding = BoundV Local tparams' $ fun_t `setAliases` closure'
            extendScope scope =
              scope
                { scopeVtable =
                    M.insert name' binding $ scopeVtable scope,
                  scopeNameMap =
                    M.insert (Term, name) (qualName name') $ scopeNameMap scope
                }

        body' <- localScope extendScope $ checkExp body

        -- We fake an ident here, but it's OK as it can't be a size
        -- anyway.
        let fake_ident = Ident name' (Info $ fromStruct fun_t) mempty
        body_full_t <- expTypeFully body'
        (body_t, _) <- unscopeType loc (M.singleton name' fake_ident) body_full_t

        return $
          LetFun
            name'
            (tparams', params', maybe_retdecl', Info rettype, e')
            body'
            (Info body_t)
            loc
-- In-place update binding @let dest = src with [idxes] = ve in body@.
-- The source must have a unique type, and every local it aliases must
-- itself be unique.  The new value must not alias the source.  The body
-- is checked while 'src' is marked as consumed.
checkExp (LetWith dest src idxes ve body NoInfo loc) =
  sequentially (checkIdent src) $ \src' _ -> do
    let src_declared_t = unInfo $ identType src'

    -- The source must be an array with at least as many dimensions as
    -- there are indices.
    (t, _) <- newArrayType (srclocOf src) "src" $ length idxes
    unify (mkUsage loc "type of target array") t $ toStruct src_declared_t

    -- Need the fully normalised type here to get the proper aliasing information.
    src_t <- normTypeFully src_declared_t

    idxes' <- mapM checkDimIndex idxes
    (elemt, _) <- sliceShape (Just (loc, Nonrigid)) idxes' =<< normTypeFully t

    let src_doc = pquote (pprName (identName src))

    unless (unique src_t) $
      typeError loc mempty $
        "Source" <+> src_doc <+> "has type" <+> ppr src_t <> ", which is not unique."

    -- Every local binding aliased by the source must be unique as well.
    vtable <- asks $ scopeVtable . termScope
    forM_ (aliases src_t) $ \v ->
      case aliasVar v `M.lookup` vtable of
        Just (BoundV Local _ v_t)
          | not $ unique v_t ->
            typeError loc mempty $
              "Source" <+> src_doc <+> "aliases" <+> pquote (pprName (aliasVar v))
                <> ", which is not consumable."
        _ -> return ()

    sequentially (checkExp ve >>= unifies "type of target array" elemt) $ \ve' _ -> do
      -- The written value may not alias the array being updated.
      ve_t <- expTypeFully ve'
      when (AliasBound (identName src') `S.member` aliases ve_t) $
        badLetWithValue loc

      bindingIdent dest (src_t `setAliases` S.empty) $ \dest' -> do
        body' <- consuming src' $ checkExp body
        body_full_t <- expTypeFully body'
        (body_t, _) <- unscopeType loc (M.singleton (identName dest') dest') body_full_t
        return $ LetWith dest' src' idxes' ve' body' (Info body_t) loc
checkExp (Update ExpBase NoInfo Name
src [DimIndexBase NoInfo Name]
idxes ExpBase NoInfo Name
ve SrcLoc
loc) = do
  (StructType
t, StructType
_) <- SrcLoc -> String -> Int -> TermTypeM (StructType, StructType)
newArrayType (ExpBase NoInfo Name -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf ExpBase NoInfo Name
src) String
"src" (Int -> TermTypeM (StructType, StructType))
-> Int -> TermTypeM (StructType, StructType)
forall a b. (a -> b) -> a -> b
$ [DimIndexBase NoInfo Name] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndexBase NoInfo Name]
idxes
  [DimIndex]
idxes' <- (DimIndexBase NoInfo Name -> TermTypeM DimIndex)
-> [DimIndexBase NoInfo Name] -> TermTypeM [DimIndex]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndexBase NoInfo Name -> TermTypeM DimIndex
checkDimIndex [DimIndexBase NoInfo Name]
idxes
  (StructType
elemt, [VName]
_) <- Maybe (SrcLoc, Rigidity)
-> [DimIndex] -> StructType -> TermTypeM (StructType, [VName])
forall as.
Maybe (SrcLoc, Rigidity)
-> [DimIndex]
-> TypeBase (DimDecl VName) as
-> TermTypeM (TypeBase (DimDecl VName) as, [VName])
sliceShape ((SrcLoc, Rigidity) -> Maybe (SrcLoc, Rigidity)
forall a. a -> Maybe a
Just (SrcLoc
loc, Rigidity
Nonrigid)) [DimIndex]
idxes' (StructType -> TermTypeM (StructType, [VName]))
-> TermTypeM StructType -> TermTypeM (StructType, [VName])
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< StructType -> TermTypeM StructType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully StructType
t

  TermTypeM Exp
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b.
TermTypeM a -> (a -> [Occurence] -> TermTypeM b) -> TermTypeM b
sequentially (ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
ve TermTypeM Exp -> (Exp -> TermTypeM Exp) -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String -> StructType -> Exp -> TermTypeM Exp
unifies String
"type of target array" StructType
elemt) ((Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp)
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ \Exp
ve' [Occurence]
_ ->
    TermTypeM Exp
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b.
TermTypeM a -> (a -> [Occurence] -> TermTypeM b) -> TermTypeM b
sequentially (ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
src TermTypeM Exp -> (Exp -> TermTypeM Exp) -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String -> StructType -> Exp -> TermTypeM Exp
unifies String
"type of target array" StructType
t) ((Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp)
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ \Exp
src' [Occurence]
_ -> do
      PatternType
src_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
src'
      Bool -> TermTypeM () -> TermTypeM ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PatternType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatternType
src_t) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
        SrcLoc -> Notes -> Doc -> TermTypeM ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM ()) -> Doc -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
          Doc
"Source" Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (ExpBase NoInfo Name -> Doc
forall a. Pretty a => a -> Doc
ppr ExpBase NoInfo Name
src)
            Doc -> Doc -> Doc
<+> Doc
"has type"
            Doc -> Doc -> Doc
<+> PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
src_t Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
", which is not unique."

      let src_als :: Aliasing
src_als = PatternType -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases PatternType
src_t
      PatternType
ve_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
ve'
      Bool -> TermTypeM () -> TermTypeM ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Aliasing -> Bool
forall a. Set a -> Bool
S.null (Aliasing -> Bool) -> Aliasing -> Bool
forall a b. (a -> b) -> a -> b
$ Aliasing
src_als Aliasing -> Aliasing -> Aliasing
forall a. Ord a => Set a -> Set a -> Set a
`S.intersection` PatternType -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases PatternType
ve_t) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ SrcLoc -> TermTypeM ()
forall a. SrcLoc -> TermTypeM a
badLetWithValue SrcLoc
loc

      SrcLoc -> Aliasing -> TermTypeM ()
consume SrcLoc
loc Aliasing
src_als
      Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ Exp -> [DimIndex] -> Exp -> SrcLoc -> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> [DimIndexBase f vn] -> ExpBase f vn -> SrcLoc -> ExpBase f vn
Update Exp
src' [DimIndex]
idxes' Exp
ve' SrcLoc
loc

-- Record updates are a bit hacky, because we do not have row typing
-- (yet?).  For now, we only permit record updates where we know the
-- full type up to the field we are updating.
checkExp (RecordUpdate ExpBase NoInfo Name
src [Name]
fields ExpBase NoInfo Name
ve NoInfo PatternType
NoInfo SrcLoc
loc) = do
  Exp
src' <- ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
src
  Exp
ve' <- ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
ve
  PatternType
a <- Exp -> TermTypeM PatternType
expTypeFully Exp
src'
  let usage :: Usage
usage = SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"record update"
  PatternType
r <- (PatternType -> Name -> TermTypeM PatternType)
-> PatternType -> [Name] -> TermTypeM PatternType
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ((Name -> PatternType -> TermTypeM PatternType)
-> PatternType -> Name -> TermTypeM PatternType
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Name -> PatternType -> TermTypeM PatternType)
 -> PatternType -> Name -> TermTypeM PatternType)
-> (Name -> PatternType -> TermTypeM PatternType)
-> PatternType
-> Name
-> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$ Usage -> Name -> PatternType -> TermTypeM PatternType
forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> PatternType -> m PatternType
mustHaveField Usage
usage) PatternType
a [Name]
fields
  PatternType
ve_t <- Exp -> TermTypeM PatternType
expType Exp
ve'
  let r' :: StructType
r' = StructType -> StructType
forall vn as. TypeBase (DimDecl vn) as -> TypeBase (DimDecl vn) as
anySizes (StructType -> StructType) -> StructType -> StructType
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
r
      ve_t' :: StructType
ve_t' = StructType -> StructType
forall vn as. TypeBase (DimDecl vn) as -> TypeBase (DimDecl vn) as
anySizes (StructType -> StructType) -> StructType -> StructType
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
ve_t
  Checking -> TermTypeM () -> TermTypeM ()
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure ([Name] -> StructType -> StructType -> Checking
CheckingRecordUpdate [Name]
fields StructType
r' StructType
ve_t') (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
    Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify Usage
usage StructType
r' StructType
ve_t'
  Maybe PatternType
maybe_a' <- (PatternType -> PatternType)
-> [Name] -> PatternType -> Maybe PatternType
forall dim als.
(TypeBase dim als -> TypeBase dim als)
-> [Name] -> TypeBase dim als -> Maybe (TypeBase dim als)
onRecordField (PatternType -> PatternType -> PatternType
forall a b. a -> b -> a
const PatternType
ve_t) [Name]
fields (PatternType -> Maybe PatternType)
-> TermTypeM PatternType -> TermTypeM (Maybe PatternType)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> TermTypeM PatternType
expTypeFully Exp
src'
  case Maybe PatternType
maybe_a' of
    Just PatternType
a' -> Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ Exp -> [Name] -> Exp -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> [Name]
-> ExpBase f vn
-> f PatternType
-> SrcLoc
-> ExpBase f vn
RecordUpdate Exp
src' [Name]
fields Exp
ve' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
a') SrcLoc
loc
    Maybe PatternType
Nothing ->
      SrcLoc -> Notes -> Doc -> TermTypeM Exp
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Exp) -> Doc -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        Doc
"Full type of"
          Doc -> Doc -> Doc
</> Int -> Doc -> Doc
indent Int
2 (ExpBase NoInfo Name -> Doc
forall a. Pretty a => a -> Doc
ppr ExpBase NoInfo Name
src)
          Doc -> Doc -> Doc
</> String -> Doc
textwrap String
" is not known at this point.  Add a size annotation to the original record to disambiguate."
checkExp (Index ExpBase NoInfo Name
e [DimIndexBase NoInfo Name]
idxes (NoInfo PatternType, NoInfo [VName])
_ SrcLoc
loc) = do
  (StructType
t, StructType
_) <- SrcLoc -> String -> Int -> TermTypeM (StructType, StructType)
newArrayType SrcLoc
loc String
"e" (Int -> TermTypeM (StructType, StructType))
-> Int -> TermTypeM (StructType, StructType)
forall a b. (a -> b) -> a -> b
$ [DimIndexBase NoInfo Name] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndexBase NoInfo Name]
idxes
  Exp
e' <- String -> StructType -> Exp -> TermTypeM Exp
unifies String
"being indexed at" StructType
t (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e
  [DimIndex]
idxes' <- (DimIndexBase NoInfo Name -> TermTypeM DimIndex)
-> [DimIndexBase NoInfo Name] -> TermTypeM [DimIndex]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndexBase NoInfo Name -> TermTypeM DimIndex
checkDimIndex [DimIndexBase NoInfo Name]
idxes
  -- XXX, the RigidSlice here will be overridden in sliceShape with a proper value.
  (PatternType
t', [VName]
retext) <-
    Maybe (SrcLoc, Rigidity)
-> [DimIndex] -> PatternType -> TermTypeM (PatternType, [VName])
forall as.
Maybe (SrcLoc, Rigidity)
-> [DimIndex]
-> TypeBase (DimDecl VName) as
-> TermTypeM (TypeBase (DimDecl VName) as, [VName])
sliceShape ((SrcLoc, Rigidity) -> Maybe (SrcLoc, Rigidity)
forall a. a -> Maybe a
Just (SrcLoc
loc, RigidSource -> Rigidity
Rigid (Maybe (DimDecl VName) -> String -> RigidSource
RigidSlice Maybe (DimDecl VName)
forall a. Maybe a
Nothing String
""))) [DimIndex]
idxes'
      (PatternType -> TermTypeM (PatternType, [VName]))
-> TermTypeM PatternType -> TermTypeM (PatternType, [VName])
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatternType
expTypeFully Exp
e'

  -- Remove aliases if the result is an overloaded type, because that
  -- will certainly not be aliased.
  PatternType
t'' <- PatternType -> TermTypeM PatternType
noAliasesIfOverloaded PatternType
t'

  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ Exp
-> [DimIndex] -> (Info PatternType, Info [VName]) -> SrcLoc -> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> [DimIndexBase f vn]
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
Index Exp
e' [DimIndex]
idxes' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t'', [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
retext) SrcLoc
loc
checkExp (Assert ExpBase NoInfo Name
e1 ExpBase NoInfo Name
e2 NoInfo String
NoInfo SrcLoc
loc) = do
  Exp
e1' <- String -> [PrimType] -> Exp -> TermTypeM Exp
require String
"being asserted" [PrimType
Bool] (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e1
  Exp
e2' <- ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e2
  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Info String -> SrcLoc -> Exp
forall (f :: * -> *) vn.
ExpBase f vn -> ExpBase f vn -> f String -> SrcLoc -> ExpBase f vn
Assert Exp
e1' Exp
e2' (String -> Info String
forall a. a -> Info a
Info (ExpBase NoInfo Name -> String
forall a. Pretty a => a -> String
pretty ExpBase NoInfo Name
e1)) SrcLoc
loc
checkExp (Lambda [UncheckedPattern]
params ExpBase NoInfo Name
body Maybe (TypeExp Name)
rettype_te NoInfo (Aliasing, StructType)
NoInfo SrcLoc
loc) =
  TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
removeSeminullOccurences (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
    TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
noUnique (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
      TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
incLevel (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        [UncheckedTypeParam]
-> [UncheckedPattern]
-> ([TypeParam] -> [Pattern] -> TermTypeM Exp)
-> TermTypeM Exp
forall a.
[UncheckedTypeParam]
-> [UncheckedPattern]
-> ([TypeParam] -> [Pattern] -> TermTypeM a)
-> TermTypeM a
bindingParams [] [UncheckedPattern]
params (([TypeParam] -> [Pattern] -> TermTypeM Exp) -> TermTypeM Exp)
-> ([TypeParam] -> [Pattern] -> TermTypeM Exp) -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ \[TypeParam]
_ [Pattern]
params' -> do
          Maybe (TypeExp VName, StructType, Liftedness)
rettype_checked <- (TypeExp Name -> TermTypeM (TypeExp VName, StructType, Liftedness))
-> Maybe (TypeExp Name)
-> TermTypeM (Maybe (TypeExp VName, StructType, Liftedness))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse TypeExp Name -> TermTypeM (TypeExp VName, StructType, Liftedness)
forall (m :: * -> *).
MonadTypeChecker m =>
TypeExp Name -> m (TypeExp VName, StructType, Liftedness)
checkTypeExp Maybe (TypeExp Name)
rettype_te
          let declared_rettype :: Maybe StructType
declared_rettype =
                case Maybe (TypeExp VName, StructType, Liftedness)
rettype_checked of
                  Just (TypeExp VName
_, StructType
st, Liftedness
_) -> StructType -> Maybe StructType
forall a. a -> Maybe a
Just StructType
st
                  Maybe (TypeExp VName, StructType, Liftedness)
Nothing -> Maybe StructType
forall a. Maybe a
Nothing
          (Exp
body', [Occurence]
closure) <-
            TermTypeM Exp -> TermTypeM (Exp, [Occurence])
forall a. TermTypeM a -> TermTypeM (a, [Occurence])
tapOccurences (TermTypeM Exp -> TermTypeM (Exp, [Occurence]))
-> TermTypeM Exp -> TermTypeM (Exp, [Occurence])
forall a b. (a -> b) -> a -> b
$ [Pattern]
-> ExpBase NoInfo Name
-> Maybe StructType
-> SrcLoc
-> TermTypeM Exp
checkFunBody [Pattern]
params' ExpBase NoInfo Name
body Maybe StructType
declared_rettype SrcLoc
loc
          PatternType
body_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
body'

          [Pattern]
params'' <- (Pattern -> TermTypeM Pattern) -> [Pattern] -> TermTypeM [Pattern]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Pattern -> TermTypeM Pattern
forall e. ASTMappable e => e -> TermTypeM e
updateTypes [Pattern]
params'

          (Maybe (TypeExp VName)
rettype', StructType
rettype_st) <-
            case Maybe (TypeExp VName, StructType, Liftedness)
rettype_checked of
              Just (TypeExp VName
te, StructType
st, Liftedness
_) ->
                (Maybe (TypeExp VName), StructType)
-> TermTypeM (Maybe (TypeExp VName), StructType)
forall (m :: * -> *) a. Monad m => a -> m a
return (TypeExp VName -> Maybe (TypeExp VName)
forall a. a -> Maybe a
Just TypeExp VName
te, StructType
st)
              Maybe (TypeExp VName, StructType, Liftedness)
Nothing -> do
                StructType
ret <-
                  [Pattern] -> StructType -> TermTypeM StructType
forall (m :: * -> *).
MonadUnify m =>
[Pattern] -> StructType -> m StructType
inferReturnSizes [Pattern]
params'' (StructType -> TermTypeM StructType)
-> StructType -> TermTypeM StructType
forall a b. (a -> b) -> a -> b
$
                    PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> StructType) -> PatternType -> StructType
forall a b. (a -> b) -> a -> b
$
                      [Pattern] -> PatternType -> PatternType
inferReturnUniqueness [Pattern]
params'' PatternType
body_t
                (Maybe (TypeExp VName), StructType)
-> TermTypeM (Maybe (TypeExp VName), StructType)
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (TypeExp VName)
forall a. Maybe a
Nothing, StructType
ret)

          [Pattern] -> PatternType -> SrcLoc -> TermTypeM ()
checkGlobalAliases [Pattern]
params' PatternType
body_t SrcLoc
loc
          Maybe Name -> [Pattern] -> TermTypeM ()
verifyFunctionParams Maybe Name
forall a. Maybe a
Nothing [Pattern]
params'

          Aliasing
closure' <- [Pattern] -> [Occurence] -> TermTypeM Aliasing
lexicalClosure [Pattern]
params'' [Occurence]
closure

          Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ [Pattern]
-> Exp
-> Maybe (TypeExp VName)
-> Info (Aliasing, StructType)
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
[PatternBase f vn]
-> ExpBase f vn
-> Maybe (TypeExp vn)
-> f (Aliasing, StructType)
-> SrcLoc
-> ExpBase f vn
Lambda [Pattern]
params'' Exp
body' Maybe (TypeExp VName)
rettype' ((Aliasing, StructType) -> Info (Aliasing, StructType)
forall a. a -> Info a
Info (Aliasing
closure', StructType
rettype_st)) SrcLoc
loc
  where
    -- Inferring the sizes of the return type of a lambda is a lot
    -- like let-generalisation.  We wish to remove any rigid sizes
    -- that were created when checking the body, except for those that
    -- are visible in types that existed before we entered the body,
    -- are parameters, or are used in parameters.
    inferReturnSizes :: [Pattern] -> StructType -> m StructType
inferReturnSizes [Pattern]
params' StructType
ret = do
      Int
cur_lvl <- m Int
forall (m :: * -> *). MonadUnify m => m Int
curLevel
      let named :: (PName, b) -> Maybe VName
named (Named VName
x, b
_) = VName -> Maybe VName
forall a. a -> Maybe a
Just VName
x
          named (PName
Unnamed, b
_) = Maybe VName
forall a. Maybe a
Nothing
          param_names :: [VName]
param_names = (Pattern -> Maybe VName) -> [Pattern] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe ((PName, StructType) -> Maybe VName
forall b. (PName, b) -> Maybe VName
named ((PName, StructType) -> Maybe VName)
-> (Pattern -> (PName, StructType)) -> Pattern -> Maybe VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pattern -> (PName, StructType)
patternParam) [Pattern]
params'
          pos_sizes :: Names
pos_sizes =
            StructType -> Names
forall als. TypeBase (DimDecl VName) als -> Names
typeDimNamesPos ([StructType] -> StructType -> StructType
forall as dim.
Monoid as =>
[TypeBase dim as] -> TypeBase dim as -> TypeBase dim as
foldFunType ((Pattern -> StructType) -> [Pattern] -> [StructType]
forall a b. (a -> b) -> [a] -> [b]
map Pattern -> StructType
patternStructType [Pattern]
params') StructType
ret)
          hide :: VName -> (Int, b) -> Bool
hide VName
k (Int
lvl, b
_) =
            Int
lvl Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
cur_lvl Bool -> Bool -> Bool
&& VName
k VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
param_names Bool -> Bool -> Bool
&& VName
k VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.notMember` Names
pos_sizes

      Names
hidden_sizes <-
        [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList ([VName] -> Names)
-> (Constraints -> [VName]) -> Constraints -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Constraints -> [VName]
forall k a. Map k a -> [k]
M.keys (Constraints -> [VName])
-> (Constraints -> Constraints) -> Constraints -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> (Int, Constraint) -> Bool) -> Constraints -> Constraints
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey VName -> (Int, Constraint) -> Bool
forall b. VName -> (Int, b) -> Bool
hide (Constraints -> Names) -> m Constraints -> m Names
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> m Constraints
forall (m :: * -> *). MonadUnify m => m Constraints
getConstraints

      let onDim :: DimDecl VName -> DimDecl VName
onDim (NamedDim QualName VName
name)
            | Bool -> Bool
not (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
name VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Names
hidden_sizes) = QualName VName -> DimDecl VName
forall vn. QualName vn -> DimDecl vn
NamedDim QualName VName
name
            | Bool
otherwise = DimDecl VName
forall vn. DimDecl vn
AnyDim
          onDim DimDecl VName
d = DimDecl VName
d

      StructType -> m StructType
forall (m :: * -> *) a. Monad m => a -> m a
return (StructType -> m StructType) -> StructType -> m StructType
forall a b. (a -> b) -> a -> b
$ (DimDecl VName -> DimDecl VName) -> StructType -> StructType
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first DimDecl VName -> DimDecl VName
onDim StructType
ret
checkExp (OpSection QualName Name
op NoInfo PatternType
_ SrcLoc
loc) = do
  (QualName VName
op', PatternType
ftype) <- SrcLoc -> QualName Name -> TermTypeM (QualName VName, PatternType)
forall (m :: * -> *).
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m (QualName VName, PatternType)
lookupVar SrcLoc
loc QualName Name
op
  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ QualName VName -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
QualName vn -> f PatternType -> SrcLoc -> ExpBase f vn
OpSection QualName VName
op' (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
ftype) SrcLoc
loc
checkExp (OpSectionLeft QualName Name
op NoInfo PatternType
_ ExpBase NoInfo Name
e (NoInfo (StructType, Maybe VName), NoInfo StructType)
_ (NoInfo PatternType, NoInfo [VName])
_ SrcLoc
loc) = do
  (QualName VName
op', PatternType
ftype) <- SrcLoc -> QualName Name -> TermTypeM (QualName VName, PatternType)
forall (m :: * -> *).
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m (QualName VName, PatternType)
lookupVar SrcLoc
loc QualName Name
op
  Arg
e_arg <- ExpBase NoInfo Name -> TermTypeM Arg
checkArg ExpBase NoInfo Name
e
  (PatternType
t1, PatternType
rt, Maybe VName
argext, [VName]
retext) <- SrcLoc
-> ApplyOp
-> PatternType
-> Arg
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
checkApply SrcLoc
loc (QualName VName -> Maybe (QualName VName)
forall a. a -> Maybe a
Just QualName VName
op', Int
0) PatternType
ftype Arg
e_arg
  case PatternType
rt of
    Scalar (Arrow Aliasing
_ PName
_ PatternType
t2 PatternType
rettype) ->
      Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        QualName VName
-> Info PatternType
-> Exp
-> (Info (StructType, Maybe VName), Info StructType)
-> (Info PatternType, Info [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
QualName vn
-> f PatternType
-> ExpBase f vn
-> (f (StructType, Maybe VName), f StructType)
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
OpSectionLeft
          QualName VName
op'
          (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
ftype)
          (Arg -> Exp
argExp Arg
e_arg)
          ((StructType, Maybe VName) -> Info (StructType, Maybe VName)
forall a. a -> Info a
Info (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t1, Maybe VName
argext), StructType -> Info StructType
forall a. a -> Info a
Info (StructType -> Info StructType) -> StructType -> Info StructType
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t2)
          (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
rettype, [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
retext)
          SrcLoc
loc
    PatternType
_ ->
      SrcLoc -> Notes -> Doc -> TermTypeM Exp
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Exp) -> Doc -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        Doc
"Operator section with invalid operator of type" Doc -> Doc -> Doc
<+> PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
ftype
checkExp (OpSectionRight QualName Name
op NoInfo PatternType
_ ExpBase NoInfo Name
e (NoInfo StructType, NoInfo (StructType, Maybe VName))
_ NoInfo PatternType
NoInfo SrcLoc
loc) = do
  (QualName VName
op', PatternType
ftype) <- SrcLoc -> QualName Name -> TermTypeM (QualName VName, PatternType)
forall (m :: * -> *).
MonadTypeChecker m =>
SrcLoc -> QualName Name -> m (QualName VName, PatternType)
lookupVar SrcLoc
loc QualName Name
op
  Arg
e_arg <- ExpBase NoInfo Name -> TermTypeM Arg
checkArg ExpBase NoInfo Name
e
  case PatternType
ftype of
    Scalar (Arrow Aliasing
as1 PName
m1 PatternType
t1 (Scalar (Arrow Aliasing
as2 PName
m2 PatternType
t2 PatternType
ret))) -> do
      (PatternType
t2', PatternType
ret', Maybe VName
argext, [VName]
_) <-
        SrcLoc
-> ApplyOp
-> PatternType
-> Arg
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
checkApply
          SrcLoc
loc
          (QualName VName -> Maybe (QualName VName)
forall a. a -> Maybe a
Just QualName VName
op', Int
1)
          (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
as2 PName
m2 PatternType
t2 (PatternType -> ScalarTypeBase (DimDecl VName) Aliasing)
-> PatternType -> ScalarTypeBase (DimDecl VName) Aliasing
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
as1 PName
m1 PatternType
t1 PatternType
ret)
          Arg
e_arg
      Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        QualName VName
-> Info PatternType
-> Exp
-> (Info StructType, Info (StructType, Maybe VName))
-> Info PatternType
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
QualName vn
-> f PatternType
-> ExpBase f vn
-> (f StructType, f (StructType, Maybe VName))
-> f PatternType
-> SrcLoc
-> ExpBase f vn
OpSectionRight
          QualName VName
op'
          (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
ftype)
          (Arg -> Exp
argExp Arg
e_arg)
          (StructType -> Info StructType
forall a. a -> Info a
Info (StructType -> Info StructType) -> StructType -> Info StructType
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t1, (StructType, Maybe VName) -> Info (StructType, Maybe VName)
forall a. a -> Info a
Info (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
t2', Maybe VName
argext))
          (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ PatternType -> (Aliasing -> Aliasing) -> PatternType
forall dim asf ast.
TypeBase dim asf -> (asf -> ast) -> TypeBase dim ast
addAliases PatternType
ret (Aliasing -> Aliasing -> Aliasing
forall a. Semigroup a => a -> a -> a
<> PatternType -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases PatternType
ret'))
          SrcLoc
loc
    PatternType
_ ->
      SrcLoc -> Notes -> Doc -> TermTypeM Exp
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM Exp) -> Doc -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$
        Doc
"Operator section with invalid operator of type" Doc -> Doc -> Doc
<+> PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
ftype
checkExp (ProjectSection [Name]
fields NoInfo PatternType
NoInfo SrcLoc
loc) = do
  PatternType
a <- SrcLoc -> String -> TermTypeM PatternType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"a"
  let usage :: Usage
usage = SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"projection at"
  PatternType
b <- (PatternType -> Name -> TermTypeM PatternType)
-> PatternType -> [Name] -> TermTypeM PatternType
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ((Name -> PatternType -> TermTypeM PatternType)
-> PatternType -> Name -> TermTypeM PatternType
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Name -> PatternType -> TermTypeM PatternType)
 -> PatternType -> Name -> TermTypeM PatternType)
-> (Name -> PatternType -> TermTypeM PatternType)
-> PatternType
-> Name
-> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$ Usage -> Name -> PatternType -> TermTypeM PatternType
forall (m :: * -> *).
MonadUnify m =>
Usage -> Name -> PatternType -> m PatternType
mustHaveField Usage
usage) PatternType
a [Name]
fields
  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ [Name] -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
[Name] -> f PatternType -> SrcLoc -> ExpBase f vn
ProjectSection [Name]
fields (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed PatternType
a PatternType
b) SrcLoc
loc
checkExp (IndexSection [DimIndexBase NoInfo Name]
idxes NoInfo PatternType
NoInfo SrcLoc
loc) = do
  (StructType
t, StructType
_) <- SrcLoc -> String -> Int -> TermTypeM (StructType, StructType)
newArrayType SrcLoc
loc String
"e" (Int -> TermTypeM (StructType, StructType))
-> Int -> TermTypeM (StructType, StructType)
forall a b. (a -> b) -> a -> b
$ [DimIndexBase NoInfo Name] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndexBase NoInfo Name]
idxes
  [DimIndex]
idxes' <- (DimIndexBase NoInfo Name -> TermTypeM DimIndex)
-> [DimIndexBase NoInfo Name] -> TermTypeM [DimIndex]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndexBase NoInfo Name -> TermTypeM DimIndex
checkDimIndex [DimIndexBase NoInfo Name]
idxes
  (StructType
t', [VName]
_) <- Maybe (SrcLoc, Rigidity)
-> [DimIndex] -> StructType -> TermTypeM (StructType, [VName])
forall as.
Maybe (SrcLoc, Rigidity)
-> [DimIndex]
-> TypeBase (DimDecl VName) as
-> TermTypeM (TypeBase (DimDecl VName) as, [VName])
sliceShape Maybe (SrcLoc, Rigidity)
forall a. Maybe a
Nothing [DimIndex]
idxes' StructType
t
  Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ [DimIndex] -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
[DimIndexBase f vn] -> f PatternType -> SrcLoc -> ExpBase f vn
IndexSection [DimIndex]
idxes' (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct (StructType -> PatternType) -> StructType -> PatternType
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ ()
-> PName
-> StructType
-> StructType
-> ScalarTypeBase (DimDecl VName) ()
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow ()
forall a. Monoid a => a
mempty PName
Unnamed StructType
t StructType
t') SrcLoc
loc
checkExp (DoLoop [VName]
_ UncheckedPattern
mergepat ExpBase NoInfo Name
mergeexp LoopFormBase NoInfo Name
form ExpBase NoInfo Name
loopbody NoInfo (PatternType, [VName])
NoInfo SrcLoc
loc) =
  TermTypeM Exp
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b.
TermTypeM a -> (a -> [Occurence] -> TermTypeM b) -> TermTypeM b
sequentially (ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
mergeexp) ((Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp)
-> (Exp -> [Occurence] -> TermTypeM Exp) -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ \Exp
mergeexp' [Occurence]
_ -> do
    Usage -> String -> PatternType -> TermTypeM ()
forall (m :: * -> *) dim as.
(MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
Usage -> String -> TypeBase dim as -> m ()
zeroOrderType
      (SrcLoc -> String -> Usage
mkUsage (ExpBase NoInfo Name -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf ExpBase NoInfo Name
mergeexp) String
"use as loop variable")
      String
"type used as loop variable"
      (PatternType -> TermTypeM ())
-> TermTypeM PatternType -> TermTypeM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatternType
expTypeFully Exp
mergeexp'

    -- The handling of dimension sizes is a bit intricate, but very
    -- similar to checking a function, followed by checking a call to
    -- it.  The overall procedure is as follows:
    --
    -- (1) All empty dimensions in the merge pattern are instantiated
    -- with nonrigid size variables.  All explicitly specified
    -- dimensions are preserved.
    --
    -- (2) The body of the loop is type-checked.  The result type is
    -- combined with the merge pattern type to determine which sizes are
    -- variant, and these are turned into size parameters for the merge
    -- pattern.
    --
    -- (3) We now conceptually have a function parameter type and return
    -- type.  We check that it can be called with the initial merge
    -- values as argument.  The result of this is the type of the loop
    -- as a whole.
    --
    -- (There is also a convergence loop for inferring uniqueness, but
    -- that's orthogonal to the size handling.)

    (PatternType
merge_t, [VName]
new_dims) <-
      SrcLoc
-> String
-> Rigidity
-> PatternType
-> TermTypeM (PatternType, [VName])
forall (m :: * -> *) als.
MonadUnify m =>
SrcLoc
-> String
-> Rigidity
-> TypeBase (DimDecl VName) als
-> m (TypeBase (DimDecl VName) als, [VName])
instantiateEmptyArrayDims SrcLoc
loc String
"loop" Rigidity
Nonrigid
        (PatternType -> TermTypeM (PatternType, [VName]))
-> (PatternType -> PatternType)
-> PatternType
-> TermTypeM (PatternType, [VName])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatternType -> PatternType
forall vn as. TypeBase (DimDecl vn) as -> TypeBase (DimDecl vn) as
anySizes -- dim handling (1)
        (PatternType -> TermTypeM (PatternType, [VName]))
-> TermTypeM PatternType -> TermTypeM (PatternType, [VName])
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatternType
expTypeFully Exp
mergeexp'

    -- dim handling (2)
    let checkLoopReturnSize :: Pattern -> Exp -> TermTypeM ([VName], Pattern)
checkLoopReturnSize Pattern
mergepat' Exp
loopbody' = do
          PatternType
loopbody_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
loopbody'
          PatternType
pat_t <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully (PatternType -> TermTypeM PatternType)
-> PatternType -> TermTypeM PatternType
forall a b. (a -> b) -> a -> b
$ Pattern -> PatternType
patternType Pattern
mergepat'
          -- We are ignoring the dimensions here, because any mismatches
          -- should be turned into fresh size variables.
          Checking -> TermTypeM () -> TermTypeM ()
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure (StructType -> StructType -> Checking
CheckingLoopBody (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> PatternType
forall vn as. TypeBase (DimDecl vn) as -> TypeBase (DimDecl vn) as
anySizes PatternType
pat_t)) (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
loopbody_t)) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
            Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
expect
              (SrcLoc -> String -> Usage
mkUsage (ExpBase NoInfo Name -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf ExpBase NoInfo Name
loopbody) String
"matching loop body to loop pattern")
              (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> PatternType
forall vn as. TypeBase (DimDecl vn) as -> TypeBase (DimDecl vn) as
anySizes PatternType
pat_t))
              (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
loopbody_t)
          PatternType
pat_t' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
pat_t
          PatternType
loopbody_t' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
loopbody_t

          -- For each new_dims, figure out what they are instantiated
          -- with in the initial value.  This is used to determine
          -- whether a size is invariant because it always matches the
          -- initial instantiation of that size.
          let initSubst :: (DimDecl vn, b) -> Maybe (QualName vn, b)
initSubst (NamedDim QualName vn
v, b
d) = (QualName vn, b) -> Maybe (QualName vn, b)
forall a. a -> Maybe a
Just (QualName vn
v, b
d)
              initSubst (DimDecl vn, b)
_ = Maybe (QualName vn, b)
forall a. Maybe a
Nothing
          Map (QualName VName) (DimDecl VName)
init_substs <-
            [(QualName VName, DimDecl VName)]
-> Map (QualName VName) (DimDecl VName)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(QualName VName, DimDecl VName)]
 -> Map (QualName VName) (DimDecl VName))
-> (PatternType -> [(QualName VName, DimDecl VName)])
-> PatternType
-> Map (QualName VName) (DimDecl VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((DimDecl VName, DimDecl VName)
 -> Maybe (QualName VName, DimDecl VName))
-> [(DimDecl VName, DimDecl VName)]
-> [(QualName VName, DimDecl VName)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (DimDecl VName, DimDecl VName)
-> Maybe (QualName VName, DimDecl VName)
forall vn b. (DimDecl vn, b) -> Maybe (QualName vn, b)
initSubst ([(DimDecl VName, DimDecl VName)]
 -> [(QualName VName, DimDecl VName)])
-> (PatternType -> [(DimDecl VName, DimDecl VName)])
-> PatternType
-> [(QualName VName, DimDecl VName)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatternType, [(DimDecl VName, DimDecl VName)])
-> [(DimDecl VName, DimDecl VName)]
forall a b. (a, b) -> b
snd
              ((PatternType, [(DimDecl VName, DimDecl VName)])
 -> [(DimDecl VName, DimDecl VName)])
-> (PatternType -> (PatternType, [(DimDecl VName, DimDecl VName)]))
-> PatternType
-> [(DimDecl VName, DimDecl VName)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatternType
-> PatternType -> (PatternType, [(DimDecl VName, DimDecl VName)])
forall as.
Monoid as =>
TypeBase (DimDecl VName) as
-> TypeBase (DimDecl VName) as
-> (TypeBase (DimDecl VName) as, [(DimDecl VName, DimDecl VName)])
anyDimOnMismatch PatternType
pat_t'
              (PatternType -> Map (QualName VName) (DimDecl VName))
-> TermTypeM PatternType
-> TermTypeM (Map (QualName VName) (DimDecl VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> TermTypeM PatternType
expTypeFully Exp
mergeexp'

          -- Figure out which of the 'new_dims' dimensions are variant.
          -- This works because we know that each dimension from
          -- new_dims in the pattern is unique and distinct.
          --
          -- Our logic here is a bit reversed: the *mismatches* (from
          -- new_dims) are what we want to extract and turn into size
          -- parameters.
          let mismatchSubst :: (DimDecl VName, DimDecl VName) -> m (Maybe (VName, Subst t))
mismatchSubst (NamedDim QualName VName
v, DimDecl VName
d)
                | QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
new_dims =
                  case QualName VName
-> Map (QualName VName) (DimDecl VName) -> Maybe (DimDecl VName)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup QualName VName
v Map (QualName VName) (DimDecl VName)
init_substs of
                    Just DimDecl VName
d'
                      | DimDecl VName
d' DimDecl VName -> DimDecl VName -> Bool
forall a. Eq a => a -> a -> Bool
== DimDecl VName
d ->
                        Maybe (VName, Subst t) -> m (Maybe (VName, Subst t))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (VName, Subst t) -> m (Maybe (VName, Subst t)))
-> Maybe (VName, Subst t) -> m (Maybe (VName, Subst t))
forall a b. (a -> b) -> a -> b
$ (VName, Subst t) -> Maybe (VName, Subst t)
forall a. a -> Maybe a
Just (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v, DimDecl VName -> Subst t
forall t. DimDecl VName -> Subst t
SizeSubst DimDecl VName
d)
                    Maybe (DimDecl VName)
_ -> do
                      [VName] -> m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell [QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v]
                      Maybe (VName, Subst t) -> m (Maybe (VName, Subst t))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (VName, Subst t)
forall a. Maybe a
Nothing
              mismatchSubst (DimDecl VName, DimDecl VName)
_ = Maybe (VName, Subst t) -> m (Maybe (VName, Subst t))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (VName, Subst t)
forall a. Maybe a
Nothing

              (Map VName (Subst t)
init_substs', [VName]
sparams) =
                Writer [VName] (Map VName (Subst t))
-> (Map VName (Subst t), [VName])
forall w a. Writer w a -> (a, w)
runWriter (Writer [VName] (Map VName (Subst t))
 -> (Map VName (Subst t), [VName]))
-> Writer [VName] (Map VName (Subst t))
-> (Map VName (Subst t), [VName])
forall a b. (a -> b) -> a -> b
$
                  [(VName, Subst t)] -> Map VName (Subst t)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, Subst t)] -> Map VName (Subst t))
-> ([Maybe (VName, Subst t)] -> [(VName, Subst t)])
-> [Maybe (VName, Subst t)]
-> Map VName (Subst t)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (VName, Subst t)] -> [(VName, Subst t)]
forall a. [Maybe a] -> [a]
catMaybes
                    ([Maybe (VName, Subst t)] -> Map VName (Subst t))
-> WriterT [VName] Identity [Maybe (VName, Subst t)]
-> Writer [VName] (Map VName (Subst t))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimDecl VName, DimDecl VName)
 -> WriterT [VName] Identity (Maybe (VName, Subst t)))
-> [(DimDecl VName, DimDecl VName)]
-> WriterT [VName] Identity [Maybe (VName, Subst t)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
                      (DimDecl VName, DimDecl VName)
-> WriterT [VName] Identity (Maybe (VName, Subst t))
forall (m :: * -> *) t.
MonadWriter [VName] m =>
(DimDecl VName, DimDecl VName) -> m (Maybe (VName, Subst t))
mismatchSubst
                      ((PatternType, [(DimDecl VName, DimDecl VName)])
-> [(DimDecl VName, DimDecl VName)]
forall a b. (a, b) -> b
snd ((PatternType, [(DimDecl VName, DimDecl VName)])
 -> [(DimDecl VName, DimDecl VName)])
-> (PatternType, [(DimDecl VName, DimDecl VName)])
-> [(DimDecl VName, DimDecl VName)]
forall a b. (a -> b) -> a -> b
$ PatternType
-> PatternType -> (PatternType, [(DimDecl VName, DimDecl VName)])
forall as.
Monoid as =>
TypeBase (DimDecl VName) as
-> TypeBase (DimDecl VName) as
-> (TypeBase (DimDecl VName) as, [(DimDecl VName, DimDecl VName)])
anyDimOnMismatch PatternType
pat_t' PatternType
loopbody_t')

          -- Make sure that any of new_dims that are invariant will be
          -- replaced with the invariant size in the loop body.  Failure
          -- to do this can cause type annotations to still refer to
          -- new_dims.
          let dimToInit :: (VName, Subst t) -> TermTypeM ()
dimToInit (VName
v, SizeSubst DimDecl VName
d) =
                VName -> Constraint -> TermTypeM ()
constrain VName
v (Constraint -> TermTypeM ()) -> Constraint -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Maybe (DimDecl VName) -> Usage -> Constraint
Size (DimDecl VName -> Maybe (DimDecl VName)
forall a. a -> Maybe a
Just DimDecl VName
d) (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"size of loop parameter")
              dimToInit (VName, Subst t)
_ =
                () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
          ((VName, Subst Any) -> TermTypeM ())
-> [(VName, Subst Any)] -> TermTypeM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (VName, Subst Any) -> TermTypeM ()
forall t. (VName, Subst t) -> TermTypeM ()
dimToInit ([(VName, Subst Any)] -> TermTypeM ())
-> [(VName, Subst Any)] -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Map VName (Subst Any) -> [(VName, Subst Any)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Subst Any)
forall t. Map VName (Subst t)
init_substs'

          Pattern
mergepat'' <- (VName -> Maybe (Subst StructType)) -> Pattern -> Pattern
forall a.
Substitutable a =>
(VName -> Maybe (Subst StructType)) -> a -> a
applySubst (VName -> Map VName (Subst StructType) -> Maybe (Subst StructType)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructType)
forall t. Map VName (Subst t)
init_substs') (Pattern -> Pattern) -> TermTypeM Pattern -> TermTypeM Pattern
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern -> TermTypeM Pattern
forall e. ASTMappable e => e -> TermTypeM e
updateTypes Pattern
mergepat'
          ([VName], Pattern) -> TermTypeM ([VName], Pattern)
forall (m :: * -> *) a. Monad m => a -> m a
return ([VName] -> [VName]
forall a. Eq a => [a] -> [a]
nub [VName]
sparams, Pattern
mergepat'')

    -- First we do a basic check of the loop body to figure out which of
    -- the merge parameters are being consumed.  For this, we first need
    -- to check the merge pattern, which requires the (initial) merge
    -- expression.
    --
    -- Play a little with occurences to ensure it does not look like
    -- none of the merge variables are being used.
    (([VName]
sparams, Pattern
mergepat', LoopFormBase Info VName
form', Exp
loopbody'), [Occurence]
bodyflow) <-
      case LoopFormBase NoInfo Name
form of
        For IdentBase NoInfo Name
i ExpBase NoInfo Name
uboundexp -> do
          Exp
uboundexp' <- String -> [PrimType] -> Exp -> TermTypeM Exp
require String
"being the bound in a 'for' loop" [PrimType]
anySignedType (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
uboundexp
          PatternType
bound_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
uboundexp'
          IdentBase NoInfo Name
-> PatternType
-> (Ident
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a.
IdentBase NoInfo Name
-> PatternType -> (Ident -> TermTypeM a) -> TermTypeM a
bindingIdent IdentBase NoInfo Name
i PatternType
bound_t ((Ident
  -> TermTypeM
       (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> (Ident
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$ \Ident
i' ->
            TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
noUnique (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
              UncheckedPattern
-> InferredType
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a.
UncheckedPattern
-> InferredType -> (Pattern -> TermTypeM a) -> TermTypeM a
bindingPattern UncheckedPattern
mergepat (PatternType -> InferredType
Ascribed PatternType
merge_t) ((Pattern
  -> TermTypeM
       (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                \Pattern
mergepat' -> TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
onlySelfAliasing (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                  TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a. TermTypeM a -> TermTypeM (a, [Occurence])
tapOccurences (TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$ do
                    Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
loopbody
                    ([VName]
sparams, Pattern
mergepat'') <- Pattern -> Exp -> TermTypeM ([VName], Pattern)
checkLoopReturnSize Pattern
mergepat' Exp
loopbody'
                    ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                      ( [VName]
sparams,
                        Pattern
mergepat'',
                        Ident -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
IdentBase f vn -> ExpBase f vn -> LoopFormBase f vn
For Ident
i' Exp
uboundexp',
                        Exp
loopbody'
                      )
        ForIn UncheckedPattern
xpat ExpBase NoInfo Name
e -> do
          (StructType
arr_t, StructType
_) <- SrcLoc -> String -> Int -> TermTypeM (StructType, StructType)
newArrayType (ExpBase NoInfo Name -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf ExpBase NoInfo Name
e) String
"e" Int
1
          Exp
e' <- String -> StructType -> Exp -> TermTypeM Exp
unifies String
"being iterated in a 'for-in' loop" StructType
arr_t (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e
          PatternType
t <- Exp -> TermTypeM PatternType
expTypeFully Exp
e'
          case PatternType
t of
            PatternType
_
              | Just PatternType
t' <- Int -> PatternType -> Maybe PatternType
forall dim as. Int -> TypeBase dim as -> Maybe (TypeBase dim as)
peelArray Int
1 PatternType
t ->
                UncheckedPattern
-> InferredType
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a.
UncheckedPattern
-> InferredType -> (Pattern -> TermTypeM a) -> TermTypeM a
bindingPattern UncheckedPattern
xpat (PatternType -> InferredType
Ascribed PatternType
t') ((Pattern
  -> TermTypeM
       (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$ \Pattern
xpat' ->
                  TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
noUnique (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                    UncheckedPattern
-> InferredType
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a.
UncheckedPattern
-> InferredType -> (Pattern -> TermTypeM a) -> TermTypeM a
bindingPattern UncheckedPattern
mergepat (PatternType -> InferredType
Ascribed PatternType
merge_t) ((Pattern
  -> TermTypeM
       (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                      \Pattern
mergepat' -> TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
onlySelfAliasing (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                        TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a. TermTypeM a -> TermTypeM (a, [Occurence])
tapOccurences (TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$ do
                          Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
loopbody
                          ([VName]
sparams, Pattern
mergepat'') <- Pattern -> Exp -> TermTypeM ([VName], Pattern)
checkLoopReturnSize Pattern
mergepat' Exp
loopbody'
                          ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                            ( [VName]
sparams,
                              Pattern
mergepat'',
                              Pattern -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
PatternBase f vn -> ExpBase f vn -> LoopFormBase f vn
ForIn Pattern
xpat' Exp
e',
                              Exp
loopbody'
                            )
              | Bool
otherwise ->
                SrcLoc
-> Notes
-> Doc
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError (ExpBase NoInfo Name -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf ExpBase NoInfo Name
e) Notes
forall a. Monoid a => a
mempty (Doc
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> Doc
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                  Doc
"Iteratee of a for-in loop must be an array, but expression has type"
                    Doc -> Doc -> Doc
<+> PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
t
        While ExpBase NoInfo Name
cond ->
          TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
noUnique (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
            UncheckedPattern
-> InferredType
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a.
UncheckedPattern
-> InferredType -> (Pattern -> TermTypeM a) -> TermTypeM a
bindingPattern UncheckedPattern
mergepat (PatternType -> InferredType
Ascribed PatternType
merge_t) ((Pattern
  -> TermTypeM
       (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> (Pattern
    -> TermTypeM
         (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$ \Pattern
mergepat' ->
              TermTypeM
  (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall b. TermTypeM b -> TermTypeM b
onlySelfAliasing (TermTypeM
   (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a. TermTypeM a -> TermTypeM (a, [Occurence])
tapOccurences (TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence]))
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pattern, LoopFormBase Info VName, Exp), [Occurence])
forall a b. (a -> b) -> a -> b
$
                  TermTypeM Exp
-> (Exp
    -> [Occurence]
    -> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp))
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
forall a b.
TermTypeM a -> (a -> [Occurence] -> TermTypeM b) -> TermTypeM b
sequentially
                    ( ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
cond
                        TermTypeM Exp -> (Exp -> TermTypeM Exp) -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String -> StructType -> Exp -> TermTypeM Exp
unifies String
"being the condition of a 'while' loop" (ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim PrimType
Bool)
                    )
                    ((Exp
  -> [Occurence]
  -> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp))
 -> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp))
-> (Exp
    -> [Occurence]
    -> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp))
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
forall a b. (a -> b) -> a -> b
$ \Exp
cond' [Occurence]
_ -> do
                      Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall b. TermTypeM b -> TermTypeM b
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
loopbody
                      ([VName]
sparams, Pattern
mergepat'') <- Pattern -> Exp -> TermTypeM ([VName], Pattern)
checkLoopReturnSize Pattern
mergepat' Exp
loopbody'
                      ([VName], Pattern, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pattern, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                        ( [VName]
sparams,
                          Pattern
mergepat'',
                          Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn. ExpBase f vn -> LoopFormBase f vn
While Exp
cond',
                          Exp
loopbody'
                        )

    Pattern
mergepat'' <- do
      PatternType
loopbody_t <- Exp -> TermTypeM PatternType
expTypeFully Exp
loopbody'
      Pattern -> Names -> PatternType -> Usage -> TermTypeM Pattern
forall (m :: * -> *) t.
(MonadUnify m, MonadTypeChecker m, Located t,
 MonadReader TermEnv m) =>
Pattern -> Names -> PatternType -> t -> m Pattern
convergePattern Pattern
mergepat' ([Occurence] -> Names
allConsumed [Occurence]
bodyflow) PatternType
loopbody_t (Usage -> TermTypeM Pattern) -> Usage -> TermTypeM Pattern
forall a b. (a -> b) -> a -> b
$
        SrcLoc -> String -> Usage
mkUsage (Exp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf Exp
loopbody') String
"being (part of) the result of the loop body"

    let consumeMerge :: PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge (Id vn
_ (Info PatternType
pt) SrcLoc
ploc) TypeBase dim Aliasing
mt
          | PatternType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatternType
pt = SrcLoc -> Aliasing -> TermTypeM ()
consume SrcLoc
ploc (Aliasing -> TermTypeM ()) -> Aliasing -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
mt
        consumeMerge (TuplePattern [PatternBase Info vn]
pats SrcLoc
_) TypeBase dim Aliasing
t
          | Just [TypeBase dim Aliasing]
ts <- TypeBase dim Aliasing -> Maybe [TypeBase dim Aliasing]
forall dim as. TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord TypeBase dim Aliasing
t =
            (PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ())
-> [PatternBase Info vn] -> [TypeBase dim Aliasing] -> TermTypeM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge [PatternBase Info vn]
pats [TypeBase dim Aliasing]
ts
        consumeMerge (PatternParens PatternBase Info vn
pat SrcLoc
_) TypeBase dim Aliasing
t =
          PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge PatternBase Info vn
pat TypeBase dim Aliasing
t
        consumeMerge (PatternAscription PatternBase Info vn
pat TypeDeclBase Info vn
_ SrcLoc
_) TypeBase dim Aliasing
t =
          PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge PatternBase Info vn
pat TypeBase dim Aliasing
t
        consumeMerge PatternBase Info vn
_ TypeBase dim Aliasing
_ =
          () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
    Pattern -> PatternType -> TermTypeM ()
forall vn dim.
PatternBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge Pattern
mergepat'' (PatternType -> TermTypeM ())
-> TermTypeM PatternType -> TermTypeM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatternType
expTypeFully Exp
mergeexp'

    -- dim handling (3)
    let sparams_anydim :: Map VName (Subst t)
sparams_anydim = [(VName, Subst t)] -> Map VName (Subst t)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, Subst t)] -> Map VName (Subst t))
-> [(VName, Subst t)] -> Map VName (Subst t)
forall a b. (a -> b) -> a -> b
$ [VName] -> [Subst t] -> [(VName, Subst t)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
sparams ([Subst t] -> [(VName, Subst t)])
-> [Subst t] -> [(VName, Subst t)]
forall a b. (a -> b) -> a -> b
$ Subst t -> [Subst t]
forall a. a -> [a]
repeat (Subst t -> [Subst t]) -> Subst t -> [Subst t]
forall a b. (a -> b) -> a -> b
$ DimDecl VName -> Subst t
forall t. DimDecl VName -> Subst t
SizeSubst DimDecl VName
forall vn. DimDecl vn
AnyDim
        loopt_anydims :: PatternType
loopt_anydims =
          (VName -> Maybe (Subst StructType)) -> PatternType -> PatternType
forall a.
Substitutable a =>
(VName -> Maybe (Subst StructType)) -> a -> a
applySubst (VName -> Map VName (Subst StructType) -> Maybe (Subst StructType)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructType)
forall t. Map VName (Subst t)
sparams_anydim) (PatternType -> PatternType) -> PatternType -> PatternType
forall a b. (a -> b) -> a -> b
$
            Pattern -> PatternType
patternType Pattern
mergepat''
    (StructType
merge_t', [VName]
_) <-
      SrcLoc
-> String
-> Rigidity
-> StructType
-> TermTypeM (StructType, [VName])
forall (m :: * -> *) als.
MonadUnify m =>
SrcLoc
-> String
-> Rigidity
-> TypeBase (DimDecl VName) als
-> m (TypeBase (DimDecl VName) als, [VName])
instantiateEmptyArrayDims SrcLoc
loc String
"loopres" Rigidity
Nonrigid (StructType -> TermTypeM (StructType, [VName]))
-> StructType -> TermTypeM (StructType, [VName])
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
loopt_anydims
    StructType
mergeexp_t <- PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> StructType)
-> TermTypeM PatternType -> TermTypeM StructType
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> TermTypeM PatternType
expTypeFully Exp
mergeexp'
    Checking -> TermTypeM () -> TermTypeM ()
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure (StructType -> StructType -> Checking
CheckingLoopInitial (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
loopt_anydims) StructType
mergeexp_t) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
      Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify
        (SrcLoc -> String -> Usage
mkUsage (Exp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf Exp
mergeexp') String
"matching initial loop values to pattern")
        StructType
merge_t'
        StructType
mergeexp_t

    (PatternType
loopt, [VName]
retext) <- SrcLoc
-> RigidSource -> PatternType -> TermTypeM (PatternType, [VName])
forall als.
SrcLoc
-> RigidSource
-> TypeBase (DimDecl VName) als
-> TermTypeM (TypeBase (DimDecl VName) als, [VName])
instantiateDimsInType SrcLoc
loc RigidSource
RigidLoop PatternType
loopt_anydims
    -- We set all of the uniqueness to be unique.  This is intentional,
    -- and matches what happens for function calls.  Those arrays that
    -- really *cannot* be consumed will alias something unconsumable,
    -- and will be caught that way.
    let bound_here :: Names
bound_here = Pattern -> Names
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatternBase f vn -> Set vn
patternNames Pattern
mergepat'' Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
sparams Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
form_bound
        form_bound :: Names
form_bound =
          case LoopFormBase Info VName
form' of
            For Ident
v Exp
_ -> VName -> Names
forall a. a -> Set a
S.singleton (VName -> Names) -> VName -> Names
forall a b. (a -> b) -> a -> b
$ Ident -> VName
forall (f :: * -> *) vn. IdentBase f vn -> vn
identName Ident
v
            ForIn Pattern
forpat Exp
_ -> Pattern -> Names
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatternBase f vn -> Set vn
patternNames Pattern
forpat
            While {} -> Names
forall a. Monoid a => a
mempty
        loopt' :: PatternType
loopt' =
          (Aliasing -> Aliasing) -> PatternType -> PatternType
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Aliasing -> Aliasing -> Aliasing
forall a. Ord a => Set a -> Set a -> Set a
`S.difference` (VName -> Alias) -> Names -> Aliasing
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map VName -> Alias
AliasBound Names
bound_here) (PatternType -> PatternType) -> PatternType -> PatternType
forall a b. (a -> b) -> a -> b
$
            PatternType
loopt PatternType -> Uniqueness -> PatternType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Unique

    -- Eliminate those new_dims that turned into sparams so it won't
    -- look like we have ambiguous sizes lying around.
    (Constraints -> Constraints) -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
(Constraints -> Constraints) -> m ()
modifyConstraints ((Constraints -> Constraints) -> TermTypeM ())
-> (Constraints -> Constraints) -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ (VName -> (Int, Constraint) -> Bool) -> Constraints -> Constraints
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey ((VName -> (Int, Constraint) -> Bool)
 -> Constraints -> Constraints)
-> (VName -> (Int, Constraint) -> Bool)
-> Constraints
-> Constraints
forall a b. (a -> b) -> a -> b
$ \VName
k (Int, Constraint)
_ -> VName
k VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
sparams

    Exp -> TermTypeM Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> TermTypeM Exp) -> Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ [VName]
-> Pattern
-> Exp
-> LoopFormBase Info VName
-> Exp
-> Info (PatternType, [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
[VName]
-> PatternBase f vn
-> ExpBase f vn
-> LoopFormBase f vn
-> ExpBase f vn
-> f (PatternType, [VName])
-> SrcLoc
-> ExpBase f vn
DoLoop [VName]
sparams Pattern
mergepat'' Exp
mergeexp' LoopFormBase Info VName
form' Exp
loopbody' ((PatternType, [VName]) -> Info (PatternType, [VName])
forall a. a -> Info a
Info (PatternType
loopt', [VName]
retext)) SrcLoc
loc
  where
    convergePattern :: Pattern -> Names -> PatternType -> t -> m Pattern
convergePattern Pattern
pat Names
body_cons PatternType
body_t t
body_loc = do
      let consumed_merge :: Names
consumed_merge = Pattern -> Names
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatternBase f vn -> Set vn
patternNames Pattern
pat Names -> Names -> Names
forall a. Ord a => Set a -> Set a -> Set a
`S.intersection` Names
body_cons

          uniquePat :: Pattern -> Pattern
uniquePat (Wildcard (Info PatternType
t) SrcLoc
wloc) =
            Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
f PatternType -> SrcLoc -> PatternBase f vn
Wildcard (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ PatternType
t PatternType -> Uniqueness -> PatternType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Nonunique) SrcLoc
wloc
          uniquePat (PatternParens Pattern
p SrcLoc
ploc) =
            Pattern -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> SrcLoc -> PatternBase f vn
PatternParens (Pattern -> Pattern
uniquePat Pattern
p) SrcLoc
ploc
          uniquePat (Id VName
name (Info PatternType
t) SrcLoc
iloc)
            | VName
name VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Names
consumed_merge =
              let t' :: PatternType
t' = PatternType
t PatternType -> Uniqueness -> PatternType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Unique PatternType -> Aliasing -> PatternType
forall dim asf ast. TypeBase dim asf -> ast -> TypeBase dim ast
`setAliases` Aliasing
forall a. Monoid a => a
mempty
               in VName -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> PatternBase f vn
Id VName
name (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t') SrcLoc
iloc
            | Bool
otherwise =
              let t' :: PatternType
t' = PatternType
t PatternType -> Uniqueness -> PatternType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Nonunique
               in VName -> Info PatternType -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> PatternBase f vn
Id VName
name (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
t') SrcLoc
iloc
          uniquePat (TuplePattern [Pattern]
pats SrcLoc
ploc) =
            [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[PatternBase f vn] -> SrcLoc -> PatternBase f vn
TuplePattern ((Pattern -> Pattern) -> [Pattern] -> [Pattern]
forall a b. (a -> b) -> [a] -> [b]
map Pattern -> Pattern
uniquePat [Pattern]
pats) SrcLoc
ploc
          uniquePat (RecordPattern [(Name, Pattern)]
fs SrcLoc
ploc) =
            [(Name, Pattern)] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
[(Name, PatternBase f vn)] -> SrcLoc -> PatternBase f vn
RecordPattern (((Name, Pattern) -> (Name, Pattern))
-> [(Name, Pattern)] -> [(Name, Pattern)]
forall a b. (a -> b) -> [a] -> [b]
map ((Pattern -> Pattern) -> (Name, Pattern) -> (Name, Pattern)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Pattern -> Pattern
uniquePat) [(Name, Pattern)]
fs) SrcLoc
ploc
          uniquePat (PatternAscription Pattern
p TypeDeclBase Info VName
t SrcLoc
ploc) =
            Pattern -> TypeDeclBase Info VName -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> TypeDeclBase f vn -> SrcLoc -> PatternBase f vn
PatternAscription Pattern
p TypeDeclBase Info VName
t SrcLoc
ploc
          uniquePat p :: Pattern
p@PatternLit {} = Pattern
p
          uniquePat (PatternConstr Name
n Info PatternType
t [Pattern]
ps SrcLoc
ploc) =
            Name -> Info PatternType -> [Pattern] -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
Name
-> f PatternType
-> [PatternBase f vn]
-> SrcLoc
-> PatternBase f vn
PatternConstr Name
n Info PatternType
t ((Pattern -> Pattern) -> [Pattern] -> [Pattern]
forall a b. (a -> b) -> [a] -> [b]
map Pattern -> Pattern
uniquePat [Pattern]
ps) SrcLoc
ploc

          -- Make the pattern unique where needed.
          pat' :: Pattern
pat' = Pattern -> Pattern
uniquePat Pattern
pat

      PatternType
pat_t <- PatternType -> m PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully (PatternType -> m PatternType) -> PatternType -> m PatternType
forall a b. (a -> b) -> a -> b
$ Pattern -> PatternType
patternType Pattern
pat'
      Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PatternType -> TypeBase () ()
forall dim as. TypeBase dim as -> TypeBase () ()
toStructural PatternType
body_t TypeBase () () -> TypeBase () () -> Bool
forall dim as1 as2.
ArrayDim dim =>
TypeBase dim as1 -> TypeBase dim as2 -> Bool
`subtypeOf` PatternType -> TypeBase () ()
forall dim as. TypeBase dim as -> TypeBase () ()
toStructural PatternType
pat_t) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$
        SrcLoc -> StructType -> [StructType] -> m ()
forall (m :: * -> *) a.
MonadTypeChecker m =>
SrcLoc -> StructType -> [StructType] -> m a
unexpectedType (t -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf t
body_loc) (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
body_t) [PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
pat_t]

      -- Check that the new values of consumed merge parameters do not
      -- alias something bound outside the loop, AND that anything
      -- returned for a unique merge parameter does not alias anything
      -- else returned.  We also update the aliases for the pattern.
      Names
bound_outside <- (TermEnv -> Names) -> m Names
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((TermEnv -> Names) -> m Names) -> (TermEnv -> Names) -> m Names
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList ([VName] -> Names) -> (TermEnv -> [VName]) -> TermEnv -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName ValBinding -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName ValBinding -> [VName])
-> (TermEnv -> Map VName ValBinding) -> TermEnv -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermScope -> Map VName ValBinding
scopeVtable (TermScope -> Map VName ValBinding)
-> (TermEnv -> TermScope) -> TermEnv -> Map VName ValBinding
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermEnv -> TermScope
termScope
      let combAliases :: TypeBase dim ast -> TypeBase shape ast -> TypeBase dim ast
combAliases TypeBase dim ast
t1 TypeBase shape ast
t2 =
            case TypeBase dim ast
t1 of
              Scalar Record {} -> TypeBase dim ast
t1
              TypeBase dim ast
_ -> TypeBase dim ast
t1 TypeBase dim ast -> (ast -> ast) -> TypeBase dim ast
forall dim asf ast.
TypeBase dim asf -> (asf -> ast) -> TypeBase dim ast
`addAliases` (ast -> ast -> ast
forall a. Semigroup a => a -> a -> a
<> TypeBase shape ast -> ast
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase shape ast
t2)

          checkMergeReturn :: PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn (Id vn
pat_v (Info PatternType
pat_v_t) SrcLoc
patloc) TypeBase dim Aliasing
t
            | PatternType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatternType
pat_v_t,
              VName
v : [VName]
_ <-
                Names -> [VName]
forall a. Set a -> [a]
S.toList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
                  (Alias -> VName) -> Aliasing -> Names
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map Alias -> VName
aliasVar (TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
t) Names -> Names -> Names
forall a. Ord a => Set a -> Set a -> Set a
`S.intersection` Names
bound_outside =
              m (PatternBase Info vn) -> t m (PatternBase Info vn)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (PatternBase Info vn) -> t m (PatternBase Info vn))
-> m (PatternBase Info vn) -> t m (PatternBase Info vn)
forall a b. (a -> b) -> a -> b
$
                SrcLoc -> Notes -> Doc -> m (PatternBase Info vn)
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> m (PatternBase Info vn)) -> Doc -> m (PatternBase Info vn)
forall a b. (a -> b) -> a -> b
$
                  Doc
"Return value for loop parameter"
                    Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (vn -> Doc
forall v. IsName v => v -> Doc
pprName vn
pat_v)
                    Doc -> Doc -> Doc
<+> Doc
"aliases"
                    Doc -> Doc -> Doc
<+> VName -> Doc
forall v. IsName v => v -> Doc
pprName VName
v Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
            | Bool
otherwise = do
              (Aliasing
cons, Aliasing
obs) <- t m (Aliasing, Aliasing)
forall s (m :: * -> *). MonadState s m => m s
get
              Bool -> t m () -> t m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Aliasing -> Bool
forall a. Set a -> Bool
S.null (Aliasing -> Bool) -> Aliasing -> Bool
forall a b. (a -> b) -> a -> b
$ TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
t Aliasing -> Aliasing -> Aliasing
forall a. Ord a => Set a -> Set a -> Set a
`S.intersection` Aliasing
cons) (t m () -> t m ()) -> t m () -> t m ()
forall a b. (a -> b) -> a -> b
$
                m () -> t m ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m () -> t m ()) -> m () -> t m ()
forall a b. (a -> b) -> a -> b
$
                  SrcLoc -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
                    Doc
"Return value for loop parameter"
                      Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (vn -> Doc
forall v. IsName v => v -> Doc
pprName vn
pat_v)
                      Doc -> Doc -> Doc
<+> Doc
"aliases other consumed loop parameter."
              Bool -> t m () -> t m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when
                ( PatternType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatternType
pat_v_t
                    Bool -> Bool -> Bool
&& Bool -> Bool
not (Aliasing -> Bool
forall a. Set a -> Bool
S.null (TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
t Aliasing -> Aliasing -> Aliasing
forall a. Ord a => Set a -> Set a -> Set a
`S.intersection` (Aliasing
cons Aliasing -> Aliasing -> Aliasing
forall a. Semigroup a => a -> a -> a
<> Aliasing
obs)))
                )
                (t m () -> t m ()) -> t m () -> t m ()
forall a b. (a -> b) -> a -> b
$ m () -> t m ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m () -> t m ()) -> m () -> t m ()
forall a b. (a -> b) -> a -> b
$
                  SrcLoc -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
                    Doc
"Return value for consuming loop parameter"
                      Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (vn -> Doc
forall v. IsName v => v -> Doc
pprName vn
pat_v)
                      Doc -> Doc -> Doc
<+> Doc
"aliases previously returned value."
              if PatternType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatternType
pat_v_t
                then (Aliasing, Aliasing) -> t m ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Aliasing
cons Aliasing -> Aliasing -> Aliasing
forall a. Semigroup a => a -> a -> a
<> TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
t, Aliasing
obs)
                else (Aliasing, Aliasing) -> t m ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Aliasing
cons, Aliasing
obs Aliasing -> Aliasing -> Aliasing
forall a. Semigroup a => a -> a -> a
<> TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
t)

              PatternBase Info vn -> t m (PatternBase Info vn)
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternBase Info vn -> t m (PatternBase Info vn))
-> PatternBase Info vn -> t m (PatternBase Info vn)
forall a b. (a -> b) -> a -> b
$ vn -> Info PatternType -> SrcLoc -> PatternBase Info vn
forall (f :: * -> *) vn.
vn -> f PatternType -> SrcLoc -> PatternBase f vn
Id vn
pat_v (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> TypeBase dim Aliasing -> PatternType
forall ast dim shape.
Monoid ast =>
TypeBase dim ast -> TypeBase shape ast -> TypeBase dim ast
combAliases PatternType
pat_v_t TypeBase dim Aliasing
t)) SrcLoc
patloc
          checkMergeReturn (Wildcard (Info PatternType
pat_v_t) SrcLoc
patloc) TypeBase dim Aliasing
t =
            PatternBase Info vn -> t m (PatternBase Info vn)
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternBase Info vn -> t m (PatternBase Info vn))
-> PatternBase Info vn -> t m (PatternBase Info vn)
forall a b. (a -> b) -> a -> b
$ Info PatternType -> SrcLoc -> PatternBase Info vn
forall (f :: * -> *) vn.
f PatternType -> SrcLoc -> PatternBase f vn
Wildcard (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> TypeBase dim Aliasing -> PatternType
forall ast dim shape.
Monoid ast =>
TypeBase dim ast -> TypeBase shape ast -> TypeBase dim ast
combAliases PatternType
pat_v_t TypeBase dim Aliasing
t)) SrcLoc
patloc
          checkMergeReturn (PatternParens PatternBase Info vn
p SrcLoc
_) TypeBase dim Aliasing
t =
            PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn PatternBase Info vn
p TypeBase dim Aliasing
t
          checkMergeReturn (PatternAscription PatternBase Info vn
p TypeDeclBase Info vn
_ SrcLoc
_) TypeBase dim Aliasing
t =
            PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn PatternBase Info vn
p TypeBase dim Aliasing
t
          checkMergeReturn (RecordPattern [(Name, PatternBase Info vn)]
pfs SrcLoc
patloc) (Scalar (Record Map Name (TypeBase dim Aliasing)
tfs)) =
            [(Name, PatternBase Info vn)] -> SrcLoc -> PatternBase Info vn
forall (f :: * -> *) vn.
[(Name, PatternBase f vn)] -> SrcLoc -> PatternBase f vn
RecordPattern ([(Name, PatternBase Info vn)] -> SrcLoc -> PatternBase Info vn)
-> (Map Name (PatternBase Info vn)
    -> [(Name, PatternBase Info vn)])
-> Map Name (PatternBase Info vn)
-> SrcLoc
-> PatternBase Info vn
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map Name (PatternBase Info vn) -> [(Name, PatternBase Info vn)]
forall k a. Map k a -> [(k, a)]
M.toList (Map Name (PatternBase Info vn) -> SrcLoc -> PatternBase Info vn)
-> t m (Map Name (PatternBase Info vn))
-> t m (SrcLoc -> PatternBase Info vn)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Map Name (t m (PatternBase Info vn))
-> t m (Map Name (PatternBase Info vn))
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence Map Name (t m (PatternBase Info vn))
pfs' t m (SrcLoc -> PatternBase Info vn)
-> t m SrcLoc -> t m (PatternBase Info vn)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> t m SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
patloc
            where
              pfs' :: Map Name (t m (PatternBase Info vn))
pfs' =
                (PatternBase Info vn
 -> TypeBase dim Aliasing -> t m (PatternBase Info vn))
-> Map Name (PatternBase Info vn)
-> Map Name (TypeBase dim Aliasing)
-> Map Name (t m (PatternBase Info vn))
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith
                  PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn
                  ([(Name, PatternBase Info vn)] -> Map Name (PatternBase Info vn)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, PatternBase Info vn)]
pfs)
                  Map Name (TypeBase dim Aliasing)
tfs
          checkMergeReturn (TuplePattern [PatternBase Info vn]
pats SrcLoc
patloc) TypeBase dim Aliasing
t
            | Just [TypeBase dim Aliasing]
ts <- TypeBase dim Aliasing -> Maybe [TypeBase dim Aliasing]
forall dim as. TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord TypeBase dim Aliasing
t =
              [PatternBase Info vn] -> SrcLoc -> PatternBase Info vn
forall (f :: * -> *) vn.
[PatternBase f vn] -> SrcLoc -> PatternBase f vn
TuplePattern
                ([PatternBase Info vn] -> SrcLoc -> PatternBase Info vn)
-> t m [PatternBase Info vn] -> t m (SrcLoc -> PatternBase Info vn)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (PatternBase Info vn
 -> TypeBase dim Aliasing -> t m (PatternBase Info vn))
-> [PatternBase Info vn]
-> [TypeBase dim Aliasing]
-> t m [PatternBase Info vn]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn [PatternBase Info vn]
pats [TypeBase dim Aliasing]
ts
                t m (SrcLoc -> PatternBase Info vn)
-> t m SrcLoc -> t m (PatternBase Info vn)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> t m SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
patloc
          checkMergeReturn PatternBase Info vn
p TypeBase dim Aliasing
_ =
            PatternBase Info vn -> t m (PatternBase Info vn)
forall (m :: * -> *) a. Monad m => a -> m a
return PatternBase Info vn
p

      (Pattern
pat'', (Aliasing
pat_cons, Aliasing
_)) <-
        StateT (Aliasing, Aliasing) m Pattern
-> (Aliasing, Aliasing) -> m (Pattern, (Aliasing, Aliasing))
forall s (m :: * -> *) a. StateT s m a -> s -> m (a, s)
runStateT (Pattern -> PatternType -> StateT (Aliasing, Aliasing) m Pattern
forall (t :: (* -> *) -> * -> *) (m :: * -> *) vn dim.
(MonadTrans t, MonadTypeChecker m, IsName vn,
 MonadState (Aliasing, Aliasing) (t m)) =>
PatternBase Info vn
-> TypeBase dim Aliasing -> t m (PatternBase Info vn)
checkMergeReturn Pattern
pat' PatternType
body_t) (Aliasing
forall a. Monoid a => a
mempty, Aliasing
forall a. Monoid a => a
mempty)

      let body_cons' :: Names
body_cons' = Names
body_cons Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> (Alias -> VName) -> Aliasing -> Names
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map Alias -> VName
aliasVar Aliasing
pat_cons
      if Names
body_cons' Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
== Names
body_cons Bool -> Bool -> Bool
&& Pattern -> PatternType
patternType Pattern
pat'' PatternType -> PatternType -> Bool
forall a. Eq a => a -> a -> Bool
== Pattern -> PatternType
patternType Pattern
pat
        then Pattern -> m Pattern
forall (m :: * -> *) a. Monad m => a -> m a
return Pattern
pat'
        else Pattern -> Names -> PatternType -> t -> m Pattern
convergePattern Pattern
pat'' Names
body_cons' PatternType
body_t t
body_loc
-- Sum-type constructor application.  A fresh type variable stands for
-- the (not yet known) sum type; each payload expression is checked,
-- and the variable is then required to have a constructor 'name'
-- whose payload types match the checked payloads.
checkExp (Constr name es NoInfo loc) = do
  t <- newTypeVar loc "t"
  checked <- traverse checkExp es
  payload_ts <- traverse expTypeFully checked
  let usage = mkUsage loc "use of constructor"
  mustHaveConstr usage name t $ map toStruct payload_ts
  -- A sum value aliases *anything* that went into its construction.
  let payload_als = mconcat $ map aliases payload_ts
      ctor_t = fromStruct t `addAliases` (<> payload_als)
  pure $ Constr name checked (Info ctor_t) loc
-- Pattern match.  The scrutinee is checked first, sequenced before the
-- cases via 'sequentially'.  The cases are then checked against the
-- scrutinee's fully-resolved type.
checkExp (Match e cs _ loc) =
  sequentially (checkExp e) checkBranches
  where
    checkBranches e' _ = do
      scrutinee_t <- expTypeFully e'
      (cs', t, retext) <- checkCases scrutinee_t cs
      -- Constrain the type produced by the match with 'zeroOrderType'
      -- (presumably ruling out function-typed results -- see Unify).
      zeroOrderType
        (mkUsage loc "being returned 'match'")
        "type returned from pattern match"
        t
      pure $ Match e' cs' (Info t, Info retext) loc
checkExp (Attr AttrInfo
info ExpBase NoInfo Name
e SrcLoc
loc) =
  AttrInfo -> Exp -> SrcLoc -> Exp
forall (f :: * -> *) vn.
AttrInfo -> ExpBase f vn -> SrcLoc -> ExpBase f vn
Attr AttrInfo
info (Exp -> SrcLoc -> Exp)
-> TermTypeM Exp -> TermTypeM (SrcLoc -> Exp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExpBase NoInfo Name -> TermTypeM Exp
checkExp ExpBase NoInfo Name
e TermTypeM (SrcLoc -> Exp) -> TermTypeM SrcLoc -> TermTypeM Exp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SrcLoc -> TermTypeM SrcLoc
forall (f :: * -> *) a. Applicative f => a -> f a
pure SrcLoc
loc

-- | Type-check the cases of a @match@ whose scrutinee has type @mt@.
--
-- The first case and the remaining cases are checked as alternatives
-- (via 'alternative'), and their result types are combined with
-- 'unifyBranchTypes'.  Returns the checked cases, the type of the whole
-- match, and the names returned by the final type computation (the
-- single case's own names when there is only one case, otherwise those
-- produced by 'unifyBranchTypes').
checkCases ::
  PatternType ->
  NE.NonEmpty (CaseBase NoInfo Name) ->
  TermTypeM (NE.NonEmpty (CaseBase Info VName), PatternType, [VName])
checkCases mt (c NE.:| more) =
  case NE.nonEmpty more of
    Nothing -> do
      -- Only one case left: its type is the type of the match.
      (c', t, retext) <- checkCase mt c
      pure (c' NE.:| [], t, retext)
    Just cs -> do
      (((c', c_t, _), (cs', cs_t, _)), dflow) <-
        tapOccurences $ checkCase mt c `alternative` checkCases mt cs
      (brancht, retext) <- unifyBranchTypes (srclocOf c) c_t cs_t
      -- Drop aliases to any variable consumed in one of the branches.
      let consumed = S.map AliasBound $ allConsumed dflow
          t = brancht `addAliases` (`S.difference` consumed)
      pure (c' NE.<| cs', t, retext)

-- | Type-check one case of a @match@ whose scrutinee has type @mt@.
--
-- The case pattern is bound with @mt@ as its ascribed type, and the body
-- is checked with those bindings in scope.  The fully resolved body type
-- is then passed through 'unscopeType' together with the names bound by
-- the pattern.  Returns the checked case, the resulting type, and the
-- names produced by 'unscopeType'.
checkCase ::
  PatternType ->
  CaseBase NoInfo Name ->
  TermTypeM (CaseBase Info VName, PatternType, [VName])
checkCase mt (CasePat pat body loc) =
  bindingPattern pat (Ascribed mt) checkBody
  where
    checkBody pat' = do
      body' <- checkExp body
      body_t <- expTypeFully body'
      (t, retext) <- unscopeType loc (patternMap pat') body_t
      pure (CasePat pat' body' loc, t, retext)

-- | An unmatched pattern.  Used by the type checker to build the error
-- message reported when a @match@ expression is not exhaustive.
data Unmatched p
  = -- | A pattern that is unmatched whenever its value is not one of the
    -- listed expressions.
    UnmatchedNum p [ExpBase Info VName]
  | -- | An unmatched boolean pattern.
    UnmatchedBool p
  | -- | An unmatched sum-type constructor pattern.
    UnmatchedConstr p
  | -- | Any other unmatched pattern.
    Unmatched p
  deriving (Functor, Show)

-- | Render an unmatched pattern in (approximately) Futhark surface syntax
-- for use in error messages.
instance Pretty (Unmatched (PatternBase Info VName)) where
  ppr um =
    case um of
      UnmatchedNum p nums ->
        pprPat p <+> "where p is not one of" <+> ppr nums
      UnmatchedBool p -> pprPat p
      UnmatchedConstr p -> pprPat p
      Unmatched p -> pprPat p
    where
      -- The part under a type ascription is printed with the ordinary
      -- pattern printer; everything else is printed structurally here.
      pprPat pat =
        case pat of
          PatternAscription inner tdecl _ -> ppr inner <> ":" <+> ppr tdecl
          PatternParens inner _ -> parens (pprPat inner)
          Id v _ _ -> pprName v
          TuplePattern pats _ -> parens (commasep (map pprPat pats))
          RecordPattern fields _ -> braces (commasep (map pprField fields))
          Wildcard {} -> "_"
          PatternLit e _ _ -> ppr e
          PatternConstr n _ ps _ -> "#" <> ppr n <+> sep (map pprPat ps)
      pprField (name, fpat) =
        text (nameToString name) <> equals <> pprPat fpat

-- | Split a pattern into the columns used by exhaustiveness checking.
--
-- Tuple components and record fields (in 'sortFields' order) each become
-- a column.  Wildcards and variables become a single 'Nothing' column.
-- Literal and constructor patterns become a single column holding
-- themselves.  Parentheses and type ascriptions are looked through.
unpackPat :: Pattern -> [Maybe Pattern]
unpackPat pat =
  case pat of
    PatternParens inner _ -> unpackPat inner
    PatternAscription inner _ _ -> unpackPat inner
    Wildcard {} -> [Nothing]
    Id {} -> [Nothing]
    TuplePattern ps _ -> map Just ps
    RecordPattern fs _ -> map (Just . snd) (sortFields (M.fromList fs))
    PatternLit {} -> [Just pat]
    PatternConstr {} -> [Just pat]

-- | @wildPattern p pos um@ wraps the pattern carried by @um@ back into the
-- composite pattern @p@.
--
-- @p@ may be a tuple, record or constructor pattern, possibly under
-- parentheses or a type ascription.  Every component of @p@ is replaced
-- by a wildcard of the same type, except the component at the 1-based
-- column @pos@, which becomes the pattern inside @um@.
--
-- Record fields are numbered in 'sortFields' order, the same order
-- 'unpackPat' uses to produce columns.  This keeps the field name paired
-- with the right column even when the source lists fields in another
-- order.  A record column that is out of range, or any other kind of
-- pattern, leaves @um@ unchanged.
wildPattern :: Pattern -> Int -> Unmatched Pattern -> Unmatched Pattern
wildPattern pat pos um =
  case pat of
    TuplePattern ps loc ->
      (\p -> TuplePattern (fill p (map wildOut ps)) loc) <$> um
    RecordPattern fs loc ->
      -- Use the same field order as 'unpackPat', and avoid the partial (!!).
      case splitAt (pos - 1) (map (fmap wildOut) (sortFields (M.fromList fs))) of
        (before, (name, _) : after) ->
          (\p -> RecordPattern (before ++ (name, p) : after) loc) <$> um
        _ -> um
    PatternAscription p _ _ -> wildPattern p pos um
    PatternParens p _ -> wildPattern p pos um
    PatternConstr n t ps loc ->
      (\p -> PatternConstr n t (fill p (map wildOut ps)) loc) <$> um
    _ -> um
  where
    -- A wildcard with the same type and location as the given pattern.
    wildOut p = Wildcard (Info (patternType p)) (srclocOf p)
    -- Put p at column pos (1-based) in place of the existing element.
    fill p ps' = take (pos - 1) ps' ++ [p] ++ drop pos ps'

-- | Report a type error for the first non-exhaustive @match@ expression
-- found in the given expression, listing the unmatched cases.  Every
-- subexpression, however deeply nested, is inspected.
checkUnmatched :: Exp -> TermTypeM ()
checkUnmatched e = void $ checkUnmatched' e >> astMap tv e
  where
    checkUnmatched' (Match _ cs _ loc) =
      let ps = fmap (\(CasePat p _ _) -> p) cs
       in case unmatched id $ NE.toList ps of
            [] -> return ()
            ps' ->
              typeError loc mempty $
                "Unmatched cases in match expression:"
                  </> indent 2 (stack (map ppr ps'))
    checkUnmatched' _ = return ()
    -- 'astMap' applies 'mapOnExp' only to the immediate children of an
    -- expression, so the mapper must recurse itself.  Calling
    -- 'checkUnmatched' (rather than just 'checkUnmatched'') makes the
    -- traversal reach matches nested at any depth.
    tv =
      ASTMapper
        { mapOnExp = \e' -> checkUnmatched e' >> return e',
          mapOnName = pure,
          mapOnQualName = pure,
          mapOnStructType = pure,
          mapOnPatternType = pure
        }

-- | A constructor pattern taken apart into its components.  Used to keep
-- the code that detects unmatched constructors separate from the code
-- for other kinds of patterns.
data ConstrPat = ConstrPat
  { -- | Name of the constructor being matched.
    constrName :: Name,
    -- | The type recorded for the pattern.
    constrType :: PatternType,
    -- | The payload patterns, in order.
    constrPayload :: [Pattern],
    -- | Where the pattern occurs in the source.
    constrSrcLoc :: SrcLoc
  }

-- Be aware of these fishy equality instances!  Two 'ConstrPat's compare
-- only by 'constrName'; the type, payload and source location are
-- ignored entirely.
instance Eq ConstrPat where
  x == y = constrName x == constrName y

-- | Orders constructor patterns by constructor name alone, consistently
-- with the name-only 'Eq' instance.
instance Ord ConstrPat where
  compare x y = compare (constrName x) (constrName y)

unmatched :: (Unmatched Pattern -> Unmatched Pattern) -> [Pattern] -> [Unmatched Pattern]
unmatched :: (Unmatched Pattern -> Unmatched Pattern)
-> [Pattern] -> [Unmatched Pattern]
unmatched Unmatched Pattern -> Unmatched Pattern
hole [Pattern]
orig_ps
  | Pattern
p : [Pattern]
_ <- [Pattern]
orig_ps,
    [(Int, [Maybe Pattern])] -> Bool
forall (t :: * -> *) a a. Foldable t => [(a, t a)] -> Bool
sameStructure [(Int, [Maybe Pattern])]
labeledCols = do
    (Int
i, [Maybe Pattern]
cols) <- [(Int, [Maybe Pattern])]
labeledCols
    let hole' :: Unmatched Pattern -> Unmatched Pattern
hole' = if Pattern -> Bool
forall (f :: * -> *) vn. PatternBase f vn -> Bool
isConstr Pattern
p then Unmatched Pattern -> Unmatched Pattern
hole else Unmatched Pattern -> Unmatched Pattern
hole (Unmatched Pattern -> Unmatched Pattern)
-> (Unmatched Pattern -> Unmatched Pattern)
-> Unmatched Pattern
-> Unmatched Pattern
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pattern -> Int -> Unmatched Pattern -> Unmatched Pattern
wildPattern Pattern
p Int
i
    case [Maybe Pattern] -> Maybe [Pattern]
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence [Maybe Pattern]
cols of
      Maybe [Pattern]
Nothing -> []
      Just [Pattern]
cs
        | (Pattern -> Bool) -> [Pattern] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Pattern -> Bool
forall (f :: * -> *) vn. PatternBase f vn -> Bool
isPatternLit [Pattern]
cs -> (Unmatched Pattern -> Unmatched Pattern)
-> [Unmatched Pattern] -> [Unmatched Pattern]
forall a b. (a -> b) -> [a] -> [b]
map Unmatched Pattern -> Unmatched Pattern
hole' ([Unmatched Pattern] -> [Unmatched Pattern])
-> [Unmatched Pattern] -> [Unmatched Pattern]
forall a b. (a -> b) -> a -> b
$ [Pattern] -> [Unmatched Pattern]
localUnmatched [Pattern]
cs
        | Bool
otherwise -> (Unmatched Pattern -> Unmatched Pattern)
-> [Pattern] -> [Unmatched Pattern]
unmatched Unmatched Pattern -> Unmatched Pattern
hole' [Pattern]
cs
  | Bool
otherwise = []
  where
    labeledCols :: [(Int, [Maybe Pattern])]
labeledCols = [Int] -> [[Maybe Pattern]] -> [(Int, [Maybe Pattern])]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
1 ..] ([[Maybe Pattern]] -> [(Int, [Maybe Pattern])])
-> [[Maybe Pattern]] -> [(Int, [Maybe Pattern])]
forall a b. (a -> b) -> a -> b
$ [[Maybe Pattern]] -> [[Maybe Pattern]]
forall a. [[a]] -> [[a]]
transpose ([[Maybe Pattern]] -> [[Maybe Pattern]])
-> [[Maybe Pattern]] -> [[Maybe Pattern]]
forall a b. (a -> b) -> a -> b
$ (Pattern -> [Maybe Pattern]) -> [Pattern] -> [[Maybe Pattern]]
forall a b. (a -> b) -> [a] -> [b]
map Pattern -> [Maybe Pattern]
unpackPat [Pattern]
orig_ps

    localUnmatched :: [Pattern] -> [Unmatched Pattern]
    localUnmatched :: [Pattern] -> [Unmatched Pattern]
localUnmatched [] = []
    localUnmatched ps' :: [Pattern]
ps'@(Pattern
p' : [Pattern]
_) =
      case Pattern -> PatternType
patternType Pattern
p' of
        Scalar (Sum Map Name [PatternType]
cs'') ->
          -- We now know that we are matching a sum type, and thus
          -- that all patterns ps' are constructors (checked by
          -- 'all isPatternLit' before this function is called).
          let constrs :: [Name]
constrs = Map Name [PatternType] -> [Name]
forall k a. Map k a -> [k]
M.keys Map Name [PatternType]
cs''
              matched :: [ConstrPat]
matched = (Pattern -> Maybe ConstrPat) -> [Pattern] -> [ConstrPat]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Pattern -> Maybe ConstrPat
constr [Pattern]
ps'
              unmatched' :: [Unmatched (PatternBase Info vn)]
unmatched' =
                (Name -> Unmatched (PatternBase Info vn))
-> [Name] -> [Unmatched (PatternBase Info vn)]
forall a b. (a -> b) -> [a] -> [b]
map (PatternBase Info vn -> Unmatched (PatternBase Info vn)
forall p. p -> Unmatched p
UnmatchedConstr (PatternBase Info vn -> Unmatched (PatternBase Info vn))
-> (Name -> PatternBase Info vn)
-> Name
-> Unmatched (PatternBase Info vn)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map Name [PatternType] -> Name -> PatternBase Info vn
forall vn. Map Name [PatternType] -> Name -> PatternBase Info vn
buildConstr Map Name [PatternType]
cs'') ([Name] -> [Unmatched (PatternBase Info vn)])
-> [Name] -> [Unmatched (PatternBase Info vn)]
forall a b. (a -> b) -> a -> b
$
                  [Name]
constrs [Name] -> [Name] -> [Name]
forall a. Eq a => [a] -> [a] -> [a]
\\ (ConstrPat -> Name) -> [ConstrPat] -> [Name]
forall a b. (a -> b) -> [a] -> [b]
map ConstrPat -> Name
constrName [ConstrPat]
matched
           in case [Unmatched (PatternBase Info Any)]
forall vn. [Unmatched (PatternBase Info vn)]
unmatched' of
                [] ->
                  let constrGroups :: [[ConstrPat]]
constrGroups = [ConstrPat] -> [[ConstrPat]]
forall a. Eq a => [a] -> [[a]]
group ([ConstrPat] -> [ConstrPat]
forall a. Ord a => [a] -> [a]
sort [ConstrPat]
matched)
                      removedConstrs :: [(Pattern, [[(Int, Pattern)]])]
removedConstrs = ([ConstrPat] -> Maybe (Pattern, [[(Int, Pattern)]]))
-> [[ConstrPat]] -> [(Pattern, [[(Int, Pattern)]])]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe [ConstrPat] -> Maybe (Pattern, [[(Int, Pattern)]])
stripConstrs [[ConstrPat]]
constrGroups
                      transposed :: [(Pattern, [[(Int, Pattern)]])]
transposed = (((Pattern, [[(Int, Pattern)]]) -> (Pattern, [[(Int, Pattern)]]))
-> [(Pattern, [[(Int, Pattern)]])]
-> [(Pattern, [[(Int, Pattern)]])]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (((Pattern, [[(Int, Pattern)]]) -> (Pattern, [[(Int, Pattern)]]))
 -> [(Pattern, [[(Int, Pattern)]])]
 -> [(Pattern, [[(Int, Pattern)]])])
-> (([[(Int, Pattern)]] -> [[(Int, Pattern)]])
    -> (Pattern, [[(Int, Pattern)]]) -> (Pattern, [[(Int, Pattern)]]))
-> ([[(Int, Pattern)]] -> [[(Int, Pattern)]])
-> [(Pattern, [[(Int, Pattern)]])]
-> [(Pattern, [[(Int, Pattern)]])]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([[(Int, Pattern)]] -> [[(Int, Pattern)]])
-> (Pattern, [[(Int, Pattern)]]) -> (Pattern, [[(Int, Pattern)]])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap) [[(Int, Pattern)]] -> [[(Int, Pattern)]]
forall a. [[a]] -> [[a]]
transpose [(Pattern, [[(Int, Pattern)]])]
removedConstrs
                      findUnmatched :: (Pattern, [[(Int, Pattern)]]) -> [Unmatched Pattern]
findUnmatched (Pattern
pc, [[(Int, Pattern)]]
trans) = do
                        [(Int, Pattern)]
col <- [[(Int, Pattern)]]
trans
                        case [(Int, Pattern)]
col of
                          [] -> []
                          ((Int
i, Pattern
_) : [(Int, Pattern)]
_) -> (Unmatched Pattern -> Unmatched Pattern)
-> [Pattern] -> [Unmatched Pattern]
unmatched (Int -> Pattern -> Unmatched Pattern -> Unmatched Pattern
wilder Int
i Pattern
pc) (((Int, Pattern) -> Pattern) -> [(Int, Pattern)] -> [Pattern]
forall a b. (a -> b) -> [a] -> [b]
map (Int, Pattern) -> Pattern
forall a b. (a, b) -> b
snd [(Int, Pattern)]
col)
                      wilder :: Int -> Pattern -> Unmatched Pattern -> Unmatched Pattern
wilder Int
i Pattern
pc Unmatched Pattern
s = (Pattern -> SrcLoc -> Pattern
forall (f :: * -> *) vn.
PatternBase f vn -> SrcLoc -> PatternBase f vn
`PatternParens` SrcLoc
forall a. Monoid a => a
mempty) (Pattern -> Pattern) -> Unmatched Pattern -> Unmatched Pattern
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern -> Int -> Unmatched Pattern -> Unmatched Pattern
wildPattern Pattern
pc Int
i Unmatched Pattern
s
                   in ((Pattern, [[(Int, Pattern)]]) -> [Unmatched Pattern])
-> [(Pattern, [[(Int, Pattern)]])] -> [Unmatched Pattern]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern, [[(Int, Pattern)]]) -> [Unmatched Pattern]
findUnmatched [(Pattern, [[(Int, Pattern)]])]
transposed
                [Unmatched (PatternBase Info Any)]
_ -> [Unmatched Pattern]
forall vn. [Unmatched (PatternBase Info vn)]
unmatched'
        Scalar (Prim PrimType
t) | Bool -> Bool
not ((Pattern -> Bool) -> [Pattern] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Pattern -> Bool
forall (f :: * -> *) vn. PatternBase f vn -> Bool
idOrWild [Pattern]
ps') ->
          -- We now know that we are matching a sum type, and thus
          -- that all patterns ps' are literals (checked by 'all
          -- isPatternLit' before this function is called).
          case PrimType
t of
            PrimType
Bool ->
              let matched :: [Bool]
matched = [Bool] -> [Bool]
forall a. Eq a => [a] -> [a]
nub ([Bool] -> [Bool]) -> [Bool] -> [Bool]
forall a b. (a -> b) -> a -> b
$ (Pattern -> Maybe Bool) -> [Pattern] -> [Bool]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Pattern -> Maybe Exp
forall (f :: * -> *) vn. PatternBase f vn -> Maybe (ExpBase f vn)
pExp (Pattern -> Maybe Exp)
-> (Exp -> Maybe Bool) -> Pattern -> Maybe Bool
forall (m :: * -> *) a b c.
Monad m =>
(a -> m b) -> (b -> m c) -> a -> m c
>=> Exp -> Maybe Bool
forall (f :: * -> *) vn. ExpBase f vn -> Maybe Bool
bool) ([Pattern] -> [Bool]) -> [Pattern] -> [Bool]
forall a b. (a -> b) -> a -> b
$ (Pattern -> Bool) -> [Pattern] -> [Pattern]
forall a. (a -> Bool) -> [a] -> [a]
filter Pattern -> Bool
forall (f :: * -> *) vn. PatternBase f vn -> Bool
isPatternLit [Pattern]
ps'
               in (Bool -> Unmatched Pattern) -> [Bool] -> [Unmatched Pattern]
forall a b. (a -> b) -> [a] -> [b]
map (Pattern -> Unmatched Pattern
forall p. p -> Unmatched p
UnmatchedBool (Pattern -> Unmatched Pattern)
-> (Bool -> Pattern) -> Bool -> Unmatched Pattern
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TypeBase () Aliasing -> Bool -> Pattern
forall vn. TypeBase () Aliasing -> Bool -> PatternBase Info vn
buildBool (ScalarTypeBase () Aliasing -> TypeBase () Aliasing
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (PrimType -> ScalarTypeBase () Aliasing
forall dim as. PrimType -> ScalarTypeBase dim as
Prim PrimType
t))) ([Bool] -> [Unmatched Pattern]) -> [Bool] -> [Unmatched Pattern]
forall a b. (a -> b) -> a -> b
$ [Bool
True, Bool
False] [Bool] -> [Bool] -> [Bool]
forall a. Eq a => [a] -> [a] -> [a]
\\ [Bool]
matched
            PrimType
_ ->
              let matched :: [Exp]
matched = (Pattern -> Maybe Exp) -> [Pattern] -> [Exp]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Pattern -> Maybe Exp
forall (f :: * -> *) vn. PatternBase f vn -> Maybe (ExpBase f vn)
pExp ([Pattern] -> [Exp]) -> [Pattern] -> [Exp]
forall a b. (a -> b) -> a -> b
$ (Pattern -> Bool) -> [Pattern] -> [Pattern]
forall a. (a -> Bool) -> [a] -> [a]
filter Pattern -> Bool
forall (f :: * -> *) vn. PatternBase f vn -> Bool
isPatternLit [Pattern]
ps'
               in [Pattern -> [Exp] -> Unmatched Pattern
forall p. p -> [Exp] -> Unmatched p
UnmatchedNum (Info PatternType -> String -> Pattern
forall (f :: * -> *).
f PatternType -> String -> PatternBase f VName
buildId (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as. PrimType -> ScalarTypeBase dim as
Prim PrimType
t) String
"p") [Exp]
matched]
        PatternType
_ -> []

    -- Whether a pattern is a constructor pattern, looking through any
    -- number of enclosing parentheses.
    --
    -- NOTE(review): unlike 'constr', this does not look through
    -- 'PatternAscription'; confirm whether that asymmetry is intended.
    isConstr :: PatternBase f vn -> Bool
    isConstr pat = case pat of
      PatternConstr {} -> True
      PatternParens inner _ -> isConstr inner
      _ -> False

    -- Split a run of constructor patterns into the first one (converted
    -- back to an ordinary pattern with 'unConstr') and, for every pattern
    -- in the run, its payload numbered from 1 (via 'stripConstr').
    -- An empty run yields 'Nothing'.
    stripConstrs :: [ConstrPat] -> Maybe (Pattern, [[(Int, Pattern)]])
    stripConstrs [] = Nothing
    stripConstrs run@(rep@ConstrPat {} : _) =
      Just (unConstr rep, map stripConstr run)

    -- Pair each payload pattern of a constructor pattern with its
    -- 1-based position in the payload.
    stripConstr :: ConstrPat -> [(Int, Pattern)]
    stripConstr cp = zip [1 ..] (constrPayload cp)

    -- Whether every labelled column has the same length as the first one
    -- (the labels themselves are ignored).  An empty list trivially
    -- qualifies.
    --
    -- Written without a partial @(x' : xs') = map snd (x : xs)@ binding:
    -- that pattern could never fail, but it was non-exhaustive and is
    -- flagged by -Wincomplete-uni-patterns.
    sameStructure [] = True
    sameStructure (x : xs) = all ((== width) . length . snd) xs
      where
        width = length (snd x)

    -- The expression inside a literal pattern.  Only a bare 'PatternLit'
    -- qualifies; parentheses and ascriptions are not looked through.
    pExp :: PatternBase f vn -> Maybe (ExpBase f vn)
    pExp pat = case pat of
      PatternLit e _ _ -> Just e
      _ -> Nothing

    -- View a pattern as a 'ConstrPat', looking through parentheses and
    -- type ascriptions.  Any other pattern yields 'Nothing'.
    constr :: Pattern -> Maybe ConstrPat
    constr pat = case pat of
      PatternConstr name (Info ty) payload loc ->
        Just (ConstrPat name ty payload loc)
      PatternParens inner _ -> constr inner
      PatternAscription inner _ _ -> constr inner
      _ -> Nothing

    -- Rebuild an ordinary 'PatternConstr' from a 'ConstrPat', wrapping
    -- the stored type back in 'Info'.
    unConstr cp =
      PatternConstr
        (constrName cp)
        (Info (constrType cp))
        (constrPayload cp)
        (constrSrcLoc cp)

    -- Whether a pattern is a literal or a constructor pattern, looking
    -- through type ascriptions and parentheses.
    isPatternLit :: PatternBase f vn -> Bool
    isPatternLit pat = case pat of
      PatternLit {} -> True
      PatternConstr {} -> True
      PatternAscription inner _ _ -> isPatternLit inner
      PatternParens inner _ -> isPatternLit inner
      _ -> False

    -- Whether a pattern is a plain variable ('Id') or a 'Wildcard',
    -- looking through type ascriptions and parentheses.
    idOrWild :: PatternBase f vn -> Bool
    idOrWild pat = case pat of
      Id {} -> True
      Wildcard {} -> True
      PatternAscription inner _ _ -> idOrWild inner
      PatternParens inner _ -> idOrWild inner
      _ -> False

    -- The value of a boolean literal expression; 'Nothing' for any other
    -- expression.
    bool :: ExpBase f vn -> Maybe Bool
    bool e = case e of
      Literal (BoolValue b) _ -> Just b
      _ -> Nothing

    -- A pattern for constructor @c@ of the sum type whose constructors
    -- (and their payload types) are given by @m@.  Every payload
    -- component becomes a 'Wildcard'.  A constructor with a non-empty
    -- payload is additionally wrapped in 'PatternParens'.  The lookup
    -- @m M.! c@ assumes @c@ is a key of @m@ and errors otherwise.
    buildConstr m c
      | null payload = constrPat
      | otherwise = PatternParens constrPat mempty
      where
        payload = [Wildcard (Info ct) mempty | ct <- m M.! c]
        constrPat = PatternConstr c (Info (Scalar (Sum m))) payload mempty
    -- A literal pattern for the boolean @b@, annotated with the type @t@
    -- after passing it through 'addSizes'.
    buildBool t b = PatternLit boolLit (Info (addSizes t)) mempty
      where
        boolLit = Literal (BoolValue b) mempty
    -- A variable pattern named @n@ with type annotation @t@.
    buildId t n = Id placeholder t mempty
      where
        -- The VName tag here will never be used since the value
        -- exists exclusively for printing warnings.
        placeholder = VName (nameFromString n) (-1)

-- | Resolve an unchecked identifier with 'lookupVar' and rebuild it with
-- the resolved 'VName' and the variable's type.  The input's (empty)
-- type annotation is ignored.
checkIdent :: IdentBase NoInfo Name -> TermTypeM Ident
checkIdent (Ident name _ loc) = toIdent <$> lookupVar loc (qualName name)
  where
    toIdent (QualName _ resolved, vt) = Ident resolved (Info vt) loc

-- | Type-check one component of an index expression.
--
-- A fixed index must have one of the types in 'anySignedType'.  Each
-- present start, end, or stride of a slice is unified with @i32@;
-- absent slice components stay absent.
checkDimIndex :: DimIndexBase NoInfo Name -> TermTypeM DimIndex
checkDimIndex (DimFix i) = do
  i' <- checkExp i
  DimFix <$> require "use as index" anySignedType i'
checkDimIndex (DimSlice start end stride) =
  DimSlice <$> checkPart start <*> checkPart end <*> checkPart stride
  where
    checkPart :: Maybe (ExpBase NoInfo Name) -> TermTypeM (Maybe Exp)
    checkPart = traverse $ \e -> do
      e' <- checkExp e
      unifies "use as index" (Scalar (Prim (Signed Int32))) e'

-- | Run @m1@, then run @m2@ on its result and on the occurrences @m1@
-- produced.  The occurrences of both are recorded, combined with
-- 'seqOccurences' (those of @m1@ first).  Returns the result of @m2@.
sequentially :: TermTypeM a -> (a -> Occurences -> TermTypeM b) -> TermTypeM b
sequentially m1 m2 = do
  (x, firstFlow) <- collectOccurences m1
  (y, secondFlow) <- collectOccurences (m2 x firstFlow)
  occur (firstFlow `seqOccurences` secondFlow)
  pure y

-- | A type-checked function argument, as produced by 'checkArg': the
-- checked expression, its type, the occurrences collected while
-- checking it, and its source location.
type Arg = (Exp, PatternType, Occurences, SrcLoc)

-- | The checked expression of an 'Arg'.
argExp :: Arg -> Exp
argExp arg = case arg of
  (e, _, _, _) -> e

-- | The type of an 'Arg'.
argType :: Arg -> PatternType
argType arg = case arg of
  (_, t, _, _) -> t

-- | Type-check a function argument.  Returns the checked expression, its
-- type, the occurrences collected (via 'collectOccurences') while
-- checking it, and its source location.
checkArg :: UncheckedExp -> TermTypeM Arg
checkArg arg = do
  (checked, flow) <- collectOccurences (checkExp arg)
  ty <- expType checked
  pure (checked, ty, flow, srclocOf checked)

-- | Delegate to 'instantiateEmptyArrayDims' with the name hint @"d"@ and
-- rigidity @'Rigid' rsrc@.  The result pairs the (presumably
-- instantiated) type with a list of names — presumably the new size
-- variables introduced; confirm against 'instantiateEmptyArrayDims'.
instantiateDimsInType ::
  SrcLoc ->
  RigidSource ->
  TypeBase (DimDecl VName) als ->
  TermTypeM (TypeBase (DimDecl VName) als, [VName])
instantiateDimsInType tloc rsrc t =
  instantiateEmptyArrayDims tloc "d" (Rigid rsrc) t

-- | Like 'instantiateDimsInType', but for the return type of the
-- (optionally named) function @fname@.  Delegates to
-- 'instantiateEmptyArrayDims' with the name hint @"ret"@ and rigidity
-- @'Rigid' ('RigidRet' fname)@.
instantiateDimsInReturnType ::
  SrcLoc ->
  Maybe (QualName VName) ->
  TypeBase (DimDecl VName) als ->
  TermTypeM (TypeBase (DimDecl VName) als, [VName])
instantiateDimsInReturnType tloc fname t =
  let rigidity = Rigid (RigidRet fname)
   in instantiateEmptyArrayDims tloc "ret" rigidity t

-- | Some information about the function or operator we are trying to
-- apply: its name, if known, and how many arguments it has previously
-- accepted.  Used for generating nicer type errors.
type ApplyOp = (Maybe (QualName VName), Int)

checkApply ::
  SrcLoc ->
  ApplyOp ->
  PatternType ->
  Arg ->
  TermTypeM (PatternType, PatternType, Maybe VName, [VName])
checkApply :: SrcLoc
-> ApplyOp
-> PatternType
-> Arg
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
checkApply
  SrcLoc
loc
  (Maybe (QualName VName)
fname, Int
_)
  (Scalar (Arrow Aliasing
as PName
pname PatternType
tp1 PatternType
tp2))
  (Exp
argexp, PatternType
argtype, [Occurence]
dflow, SrcLoc
argloc) =
    Checking
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure (Maybe (QualName VName)
-> Exp -> StructType -> StructType -> Checking
CheckingApply Maybe (QualName VName)
fname Exp
argexp (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
tp1) (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
argtype)) (TermTypeM (PatternType, PatternType, Maybe VName, [VName])
 -> TermTypeM (PatternType, PatternType, Maybe VName, [VName]))
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
forall a b. (a -> b) -> a -> b
$ do
      Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
expect (SrcLoc -> String -> Usage
mkUsage SrcLoc
argloc String
"use as function argument") (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
tp1) (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
argtype)

      -- Perform substitutions of instantiated variables in the types.
      PatternType
tp1' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
tp1
      (PatternType
tp2', [VName]
ext) <- SrcLoc
-> Maybe (QualName VName)
-> PatternType
-> TermTypeM (PatternType, [VName])
forall als.
SrcLoc
-> Maybe (QualName VName)
-> TypeBase (DimDecl VName) als
-> TermTypeM (TypeBase (DimDecl VName) als, [VName])
instantiateDimsInReturnType SrcLoc
loc Maybe (QualName VName)
fname (PatternType -> TermTypeM (PatternType, [VName]))
-> TermTypeM PatternType -> TermTypeM (PatternType, [VName])
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
tp2
      PatternType
argtype' <- PatternType -> TermTypeM PatternType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatternType
argtype

      -- Check whether this would produce an impossible return type.
      let (Names
_, Names
tp2_paramdims, Names
_) = StructType -> (Names, Names, Names)
dimUses (StructType -> (Names, Names, Names))
-> StructType -> (Names, Names, Names)
forall a b. (a -> b) -> a -> b
$ PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
tp2'
      case (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Names
tp2_paramdims) [VName]
ext of
        [] -> () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
        [VName]
ext_paramdims -> do
          let onDim :: DimDecl VName -> DimDecl VName
onDim (NamedDim QualName VName
qn)
                | QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
qn VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
ext_paramdims = DimDecl VName
forall vn. DimDecl vn
AnyDim
              onDim DimDecl VName
d = DimDecl VName
d
          SrcLoc -> Notes -> Doc -> TermTypeM ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM ()) -> Doc -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
            Doc
"Anonymous size would appear in function parameter of return type:"
              Doc -> Doc -> Doc
</> Int -> Doc -> Doc
indent Int
2 (PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr ((DimDecl VName -> DimDecl VName) -> PatternType -> PatternType
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first DimDecl VName -> DimDecl VName
onDim PatternType
tp2'))
              Doc -> Doc -> Doc
</> String -> Doc
textwrap String
"This is usually because a higher-order function is used with functional arguments that return anonymous sizes, which are then used as parameters of other function arguments."

      [Occurence] -> TermTypeM ()
occur [Aliasing -> SrcLoc -> Occurence
observation Aliasing
as SrcLoc
loc]

      [Occurence] -> TermTypeM ()
checkOccurences [Occurence]
dflow

      case [Occurence] -> Maybe Occurence
anyConsumption [Occurence]
dflow of
        Just Occurence
c ->
          let msg :: String
msg = String
"type of expression with consumption at " String -> ShowS
forall a. [a] -> [a] -> [a]
++ SrcLoc -> String
forall a. Located a => a -> String
locStr (Occurence -> SrcLoc
location Occurence
c)
           in Usage -> String -> PatternType -> TermTypeM ()
forall (m :: * -> *) dim as.
(MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
Usage -> String -> TypeBase dim as -> m ()
zeroOrderType (SrcLoc -> String -> Usage
mkUsage SrcLoc
argloc String
"potential consumption in expression") String
msg PatternType
tp1
        Maybe Occurence
_ -> () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

      [Occurence]
occurs <- ([Occurence]
dflow [Occurence] -> [Occurence] -> [Occurence]
`seqOccurences`) ([Occurence] -> [Occurence])
-> TermTypeM [Occurence] -> TermTypeM [Occurence]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SrcLoc -> PatternType -> Diet -> TermTypeM [Occurence]
consumeArg SrcLoc
argloc PatternType
argtype' (PatternType -> Diet
forall shape as. TypeBase shape as -> Diet
diet PatternType
tp1')

      SrcLoc -> Aliasing -> TermTypeM ()
checkIfConsumable SrcLoc
loc (Aliasing -> TermTypeM ()) -> Aliasing -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ (VName -> Alias) -> Names -> Aliasing
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map VName -> Alias
AliasBound (Names -> Aliasing) -> Names -> Aliasing
forall a b. (a -> b) -> a -> b
$ [Occurence] -> Names
allConsumed [Occurence]
occurs
      [Occurence] -> TermTypeM ()
occur [Occurence]
occurs

      (Maybe VName
argext, VName -> Maybe (Subst StructType)
parsubst) <-
        case PName
pname of
          Named VName
pname' -> do
            (DimDecl VName
d, Maybe VName
argext) <- PatternType -> Exp -> TermTypeM (DimDecl VName, Maybe VName)
forall dim as.
TypeBase dim as -> Exp -> TermTypeM (DimDecl VName, Maybe VName)
sizeSubst PatternType
tp1' Exp
argexp
            (Maybe VName, VName -> Maybe (Subst StructType))
-> TermTypeM (Maybe VName, VName -> Maybe (Subst StructType))
forall (m :: * -> *) a. Monad m => a -> m a
return
              ( Maybe VName
argext,
                (VName -> Map VName (Subst StructType) -> Maybe (Subst StructType)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` VName -> Subst StructType -> Map VName (Subst StructType)
forall k a. k -> a -> Map k a
M.singleton VName
pname' (DimDecl VName -> Subst StructType
forall t. DimDecl VName -> Subst t
SizeSubst DimDecl VName
d))
              )
          PName
_ -> (Maybe VName, VName -> Maybe (Subst StructType))
-> TermTypeM (Maybe VName, VName -> Maybe (Subst StructType))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe VName
forall a. Maybe a
Nothing, Maybe (Subst StructType) -> VName -> Maybe (Subst StructType)
forall a b. a -> b -> a
const Maybe (Subst StructType)
forall a. Maybe a
Nothing)
      let tp2'' :: PatternType
tp2'' = (VName -> Maybe (Subst StructType)) -> PatternType -> PatternType
forall a.
Substitutable a =>
(VName -> Maybe (Subst StructType)) -> a -> a
applySubst VName -> Maybe (Subst StructType)
parsubst (PatternType -> PatternType) -> PatternType -> PatternType
forall a b. (a -> b) -> a -> b
$ PatternType -> Diet -> PatternType -> PatternType
returnType PatternType
tp2' (PatternType -> Diet
forall shape as. TypeBase shape as -> Diet
diet PatternType
tp1') PatternType
argtype'

      (PatternType, PatternType, Maybe VName, [VName])
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternType
tp1', PatternType
tp2'', Maybe VName
argext, [VName]
ext)
    where
      sizeSubst :: TypeBase dim as -> Exp -> TermTypeM (DimDecl VName, Maybe VName)
sizeSubst (Scalar (Prim (Signed IntType
Int32))) Exp
e = Maybe (QualName VName)
-> Exp -> TermTypeM (DimDecl VName, Maybe VName)
dimFromArg Maybe (QualName VName)
fname Exp
e
      sizeSubst TypeBase dim as
_ Exp
_ = (DimDecl VName, Maybe VName)
-> TermTypeM (DimDecl VName, Maybe VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (DimDecl VName
forall vn. DimDecl vn
AnyDim, Maybe VName
forall a. Maybe a
Nothing)
checkApply SrcLoc
loc ApplyOp
fname tfun :: PatternType
tfun@(Scalar TypeVar {}) Arg
arg = do
  StructType
tv <- SrcLoc -> String -> TermTypeM StructType
forall (m :: * -> *) als dim.
(MonadUnify m, Monoid als) =>
SrcLoc -> String -> m (TypeBase dim als)
newTypeVar SrcLoc
loc String
"b"
  Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"use as function") (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatternType
tfun) (StructType -> TermTypeM ()) -> StructType -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
    ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ ()
-> PName
-> StructType
-> StructType
-> ScalarTypeBase (DimDecl VName) ()
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow ()
forall a. Monoid a => a
mempty PName
Unnamed (PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (Arg -> PatternType
argType Arg
arg)) StructType
tv
  PatternType
tfun' <- PatternType -> TermTypeM PatternType
forall (m :: * -> *). MonadUnify m => PatternType -> m PatternType
normPatternType PatternType
tfun
  SrcLoc
-> ApplyOp
-> PatternType
-> Arg
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
checkApply SrcLoc
loc ApplyOp
fname PatternType
tfun' Arg
arg
checkApply SrcLoc
loc (Maybe (QualName VName)
fname, Int
prev_applied) PatternType
ftype (Exp
argexp, PatternType
_, [Occurence]
_, SrcLoc
_) = do
  let fname' :: Doc
fname' = Doc -> (QualName VName -> Doc) -> Maybe (QualName VName) -> Doc
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Doc
"expression" (Doc -> Doc
pquote (Doc -> Doc) -> (QualName VName -> Doc) -> QualName VName -> Doc
forall b c a. (b -> c) -> (a -> b) -> a -> c
. QualName VName -> Doc
forall a. Pretty a => a -> Doc
ppr) Maybe (QualName VName)
fname

  SrcLoc
-> Notes
-> Doc
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TermTypeM (PatternType, PatternType, Maybe VName, [VName]))
-> Doc
-> TermTypeM (PatternType, PatternType, Maybe VName, [VName])
forall a b. (a -> b) -> a -> b
$
    if Int
prev_applied Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0
      then
        Doc
"Cannot apply" Doc -> Doc -> Doc
<+> Doc
fname' Doc -> Doc -> Doc
<+> Doc
"as function, as it has type:"
          Doc -> Doc -> Doc
</> Int -> Doc -> Doc
indent Int
2 (PatternType -> Doc
forall a. Pretty a => a -> Doc
ppr PatternType
ftype)
      else
        Doc
"Cannot apply" Doc -> Doc -> Doc
<+> Doc
fname' Doc -> Doc -> Doc
<+> Doc
"to argument #" Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Int -> Doc
forall a. Pretty a => a -> Doc
ppr (Int
prev_applied Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
          Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (String -> Doc
forall a. Pretty a => a -> Doc
shorten (String -> Doc) -> String -> Doc
forall a b. (a -> b) -> a -> b
$ Doc -> String
forall a. Pretty a => a -> String
pretty (Doc -> String) -> Doc -> String
forall a b. (a -> b) -> a -> b
$ Doc -> Doc
flatten (Doc -> Doc) -> Doc -> Doc
forall a b. (a -> b) -> a -> b
$ Exp -> Doc
forall a. Pretty a => a -> Doc
ppr Exp
argexp) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
","
          Doc -> Doc -> Doc
<+/> Doc
"as"
          Doc -> Doc -> Doc
<+> Doc
fname'
          Doc -> Doc -> Doc
<+> Doc
"only takes"
          Doc -> Doc -> Doc
<+> Int -> Doc
forall a. Pretty a => a -> Doc
ppr Int
prev_applied
          Doc -> Doc -> Doc
<+> Doc
arguments Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
  where
    arguments :: Doc
arguments
      | Int
prev_applied Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 = Doc
"argument"
      | Bool
otherwise = Doc
"arguments"

-- | Try to interpret an expression as a constant 32-bit signed integer.
--
-- Recognises @i32@-typed literals, integer literals, and (possibly
-- nested) negations of those.  The value is first computed as an
-- unbounded 'Integer', and 'Nothing' is returned when the result does
-- not fit in an 'Int32'.  Previously 'fromInteger' silently wrapped
-- around, turning, say, the literal @4294967299@ into the constant @3@.
isInt32 :: Exp -> Maybe Int32
isInt32 e = asInteger e >>= toInt32
  where
    -- The exact (unbounded) value denoted by a constant expression.
    asInteger :: Exp -> Maybe Integer
    asInteger (Literal (SignedValue (Int32Value k)) _) = Just $ toInteger k
    asInteger (IntLit k _ _) = Just k
    asInteger (Negate x _) = negate <$> asInteger x
    asInteger _ = Nothing

    -- Narrow to Int32 only when the value is representable.
    toInt32 :: Integer -> Maybe Int32
    toInt32 k
      | k < toInteger (minBound :: Int32) = Nothing
      | k > toInteger (maxBound :: Int32) = Nothing
      | otherwise = Just $ fromInteger k

-- | Statically determine a dimension from an expression, if possible.
--
-- A variable becomes a 'NamedDim', and parentheses (plain or
-- qualified) are looked through.  Anything that 'isInt32' recognises
-- becomes a 'ConstDim'.  Every other expression yields 'Nothing'.
maybeDimFromExp :: Exp -> Maybe (DimDecl VName)
maybeDimFromExp expr =
  case expr of
    Var name _ _ -> Just (NamedDim name)
    Parens inner _ -> maybeDimFromExp inner
    QualParens _ inner _ -> maybeDimFromExp inner
    _ -> fmap (ConstDim . fromIntegral) (isInt32 expr)

-- | Determine the dimension denoted by an expression.
--
-- Parentheses are stripped first.  When 'maybeDimFromExp' finds a
-- static dimension, it is returned together with 'Nothing'.  Otherwise
-- 'extSize' is invoked at the expression's location, with the size
-- source produced by applying @rf@ to the expression.  The second
-- component is then whatever 'extSize' reports (presumably the name of
-- a freshly created size; not visible here).
dimFromExp :: (Exp -> SizeSource) -> Exp -> TermTypeM (DimDecl VName, Maybe VName)
dimFromExp rf expr =
  case expr of
    Parens inner _ -> dimFromExp rf inner
    QualParens _ inner _ -> dimFromExp rf inner
    _ ->
      case maybeDimFromExp expr of
        Just known -> pure (known, Nothing)
        Nothing -> extSize (srclocOf expr) (rf expr)

-- | 'dimFromExp' specialised to function arguments.  Any size that must
-- be invented is attributed to argument position of the (optionally
-- known) function @fname@, and recorded with the argument expression
-- stripped of its type information via 'bareExp'.
dimFromArg :: Maybe (QualName VName) -> Exp -> TermTypeM (DimDecl VName, Maybe VName)
dimFromArg fname arg = dimFromExp argSource arg
  where
    argSource :: Exp -> SizeSource
    argSource = SourceArg (FName fname) . bareExp

-- | @returnType ret_type arg_diet arg_type@ computes the type that
-- results from applying a function whose return type is @ret_type@ to
-- an argument of type @arg_type@, where the function consumes that
-- argument according to the diet @arg_diet@.
--
-- Unique arrays and type variables in the result carry no aliases.
-- Non-unique arrays and type variables absorb the aliases of the
-- argument that survive 'maskAliases', and are deliberately marked
-- 'Unique'.  Function results additionally keep the aliases of an
-- existing closure.  Records and sum types are handled component-wise.
returnType ::
  PatternType ->
  Diet ->
  PatternType ->
  PatternType
returnType ret d arg =
  case ret of
    Array _ Unique et shape ->
      Array mempty Unique et shape
    Array als Nonunique et shape ->
      Array (als <> argAls) Unique et shape -- Intentional!
    Scalar (Record fs) ->
      Scalar $ Record $ fmap recur fs
    Scalar (Prim t) ->
      Scalar $ Prim t
    Scalar (TypeVar _ Unique t targs) ->
      Scalar $ TypeVar mempty Unique t targs
    Scalar (TypeVar als Nonunique t targs) ->
      Scalar $ TypeVar (als <> argAls) Unique t targs -- Intentional!
    Scalar (Arrow old_als v t1 t2) ->
      -- Make sure to propagate the aliases of an existing closure.
      let closureAls = old_als <> argAls
       in Scalar $ Arrow closureAls v (t1 `setAliases` mempty) (t2 `setAliases` closureAls)
    Scalar (Sum cs) ->
      Scalar $ Sum $ fmap (map recur) cs
  where
    -- Aliases of the argument that are not masked away by the diet.
    argAls :: Aliasing
    argAls = aliases $ maskAliases arg d

    -- Apply the same diet and argument to a component type.
    recur :: PatternType -> PatternType
    recur et = returnType et d arg

-- | @t `maskAliases` d@ clears the aliases (sets them to 'mempty') of
-- those parts of @t@ that the 'Diet' @d@ marks as consumed.
--
-- 'Observe' and function diets leave the type untouched.  A record
-- diet is applied field-wise to a record type, keeping only fields
-- present in both.  Any other combination of type and record diet is
-- considered impossible and calls 'error'.
maskAliases ::
  Monoid as =>
  TypeBase shape as ->
  Diet ->
  TypeBase shape as
maskAliases t d =
  case (t, d) of
    (_, Consume) -> t `setAliases` mempty
    (_, Observe) -> t
    (_, FuncDiet {}) -> t
    (Scalar (Record ets), RecordDiet ds) ->
      Scalar . Record $ M.intersectionWith maskAliases ets ds
    _ -> error "Invalid arguments passed to maskAliases."

-- | Compute the occurrences produced by passing an argument of type
-- @argT@ to a parameter with diet @diet@.
--
-- * Record arguments are processed field by field against a
--   'RecordDiet', keeping only fields present in both.
-- * A non-unique array or type variable passed to a consuming
--   parameter is a type error.
-- * For a function-typed argument, it is a type error if the passed
--   function's parameter is unique where the expected diet only
--   observes it.  Otherwise the result part of the function is
--   checked against the result part of the diet.
-- * Anything else records the argument's aliases as a consumption
--   (for 'Consume') or an observation (for any other diet).
consumeArg :: SrcLoc -> PatternType -> Diet -> TermTypeM [Occurence]
consumeArg loc argT diet =
  case (argT, diet) of
    (Scalar (Record ets), RecordDiet ds) ->
      fmap (concat . M.elems) $
        traverse (uncurry (consumeArg loc)) $
          M.intersectionWith (,) ets ds
    (Array _ Nonunique _ _, Consume) ->
      nonUniqueError
    (Scalar (TypeVar _ Nonunique _ _), Consume) ->
      nonUniqueError
    (Scalar (Arrow _ _ t1 t2), FuncDiet d pd)
      | contravariantArg t1 d ->
        consumeArg loc t2 pd
      | otherwise ->
        typeError loc mempty "Non-consuming higher-order parameter passed consuming argument."
    (_, Consume) ->
      pure [consumption (aliases argT) loc]
    _ ->
      pure [observation (aliases argT) loc]
  where
    nonUniqueError :: TermTypeM [Occurence]
    nonUniqueError =
      typeError loc mempty "Consuming parameter passed non-unique argument."

    -- False exactly when a unique parameter type is matched against an
    -- observing diet, looking through records and function types.
    contravariantArg :: TypeBase dim as -> Diet -> Bool
    contravariantArg (Array _ Unique _ _) Observe = False
    contravariantArg (Scalar (TypeVar _ Unique _ _)) Observe = False
    contravariantArg (Scalar (Record ets)) (RecordDiet ds) =
      and (M.intersectionWith contravariantArg ets ds)
    contravariantArg (Scalar (Arrow _ _ tp tr)) (FuncDiet dp dr) =
      contravariantArg tp dp && contravariantArg tr dr
    contravariantArg _ _ = True

-- | Type-check a single expression in isolation.  This expression may
-- turn out to be polymorphic, in which case the list of type
-- parameters will be non-empty.
--
-- The expression's type is let-generalised under the name @<exp>@, and
-- 'fixOverloadedTypes' is applied to that type's variables.  The
-- expression with updated types is then passed through
-- 'checkUnmatched', 'causalityCheck' and 'literalOverflowCheck', in
-- that order.  The occurrences returned by 'runTermTypeM' are
-- discarded.
checkOneExp :: UncheckedExp -> TypeM ([TypeParam], Exp)
checkOneExp e = do
  (result, _occurrences) <- runTermTypeM check
  pure result
  where
    check :: TermTypeM ([TypeParam], Exp)
    check = do
      checked <- checkExp e
      let ty = toStruct $ typeOf checked
      (tparams, _, _, _) <-
        letGeneralise (nameFromString "<exp>") (srclocOf e) [] [] ty
      fixOverloadedTypes $ typeVars ty
      final <- updateTypes checked
      checkUnmatched final
      causalityCheck final
      literalOverflowCheck final
      pure (tparams, final)

-- Verify that all sum type constructors and empty array literals have
-- a size that is known (rigid or a type parameter).  This is to
-- ensure that we can actually determine their shape at run-time.
causalityCheck :: Exp -> TermTypeM ()
causalityCheck :: Exp -> TermTypeM ()
causalityCheck Exp
binding_body = do
  Constraints
constraints <- TermTypeM Constraints
forall (m :: * -> *). MonadUnify m => m Constraints
getConstraints

  let checkCausality :: Doc
-> Names
-> TypeBase (DimDecl VName) as
-> SrcLoc
-> Maybe (t (Either TypeError) a)
checkCausality Doc
what Names
known TypeBase (DimDecl VName) as
t SrcLoc
loc
        | (VName
d, SrcLoc
dloc) : [(VName, SrcLoc)]
_ <-
            (VName -> Maybe (VName, SrcLoc)) -> [VName] -> [(VName, SrcLoc)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Constraints -> Names -> VName -> Maybe (VName, SrcLoc)
forall a a.
Ord a =>
Map a (a, Constraint) -> Set a -> a -> Maybe (a, SrcLoc)
unknown Constraints
constraints Names
known) ([VName] -> [(VName, SrcLoc)]) -> [VName] -> [(VName, SrcLoc)]
forall a b. (a -> b) -> a -> b
$
              Names -> [VName]
forall a. Set a -> [a]
S.toList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ StructType -> Names
forall als. TypeBase (DimDecl VName) als -> Names
typeDimNames (StructType -> Names) -> StructType -> Names
forall a b. (a -> b) -> a -> b
$ TypeBase (DimDecl VName) as -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct TypeBase (DimDecl VName) as
t =
          t (Either TypeError) a -> Maybe (t (Either TypeError) a)
forall a. a -> Maybe a
Just (t (Either TypeError) a -> Maybe (t (Either TypeError) a))
-> t (Either TypeError) a -> Maybe (t (Either TypeError) a)
forall a b. (a -> b) -> a -> b
$ Either TypeError a -> t (Either TypeError) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (Either TypeError a -> t (Either TypeError) a)
-> Either TypeError a -> t (Either TypeError) a
forall a b. (a -> b) -> a -> b
$ Doc
-> SrcLoc
-> VName
-> SrcLoc
-> TypeBase (DimDecl VName) as
-> Either TypeError a
forall v a b b.
(IsName v, Pretty a, Located b) =>
Doc -> SrcLoc -> v -> b -> a -> Either TypeError b
causality Doc
what SrcLoc
loc VName
d SrcLoc
dloc TypeBase (DimDecl VName) as
t
        | Bool
otherwise = Maybe (t (Either TypeError) a)
forall a. Maybe a
Nothing

      checkParamCausality :: Names -> Pattern -> Maybe (t (Either TypeError) a)
checkParamCausality Names
known Pattern
p =
        Doc
-> Names -> PatternType -> SrcLoc -> Maybe (t (Either TypeError) a)
forall (t :: (* -> *) -> * -> *) as a.
MonadTrans t =>
Doc
-> Names
-> TypeBase (DimDecl VName) as
-> SrcLoc
-> Maybe (t (Either TypeError) a)
checkCausality (Pattern -> Doc
forall a. Pretty a => a -> Doc
ppr Pattern
p) Names
known (Pattern -> PatternType
patternType Pattern
p) (Pattern -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf Pattern
p)

      onExp ::
        S.Set VName ->
        Exp ->
        StateT (S.Set VName) (Either TypeError) Exp

      onExp :: Names -> Exp -> StateT Names (Either TypeError) Exp
onExp Names
known (Var QualName VName
v (Info PatternType
t) SrcLoc
loc)
        | Just StateT Names (Either TypeError) Exp
bad <- Doc
-> Names
-> PatternType
-> SrcLoc
-> Maybe (StateT Names (Either TypeError) Exp)
forall (t :: (* -> *) -> * -> *) as a.
MonadTrans t =>
Doc
-> Names
-> TypeBase (DimDecl VName) as
-> SrcLoc
-> Maybe (t (Either TypeError) a)
checkCausality (Doc -> Doc
pquote (QualName VName -> Doc
forall a. Pretty a => a -> Doc
ppr QualName VName
v)) Names
known PatternType
t SrcLoc
loc =
          StateT Names (Either TypeError) Exp
bad
      onExp Names
known (ArrayLit [] (Info PatternType
t) SrcLoc
loc)
        | Just StateT Names (Either TypeError) Exp
bad <- Doc
-> Names
-> PatternType
-> SrcLoc
-> Maybe (StateT Names (Either TypeError) Exp)
forall (t :: (* -> *) -> * -> *) as a.
MonadTrans t =>
Doc
-> Names
-> TypeBase (DimDecl VName) as
-> SrcLoc
-> Maybe (t (Either TypeError) a)
checkCausality Doc
"empty array" Names
known PatternType
t SrcLoc
loc =
          StateT Names (Either TypeError) Exp
bad
      onExp Names
known (Lambda [Pattern]
params Exp
_ Maybe (TypeExp VName)
_ Info (Aliasing, StructType)
_ SrcLoc
_)
        | StateT Names (Either TypeError) Exp
bad : [StateT Names (Either TypeError) Exp]
_ <- (Pattern -> Maybe (StateT Names (Either TypeError) Exp))
-> [Pattern] -> [StateT Names (Either TypeError) Exp]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Names -> Pattern -> Maybe (StateT Names (Either TypeError) Exp)
forall (t :: (* -> *) -> * -> *) a.
MonadTrans t =>
Names -> Pattern -> Maybe (t (Either TypeError) a)
checkParamCausality Names
known) [Pattern]
params =
          StateT Names (Either TypeError) Exp
bad
      onExp Names
known e :: Exp
e@(Coerce Exp
what TypeDeclBase Info VName
_ (Info PatternType
_, Info [VName]
ext) SrcLoc
_) = do
        (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ([VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<>)
        StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (StateT Names (Either TypeError) Exp
 -> StateT Names (Either TypeError) ())
-> StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall a b. (a -> b) -> a -> b
$ Names -> Exp -> StateT Names (Either TypeError) Exp
onExp Names
known Exp
what
        Exp -> StateT Names (Either TypeError) Exp
forall (m :: * -> *) a. Monad m => a -> m a
return Exp
e
      onExp Names
known e :: Exp
e@(LetPat Pattern
_ Exp
bindee_e Exp
body_e (Info PatternType
_, Info [VName]
ext) SrcLoc
_) = do
        Names
-> Exp -> Exp -> [VName] -> StateT Names (Either TypeError) ()
sequencePoint Names
known Exp
bindee_e Exp
body_e [VName]
ext
        Exp -> StateT Names (Either TypeError) Exp
forall (m :: * -> *) a. Monad m => a -> m a
return Exp
e
      onExp Names
known e :: Exp
e@(Apply Exp
f Exp
arg (Info (Diet
_, Maybe VName
p)) (Info PatternType
_, Info [VName]
ext) SrcLoc
_) = do
        Names
-> Exp -> Exp -> [VName] -> StateT Names (Either TypeError) ()
sequencePoint Names
known Exp
arg Exp
f ([VName] -> StateT Names (Either TypeError) ())
-> [VName] -> StateT Names (Either TypeError) ()
forall a b. (a -> b) -> a -> b
$ Maybe VName -> [VName]
forall a. Maybe a -> [a]
maybeToList Maybe VName
p [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
ext
        Exp -> StateT Names (Either TypeError) Exp
forall (m :: * -> *) a. Monad m => a -> m a
return Exp
e
      onExp
        Names
known
        e :: Exp
e@( BinOp
              (QualName VName
f, SrcLoc
floc)
              Info PatternType
ft
              (Exp
x, Info (StructType
_, Maybe VName
xp))
              (Exp
y, Info (StructType
_, Maybe VName
yp))
              Info PatternType
_
              (Info [VName]
ext)
              SrcLoc
_
            ) = do
          Names
args_known <-
            Either TypeError Names -> StateT Names (Either TypeError) Names
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (Either TypeError Names -> StateT Names (Either TypeError) Names)
-> Either TypeError Names -> StateT Names (Either TypeError) Names
forall a b. (a -> b) -> a -> b
$
              StateT Names (Either TypeError) ()
-> Names -> Either TypeError Names
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m s
execStateT (Names
-> Exp -> Exp -> [VName] -> StateT Names (Either TypeError) ()
sequencePoint Names
known Exp
x Exp
y ([VName] -> StateT Names (Either TypeError) ())
-> [VName] -> StateT Names (Either TypeError) ()
forall a b. (a -> b) -> a -> b
$ [Maybe VName] -> [VName]
forall a. [Maybe a] -> [a]
catMaybes [Maybe VName
xp, Maybe VName
yp]) Names
forall a. Monoid a => a
mempty
          StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (StateT Names (Either TypeError) Exp
 -> StateT Names (Either TypeError) ())
-> StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall a b. (a -> b) -> a -> b
$ Names -> Exp -> StateT Names (Either TypeError) Exp
onExp (Names
args_known Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
known) (QualName VName -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
QualName vn -> f PatternType -> SrcLoc -> ExpBase f vn
Var QualName VName
f Info PatternType
ft SrcLoc
floc)
          (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((Names
args_known Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<>)
          Exp -> StateT Names (Either TypeError) Exp
forall (m :: * -> *) a. Monad m => a -> m a
return Exp
e
      onExp Names
known Exp
e = do
        Names -> Exp -> StateT Names (Either TypeError) ()
forall a.
ASTMappable a =>
Names -> a -> StateT Names (Either TypeError) ()
recurse Names
known Exp
e

        case Exp
e of
          DoLoop [VName]
_ Pattern
_ Exp
_ LoopFormBase Info VName
_ Exp
_ (Info (PatternType
_, [VName]
ext)) SrcLoc
_ ->
            (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext)
          If Exp
_ Exp
_ Exp
_ (Info PatternType
_, Info [VName]
ext) SrcLoc
_ ->
            (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext)
          Index Exp
_ [DimIndex]
_ (Info PatternType
_, Info [VName]
ext) SrcLoc
_ ->
            (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext)
          Match Exp
_ NonEmpty (CaseBase Info VName)
_ (Info PatternType
_, Info [VName]
ext) SrcLoc
_ ->
            (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext)
          Range Exp
_ Maybe Exp
_ Inclusiveness Exp
_ (Info PatternType
_, Info [VName]
ext) SrcLoc
_ ->
            (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext)
          Exp
_ ->
            () -> StateT Names (Either TypeError) ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

        Exp -> StateT Names (Either TypeError) Exp
forall (m :: * -> *) a. Monad m => a -> m a
return Exp
e

      recurse :: Names -> a -> StateT Names (Either TypeError) ()
recurse Names
known = StateT Names (Either TypeError) a
-> StateT Names (Either TypeError) ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (StateT Names (Either TypeError) a
 -> StateT Names (Either TypeError) ())
-> (a -> StateT Names (Either TypeError) a)
-> a
-> StateT Names (Either TypeError) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASTMapper (StateT Names (Either TypeError))
-> a -> StateT Names (Either TypeError) a
forall x (m :: * -> *).
(ASTMappable x, Monad m) =>
ASTMapper m -> x -> m x
astMap ASTMapper (StateT Names (Either TypeError))
mapper
        where
          mapper :: ASTMapper (StateT Names (Either TypeError))
mapper = ASTMapper (StateT Names (Either TypeError))
forall (m :: * -> *). Monad m => ASTMapper m
identityMapper {mapOnExp :: Exp -> StateT Names (Either TypeError) Exp
mapOnExp = Names -> Exp -> StateT Names (Either TypeError) Exp
onExp Names
known}

      sequencePoint :: Names
-> Exp -> Exp -> [VName] -> StateT Names (Either TypeError) ()
sequencePoint Names
known Exp
x Exp
y [VName]
ext = do
        Names
new_known <- Either TypeError Names -> StateT Names (Either TypeError) Names
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (Either TypeError Names -> StateT Names (Either TypeError) Names)
-> Either TypeError Names -> StateT Names (Either TypeError) Names
forall a b. (a -> b) -> a -> b
$ StateT Names (Either TypeError) Exp
-> Names -> Either TypeError Names
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m s
execStateT (Names -> Exp -> StateT Names (Either TypeError) Exp
onExp Names
known Exp
x) Names
forall a. Monoid a => a
mempty
        StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (StateT Names (Either TypeError) Exp
 -> StateT Names (Either TypeError) ())
-> StateT Names (Either TypeError) Exp
-> StateT Names (Either TypeError) ()
forall a b. (a -> b) -> a -> b
$ Names -> Exp -> StateT Names (Either TypeError) Exp
onExp (Names
new_known Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
known) Exp
y
        (Names -> Names) -> StateT Names (Either TypeError) ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((Names
new_known Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> [VName] -> Names
forall a. Ord a => [a] -> Set a
S.fromList [VName]
ext) Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<>)

  (TypeError -> TermTypeM ())
-> (Exp -> TermTypeM ()) -> Either TypeError Exp -> TermTypeM ()
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either TypeError -> TermTypeM ()
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (TermTypeM () -> Exp -> TermTypeM ()
forall a b. a -> b -> a
const (TermTypeM () -> Exp -> TermTypeM ())
-> TermTypeM () -> Exp -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ () -> TermTypeM ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()) (Either TypeError Exp -> TermTypeM ())
-> Either TypeError Exp -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
    StateT Names (Either TypeError) Exp
-> Names -> Either TypeError Exp
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT (Names -> Exp -> StateT Names (Either TypeError) Exp
onExp Names
forall a. Monoid a => a
mempty Exp
binding_body) Names
forall a. Monoid a => a
mempty
  where
    unknown :: Map a (a, Constraint) -> Set a -> a -> Maybe (a, SrcLoc)
unknown Map a (a, Constraint)
constraints Set a
known a
v = do
      Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$ a
v a -> Set a -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.notMember` Set a
known
      SrcLoc
loc <- Map a (a, Constraint) -> a -> Maybe SrcLoc
forall k a. Ord k => Map k (a, Constraint) -> k -> Maybe SrcLoc
unknowable Map a (a, Constraint)
constraints a
v
      (a, SrcLoc) -> Maybe (a, SrcLoc)
forall (m :: * -> *) a. Monad m => a -> m a
return (a
v, SrcLoc
loc)

    unknowable :: Map k (a, Constraint) -> k -> Maybe SrcLoc
unknowable Map k (a, Constraint)
constraints k
v =
      case (a, Constraint) -> Constraint
forall a b. (a, b) -> b
snd ((a, Constraint) -> Constraint)
-> Maybe (a, Constraint) -> Maybe Constraint
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> k -> Map k (a, Constraint) -> Maybe (a, Constraint)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup k
v Map k (a, Constraint)
constraints of
        Just (UnknowableSize SrcLoc
loc RigidSource
_) -> SrcLoc -> Maybe SrcLoc
forall a. a -> Maybe a
Just SrcLoc
loc
        Maybe Constraint
_ -> Maybe SrcLoc
forall a. Maybe a
Nothing

    causality :: Doc -> SrcLoc -> v -> b -> a -> Either TypeError b
causality Doc
what SrcLoc
loc v
d b
dloc a
t =
      TypeError -> Either TypeError b
forall a b. a -> Either a b
Left (TypeError -> Either TypeError b)
-> TypeError -> Either TypeError b
forall a b. (a -> b) -> a -> b
$
        SrcLoc -> Notes -> Doc -> TypeError
TypeError SrcLoc
loc Notes
forall a. Monoid a => a
mempty (Doc -> TypeError) -> Doc -> TypeError
forall a b. (a -> b) -> a -> b
$
          Doc
"Causality check: size" Doc -> Doc -> Doc
<+/> Doc -> Doc
pquote (v -> Doc
forall v. IsName v => v -> Doc
pprName v
d)
            Doc -> Doc -> Doc
<+/> Doc
"needed for type of"
            Doc -> Doc -> Doc
<+> Doc
what Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
colon
            Doc -> Doc -> Doc
</> Int -> Doc -> Doc
indent Int
2 (a -> Doc
forall a. Pretty a => a -> Doc
ppr a
t)
            Doc -> Doc -> Doc
</> Doc
"But"
            Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (v -> Doc
forall v. IsName v => v -> Doc
pprName v
d)
            Doc -> Doc -> Doc
<+> Doc
"is computed at"
            Doc -> Doc -> Doc
<+/> String -> Doc
text (SrcLoc -> b -> String
forall a b. (Located a, Located b) => a -> b -> String
locStrRel SrcLoc
loc b
dloc) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
            Doc -> Doc -> Doc
</> Doc
""
            Doc -> Doc -> Doc
</> Doc
"Hint:"
            Doc -> Doc -> Doc
<+> Doc -> Doc
align
              ( String -> Doc
textwrap String
"Bind the expression producing" Doc -> Doc -> Doc
<+> Doc -> Doc
pquote (v -> Doc
forall v. IsName v => v -> Doc
pprName v
d)
                  Doc -> Doc -> Doc
<+> Doc
"with 'let' beforehand."
              )

-- | Traverse the expression and raise a type error (via 'typeError') if
-- any integer or floating-point literal does not fit in its inferred
-- type.  Out-of-range literals are rejected outright; no warning is
-- emitted.
--
-- An integer literal that appears directly under a 'Negate' is checked
-- as its negated value, so e.g. a negated @128@ is accepted at type @i8@.
-- Integer literals given a floating-point type are rejected only if the
-- conversion to that type overflows to infinity.
--
-- Note: currently unable to detect float underflow (such as 1e-400 -> 0).
literalOverflowCheck :: Exp -> TermTypeM ()
literalOverflowCheck = void . check
  where
    -- Literals (and negated integer literals) are checked directly; any
    -- other expression is traversed so that nested literals are checked.
    check e@(IntLit x ty loc) = e <$ checkInt x ty loc
    check e@(FloatLit x ty loc) = e <$ checkFloat x ty loc
    check e@(Negate (IntLit x ty loc1) loc2) =
      e <$ checkInt (negate x) ty (loc1 <> loc2)
    check e = astMap identityMapper {mapOnExp = check} e

    -- Reaching the fallback clauses means the literal's inferred type is
    -- not a primitive numeric type; this is treated as an internal error.
    checkInt x (Info (Scalar (Prim t))) loc = errorBounds (inBoundsI x t) x t loc
    checkInt _ _ _ = error "Inferred type of int literal is not a number"

    checkFloat x (Info (Scalar (Prim (FloatType t)))) loc =
      errorBounds (inBoundsF x t) x t loc
    checkFloat _ _ _ = error "Inferred type of float literal is not a float"

    bitWidth ty = 8 * intByteSize ty :: Int

    -- Signed range is [-2^(w-1), 2^(w-1)); unsigned range is [0, 2^w).
    inBoundsI x (Signed t) =
      let half = 2 ^ (bitWidth t - 1) in x >= negate half && x < half
    inBoundsI x (Unsigned t) = x >= 0 && x < 2 ^ bitWidth t
    inBoundsI x (FloatType Float32) = not $ isInfinite (fromIntegral x :: Float)
    inBoundsI x (FloatType Float64) = not $ isInfinite (fromIntegral x :: Double)
    inBoundsI _ Bool = error "Inferred type of int literal is not a number"

    inBoundsF x Float32 = not $ isInfinite (realToFrac x :: Float)
    inBoundsF x Float64 = not $ isInfinite x

    -- Reports an error (not a warning) when the literal is out of range.
    errorBounds inBounds x ty loc =
      unless inBounds $
        typeError loc mempty $
          "Literal " <> ppr x
            <> " out of bounds for inferred type "
            <> ppr ty
            <> "."

-- | Type-check a top-level (or module-level) function definition.
-- Despite the name, this is also used for checking constant
-- definitions, by treating them as 0-ary functions.
--
-- Beyond the inference done by 'checkBinding', this finalises the
-- definition:
--
-- * all remaining overloaded type constraints are resolved with
--   'fixOverloadedTypes';
-- * inferred types are substituted into the body, the parameters and
--   the return type annotation;
-- * the body goes through the exhaustiveness, causality and
--   literal-overflow checks.
--
-- A type error is reported if the name is one of those listed in
-- 'doNotShadow'.
checkFunDef ::
  ( Name,
    Maybe UncheckedTypeExp,
    [UncheckedTypeParam],
    [UncheckedPattern],
    UncheckedExp,
    SrcLoc
  ) ->
  TypeM
    ( VName,
      [TypeParam],
      [Pattern],
      Maybe (TypeExp VName),
      StructType,
      [VName],
      Exp
    )
checkFunDef binding@(fname, _, _, _, _, loc) =
  fst <$> runTermTypeM finalise
  where
    finalise = do
      (tparams', params', maybe_retdecl', rettype', retext, body') <-
        checkBinding binding

      -- A top-level definition may not leave overloaded types open: they
      -- are defaulted or reported as ambiguous.  The type variables that
      -- occur in the signature are passed along so that defaulting one
      -- of them produces a warning.
      let toplevel_tyvars =
            typeVars rettype' <> foldMap (typeVars . patternType) params'
      fixOverloadedTypes toplevel_tyvars

      -- Substitute the now-known types everywhere.
      body'' <- updateTypes body'
      params'' <- updateTypes params'
      maybe_retdecl'' <- traverse updateTypes maybe_retdecl'
      rettype'' <- normTypeFully rettype'

      -- Post-inference checks on the fully typed body, in order:
      -- exhaustiveness of pattern matches, whether the body can actually
      -- be evaluated (causality), and literal ranges.
      mapM_ ($ body'') [checkUnmatched, causalityCheck, literalOverflowCheck]

      fname' <- bindSpaced [(Term, fname)] $ do
        checked_name <- checkName Term fname loc
        when (nameToString fname `elem` doNotShadow) $
          typeError loc mempty $
            "The" <+> pprName fname <+> "operator may not be redefined."
        pure checked_name

      pure (fname', tparams', params'', maybe_retdecl'', rettype'', retext, body'')

-- | This is "fixing" as in "setting them", not "correcting them".  We
-- only make very conservative fixing.
fixOverloadedTypes :: Names -> TermTypeM ()
fixOverloadedTypes :: Names -> TermTypeM ()
fixOverloadedTypes Names
tyvars_at_toplevel =
  TermTypeM Constraints
forall (m :: * -> *). MonadUnify m => m Constraints
getConstraints TermTypeM Constraints
-> (Constraints -> TermTypeM ()) -> TermTypeM ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= ((VName, Constraint) -> TermTypeM ())
-> [(VName, Constraint)] -> TermTypeM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (VName, Constraint) -> TermTypeM ()
forall (m :: * -> *).
(MonadUnify m, MonadTypeChecker m) =>
(VName, Constraint) -> m ()
fixOverloaded ([(VName, Constraint)] -> TermTypeM ())
-> (Constraints -> [(VName, Constraint)])
-> Constraints
-> TermTypeM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName Constraint -> [(VName, Constraint)]
forall k a. Map k a -> [(k, a)]
M.toList (Map VName Constraint -> [(VName, Constraint)])
-> (Constraints -> Map VName Constraint)
-> Constraints
-> [(VName, Constraint)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Int, Constraint) -> Constraint)
-> Constraints -> Map VName Constraint
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (Int, Constraint) -> Constraint
forall a b. (a, b) -> b
snd
  where
    fixOverloaded :: (VName, Constraint) -> m ()
fixOverloaded (VName
v, Overloaded [PrimType]
ots Usage
usage)
      | IntType -> PrimType
Signed IntType
Int32 PrimType -> [PrimType] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [PrimType]
ots = do
        Usage -> StructType -> StructType -> m ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify Usage
usage (ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (()
-> Uniqueness
-> TypeName
-> [TypeArg (DimDecl VName)]
-> ScalarTypeBase (DimDecl VName) ()
forall dim as.
as
-> Uniqueness -> TypeName -> [TypeArg dim] -> ScalarTypeBase dim as
TypeVar () Uniqueness
Nonunique (VName -> TypeName
typeName VName
v) [])) (StructType -> m ()) -> StructType -> m ()
forall a b. (a -> b) -> a -> b
$
          ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim (PrimType -> ScalarTypeBase (DimDecl VName) ())
-> PrimType -> ScalarTypeBase (DimDecl VName) ()
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
Signed IntType
Int32
        Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (VName
v VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Names
tyvars_at_toplevel) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$
          Usage -> String -> m ()
forall (m :: * -> *) loc.
(MonadTypeChecker m, Located loc) =>
loc -> String -> m ()
warn Usage
usage String
"Defaulting ambiguous type to i32."
      | FloatType -> PrimType
FloatType FloatType
Float64 PrimType -> [PrimType] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [PrimType]
ots = do
        Usage -> StructType -> StructType -> m ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify Usage
usage (ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (()
-> Uniqueness
-> TypeName
-> [TypeArg (DimDecl VName)]
-> ScalarTypeBase (DimDecl VName) ()
forall dim as.
as
-> Uniqueness -> TypeName -> [TypeArg dim] -> ScalarTypeBase dim as
TypeVar () Uniqueness
Nonunique (VName -> TypeName
typeName VName
v) [])) (StructType -> m ()) -> StructType -> m ()
forall a b. (a -> b) -> a -> b
$
          ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim (PrimType -> ScalarTypeBase (DimDecl VName) ())
-> PrimType -> ScalarTypeBase (DimDecl VName) ()
forall a b. (a -> b) -> a -> b
$ FloatType -> PrimType
FloatType FloatType
Float64
        Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (VName
v VName -> Names -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Names
tyvars_at_toplevel) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$
          Usage -> String -> m ()
forall (m :: * -> *) loc.
(MonadTypeChecker m, Located loc) =>
loc -> String -> m ()
warn Usage
usage String
"Defaulting ambiguous type to f64."
      | Bool
otherwise =
        Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
          Doc
"Type is ambiguous (could be one of" Doc -> Doc -> Doc
<+> [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
ots) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
")."
            Doc -> Doc -> Doc
</> Doc
"Add a type annotation to disambiguate the type."
    fixOverloaded (VName
_, NoConstraint Liftedness
_ Usage
usage) =
      Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
        Doc
"Type of expression is ambiguous."
          Doc -> Doc -> Doc
</> Doc
"Add a type annotation to disambiguate the type."
    fixOverloaded (VName
_, Equality Usage
usage) =
      Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
        Doc
"Type is ambiguous (must be equality type)."
          Doc -> Doc -> Doc
</> Doc
"Add a type annotation to disambiguate the type."
    fixOverloaded (VName
_, HasFields Map Name StructType
fs Usage
usage) =
      Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
        Doc
"Type is ambiguous.  Must be record with fields:"
          Doc -> Doc -> Doc
</> Int -> Doc -> Doc
indent Int
2 ([Doc] -> Doc
stack ([Doc] -> Doc) -> [Doc] -> Doc
forall a b. (a -> b) -> a -> b
$ ((Name, StructType) -> Doc) -> [(Name, StructType)] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map (Name, StructType) -> Doc
forall a a. (Pretty a, Pretty a) => (a, a) -> Doc
field ([(Name, StructType)] -> [Doc]) -> [(Name, StructType)] -> [Doc]
forall a b. (a -> b) -> a -> b
$ Map Name StructType -> [(Name, StructType)]
forall k a. Map k a -> [(k, a)]
M.toList Map Name StructType
fs)
          Doc -> Doc -> Doc
</> Doc
"Add a type annotation to disambiguate the type."
      where
        field :: (a, a) -> Doc
field (a
l, a
t) = a -> Doc
forall a. Pretty a => a -> Doc
ppr a
l Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
colon Doc -> Doc -> Doc
<+> Doc -> Doc
align (a -> Doc
forall a. Pretty a => a -> Doc
ppr a
t)
    fixOverloaded (VName
_, HasConstrs Map Name [StructType]
cs Usage
usage) =
      Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
        Doc
"Type is ambiguous (must be a sum type with constructors:"
          Doc -> Doc -> Doc
<+> ScalarTypeBase (DimDecl VName) () -> Doc
forall a. Pretty a => a -> Doc
ppr (Map Name [StructType] -> ScalarTypeBase (DimDecl VName) ()
forall dim as. Map Name [TypeBase dim as] -> ScalarTypeBase dim as
Sum Map Name [StructType]
cs) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
")."
          Doc -> Doc -> Doc
</> Doc
"Add a type annotation to disambiguate the type."
    fixOverloaded (VName
_, Size Maybe (DimDecl VName)
Nothing Usage
usage) =
      Usage -> Notes -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError Usage
usage Notes
forall a. Monoid a => a
mempty Doc
"Size is ambiguous."
    fixOverloaded (VName, Constraint)
_ = () -> m ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

-- | Names bound anywhere inside the parameter patterns, minus those
-- that 'patternParam' reports as 'Named' parameters.  These are the
-- names 'inferredReturnType' must unscope from an inferred return type.
hiddenParamNames :: [Pattern] -> Names
hiddenParamNames params =
  foldMap patternNames params `S.difference` visible
  where
    -- Parameters whose name is exposed (the 'Unnamed' ones are skipped
    -- by the failing pattern match in the comprehension).
    visible = S.fromList [v | (Named v, _) <- map patternParam params]

-- | Compute the return type of a binding that has no return type
-- annotation, starting from the type of its body.
--
-- The inferred type may refer to names that are bound by the
-- parameter patterns, but which will not be visible in the type.
-- These we must turn into fresh type variables, which will be
-- existential in the return type.
inferredReturnType :: SrcLoc -> [Pattern] -> PatternType -> TermTypeM StructType
inferredReturnType loc params t = do
  (unscoped, _) <-
    unscopeType loc hidden_idents $ inferReturnUniqueness params t
  pure $ toStruct unscoped
  where
    hidden = hiddenParamNames params
    -- Identifiers from the parameter patterns whose names are hidden.
    hidden_idents =
      M.filterWithKey (\v _ -> v `S.member` hidden) $
        foldMap patternMap params

-- | Type-check a single value binding, given as its name, optional
-- return type annotation, type parameters, value parameters, body and
-- location.
--
-- Returns the type parameters and parameter patterns produced by
-- 'letGeneralise', the checked return type annotation (if one was
-- given), the generalised return type, the extra names returned by
-- 'letGeneralise', and the checked body.
checkBinding ::
  ( Name,
    Maybe UncheckedTypeExp,
    [UncheckedTypeParam],
    [UncheckedPattern],
    UncheckedExp,
    SrcLoc
  ) ->
  TermTypeM
    ( [TypeParam],
      [Pattern],
      Maybe (TypeExp VName),
      StructType,
      [VName],
      Exp
    )
checkBinding (fname, maybe_retdecl, tparams, params, body, loc) =
  noUnique . incLevel . bindingParams tparams params $ \tparams' params' -> do
    -- Size parameters are rejected on bindings without value parameters.
    when (null params && any isSizeParam tparams) $
      typeError
        loc
        mempty
        "Size parameters are only allowed on bindings that also have value parameters."

    checked_retdecl <- traverse checkRetDecl maybe_retdecl

    -- The location handed to 'checkFunBody' is that of the return
    -- type annotation when present, otherwise that of the binding.
    body' <-
      checkFunBody params' body (snd <$> checked_retdecl) $
        maybe loc srclocOf maybe_retdecl

    params'' <- mapM updateTypes params'
    body_t <- expTypeFully body'

    (retdecl_out, rettype) <-
      maybe (inferRet params'' body_t) (declaredRet params'' body_t) checked_retdecl

    verifyFunctionParams (Just fname) params''

    (tparams'', params''', rettype', retext) <-
      letGeneralise fname loc tparams' params'' rettype

    checkGlobalAliases params'' body_t loc

    pure (tparams'', params''', retdecl_out, rettype', retext, body')
  where
    -- Elaborate the user-written return type; empty array dimensions
    -- are instantiated (nonrigidly) by 'instantiateEmptyArrayDims'.
    checkRetDecl :: UncheckedTypeExp -> TermTypeM (TypeExp VName, StructType)
    checkRetDecl retdecl = do
      (retdecl', ret_nodims, _) <- checkTypeExp retdecl
      (ret, _) <- instantiateEmptyArrayDims loc "funret" Nonrigid ret_nodims
      pure (retdecl', ret)

    -- Return type given by an annotation: check result aliasing against
    -- it, then normalise it.
    declaredRet ::
      [Pattern] ->
      PatternType ->
      (TypeExp VName, StructType) ->
      TermTypeM (Maybe (TypeExp VName), StructType)
    declaredRet ps body_t (retdecl', ret) = do
      let structural = toStructural ret
      checkReturnAlias structural ps body_t
      when (null params) $ nothingMustBeUnique loc structural
      ret' <- normTypeFully ret
      pure (Just retdecl', ret')

    -- No annotation: derive the return type from the body's type.
    inferRet ::
      [Pattern] ->
      PatternType ->
      TermTypeM (Maybe (TypeExp VName), StructType)
    inferRet ps body_t
      | null params = pure (Nothing, toStruct $ body_t `setUniqueness` Nonunique)
      | otherwise = (,) Nothing <$> inferredReturnType loc ps body_t

    -- Walk the (uniqueness, aliases) components of the result in order.
    -- A unique component may not share an alias with any earlier
    -- component.  A nonunique component may not share an alias with an
    -- earlier unique one.
    checkReturnAlias rettp params' =
      foldM_ (checkComponent params') S.empty . returnAliasing rettp

    checkComponent params' seen (Unique, names)
      | any (`S.member` seen_names) names =
        uniqueReturnAliased fname loc
      | otherwise = do
        notAliasingParam params' names
        pure $ seen <> tag Unique names
      where
        seen_names = S.map snd seen
    checkComponent _ seen (Nonunique, names)
      | any (\v -> (Unique, v) `S.member` seen) names =
        uniqueReturnAliased fname loc
      | otherwise =
        pure $ seen <> tag Nonunique names

    -- A unique part of the result must not alias a parameter whose
    -- type is not unique.
    notAliasingParam params' names =
      forM_ params' $ \p ->
        forM_ (find (consumedNonunique names) (patternIdents p)) $ \p' ->
          returnAliased fname (baseName $ identName p') loc

    consumedNonunique names p' =
      not (unique $ unInfo $ identType p') && (identName p' `S.member` names)

    tag u = S.map ((,) u)

    -- Pair the uniqueness of each component of the expected type with
    -- the aliases of the matching component of the actual type,
    -- descending into records present on both sides.
    returnAliasing (Scalar (Record expected_fs)) (Scalar (Record got_fs)) =
      concat $ M.elems $ M.intersectionWith returnAliasing expected_fs got_fs
    returnAliasing expected got =
      [(uniqueness expected, S.map aliasVar $ aliases got)]

-- | Extract the size names that occur in positive position in a type.
-- Here that means the parameter types of the top-level arrows of a
-- function type.  Parameters that are themselves function types
-- contribute nothing.  A type that is not a function type yields the
-- empty set.
typeDimNamesPos :: TypeBase (DimDecl VName) als -> S.Set VName
typeDimNamesPos (Scalar (Arrow _ _ param_t ret_t)) =
  inParam param_t <> typeDimNamesPos ret_t
  where
    -- Polymorphic in the aliasing so it can also be used on the
    -- 'StructType's found inside type arguments.
    inParam :: TypeBase (DimDecl VName) als -> S.Set VName
    inParam t = case t of
      Scalar Arrow {} -> mempty
      Scalar (Record fs) -> foldMap inParam fs
      Scalar (TypeVar _ _ _ targs) -> foldMap inTypeArg targs
      _ -> typeDimNames t

    inTypeArg :: TypeArg (DimDecl VName) -> S.Set VName
    inTypeArg (TypeArgDim (NamedDim d) _) = S.singleton $ qualLeaf d
    inTypeArg (TypeArgDim _ _) = mempty
    inTypeArg (TypeArgType t _) = inParam t
typeDimNamesPos _ = mempty

-- | For a binding with parameters, report a type error if its result
-- aliases a variable that is neither bound by the parameters nor bound
-- with 'Local' locality in the current vtable.  Bindings without
-- parameters are not checked.
checkGlobalAliases :: [Pattern] -> PatternType -> SrcLoc -> TermTypeM ()
checkGlobalAliases params body_t loc = do
  vtable <- asks $ scopeVtable . termScope
  let isLocal v
        | Just (BoundV Local _ _) <- M.lookup v vtable = True
        | otherwise = False
      escaping =
        boundArrayAliases body_t `S.difference` foldMap patternNames params
      als = filter (not . isLocal) $ S.toList escaping
  unless (null params) $
    case als of
      v : _ ->
        typeError loc mempty $
          "Function result aliases the free variable "
            <> pquote (pprName v)
            <> "."
            </> "Use"
            <+> pquote "copy"
            <+> "to break the aliasing."
      [] -> pure ()

-- | Decide which parts of an inferred return type keep their
-- uniqueness.  Record types are handled field by field.  Any other
-- component keeps its uniqueness only if every array it aliases is
-- among 'uniqueParamNames' of the parameters, and none of its aliases
-- is in 'aliasesMultipleTimes' of the whole type.  Otherwise it is
-- made 'Nonunique'.
inferReturnUniqueness :: [Pattern] -> PatternType -> PatternType
inferReturnUniqueness params t = go t
  where
    forbidden = aliasesMultipleTimes t
    uniques = uniqueParamNames params

    go (Scalar (Record fs)) = Scalar $ Record $ fmap go fs
    go t'
      | mayStayUnique t' = t'
      | otherwise = t' `setUniqueness` Nonunique

    mayStayUnique t' =
      all (`S.member` uniques) (boundArrayAliases t')
        && all ((`S.notMember` forbidden) . aliasVar) (aliases t')

-- An alias inhibits uniqueness if it is used in disjoint values.
--
-- Returns the variables that occur as an alias in more than one leaf
-- of the given type.  Records are split into their fields; any other
-- type counts as a single leaf, and a variable is counted at most once
-- per leaf.
aliasesMultipleTimes :: PatternType -> Names
aliasesMultipleTimes = M.keysSet . M.filter (> 1) . occurrences
  where
    occurrences :: PatternType -> M.Map VName Int
    occurrences (Scalar (Record fs)) =
      M.unionsWith (+) $ map occurrences $ M.elems fs
    occurrences leaf =
      M.fromList [(aliasVar a, 1) | a <- S.toList $ aliases leaf]

-- | The names of all identifiers bound by the given parameter patterns
-- whose types are unique.
uniqueParamNames :: [Pattern] -> Names
uniqueParamNames params =
  S.fromList
    [ identName ident
      | ident <- S.toList $ foldMap patternIdents params,
        unique $ unInfo $ identType ident
    ]

-- | The bound (in-scope) variables aliased by the arrays and type
-- variables anywhere in a type.  Records and sum-type payloads are
-- descended into; primitive and function types contribute nothing.
boundArrayAliases :: PatternType -> S.Set VName
boundArrayAliases t =
  case t of
    Array als _ _ _ -> boundAliases als
    Scalar (TypeVar als _ _ _) -> boundAliases als
    Scalar (Record fs) -> foldMap boundArrayAliases fs
    Scalar (Sum cs) -> foldMap (foldMap boundArrayAliases) cs
    Scalar Prim {} -> S.empty
    Scalar Arrow {} -> S.empty

-- | The set of in-scope variables that are being aliased: the
-- variables of the 'AliasBound' entries.  'AliasFree' entries are
-- ignored.
boundAliases :: Aliasing -> S.Set VName
boundAliases als = S.fromList [aliasVar a | a@AliasBound {} <- S.toList als]

-- | Reject a type that is unique anywhere in its array, type-variable,
-- record-field or sum-payload structure.  This is used for top-level
-- constants; the error is reported at the given location.
nothingMustBeUnique :: SrcLoc -> TypeBase () () -> TermTypeM ()
nothingMustBeUnique loc t =
  case t of
    Array _ Unique _ _ -> complain
    Scalar (TypeVar _ Unique _ _) -> complain
    Scalar (Record fields) -> mapM_ (nothingMustBeUnique loc) fields
    Scalar (Sum constrs) -> mapM_ (mapM_ (nothingMustBeUnique loc)) constrs
    _ -> pure ()
  where
    complain =
      typeError loc mempty "A top-level constant cannot have a unique type."

-- | Verify certain restrictions on function parameters, and bail out
-- on dubious constructions.
--
-- These restrictions apply to all functions (anonymous or otherwise).
-- Top-level functions have further restrictions that are checked
-- during let-generalisation.
--
-- The parameter types are first brought up to date with 'updateTypes'.
-- The parameters are then examined left to right.  It is an error for
-- a parameter's sizes ('patternDimNames') to mention a name bound
-- anywhere in the parameter patterns, unless an earlier parameter
-- already bound that name as a 'Named' parameter ('patternParam').
verifyFunctionParams :: Maybe Name -> [Pattern] -> TermTypeM ()
verifyFunctionParams fname params =
  onFailure (CheckingParams fname) $ do
    params' <- mapM updateTypes params
    go (foldMap patternNames params) params'
  where
    go _ [] = return ()
    go forbidden (p : ps) =
      case S.toList $ patternDimNames p `S.intersection` forbidden of
        d : _ -> typeError p mempty $ inaccessibleSize p d
        [] -> go (release p forbidden) ps

    -- A named parameter makes its name usable by the parameters after it.
    release p forbidden =
      case patternParam p of
        (Named v, _) -> S.delete v forbidden
        _ -> forbidden

    inaccessibleSize p d =
      "Parameter" <+> pquote (ppr p)
        <+/> "refers to size" <+> pquote (pprName d)
        <> comma
        <+/> textwrap "which will not be accessible to the caller"
        <> comma
        <+/> textwrap "possibly because it is nested in a tuple or record."
        <+/> textwrap "Consider ascribing an explicit type that does not reference "
        <> pquote (pprName d)
        <> "."

-- | Classify the named sizes that occur in a type by their position.
--
-- Returns the sizes of the immediate type produced,
-- the sizes of parameter types, and the sizes of return types.
dimUses :: StructType -> (Names, Names, Names)
dimUses t = execWriter $ traverseDims record t
  where
    record _ PosImmediate (NamedDim v) = tell (one v, mempty, mempty)
    record _ PosParam (NamedDim v) = tell (mempty, one v, mempty)
    record _ PosReturn (NamedDim v) = tell (mempty, mempty, one v)
    record _ _ _ = pure ()

    one v = S.singleton (qualLeaf v)

-- | Find all type variables in the given type that are covered by
-- the constraints, and produce type parameters that close over them.
--
-- The passed-in list of type parameters is always prepended to the
-- produced list of type parameters.
--
-- Only constraints on names that occur in the function type (its type
-- variables or sizes) are considered, and names already among the given
-- type parameters are skipped.  An unknowable size used in a parameter
-- type is a type error.  An unknowable size in immediate position is
-- returned in the third component.  Unknowable sizes in return position
-- are replaced by 'AnyDim' in the returned type.
closeOverTypes ::
  Name ->
  SrcLoc ->
  [TypeParam] ->
  [StructType] ->
  StructType ->
  Constraints ->
  TermTypeM ([TypeParam], StructType, [VName])
closeOverTypes defname defloc tparams paramts ret substs = do
  closed <- mapM closeOver $ M.toList $ M.map snd to_close_over
  let (more_tparams, retext) = partitionEithers $ catMaybes closed
  return (tparams ++ more_tparams, applySubst unknowableToAnyDim ret, retext)
  where
    t = foldFunType paramts ret
    visible = typeVars t <> typeDimNames t
    to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs
    (produced_sizes, param_sizes, ret_sizes) = dimUses t

    -- Unknowable sizes in return position become 'AnyDim'.
    unknowableToAnyDim v
      | v `S.member` ret_sizes,
        Just (_, UnknowableSize {}) <- M.lookup v substs =
        Just $ SizeSubst AnyDim
      | otherwise =
        Nothing

    closeOver (k, c)
      -- Avoid duplicate type parameters.
      | k `elem` map typeParamName tparams = return Nothing
      | otherwise =
        case c of
          NoConstraint l usage -> param $ TypeParamType l k $ srclocOf usage
          ParamType l loc -> param $ TypeParamType l k loc
          Size Nothing usage -> param $ TypeParamDim k $ srclocOf usage
          UnknowableSize {}
            | k `S.member` param_sizes -> unknowableInParam k
            | k `S.member` produced_sizes -> return $ Just $ Right k
          -- Anything else (including unknowable sizes not matched by the
          -- guards above) is not closed over.
          _ -> return Nothing

    param tp = return $ Just $ Left tp

    unknowableInParam k = do
      notes <- dimNotes defloc $ NamedDim $ qualName k
      typeError defloc notes $
        "Unknowable size" <+> pquote (pprName k)
          <+> "imposes constraint on type of"
          <+> pquote (pprName defname)
          <> ", which is inferred as:"
          </> indent 2 (ppr t)

-- | Let-generalise a function definition.
--
-- Closes over the candidate type variables and sizes (criteria below)
-- via 'closeOverTypes'.  It then reports an error if any size parameter
-- is unused in the parameter and return types.  Finally it removes the
-- closed-over names from the constraint set.  Returns the extended type
-- parameters, the parameters unchanged, the updated return type, and
-- the unknowable sizes reported by 'closeOverTypes'.
letGeneralise ::
  Name ->
  SrcLoc ->
  [TypeParam] ->
  [Pattern] ->
  StructType ->
  TermTypeM ([TypeParam], [Pattern], StructType, [VName])
letGeneralise defname defloc tparams params rettype =
  onFailure (CheckingLetGeneralise defname) $ do
    now_substs <- getConstraints

    -- Candidates for let-generalisation are those type variables that
    --
    -- (1) were not known before we checked this function,
    --
    -- (2) are not used in the (new) definition of any type variable
    -- known before we checked this function, and
    --
    -- (3) are not referenced from an overloaded type (for example,
    -- the element types of an incompletely resolved record type).
    -- This is a bit more restrictive than I'd like, and SML for
    -- example does not have this restriction.
    --
    -- Criteria (1) and (2) are implemented by looking at the binding
    -- level of the type variables; criterion (3) by excluding the
    -- variables returned by 'overloadedTypeVars'.
    let keep_type_vars = overloadedTypeVars now_substs

    cur_lvl <- curLevel
    let candidate k (lvl, _) =
          (k `S.notMember` keep_type_vars) && (lvl >= cur_lvl)
        new_substs = M.filterWithKey candidate now_substs
        param_types = map patternStructType params

    (tparams', rettype', retext) <-
      closeOverTypes defname defloc tparams param_types rettype new_substs

    rettype'' <- updateTypes rettype'

    let used_sizes = foldMap typeDimNames $ rettype'' : param_types
        unused_size_params =
          [ tp
            | tp <- tparams',
              isSizeParam tp,
              typeParamName tp `S.notMember` used_sizes
          ]
    case unused_size_params of
      tp : _ ->
        typeError defloc mempty $
          "Size parameter" <+> pquote (ppr tp) <+> "unused."
      [] -> return ()

    -- Only the type variables that were not closed over by
    -- let-generalisation remain as constraints.
    let closed_over = map typeParamName tparams'
    modifyConstraints $ M.filterWithKey $ \k _ -> k `notElem` closed_over

    return (tparams', params, rettype'', retext)

-- | Type-check the body of a function whose parameters are @params@.
--
-- The body is checked under 'noSizeEscape'.  If a return type
-- annotation is given, then:
--
--   1. sizes bound by the names in 'hiddenParamNames' are turned into
--      existential sizes in the body type (via 'unscopeType');
--
--   2. the resulting body type is unified with the annotation (after
--      'instantiateEmptyArrayDims');
--
--   3. since unification ignores uniqueness, the normalised body type
--      is additionally required to be a subtype of the annotation with
--      all of its sizes replaced by 'anySizes'.
checkFunBody ::
  [Pattern] ->
  UncheckedExp ->
  Maybe StructType ->
  SrcLoc ->
  TermTypeM Exp
checkFunBody params body maybe_rettype loc = do
  body' <- noSizeEscape $ checkExp body

  -- Unify body return type with return annotation, if one exists.
  case maybe_rettype of
    Just rettype -> do
      (rettype_withdims, _) <- instantiateEmptyArrayDims loc "impl" Nonrigid rettype

      body_t <- expTypeFully body'
      -- We need to turn any sizes provided by "hidden" parameter
      -- names into existential sizes instead.
      let hidden = hiddenParamNames params
      (body_t', _) <-
        unscopeType
          loc
          ( M.filterWithKey (const . (`S.member` hidden)) $
              foldMap patternMap params
          )
          body_t

      let usage = mkUsage (srclocOf body) "return type annotation"
      onFailure (CheckingReturn rettype (toStruct body_t')) $
        expect usage rettype_withdims $ toStruct body_t'

      -- We also have to make sure that uniqueness matches.  This is done
      -- explicitly, because uniqueness is ignored by unification.
      rettype' <- normTypeFully rettype
      -- Substs may have changed, so re-normalise the *body* type.  (This
      -- previously normalised 'rettype' instead, which made the check
      -- below compare the annotation against itself, so it never failed.)
      body_t'' <- normTypeFully $ toStruct body_t'
      unless (body_t'' `subtypeOf` anySizes rettype') $
        typeError (srclocOf body) mempty $
          "Body type" </> indent 2 (ppr body_t'')
            </> "is not a subtype of annotated type"
            </> indent 2 (ppr rettype')
    Nothing -> return ()

  return body'

--- Consumption

-- | Append the given occurences to the writer output of the
-- type-checking monad.
occur :: Occurences -> TermTypeM ()
occur occs = tell occs

-- | Proclaim that we have made read-only use of the given variable.
-- The recorded observation covers the variable itself together with
-- every alias carried by its type.
observe :: Ident -> TermTypeM ()
observe (Ident nm (Info t) loc) =
  occur [observation observed loc]
  where
    observed = S.insert (AliasBound nm) (aliases t)

-- | Raise a type error at @loc@ if some variable in the aliasing set
-- may not be consumed.
--
-- A variable is consumable only if the vtable binds it as a 'Local'
-- 'BoundV'.  If its type has array rank above zero, or is a type
-- variable, that type must also be 'unique'.  The error names the
-- first offending variable in 'S.toList' order.
checkIfConsumable :: SrcLoc -> Aliasing -> TermTypeM ()
checkIfConsumable loc als = do
  vtable <- asks (scopeVtable . termScope)
  let consumable v =
        case M.lookup v vtable of
          Just (BoundV Local _ t)
            | arrayRank t > 0 -> unique t
            | Scalar TypeVar {} <- t -> unique t
            | otherwise -> True
          _ -> False
  case find (not . consumable) (map aliasVar (S.toList als)) of
    Nothing -> return ()
    Just v ->
      typeError loc mempty $
        "Would consume variable" <+> pquote (pprName v)
          <> ", which is not allowed."

-- | Proclaim that we have written to the given variable.  Every
-- variable in the aliasing set is first checked with
-- 'checkIfConsumable'.  A consumption occurence for the set is then
-- recorded at @loc@.
consume :: SrcLoc -> Aliasing -> TermTypeM ()
consume loc als =
  checkIfConsumable loc als >> occur [consumption als loc]

-- | Proclaim that we have written to the given variable, then run the
-- given computation with that variable marked as consumed.
--
-- The variable and all aliases in its type are passed to 'consume'.
-- Inside @m@, only the variable itself is rebound, as 'WasConsumed'
-- at @loc@.
consuming :: Ident -> TermTypeM a -> TermTypeM a
consuming (Ident name (Info t) loc) m = do
  consume loc consumed
  localScope markConsumed m
  where
    consumed = S.insert (AliasBound name) (aliases t)
    markConsumed scope =
      scope {scopeVtable = M.insert name (WasConsumed loc) (scopeVtable scope)}

-- | Run a computation and return its result paired with the
-- occurences it produced.  Those occurences are removed from the
-- enclosing writer output, so the caller becomes responsible for them.
collectOccurences :: TermTypeM a -> TermTypeM (a, Occurences)
collectOccurences m = censor (const mempty) (listen m)

-- | Run a computation and also return the occurences it produced.
-- Unlike 'collectOccurences', the occurences stay in the enclosing
-- writer output.
tapOccurences :: TermTypeM a -> TermTypeM (a, Occurences)
tapOccurences m = listen m

-- | Run a computation, dropping from its writer output every occurence
-- for which 'seminullOccurence' holds.
removeSeminullOccurences :: TermTypeM a -> TermTypeM a
removeSeminullOccurences = censor dropSeminull
  where
    dropSeminull occs = [occ | occ <- occs, not (seminullOccurence occ)]

-- | Emit an "Unused variable" warning if the identifier is not among
-- 'allOccuring' of the given occurences.  Identifiers whose
-- pretty-printed name starts with an underscore never trigger the
-- warning.
checkIfUsed :: Occurences -> Ident -> TermTypeM ()
checkIfUsed occs v =
  when (unused && not silenced) $
    warn (srclocOf v) $
      "Unused variable " ++ quote (pretty (baseName name)) ++ "."
  where
    name = identName v
    unused = not $ name `S.member` allOccuring occs
    silenced = "_" `isPrefixOf` prettyName name

-- | Run two alternative computations in sequence, each under
-- 'noSizeEscape'.  The occurences of each are checked on their own
-- with 'checkOccurences'.  The occurences reported to the enclosing
-- computation are their 'altOccurences' combination, not their
-- concatenation.
alternative :: TermTypeM a -> TermTypeM b -> TermTypeM (a, b)
alternative m1 m2 = pass $ do
  (x, occs1) <- listen $ noSizeEscape m1
  (y, occs2) <- listen $ noSizeEscape m2
  mapM_ checkOccurences [occs1, occs2]
  pure ((x, y), const (occs1 `altOccurences` occs2))

-- | Make all bindings nonunique.  Every 'BoundV' in the vtable has its
-- type's uniqueness set to 'Nonunique'.  All other bindings are left
-- unchanged.
noUnique :: TermTypeM a -> TermTypeM a
noUnique = localScope $ \scope ->
  scope {scopeVtable = fmap nonunique (scopeVtable scope)}
  where
    nonunique :: ValBinding -> ValBinding
    nonunique binding =
      case binding of
        BoundV l tparams t -> BoundV l tparams (t `setUniqueness` Nonunique)
        other -> other

-- | Run a computation in which the aliases of every 'BoundV' binding
-- are narrowed to at most the bound variable itself.  This is done by
-- intersecting them with the singleton 'AliasBound' of its own name.
-- All other bindings are left unchanged.
onlySelfAliasing :: TermTypeM a -> TermTypeM a
onlySelfAliasing = localScope $ \scope ->
  scope {scopeVtable = M.mapWithKey selfOnly (scopeVtable scope)}
  where
    selfOnly :: VName -> ValBinding -> ValBinding
    selfOnly k (BoundV l tparams t) =
      BoundV l tparams $
        t `addAliases` S.intersection (S.singleton (AliasBound k))
    selfOnly _ other = other

-- | Build an array type from an element type, a shape and a
-- uniqueness.  Before building it, 'zeroOrderType' checks that the
-- element type is acceptable as an array element; any failure is
-- attributed to @loc@.
arrayOfM ::
  (Pretty (ShapeDecl dim), Monoid as) =>
  SrcLoc ->
  TypeBase dim as ->
  ShapeDecl dim ->
  Uniqueness ->
  TermTypeM (TypeBase dim as)
arrayOfM loc t shape u =
  zeroOrderType usage "type used in array" t >> pure (arrayOf t shape u)
  where
    usage = mkUsage loc "use as array element"

-- | Apply 'normTypeFully' to every struct type and pattern type inside
-- the given syntax element, recursing into subexpressions.  Names and
-- qualified names are left untouched.
updateTypes :: ASTMappable e => e -> TermTypeM e
updateTypes = astMap mapper
  where
    mapper :: ASTMapper TermTypeM
    mapper =
      ASTMapper
        { mapOnExp = astMap mapper,
          mapOnName = pure,
          mapOnQualName = pure,
          mapOnStructType = normTypeFully,
          mapOnPatternType = normTypeFully
        }