{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
-- | Implementation of unification and other core type system building
-- blocks.
module Language.Futhark.TypeChecker.Unify
  ( Constraint(..)
  , Usage
  , mkUsage
  , mkUsage'
  , Level
  , Constraints
  , MonadUnify(..)
  , Rigidity(..)
  , RigidSource(..)
  , BreadCrumbs
  , noBreadCrumbs
  , hasNoBreadCrumbs
  , dimNotes
  , mkTypeVarName

  , zeroOrderType
  , mustHaveConstr
  , mustHaveField
  , mustBeOneOf
  , equalityType
  , normType
  , normPatternType
  , normTypeFully
  , instantiateEmptyArrayDims

  , unify
  , expect
  , unifyMostCommon
  , anyDimOnMismatch
  , doUnification
  )
where

import Control.Monad.Except
import Control.Monad.Writer hiding (Sum)
import Control.Monad.RWS.Strict hiding (Sum)
import Control.Monad.State
import Data.Bifoldable (biany)
import Data.List (intersect)
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Language.Futhark hiding (unifyDims)
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Types
import Futhark.Util.Pretty hiding (empty)

-- | A piece of information that describes what process the type
-- checker is currently performing.  This is used to give better error
-- messages for unification errors.
data BreadCrumb = MatchingTypes StructType StructType
                  -- ^ Matching one type against another.
                | MatchingFields [Name]
                  -- ^ Matching the types of a record field; the list
                  -- is the path of field names leading to it.
                | MatchingConstructor Name
                  -- ^ Matching the types of the named constructor.
                | Matching Doc
                  -- ^ A free-form description, printed verbatim.

-- Render a single breadcrumb as a sentence for use in error messages.
instance Pretty BreadCrumb where
  ppr (MatchingTypes t1 t2) =
    "When matching type" </> indent 2 (ppr t1) </>
    "with" </> indent 2 (ppr t2)
  ppr (MatchingFields fields) =
    -- Nested field names are joined with dots, e.g. 'a.b.c'.
    "When matching types of record field" <+>
    pquote (mconcat $ punctuate "." $ map ppr fields) <> dot
  ppr (MatchingConstructor c) =
    "When matching types of constructor" <+> pquote (ppr c) <> dot
  ppr (Matching s) =
    s

-- | Unification failures can occur deep down inside complicated types
-- (consider nested records).  We leave breadcrumbs behind us so we
-- can report the path we took to find the mismatch.  The most recently
-- dropped breadcrumb is at the head of the list.
newtype BreadCrumbs = BreadCrumbs [BreadCrumb]

-- | An empty path: no breadcrumbs have been dropped yet.
noBreadCrumbs :: BreadCrumbs
noBreadCrumbs = BreadCrumbs []

-- | Is the path empty?
hasNoBreadCrumbs :: BreadCrumbs -> Bool
hasNoBreadCrumbs (BreadCrumbs xs) = null xs

-- | Drop a breadcrumb on the path behind you.  The new crumb is pushed
-- onto the front of the path, except that a 'MatchingFields' crumb
-- dropped directly on top of another 'MatchingFields' crumb is merged
-- into it, with the previously recorded field names first.
breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs
breadCrumb (MatchingFields xs) (BreadCrumbs (MatchingFields ys : bcs)) =
  -- Extend the existing field path instead of adding a second crumb,
  -- so the message reads as a single dotted path.
  BreadCrumbs $ MatchingFields (ys ++ xs) : bcs
breadCrumb bc (BreadCrumbs bcs) =
  BreadCrumbs $ bc : bcs

-- An empty path renders as nothing.  Otherwise the path starts on a
-- fresh line, and the crumbs are stacked one per line.
instance Pretty BreadCrumbs where
  ppr (BreadCrumbs []) = mempty
  ppr (BreadCrumbs bcs) = line <> stack (map ppr bcs)

-- | A usage that caused a type constraint: an optional description of
-- the usage together with its source location.
data Usage = Usage (Maybe String) SrcLoc
  deriving (Show)

-- | Construct a 'Usage' from a location and a description.
mkUsage :: SrcLoc -> String -> Usage
mkUsage loc desc = Usage (Just desc) loc

-- | Construct a 'Usage' that has just a location, but no particular
-- description.
mkUsage' :: SrcLoc -> Usage
mkUsage' = Usage Nothing

-- Without a description, a usage renders as "use at <loc>".  With one,
-- it renders as "<description> at <loc>".
instance Pretty Usage where
  ppr (Usage Nothing loc) = "use at " <> textwrap (locStr loc)
  ppr (Usage (Just s) loc) = textwrap s <+/> "at" <+> textwrap (locStr loc)

-- The location of a usage is the location it was constructed with.
instance Located Usage where
  locOf (Usage _ loc) = locOf loc

-- | The level at which a type variable is bound.  Higher means
-- deeper.  We can only unify a type variable at level @i@ with a type
-- @t@ if all type names that occur in @t@ are at most at level @i@.
-- Every entry in 'Constraints' records the level of its variable.
type Level = Int

-- | A constraint on a yet-ambiguous type variable.
data Constraint = NoConstraint Liftedness Usage
                | ParamType Liftedness SrcLoc
                  -- ^ A rigid type parameter, which unification
                  -- cannot substitute.
                | Constraint StructType Usage
                  -- ^ The variable has been determined to be this
                  -- type; 'normType' replaces the variable with it.
                | Overloaded [PrimType] Usage
                | HasFields (M.Map Name StructType) Usage
                | Equality Usage
                | HasConstrs (M.Map Name [StructType]) Usage
                | ParamSize SrcLoc
                  -- ^ A rigid size parameter, which unification
                  -- cannot substitute.
                | Size (Maybe (DimDecl VName)) Usage
                  -- ^ Is not actually a type, but a term-level size,
                  -- possibly already set to something specific.
                | UnknowableSize SrcLoc RigidSource
                  -- ^ A size that does not unify with anything -
                  -- created from the result of applying a function
                  -- whose return size is existential, or otherwise
                  -- hiding a size.
                deriving (Show)

-- A constraint is located at the usage (or source location) that
-- introduced it.
instance Located Constraint where
  locOf (NoConstraint _ usage) = locOf usage
  locOf (ParamType _ loc) = locOf loc
  locOf (Constraint _ usage) = locOf usage
  locOf (Overloaded _ usage) = locOf usage
  locOf (HasFields _ usage) = locOf usage
  locOf (Equality usage) = locOf usage
  locOf (HasConstrs _ usage) = locOf usage
  locOf (ParamSize loc) = locOf loc
  locOf (Size _ usage) = locOf usage
  locOf (UnknowableSize loc _) = locOf loc

-- | Mapping from fresh type variables, instantiated from the type
-- schemes of polymorphic functions, to (possibly) specific types as
-- determined on application and the location of that application, or
-- a partial constraint on their type.  Each entry also records the
-- 'Level' at which the variable is bound.
type Constraints = M.Map VName (Level, Constraint)

-- | The substitution, if any, that the constraints currently assign to
-- a variable:
--
-- * a variable constrained to a type is substituted by that type;
--
-- * an 'Overloaded' variable yields 'PrimSubst';
--
-- * a size that has been set to a specific dimension yields that
--   dimension, itself substituted recursively.
--
-- Any other constraint, or a missing entry, gives 'Nothing'.
lookupSubst :: VName -> Constraints -> Maybe (Subst StructType)
lookupSubst v constraints = case snd <$> M.lookup v constraints of
  Just (Constraint t _) -> Just $ Subst t
  Just Overloaded{} -> Just PrimSubst
  Just (Size (Just d) _) ->
    Just $ SizeSubst $ applySubst (`lookupSubst` constraints) d
  _ -> Nothing

-- | The source of a rigid size.
data RigidSource
  = RigidArg (Maybe (QualName VName)) String
    -- ^ A function argument that is not a constant or variable name.
  | RigidRet (Maybe (QualName VName))
    -- ^ An existential return size.
  | RigidLoop
    -- ^ An unknown size of a returned value (presumably the result
    -- of a loop, going by the name).
  | RigidSlice (Maybe (DimDecl VName)) String
    -- ^ Produced by a complicated slice expression.
  | RigidRange
    -- ^ Produced by a complicated range expression.
  | RigidBound String
    -- ^ Produced by a range expression with this bound.
  | RigidCond StructType StructType
    -- ^ Mismatch in branches.
  | RigidUnify
    -- ^ Invented during unification.
  | RigidOutOfScope SrcLoc VName
    -- ^ Arising from the given variable going out of scope.  The
    -- 'SrcLoc' is where that variable was originally bound.
  deriving (Eq, Ord, Show)

-- | The rigidity of a size variable.  All rigid sizes are tagged with
-- information about how they were generated.
data Rigidity = Rigid RigidSource | Nonrigid
              deriving (Eq, Ord, Show)

-- | Describe where a rigid size came from.  The result is a sentence
-- fragment meant to follow the quoted name of the size, as in
-- @'n' is unknown size returned by ...@.  The first location is the
-- current context.  The second is where the size originated, and it is
-- rendered relative to the first.
prettySource :: SrcLoc -> SrcLoc -> RigidSource -> Doc
prettySource ctx loc (RigidRet Nothing) =
  "is unknown size returned by function at" <+>
  text (locStrRel ctx loc) <> "."

prettySource ctx loc (RigidRet (Just fname)) =
  "is unknown size returned by" <+> pquote (ppr fname) <+>
  "at" <+> text (locStrRel ctx loc) <> "."

prettySource ctx loc (RigidArg fname arg) =
  "is value of argument" </>
  indent 2 (shorten arg) </>
  "passed to" <+> fname' <+> "at" <+> text (locStrRel ctx loc) <> "."
  where fname' = maybe "function" (pquote . ppr) fname

prettySource ctx loc (RigidSlice d slice) =
  "is size produced by slice" </>
  indent 2 (shorten slice) </>
  d_desc <> "at" <+> text (locStrRel ctx loc) <> "."
  where d_desc = case d of
                   Just d' -> "of dimension of size " <> pquote (ppr d') <> " "
                   Nothing -> mempty

prettySource ctx loc RigidLoop =
  "is unknown size of value returned at" <+> text (locStrRel ctx loc) <> "."

prettySource ctx loc RigidRange =
  "is unknown length of range at" <+> text (locStrRel ctx loc) <> "."

prettySource ctx loc (RigidBound bound) =
  "generated from expression" </>
  indent 2 (shorten bound) </>
  "used in range at " <> text (locStrRel ctx loc) <> "."

prettySource ctx loc (RigidOutOfScope boundloc v) =
  "is an unknown size arising from " <> pquote (pprName v) <>
  " going out of scope at " <> text (locStrRel ctx loc) <> "." </>
  "Originally bound at " <> text (locStrRel ctx boundloc) <> "."

prettySource _ _ RigidUnify =
  "is an artificial size invented during unification of functions with anonymous sizes"

prettySource ctx loc (RigidCond t1 t2) =
  "is unknown due to conditional expression at " <>
  text (locStrRel ctx loc) <> "." </>
  "One branch returns array of type: " <> align (ppr t1) </>
  "The other an array of type:       " <> align (ppr t2)

-- | Retrieve notes describing the purpose or origin of the given
-- 'DimDecl'.  The location is used as the *current* location, for the
-- purpose of reporting relative locations.  Only a named dimension
-- whose constraint is 'UnknowableSize' produces a note.  Anything else
-- produces no notes.
dimNotes :: (Located a, MonadUnify m) => a -> DimDecl VName -> m Notes
dimNotes ctx (NamedDim d) = do
  c <- M.lookup (qualLeaf d) <$> getConstraints
  case c of
    Just (_, UnknowableSize loc rsrc) ->
      return $ aNote $ pretty $
        pquote (ppr d) <+> prettySource (srclocOf ctx) loc rsrc
    _ -> return mempty
dimNotes _ _ = return mempty

-- | Collect the 'dimNotes' for every size name that occurs in the
-- given type, relative to the given current location.
typeNotes :: (Located a, MonadUnify m) => a -> StructType -> m Notes
typeNotes ctx =
  fmap mconcat . mapM (dimNotes ctx . NamedDim . qualName) .
  S.toList . typeDimNames

-- | Monads that wish to perform unification must implement this type
-- class.
class Monad m => MonadUnify m where
  -- | Retrieve the current constraints.
  getConstraints :: m Constraints
  -- | Replace the current constraints.
  putConstraints :: Constraints -> m ()
  -- | Apply a function to the current constraints.  The default is a
  -- 'getConstraints' followed by a 'putConstraints'.
  modifyConstraints :: (Constraints -> Constraints) -> m ()
  modifyConstraints f = getConstraints >>= putConstraints . f

  -- | Presumably creates a fresh type variable; the 'String' appears
  -- to be a name hint (NOTE(review): confirm against instances).
  newTypeVar :: Monoid als => SrcLoc -> String -> m (TypeBase dim als)
  -- | Create a fresh size variable with the given rigidity.
  newDimVar :: SrcLoc -> Rigidity -> String -> m VName

  -- | The 'Level' the computation is currently at.
  curLevel :: m Level

  -- | Report that two types do not match.  The result type is fully
  -- polymorphic, so this cannot produce an ordinary result and is
  -- expected to abort the computation.
  matchError :: Located loc => loc -> Notes -> BreadCrumbs
             -> StructType -> StructType -> m a

  -- | Report a unification failure described by an arbitrary document.
  unifyError :: Located loc => loc -> Notes -> BreadCrumbs
             -> Doc -> m a

-- | Replace all type variables with their substitution, as given by
-- the current constraints.
normTypeFully :: (Substitutable a, MonadUnify m) => a -> m a
normTypeFully t = do
  constraints <- getConstraints
  return $ applySubst (`lookupSubst` constraints) t

-- | Replace any top-level type variable with its substitution.  This
-- is repeated until the top level is no longer a variable constrained
-- to a known type.  Nested types are left untouched; see
-- 'normTypeFully' for that.
normType :: MonadUnify m => StructType -> m StructType
normType t@(Scalar (TypeVar _ _ (TypeName [] v) [])) = do
  constraints <- getConstraints
  case snd <$> M.lookup v constraints of
    Just (Constraint t' _) -> normType t'
    _ -> return t
normType t = return t

-- | Like 'normType', but for a 'PatternType'.  When a top-level type
-- variable is replaced, the uniqueness and aliasing of the variable
-- are transferred onto its substitution.
normPatternType :: MonadUnify m => PatternType -> m PatternType
normPatternType t@(Scalar (TypeVar als u (TypeName [] v) [])) = do
  constraints <- getConstraints
  case snd <$> M.lookup v constraints of
    Just (Constraint t' _) ->
      normPatternType $ t' `setUniqueness` u `setAliases` als
    _ -> return t
normPatternType t = return t

-- | Does this constraint make its variable rigid, i.e. something that
-- unification may not substitute?
rigidConstraint :: Constraint -> Bool
rigidConstraint ParamType{} = True
rigidConstraint ParamSize{} = True
rigidConstraint UnknowableSize{} = True
rigidConstraint _ = False

-- | Replace 'AnyDim' dimensions that occur as 'PosImmediate' or
-- 'PosParam' with a fresh 'NamedDim'.  Each fresh size is created with
-- the given location, rigidity and description.  The names of all
-- fresh sizes are returned alongside the new type.
instantiateEmptyArrayDims :: MonadUnify m =>
                             SrcLoc -> String -> Rigidity
                          -> TypeBase (DimDecl VName) als
                          -> m (TypeBase (DimDecl VName) als, [VName])
instantiateEmptyArrayDims tloc desc r = runWriterT . traverseDims onDim
  where onDim _ PosImmediate AnyDim = inst
        onDim _ PosParam AnyDim = inst
        onDim _ _ d = return d

        -- Invent a new size variable and record its name.
        inst = do
          dim <- lift $ newDimVar tloc r desc
          tell [dim]
          return $ NamedDim $ qualName dim

-- | Is the given type variable the name of an abstract type or type
-- parameter, which we cannot substitute?  A variable with no entry in
-- the constraints at all is also treated as rigid.
isRigid :: VName -> Constraints -> Bool
isRigid v constraints =
  maybe True (rigidConstraint . snd) $ M.lookup v constraints

-- | If the given type variable is nonrigid, what is its level?  This
-- returns 'Nothing' for rigid variables and for variables with no
-- constraint.
isNonRigid :: VName -> Constraints -> Maybe Level
isNonRigid v constraints = do
  (lvl, c) <- M.lookup v constraints
  guard $ not $ rigidConstraint c
  return lvl

-- | A strategy for unifying two dimensions.  Its arguments are, in
-- order: the breadcrumbs used for error reporting, the dimension names
-- bound by the enclosing unification (such as function parameter
-- names), a lookup giving the level of a nonrigid variable, and the
-- two dimensions to unify.
type UnifyDims m =
  BreadCrumbs ->
  [VName] ->
  (VName -> Maybe Level) ->
  DimDecl VName ->
  DimDecl VName ->
  m ()

-- | Turn a dimension-unification strategy into one that receives its
-- two dimensions in the opposite order.  All other arguments are
-- passed through unchanged.
flipUnifyDims :: UnifyDims m -> UnifyDims m
flipUnifyDims onDims bcs bound nonrigid =
  flip (onDims bcs bound nonrigid)

-- | Unify two types, delegating every dimension comparison to the
-- given 'UnifyDims' strategy.
--
-- The internal worker @subunify@ threads two extra pieces of state:
--
--   * an orientation flag.  When it is set, dimension pairs are
--     swapped before being handed to the strategy.  It starts out
--     'False' and is negated when descending into function parameter
--     types.
--
--   * a list of bound dimension names: function parameter names plus
--     the dimensions invented by 'instantiateEmptyArrayDims'.  These
--     are replaced by 'AnyDim' before a type is linked to a type
--     variable, or before sum-type payloads are handed on.
unifyWith ::
  MonadUnify m =>
  UnifyDims m ->
  Usage ->
  BreadCrumbs ->
  StructType ->
  StructType ->
  m ()
unifyWith onDims usage = subunify False mempty
  where
    -- Produce the pair as given, or swapped when the flag is set.
    swap :: Bool -> a -> a -> (a, a)
    swap flipped x y
      | flipped = (y, x)
      | otherwise = (x, y)

    -- Elements of the first list missing from the second, followed by
    -- elements of the second list missing from the first.
    unshared :: Eq a => [a] -> [a] -> [a]
    unshared xs ys = filter (`notElem` ys) xs ++ filter (`notElem` xs) ys

    pname :: PName -> Maybe VName
    pname (Named x) = Just x
    pname Unnamed = Nothing

    -- Replace one parameter name with the other in a return type, in
    -- case of dependent types.  I.e., we want type '(n: i32) -> [n]i32'
    -- to unify with type '(x: i32) -> [x]i32'.  Occurrences of the
    -- second name are rewritten to the first.
    renameParam :: PName -> PName -> StructType -> StructType
    renameParam (Named new) (Named old) = applySubst rename
      where
        rename v
          | v == old = Just $ SizeSubst $ NamedDim $ qualName new
          | otherwise = Nothing
    renameParam _ _ = id

    subunify ord bound bcs t1 t2 = do
      constraints <- getConstraints

      t1' <- normType t1
      t2' <- normType t2

      let nonrigid :: VName -> Maybe Level
          nonrigid v = isNonRigid v constraints

          failure = matchError (srclocOf usage) mempty bcs t1' t2'

          -- Resolve a dimension through the current substitution.
          resolve :: DimDecl VName -> DimDecl VName
          resolve = applySubst (`lookupSubst` constraints)

          -- Remove any of the intermediate dimensions we added just
          -- for unification purposes.
          unbound :: StructType -> StructType
          unbound = applySubst forgetBound
            where
              forgetBound d
                | d `elem` bound = Just $ SizeSubst AnyDim
                | otherwise = Nothing

          -- Link a type variable (at the given level) to a type.  The
          -- flag decides whether the dimension strategy seen by
          -- linkVarToType is flipped.
          link flipped v lvl =
            linkVarToType dimStrategy usage bcs v lvl . unbound
            where
              dimStrategy
                | flipped = flipUnifyDims onDims
                | otherwise = onDims

          onDims' bcs' (d1, d2) =
            onDims bcs' bound nonrigid (resolve d1) (resolve d2)

          unifyTypeArg bcs' (TypeArgDim d1 _) (TypeArgDim d2 _) =
            onDims' bcs' (swap ord d1 d2)
          unifyTypeArg bcs' (TypeArgType t _) (TypeArgType arg_t _) =
            subunify ord bound bcs' t arg_t
          unifyTypeArg bcs' _ _ =
            unifyError usage mempty bcs'
              "Cannot unify a type argument with a dimension argument (or vice versa)."

      -- NOTE: alternative order matters; a failed guard falls through
      -- to the next alternative.
      case (t1', t2') of
        (Scalar (Record fs), Scalar (Record arg_fs))
          | M.keys fs == M.keys arg_fs ->
              forM_ (M.toList (M.intersectionWith (,) fs arg_fs)) $
                \(k, (k_t1, k_t2)) ->
                  subunify ord bound (breadCrumb (MatchingFields [k]) bcs) k_t1 k_t2
          | otherwise ->
              unifyError usage mempty bcs $
                "Unshared fields:"
                  <+> commasep (map ppr (unshared (M.keys fs) (M.keys arg_fs)))
                  <> "."

        (Scalar (TypeVar _ _ (TypeName _ tn) targs),
         Scalar (TypeVar _ _ (TypeName _ arg_tn) arg_targs))
          | tn == arg_tn,
            length targs == length arg_targs ->
              zipWithM_
                (unifyTypeArg (breadCrumb (Matching "When matching type arguments.") bcs))
                targs
                arg_targs

        (Scalar (TypeVar _ _ (TypeName [] v1) []),
         Scalar (TypeVar _ _ (TypeName [] v2) [])) ->
          let bindLeft lvl = link ord v1 lvl t2'
              bindRight lvl = link (not ord) v2 lvl t1'
           in case (nonrigid v1, nonrigid v2) of
                (Nothing, Nothing) -> failure
                (Just lvl1, Nothing) -> bindLeft lvl1
                (Nothing, Just lvl2) -> bindRight lvl2
                (Just lvl1, Just lvl2)
                  | lvl1 <= lvl2 -> bindLeft lvl1
                  | otherwise -> bindRight lvl2

        (Scalar (TypeVar _ _ (TypeName [] v1) []), _)
          | Just lvl <- nonrigid v1 ->
              link ord v1 lvl t2'

        (_, Scalar (TypeVar _ _ (TypeName [] v2) []))
          | Just lvl <- nonrigid v2 ->
              link (not ord) v2 lvl t1'

        (Scalar (Arrow _ p1 a1 b1), Scalar (Arrow _ p2 a2 b2)) -> do
          let (r1, r2) = swap ord (Rigid RigidUnify) Nonrigid
              loc = srclocOf usage
          (a1', a1_dims) <- instantiateEmptyArrayDims loc "anonymous" r1 a1
          (a2', a2_dims) <- instantiateEmptyArrayDims loc "anonymous" r2 a2
          let bound' = bound <> mapMaybe pname [p1, p2] <> a1_dims <> a2_dims
          -- Parameter types are compared with the orientation negated,
          -- and without the names bound by this arrow.
          subunify (not ord) bound
            (breadCrumb (Matching "When matching parameter types.") bcs)
            a1' a2'
          subunify ord bound'
            (breadCrumb (Matching "When matching return types.") bcs)
            b1 (renameParam p1 p2 b2)

        (Array{}, Array{})
          | ShapeDecl (t1_d : _) <- arrayShape t1',
            ShapeDecl (t2_d : _) <- arrayShape t2',
            Just t1'' <- peelArray 1 t1',
            Just t2'' <- peelArray 1 t2' -> do
              -- Unify the outermost dimension, then the element types.
              onDims' bcs (swap ord t1_d t2_d)
              subunify ord bound bcs t1'' t2''

        (Scalar (Sum cs), Scalar (Sum arg_cs))
          | M.keys cs == M.keys arg_cs ->
              unifySharedConstructors onDims usage bcs
                (M.map (map unbound) cs)
                (M.map (map unbound) arg_cs)
          | otherwise ->
              unifyError usage mempty bcs $
                "Unshared constructors:"
                  <+> commasep (map (("#" <>) . ppr) (unshared (M.keys cs) (M.keys arg_cs)))
                  <> "."

        _ ->
          if t1' == t2' then return () else failure

-- | The symmetric dimension-unification strategy.
--
-- * Equal dimensions unify trivially.
-- * Otherwise, if the left dimension is a named nonrigid variable, it is
--   linked to the right one.
-- * Otherwise, if the right dimension is a named nonrigid variable, it
--   is linked to the left one.
-- * Anything else is reported as a mismatch, with notes for both
--   dimensions.
--
-- The list of bound names is not consulted.
unifyDims :: MonadUnify m => Usage -> UnifyDims m
unifyDims usage bcs _ nonrigid d1 d2
  | d1 == d2 = return ()
  | Just (v, lvl) <- flexible d1 = linkVarToDim usage bcs v lvl d2
  | Just (v, lvl) <- flexible d2 = linkVarToDim usage bcs v lvl d1
  | otherwise = do
      notes1 <- dimNotes usage d1
      notes2 <- dimNotes usage d2
      unifyError usage (notes1 <> notes2) bcs $
        "Dimensions" <+> pquote (ppr d1) <+>
        "and" <+> pquote (ppr d2) <+> "do not match."
  where
    -- A named dimension that is a nonrigid variable, with its level.
    flexible :: DimDecl VName -> Maybe (VName, Level)
    flexible (NamedDim (QualName _ v)) = (,) v <$> nonrigid v
    flexible _ = Nothing

-- | Unifies two types.  Dimensions are compared with the symmetric
-- 'unifyDims' strategy, starting from an empty breadcrumb trail.
unify :: MonadUnify m => Usage -> StructType -> StructType -> m ()
unify usage t1 t2 =
  unifyWith (unifyDims usage) usage noBreadCrumbs t1 t2

-- | @expect super sub@ checks that @sub@ is a subtype of @super@.
--
-- This uses an asymmetric dimension strategy, tried in this order:
--
-- 1. 'AnyDim' on the left accepts anything.
-- 2. Equal dimensions unify trivially.
-- 3. A named nonrigid variable on the left is linked to the right
--    dimension, unless that dimension is 'AnyDim' or a bound name.
-- 4. A named nonrigid variable on the right is linked to the left
--    dimension, unless that dimension is a bound name.
-- 5. Anything else is reported as a mismatch.
expect :: MonadUnify m => Usage -> StructType -> StructType -> m ()
expect usage = unifyWith checkDims usage noBreadCrumbs
  where
    checkDims bcs bound nonrigid lhs rhs
      | AnyDim <- lhs = return ()
      | lhs == rhs = return ()
      | Just (v, lvl) <- flexible lhs,
        rhs /= AnyDim,
        not (isBound rhs) =
          linkVarToDim usage bcs v lvl rhs
      | Just (v, lvl) <- flexible rhs,
        not (isBound lhs) =
          linkVarToDim usage bcs v lvl lhs
      | otherwise = do
          lhsNotes <- dimNotes usage lhs
          rhsNotes <- dimNotes usage rhs
          unifyError usage (lhsNotes <> rhsNotes) bcs $
            "Dimensions" <+> pquote (ppr lhs) <+>
            "and" <+> pquote (ppr rhs) <+> "do not match."
      where
        -- A named dimension that is a nonrigid variable, with its level.
        flexible (NamedDim (QualName _ v)) = (,) v <$> nonrigid v
        flexible _ = Nothing

        -- Is this a named dimension among the bound names?
        isBound (NamedDim (QualName _ v)) = v `elem` bound
        isBound _ = False

-- | Does any dimension in the type equal 'AnyDim'?  The second
-- component of a 'StructType' (here @()@) never contributes.
hasEmptyDims :: StructType -> Bool
hasEmptyDims t = biany isAnyDim (const False) t
  where
    isAnyDim :: DimDecl VName -> Bool
    isAnyDim AnyDim = True
    isAnyDim _ = False

-- | Report an occurs-check failure if the type variable @vn@ is among
-- the type variables of @tp@.  Otherwise this does nothing.
occursCheck :: MonadUnify m =>
               Usage -> BreadCrumbs
            -> VName -> StructType -> m ()
occursCheck usage bcs vn tp
  | vn `S.notMember` typeVars tp =
      return ()
  | otherwise =
      unifyError usage mempty bcs $
        "Occurs check: cannot instantiate" <+> pprName vn <+> "with" <+> ppr tp <> "."

-- | Check that binding @vn@, at level @max_lvl@, to @tp@ does not let a
-- variable escape its scope.
--
-- The check covers every type variable and size name mentioned in @tp@
-- whose recorded level is greater than @max_lvl@:
--
-- * If its constraint is rigid, a scope-violation error is raised.
-- * Otherwise its level is lowered to @max_lvl@.
--
-- All lookups use a single snapshot of the constraints, taken before
-- any of these updates.
scopeCheck :: MonadUnify m =>
              Usage -> BreadCrumbs
           -> VName -> Level -> StructType -> m ()
scopeCheck usage bcs vn max_lvl tp = do
  snapshot <- getConstraints
  mapM_ (visit snapshot) $ typeVars tp <> typeDimNames tp
  where
    visit cs name =
      case M.lookup name cs of
        Just (name_lvl, c)
          | name_lvl > max_lvl ->
              if rigidConstraint c
                then escaped name
                else modifyConstraints $ M.insert name (max_lvl, c)
        _ ->
          return ()

    -- A rigid variable would be pulled out of the scope it is bound in.
    escaped name = do
      notes <- typeNotes usage tp
      unifyError usage notes bcs $
        "Cannot unify type" </>
        indent 2 (ppr tp) </>
        "with" <+> pquote (pprName vn) <+> "(scope violation)." </>
        "This is because" <+> pquote (pprName name) <+>
        "is rigidly bound in a deeper scope."

-- | Bind the type variable @vn@, which lives at level @lvl@, to the
-- type @tp@.
--
-- The occurs check and the scope check are performed first.  @vn@ is
-- then recorded in the constraints as @'Constraint' tp' usage@, where
-- @tp'@ is @tp@ with uniqueness removed.  Finally, the constraint that
-- @vn@ carried /before/ this call is enforced on, or passed on to, @tp@:
--
-- * @'NoConstraint' 'Unlifted'@: @tp'@ must be of zero order and must
--   not contain anonymous sizes.
--
-- * 'Equality': @tp'@ must admit equality.
--
-- * 'Overloaded': unless @tp@ already is one of the permitted primitive
--   types, @tp'@ must be a non-rigid type variable, which is then
--   restricted to those types.
--
-- * 'HasFields': @tp@ must be a record with at least the required
--   fields (whose types are unified with the required ones), or a
--   non-rigid type variable, which then takes over the requirement.
--
-- * 'HasConstrs': @tp@ must be a sum type with at least the required
--   constructors (shared constructors are unified), or a non-rigid type
--   variable, whose constructor requirements are merged with these.
--
-- Any other previous constraint, or none at all, imposes nothing further.
linkVarToType :: MonadUnify m =>
                 UnifyDims m -> Usage -> BreadCrumbs
              -> VName -> Level -> StructType -> m ()
linkVarToType onDims usage bcs vn lvl tp = do
  occursCheck usage bcs vn tp
  scopeCheck usage bcs vn lvl tp

  -- Take the snapshot *before* overwriting vn's entry, so that the
  -- constraint it previously carried can still be inspected.
  before <- getConstraints
  modifyConstraints $ M.insert vn (lvl, Constraint tp_nu usage)
  maybe (return ()) (enforce before . snd) $ M.lookup vn before

  where
    -- The type vn is actually bound to: tp without uniqueness.
    tp_nu = removeUniqueness tp

    vn_quoted = pquote (pprName vn)

    -- Dispatch on the constraint vn had before it was bound.
    enforce cs prev =
      case prev of
        NoConstraint Unlifted unlift_usage ->
          enforceUnlifted unlift_usage
        Equality _ ->
          equalityType usage tp_nu
        Overloaded ts old_usage
          | tp `notElem` map (Scalar . Prim) ts ->
              enforceOverloaded cs ts old_usage
        HasFields required_fields old_usage ->
          enforceFields cs required_fields old_usage
        HasConstrs required_cs old_usage ->
          enforceConstrs cs required_cs old_usage
        _ ->
          return ()

    -- vn must not become a function type, nor contain anonymous sizes.
    enforceUnlifted unlift_usage = do
      let why = "When verifying that" <+> vn_quoted <+>
                textwrap "is not instantiated with a function type, due to" <+>
                ppr unlift_usage
      zeroOrderTypeWith usage (breadCrumb (Matching why) bcs) tp_nu
      when (hasEmptyDims tp_nu) $
        unifyError usage mempty bcs $
          "Type variable" <+> pprName vn <+>
          "cannot be instantiated with type containing anonymous sizes:" </>
          indent 2 (ppr tp) </>
          textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically.  This can also be due to the return size being a value parameter.  Add type annotation to clarify."

    -- tp is not one of the permitted primitive types, so the only
    -- remaining option is a non-rigid type variable that inherits the
    -- restriction.
    enforceOverloaded cs ts old_usage =
      case tp_nu of
        Scalar (TypeVar _ _ (TypeName [] v) [])
          | not (isRigid v cs) ->
              linkVarToTypes usage v ts
        _ ->
          unifyError usage mempty bcs $
            "Cannot instantiate" <+> vn_quoted <+>
            "with type" </> indent 2 (ppr tp) </> "as" <+>
            vn_quoted <+> "must be one of" <+>
            commasep (map ppr ts) <+/>
            "due to" <+/> ppr old_usage <> "."

    enforceFields cs required_fields old_usage =
      case tp of
        Scalar (Record tp_fields)
          | M.keysSet required_fields `S.isSubsetOf` M.keysSet tp_fields -> do
              required_fields' <- mapM normTypeFully required_fields
              let why = pprName vn <+>
                        "must be a record with at least the fields:" </>
                        indent 2 (ppr (Record required_fields')) </>
                        "due to" <+> ppr old_usage <> "."
                  bcs' = breadCrumb (Matching why) bcs
                  pairs = M.elems $ M.intersectionWith (,) required_fields tp_fields
              mapM_ (\(want, have) -> unifyWith onDims usage bcs' want have) pairs
        Scalar (TypeVar _ _ (TypeName [] v) [])
          -- NOTE(review): this replaces whatever constraint v already
          -- had, whereas the HasConstrs case merges; confirm that an
          -- existing HasFields on v cannot be lost here.
          | not (isRigid v cs) ->
              modifyConstraints $ M.insert v (lvl, HasFields required_fields old_usage)
        _ ->
          unifyError usage mempty bcs $
            "Cannot instantiate" <+> vn_quoted <+> "with type" </>
            indent 2 (ppr tp) </>
            "as" <+> vn_quoted <+> "must be a record with fields" </>
            indent 2 (ppr (Record required_fields)) </>
            "due to" <+> ppr old_usage <> "."

    enforceConstrs cs required_cs old_usage =
      case tp of
        Scalar (Sum tp_cs)
          | M.keysSet required_cs `S.isSubsetOf` M.keysSet tp_cs ->
              unifySharedConstructors onDims usage bcs required_cs tp_cs
        Scalar (TypeVar _ _ (TypeName [] v) [])
          | not (isRigid v cs) -> do
              case snd <$> M.lookup v cs of
                Just (HasConstrs v_cs _) ->
                  unifySharedConstructors onDims usage bcs required_cs v_cs
                _ ->
                  return ()
              modifyConstraints $
                M.insertWith mergeConstrs v (lvl, HasConstrs required_cs old_usage)
        _ ->
          unifyError usage mempty bcs
            "Cannot unify a sum type with a non-sum type"

    -- Union the constructor requirements when the target variable
    -- already has some; any other existing constraint is replaced.
    mergeConstrs (_, HasConstrs new_cs new_usage) (_, HasConstrs old_cs _) =
      (lvl, HasConstrs (new_cs `M.union` old_cs) new_usage)
    mergeConstrs fresh _ = fresh

-- | Bind the size variable @vn@, at level @lvl@, to the dimension @dim@.
--
-- If @dim@ is a named size whose recorded level is greater than
-- @lvl@, that size must not escape its scope:
--
-- * A 'ParamSize' is reported as a scope violation.
-- * Any other constraint has its level lowered to @lvl@.
--
-- Afterwards @vn@ is recorded as @'Size' (Just dim)@.
linkVarToDim :: MonadUnify m =>
                Usage -> BreadCrumbs
             -> VName -> Level -> DimDecl VName -> m ()
linkVarToDim usage bcs vn lvl dim = do
  cs <- getConstraints
  case dim of
    NamedDim dim_qn -> checkEscape cs dim_qn
    _ -> return ()
  modifyConstraints $ M.insert vn (lvl, Size (Just dim) usage)
  where
    checkEscape cs dim_qn =
      let dim_v = qualLeaf dim_qn
      in case M.lookup dim_v cs of
           Just (dim_lvl, c)
             | dim_lvl > lvl ->
                 case c of
                   ParamSize{} -> sizeEscapes dim_qn
                   _ -> modifyConstraints $ M.insert dim_v (lvl, c)
           _ ->
             return ()

    sizeEscapes dim_qn = do
      notes <- dimNotes usage dim
      unifyError usage notes bcs $
        "Cannot unify size variable" <+> pquote (ppr dim_qn) <+>
        "with" <+> pquote (pprName vn) <+> "(scope violation)." </>
        "This is because" <+> pquote (ppr dim_qn) <+>
        "is rigidly bound in a deeper scope."

-- | Strip uniqueness from a type.
--
-- Records, both sides of function types and the payloads of sum-type
-- constructors are processed recursively.  Any other type has its
-- top-level uniqueness set to 'Nonunique'.
removeUniqueness :: TypeBase dim as -> TypeBase dim as
removeUniqueness t =
  case t of
    Scalar (Record fields) ->
      Scalar (Record (fmap removeUniqueness fields))
    Scalar (Arrow als p param_t ret_t) ->
      Scalar (Arrow als p (removeUniqueness param_t) (removeUniqueness ret_t))
    Scalar (Sum constrs) ->
      Scalar (Sum (fmap (map removeUniqueness) constrs))
    _ ->
      setUniqueness t Nonunique

-- | Assert that this type must be one of the given primitive types.
--
-- With exactly one candidate this is ordinary unification against that
-- primitive type.  Otherwise the type is normalised first:
--
-- * A non-rigid type variable becomes restricted to the candidates.
-- * A primitive type is accepted if it is among them.
-- * Anything else is a unification error.
mustBeOneOf :: MonadUnify m => [PrimType] -> Usage -> StructType -> m ()
mustBeOneOf ts usage t =
  case ts of
    [req_t] ->
      unify usage (Scalar (Prim req_t)) t
    _ -> do
      t' <- normType t
      cs <- getConstraints
      case t' of
        Scalar (TypeVar _ _ (TypeName [] v) [])
          | not (isRigid v cs) ->
              linkVarToTypes usage v ts
        Scalar (Prim pt)
          | pt `elem` ts ->
              return ()
        _ ->
          unifyError usage mempty noBreadCrumbs $
            text "Cannot unify type" <+> pquote (ppr t) <+>
            "with any of " <> commasep (map ppr ts) <> "."

linkVarToTypes :: MonadUnify m => Usage -> VName -> [PrimType] -> m ()
linkVarToTypes :: Usage -> VName -> [PrimType] -> m ()
linkVarToTypes Usage
usage VName
vn [PrimType]
ts = do
  Maybe (Int, Constraint)
vn_constraint <- VName -> Constraints -> Maybe (Int, Constraint)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
vn (Constraints -> Maybe (Int, Constraint))
-> m Constraints -> m (Maybe (Int, Constraint))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> m Constraints
forall (m :: * -> *). MonadUnify m => m Constraints
getConstraints
  case Maybe (Int, Constraint)
vn_constraint of
    Just (Int
lvl, Overloaded [PrimType]
vn_ts Usage
vn_usage) ->
      case [PrimType]
ts [PrimType] -> [PrimType] -> [PrimType]
forall a. Eq a => [a] -> [a] -> [a]
`intersect` [PrimType]
vn_ts of
        [] -> Usage -> Notes -> BreadCrumbs -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadUnify m, Located loc) =>
loc -> Notes -> BreadCrumbs -> Doc -> m a
unifyError Usage
usage Notes
forall a. Monoid a => a
mempty BreadCrumbs
noBreadCrumbs (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
              Doc
"Type constrained to one of" Doc -> Doc -> Doc
<+>
              [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
ts) Doc -> Doc -> Doc
<+> Doc
"but also one of" Doc -> Doc -> Doc
<+>
              [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
vn_ts) Doc -> Doc -> Doc
<+> Doc
"due to" Doc -> Doc -> Doc
<+> Usage -> Doc
forall a. Pretty a => a -> Doc
ppr Usage
vn_usage Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."
        [PrimType]
ts' -> (Constraints -> Constraints) -> m ()
forall (m :: * -> *).
MonadUnify m =>
(Constraints -> Constraints) -> m ()
modifyConstraints ((Constraints -> Constraints) -> m ())
-> (Constraints -> Constraints) -> m ()
forall a b. (a -> b) -> a -> b
$ VName -> (Int, Constraint) -> Constraints -> Constraints
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
vn (Int
lvl, [PrimType] -> Usage -> Constraint
Overloaded [PrimType]
ts' Usage
usage)

    Just (Int
_, HasConstrs Map Name [StructType]
_ Usage
vn_usage) ->
      Usage -> Notes -> BreadCrumbs -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadUnify m, Located loc) =>
loc -> Notes -> BreadCrumbs -> Doc -> m a
unifyError Usage
usage Notes
forall a. Monoid a => a
mempty BreadCrumbs
noBreadCrumbs (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
      Doc
"Type constrained to one of" Doc -> Doc -> Doc
<+> [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
ts) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<>
      Doc
", but also inferred to be sum type due to" Doc -> Doc -> Doc
<+> Usage -> Doc
forall a. Pretty a => a -> Doc
ppr Usage
vn_usage Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."

    Just (Int
_, HasFields Map Name StructType
_ Usage
vn_usage) ->
      Usage -> Notes -> BreadCrumbs -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadUnify m, Located loc) =>
loc -> Notes -> BreadCrumbs -> Doc -> m a
unifyError Usage
usage Notes
forall a. Monoid a => a
mempty BreadCrumbs
noBreadCrumbs (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
      Doc
"Type constrained to one of" Doc -> Doc -> Doc
<+> [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
ts) Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<>
      Doc
", but also inferred to be record due to" Doc -> Doc -> Doc
<+> Usage -> Doc
forall a. Pretty a => a -> Doc
ppr Usage
vn_usage Doc -> Doc -> Doc
forall a. Semigroup a => a -> a -> a
<> Doc
"."

    Just (Int
lvl, Constraint
_) -> (Constraints -> Constraints) -> m ()
forall (m :: * -> *).
MonadUnify m =>
(Constraints -> Constraints) -> m ()
modifyConstraints ((Constraints -> Constraints) -> m ())
-> (Constraints -> Constraints) -> m ()
forall a b. (a -> b) -> a -> b
$ VName -> (Int, Constraint) -> Constraints -> Constraints
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
vn (Int
lvl, [PrimType] -> Usage -> Constraint
Overloaded [PrimType]
ts Usage
usage)

    Maybe (Int, Constraint)
Nothing ->
      Usage -> Notes -> BreadCrumbs -> Doc -> m ()
forall (m :: * -> *) loc a.
(MonadUnify m, Located loc) =>
loc -> Notes -> BreadCrumbs -> Doc -> m a
unifyError Usage
usage Notes
forall a. Monoid a => a
mempty BreadCrumbs
noBreadCrumbs (Doc -> m ()) -> Doc -> m ()
forall a b. (a -> b) -> a -> b
$
      Doc
"Cannot constrain type to one of" Doc -> Doc -> Doc
<+> [Doc] -> Doc
commasep ((PrimType -> Doc) -> [PrimType] -> [Doc]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Doc
forall a. Pretty a => a -> Doc
ppr [PrimType]
ts)

-- | Assert that this type must support equality.
--
-- A type that is not of order zero is rejected outright.  Every type
-- variable occurring in the type is then checked against the current
-- constraint set; an unconstrained variable is given an 'Equality'
-- constraint, and the payload types of a variable constrained by
-- 'HasConstrs' are checked recursively.
equalityType :: (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
                Usage -> TypeBase dim as -> m ()
equalityType usage t = do
  unless (orderZero t) $
    unifyError usage mempty noBreadCrumbs $
    "Type" <+> pquote (ppr t) <+> "does not support equality (is higher-order)."
  mapM_ mustBeEquality $ typeVars t
  where
    -- Check a single type variable, dispatching on its constraint.
    mustBeEquality vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        -- Constrained to be another bare type variable: check that one.
        Just (_, Constraint (Scalar (TypeVar _ _ (TypeName [] vn') [])) _) ->
          mustBeEquality vn'
        Just (_, Constraint vn_t cusage)
          | not $ orderZero vn_t ->
              unifyError usage mempty noBreadCrumbs $
              "Type" <+> pquote (ppr t) <+> "does not support equality." </>
              "Constrained to be higher-order due to" <+> ppr cusage <> "."
          | otherwise -> return ()
        -- Not yet constrained: record the equality requirement.
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, Equality usage)
        Just (_, Overloaded _ _) ->
          return () -- All primtypes support equality.
        Just (_, Equality{}) ->
          return ()
        -- Every constructor payload type must support equality as well.
        Just (_, HasConstrs cs _) ->
          mapM_ (equalityType usage) $ concat $ M.elems cs
        _ ->
          unifyError usage mempty noBreadCrumbs $
          "Type" <+> pprName vn <+> "does not support equality."

-- | Assert that a type is of order zero, attaching the given bread
-- crumbs to any error.  The type itself must be of order zero; each
-- unconstrained type variable in it is tightened to 'Unlifted', and a
-- type variable that is a lifted type parameter is rejected.
zeroOrderTypeWith :: (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
                     Usage -> BreadCrumbs -> TypeBase dim as -> m ()
zeroOrderTypeWith usage bcs t = do
  unless (orderZero t) $
    unifyError usage mempty bcs $
    "Type" </> indent 2 (ppr t) </> "found to be functional."
  forM_ (S.toList (typeVars t)) checkVar
  where
    -- Inspect the constraint currently attached to one type variable.
    checkVar v = do
      cs <- getConstraints
      case M.lookup v cs of
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert v (lvl, NoConstraint Unlifted usage)
        Just (_, ParamType Lifted ploc) ->
          unifyError usage mempty bcs $
          "Type parameter" <+> pquote (pprName v) <+> "at" <+>
          text (locStr ploc) <+> "may be a function."
        _ ->
          return ()

-- | Assert that this type must be zero-order.  The description is
-- recorded as a "When checking ..." bread crumb on any error that
-- results.
zeroOrderType :: (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
                 Usage -> String -> TypeBase dim as -> m ()
zeroOrderType usage desc t =
  let crumb = Matching ("When checking" <+> textwrap desc)
      crumbs = breadCrumb crumb noBreadCrumbs
  in zeroOrderTypeWith usage crumbs t

-- | Unify the payload types of every constructor that occurs in both
-- maps, in ascending constructor-name order.  Constructors present in
-- only one of the maps are ignored.  A shared constructor whose two
-- payload lists differ in length is reported as an error.
unifySharedConstructors :: MonadUnify m =>
                           UnifyDims m -> Usage -> BreadCrumbs
                        -> M.Map Name [StructType]
                        -> M.Map Name [StructType]
                        -> m ()
unifySharedConstructors onDims usage bcs cs1 cs2 =
  mapM_ check $ M.toList $ M.intersectionWith (,) cs1 cs2
  where
    check (c, (f1, f2))
      | length f1 /= length f2 =
          unifyError usage mempty bcs $
          "Cannot unify constructor" <+> pquote (pprName c) <> "."
      | otherwise =
          let crumbs = breadCrumb (MatchingConstructor c) bcs
          in zipWithM_ (unifyWith onDims usage crumbs) f1 f2

-- | In @mustHaveConstr usage c t fs@, the type @t@ must have a
-- constructor named @c@ that takes arguments of types @fs@.
--
-- * If @t@ is a type variable with no constraint, the payload types are
--   scope-checked and the variable becomes constrained to have @c@.
-- * If @t@ is a type variable already constrained to have constructors,
--   @c@ is added if absent; otherwise its recorded payload is unified
--   with @fs@.
-- * If @t@ is a sum type, @c@ must be one of its constructors and its
--   payload is unified with @fs@.
-- * Otherwise @t@ is unified with the sum type containing only @c@.
mustHaveConstr :: MonadUnify m =>
                  Usage -> Name -> StructType -> [StructType] -> m ()
mustHaveConstr usage c t fs = do
  constraints <- getConstraints
  case t of
    Scalar (TypeVar _ _ (TypeName _ tn) [])
      | Just (lvl, NoConstraint{}) <- M.lookup tn constraints -> do
          mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs
          modifyConstraints $ M.insert tn (lvl, HasConstrs (M.singleton c fs) usage)
      | Just (lvl, HasConstrs cs _) <- M.lookup tn constraints ->
          case M.lookup c cs of
            Nothing ->
              modifyConstraints $ M.insert tn (lvl, HasConstrs (M.insert c fs cs) usage)
            Just fs' ->
              unifyPayload fs'
    -- A type variable with any other constraint falls through to the
    -- final alternative.

    Scalar (Sum cs) ->
      case M.lookup c cs of
        Nothing ->
          unifyError usage mempty noBreadCrumbs $
          "Constructor" <+> pquote (ppr c) <+> "not present in type."
        Just fs' ->
          unifyPayload fs'

    _ ->
      unify usage t $ Scalar $ Sum $ M.singleton c fs

  where
    -- Unify the expected payload types with the payload already known
    -- for @c@.  The two payloads must have the same arity.
    unifyPayload fs'
      | length fs == length fs' = zipWithM_ (unify usage) fs fs'
      | otherwise =
          unifyError usage mempty noBreadCrumbs $
          "Different arity for constructor" <+> pquote (ppr c) <> "."

-- | Assert that type @t@ has a record field named @l@, and return the
-- type of that field.
--
-- * If @t@ is an unconstrained type variable, it becomes constrained to
--   have field @l@ with a fresh type.
-- * If @t@ is a type variable already constrained to have fields, an
--   existing field @l@ is unified (via @onDims@ and @bcs@) with the
--   fresh type; otherwise the field is added.
-- * If @t@ is a record type, field @l@ must be present.
-- * Otherwise @t@ is unified with a one-field record type.
mustHaveFieldWith :: MonadUnify m =>
                     UnifyDims m -> Usage -> BreadCrumbs
                  -> Name -> PatternType -> m PatternType
mustHaveFieldWith onDims usage bcs l t = do
  constraints <- getConstraints
  -- Fresh type variable standing for the type of field @l@.
  l_type <- newTypeVar (srclocOf usage) "t"
  let l_type' = toStruct l_type
  case t of
    Scalar (TypeVar _ _ (TypeName _ tn) [])
      | Just (lvl, NoConstraint{}) <- M.lookup tn constraints -> do
          scopeCheck usage bcs tn lvl l_type'
          modifyConstraints $ M.insert tn (lvl, HasFields (M.singleton l l_type') usage)
          return l_type
      | Just (lvl, HasFields fields _) <- M.lookup tn constraints -> do
          case M.lookup l fields of
            Just t' -> unifyWith onDims usage bcs l_type' t'
            Nothing -> modifyConstraints $ M.insert tn
                       (lvl, HasFields (M.insert l l_type' fields) usage)
          return l_type
    -- NOTE(review): the 'unify' calls in the record branch and in the
    -- fallback below do not use @onDims@ or @bcs@ - confirm intentional.
    Scalar (Record fields)
      | Just t' <- M.lookup l fields -> do
          unify usage l_type' $ toStruct t'
          return t'
      | otherwise ->
          unifyError usage mempty bcs $
          "Attempt to access field" <+> pquote (ppr l) <+> "of value of type" <+>
          ppr (toStructural t) <> "."
    _ -> do
      unify usage (toStruct t) $ Scalar $ Record $ M.singleton l l_type'
      return l_type

-- | Assert that some type must have a field with this name and type.
-- Returns the type of the field.  Dimensions are unified with
-- 'unifyDims', and no bread crumbs are attached to errors.
mustHaveField :: MonadUnify m =>
                 Usage -> Name -> PatternType -> m PatternType
mustHaveField usage field t =
  mustHaveFieldWith (unifyDims usage) usage noBreadCrumbs field t

-- | Replace dimension mismatches with AnyDim.
--
-- The two types are walked together with 'matchDims'.  Where the two
-- dimensions are equal, the first is kept.  Where they differ, 'AnyDim'
-- is used, and the mismatching pair is appended to the returned list in
-- the order 'matchDims' visits them.
anyDimOnMismatch :: Monoid as =>
                    TypeBase (DimDecl VName) as -> TypeBase (DimDecl VName) as
                 -> (TypeBase (DimDecl VName) as, [(DimDecl VName, DimDecl VName)])
anyDimOnMismatch t1 t2 = runWriter (matchDims pick t1 t2)
  where
    pick d1 d2
      | d1 /= d2 = AnyDim <$ tell [(d1, d2)]
      | otherwise = pure d1

-- | Combine two types with 'matchDims'.  Where the two dimensions
-- being matched are equal, the first is kept.  Where they differ, a
-- fresh size variable is made with 'newDimVar', rigid with a
-- 'RigidCond' source that mentions both types.  A given pair of
-- differing dimensions is always mapped to the same fresh variable.
--
-- Returns the combined type and every fresh size variable created.
-- The list comes from 'M.elems' of a map keyed on the dimension pair,
-- so it is ordered by that key rather than by creation order.
--
-- NOTE(review): presumably callers pass structurally compatible types
-- (e.g. after unification); how 'matchDims' treats incompatible shapes
-- is not visible here.
newDimOnMismatch :: (Monoid as, MonadUnify m) =>
                    SrcLoc -> TypeBase (DimDecl VName) as -> TypeBase (DimDecl VName) as
                 -> m (TypeBase (DimDecl VName) as, [VName])
newDimOnMismatch loc t1 t2 = do
  (t, invented) <- runStateT (matchDims onDims t1 t2) M.empty
  pure (t, M.elems invented)
  where
    rigidity = Rigid $ RigidCond (toStruct t1) (toStruct t2)

    onDims d1 d2
      | d1 == d2 = pure d1
      | otherwise = do
          -- Reuse the size already invented for this exact mismatch,
          -- so repeated occurrences of it end up equal to each other.
          known <- gets (M.lookup (d1, d2))
          v <- maybe (inventSize (d1, d2)) pure known
          pure $ NamedDim $ qualName v

    -- Create a fresh rigid size in the underlying monad and remember
    -- it under the given mismatch key.
    inventSize key = do
      v <- lift $ newDimVar loc rigidity "differ"
      modify (M.insert key v)
      pure v

-- | Unify two types while tolerating size mismatches.  The types are
-- first unified with every dimension comparison accepted, then both are
-- fully normalised.  Finally, 'newDimOnMismatch' replaces each pair of
-- differing dimensions with a fresh size variable.  Returns the
-- resulting type together with the fresh size variables thus created.
unifyMostCommon :: MonadUnify m =>
                   Usage -> PatternType -> PatternType -> m (PatternType, [VName])
unifyMostCommon usage t1 t2 = do
  unifyWith acceptAnyDims usage noBreadCrumbs (toStruct t1) (toStruct t2)
  t1' <- normTypeFully t1
  t2' <- normTypeFully t2
  newDimOnMismatch (srclocOf usage) t1' t2'
  where
    -- Dimension mismatches are not errors here; they are handled by
    -- 'newDimOnMismatch' after the structural unification.
    acceptAnyDims _ _ _ _ _ = pure ()

-- Simple MonadUnify implementation.

-- State threaded through 'UnifyM': the constraints collected so far,
-- and a counter from which fresh variable names and tags are drawn.
type UnifyMState = (Constraints, Int)

-- | A self-contained monad for running unification: a state over
-- 'UnifyMState' that can fail with a 'TypeError'.  All instances are
-- inherited from the underlying transformer stack.
newtype UnifyM a = UnifyM (StateT UnifyMState (Except TypeError) a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadError TypeError
           , MonadState UnifyMState
           )

-- | Produce a fresh name from the counter in the 'UnifyM' state.  The
-- current counter value becomes both the 'VName' tag and the subscript
-- in the human-readable name, and the counter is then incremented.
newVar :: String -> UnifyM VName
newVar name = do
  tag <- state $ \(constraints, next) -> (next, (constraints, next + 1))
  pure $ VName (mkTypeVarName name tag) tag

instance MonadUnify UnifyM where
  -- The constraints are the first component of the state; replacing
  -- them leaves the name counter (second component) untouched.
  getConstraints = gets fst
  putConstraints cs = modify $ \(_, counter) -> (cs, counter)

  -- A fresh type variable is recorded at level 0 with a lifted
  -- 'NoConstraint'.
  newTypeVar loc name = do
    v <- newVar name
    let usage = Usage Nothing loc
    modifyConstraints $ M.insert v (0, NoConstraint Lifted usage)
    pure $ Scalar $ TypeVar mempty Nonunique (typeName v) []

  -- A rigid size is recorded as 'UnknowableSize' (keeping its source);
  -- a non-rigid one as a 'Size' constraint with no known value.  Both
  -- are recorded at level 0.
  newDimVar loc rigidity name = do
    dim <- newVar name
    let constraint =
          case rigidity of
            Rigid src -> UnknowableSize loc src
            Nonrigid -> Size Nothing $ Usage Nothing loc
    modifyConstraints $ M.insert dim (0, constraint)
    pure dim

  -- Everything in this monad lives at a single level.
  curLevel = pure 0

  unifyError loc notes bcs doc =
    throwError $ TypeError (srclocOf loc) notes $ doc <> ppr bcs

  -- Same error shape as 'unifyError', with a message naming both types.
  matchError loc notes bcs t1 t2 =
    unifyError loc notes bcs $
      "Types" </> indent 2 (ppr t1) </> "and" </> indent 2 (ppr t2) </> "do not match."

-- | Construct the name of a new type variable from a base description
-- and a tag number.  The tag is written in Unicode subscript characters
-- after the description (e.g. @"t"@ and @3@ give @t₃@).
--
-- Note that this is distinct from actually constructing a 'VName'; the
-- tag here is intended for human consumption, but the machine does not
-- care.
--
-- A negative tag keeps its sign as a subscript minus (@₋@).  Dropping
-- the sign would make @-n@ and @n@ produce the same name for a given
-- description.
mkTypeVarName :: String -> Int -> Name
mkTypeVarName desc i =
  nameFromString $ desc ++ mapMaybe subscript (show i)
  where subscript = flip lookup $ zip "-0123456789" "₋₀₁₂₃₄₅₆₇₈₉"

-- | Run a 'UnifyM' computation purely.  Each given type parameter is
-- entered into the initial constraints at level 0:
--
-- * size parameters get a 'Size' constraint with no known value;
-- * type parameters get a 'NoConstraint' with their declared liftedness.
--
-- The fresh-name counter starts at 0.
runUnifyM :: [TypeParam] -> UnifyM a -> Either TypeError a
runUnifyM tparams (UnifyM action) =
  runExcept (evalStateT action initialState)
  where
    initialState :: UnifyMState
    initialState = (M.fromList (map paramConstraint tparams), 0)

    paramConstraint tparam =
      case tparam of
        TypeParamDim p loc -> (p, (0, Size Nothing (Usage Nothing loc)))
        TypeParamType l p loc -> (p, (0, NoConstraint l (Usage Nothing loc)))

-- | Perform a unification of two types outside a monadic context.  Only
-- the given type parameters may be instantiated; all other types are
-- considered rigid.
--
-- Both types first go through 'instantiateEmptyArrayDims' with a rigid
-- 'RigidUnify' source (using the descriptions @"n"@ and @"m"@), and the
-- instantiated first type is expected to match the instantiated second.
-- On success, the original (not instantiated) second type is returned,
-- fully normalised under the resulting constraints.
doUnification :: SrcLoc -> [TypeParam]
              -> StructType -> StructType
              -> Either TypeError StructType
doUnification loc tparams t1 t2 =
  runUnifyM tparams $ do
    let rigid = Rigid RigidUnify
        instantiate desc t = fst <$> instantiateEmptyArrayDims loc desc rigid t
    t1' <- instantiate "n" t1
    t2' <- instantiate "m" t2
    expect (Usage Nothing loc) t1' t2'
    normTypeFully t2