{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances, DeriveFunctor #-}
{-# Language TupleSections #-}
module Language.Futhark.TypeChecker.Terms
( checkOneExp
, checkFunDef
)
where
import Control.Monad.Except
import Control.Monad.State
import Control.Monad.RWS hiding (Sum)
import qualified Control.Monad.Fail as Fail
import Data.Char (isAlpha)
import Data.List
import qualified Data.List.NonEmpty as NE
import Data.Loc
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Prelude hiding (mod)
import Language.Futhark
import Language.Futhark.Semantic (includeToString)
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad hiding (BoundV, checkQualNameWithEnv)
import Language.Futhark.TypeChecker.Types hiding (checkTypeDecl)
import Language.Futhark.TypeChecker.Unify hiding (Usage)
import qualified Language.Futhark.TypeChecker.Types as Types
import qualified Language.Futhark.TypeChecker.Monad as TypeM
import Futhark.Util.Pretty hiding (space, bool, group)
data Usage = Consumed SrcLoc
| Observed SrcLoc
deriving (Eq, Ord, Show)
type Names = S.Set VName
data Occurence = Occurence { observed :: Names
, consumed :: Maybe Names
, location :: SrcLoc
}
deriving (Eq, Show)
instance Located Occurence where
locOf = locOf . location
observation :: Aliasing -> SrcLoc -> Occurence
observation = flip Occurence Nothing . S.map aliasVar
consumption :: Aliasing -> SrcLoc -> Occurence
consumption = Occurence S.empty . Just . S.map aliasVar
nullOccurence :: Occurence -> Bool
nullOccurence occ = S.null (observed occ) && isNothing (consumed occ)
seminullOccurence :: Occurence -> Bool
seminullOccurence occ = S.null (observed occ) && maybe True S.null (consumed occ)
type Occurences = [Occurence]
type UsageMap = M.Map VName [Usage]
usageMap :: Occurences -> UsageMap
usageMap = foldl comb M.empty
where comb m (Occurence obs cons loc) =
let m' = S.foldl' (ins $ Observed loc) m obs
in S.foldl' (ins $ Consumed loc) m' $ fromMaybe mempty cons
ins v m k = M.insertWith (++) k [v] m
combineOccurences :: MonadTypeChecker m => VName -> Usage -> Usage -> m Usage
combineOccurences _ (Observed loc) (Observed _) = return $ Observed loc
combineOccurences name (Consumed wloc) (Observed rloc) =
useAfterConsume (baseName name) rloc wloc
combineOccurences name (Observed rloc) (Consumed wloc) =
useAfterConsume (baseName name) rloc wloc
combineOccurences name (Consumed loc1) (Consumed loc2) =
consumeAfterConsume (baseName name) (max loc1 loc2) (min loc1 loc2)
checkOccurences :: MonadTypeChecker m => Occurences -> m ()
checkOccurences = void . M.traverseWithKey comb . usageMap
where comb _ [] = return ()
comb name (u:us) = foldM_ (combineOccurences name) u us
allObserved :: Occurences -> Names
allObserved = S.unions . map observed
allConsumed :: Occurences -> Names
allConsumed = S.unions . map (fromMaybe mempty . consumed)
allOccuring :: Occurences -> Names
allOccuring occs = allConsumed occs <> allObserved occs
anyConsumption :: Occurences -> Maybe Occurence
anyConsumption = find (isJust . consumed)
seqOccurences :: Occurences -> Occurences -> Occurences
seqOccurences occurs1 occurs2 =
filter (not . nullOccurence) $ map filt occurs1 ++ occurs2
where filt occ =
occ { observed = observed occ `S.difference` postcons }
postcons = allConsumed occurs2
altOccurences :: Occurences -> Occurences -> Occurences
altOccurences occurs1 occurs2 =
filter (not . nullOccurence) $ map filt1 occurs1 ++ map filt2 occurs2
where filt1 occ =
occ { consumed = S.difference <$> consumed occ <*> pure cons2
, observed = observed occ `S.difference` cons2 }
filt2 occ =
occ { consumed = consumed occ
, observed = observed occ `S.difference` cons1 }
cons1 = allConsumed occurs1
cons2 = allConsumed occurs2
data Locality = Local | Global
deriving (Show)
data ValBinding = BoundV Locality [TypeParam] PatternType
| OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType)
| EqualityF
| OpaqueF
| WasConsumed SrcLoc
deriving (Show)
data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding
, scopeTypeTable :: M.Map VName TypeBinding
, scopeNameMap :: NameMap
, scopeBreadCrumbs :: [BreadCrumb]
} deriving (Show)
instance Semigroup TermScope where
TermScope vt1 tt1 nt1 bc1 <> TermScope vt2 tt2 nt2 bc2 =
TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (nt2 `M.union` nt1) (bc1 <> bc2)
instance Monoid TermScope where
mempty = TermScope mempty mempty mempty mempty
envToTermScope :: Env -> TermScope
envToTermScope env = TermScope vtable (envTypeTable env) (envNameMap env) mempty
where vtable = M.mapWithKey valBinding $ envVtable env
valBinding k (TypeM.BoundV tps v) =
BoundV Global tps $ v `setAliases`
(if arrayRank v > 0 then S.singleton (AliasBound k) else mempty)
constraintTypeVars :: Constraints -> Names
constraintTypeVars = mconcat . map f . M.elems
where f (Constraint t _) = typeVars t
f _ = mempty
overloadedTypeVars :: Constraints -> Names
overloadedTypeVars = mconcat . map f . M.elems
where f (HasFields fs _) = mconcat $ map typeVars $ M.elems fs
f _ = mempty
expType :: Exp -> TermTypeM PatternType
expType = normaliseType . typeOf
type TermTypeState = (Constraints, Int)
newtype TermTypeM a = TermTypeM (RWST
TermScope
Occurences
TermTypeState
TypeM
a)
deriving (Monad, Functor, Applicative,
MonadReader TermScope,
MonadWriter Occurences,
MonadState TermTypeState,
MonadError TypeError)
instance Fail.MonadFail TermTypeM where
fail = typeError (noLoc :: SrcLoc) . ("unknown failure (likely a bug): "++)
instance MonadUnify TermTypeM where
getConstraints = gets fst
putConstraints x = modify $ \s -> (x, snd s)
newTypeVar loc desc = do
i <- incCounter
v <- newID $ mkTypeVarName desc i
modifyConstraints $ M.insert v $ NoConstraint Nothing $ mkUsage' loc
return $ Scalar $ TypeVar mempty Nonunique (typeName v) []
instance MonadBreadCrumbs TermTypeM where
breadCrumb bc = local $ \env ->
env { scopeBreadCrumbs = bc : scopeBreadCrumbs env }
getBreadCrumbs = asks scopeBreadCrumbs
runTermTypeM :: TermTypeM a -> TypeM (a, Occurences)
runTermTypeM (TermTypeM m) = do
initial_scope <- (initialTermScope <>) <$> (envToTermScope <$> askEnv)
evalRWST m initial_scope (mempty, 0)
liftTypeM :: TypeM a -> TermTypeM a
liftTypeM = TermTypeM . lift
incCounter :: TermTypeM Int
incCounter = do (x, i) <- get
put (x, i+1)
return i
initialTermScope :: TermScope
initialTermScope = TermScope initialVtable mempty topLevelNameMap mempty
where initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics
funF ts t = foldr (arrow . Scalar . Prim) (Scalar $ Prim t) ts
arrow x y = Scalar $ Arrow mempty Unnamed x y
addIntrinsicF (name, IntrinsicMonoFun ts t) =
Just (name, BoundV Global [] $ funF ts t)
addIntrinsicF (name, IntrinsicOverloadedFun ts pts rts) =
Just (name, OverloadedF ts pts rts)
addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) =
Just (name, BoundV Global tvs $
fromStruct $ vacuousShapeAnnotations $
Scalar $ Arrow mempty Unnamed pts' rt)
where pts' = case pts of [pt] -> pt
_ -> tupleRecord pts
addIntrinsicF (name, IntrinsicEquality) =
Just (name, EqualityF)
addIntrinsicF (name, IntrinsicOpaque) =
Just (name, OpaqueF)
addIntrinsicF _ = Nothing
instance MonadTypeChecker TermTypeM where
warn loc problem = liftTypeM $ warn loc problem
newName = liftTypeM . newName
newID = liftTypeM . newID
checkQualName space name loc = snd <$> checkQualNameWithEnv space name loc
bindNameMap m = local $ \scope ->
scope { scopeNameMap = m <> scopeNameMap scope }
localEnv env (TermTypeM m) = do
cur_state <- get
cur_scope <- ask
let cur_scope' =
cur_scope { scopeNameMap = scopeNameMap cur_scope `M.difference` envNameMap env }
(x,new_state,occs) <- liftTypeM $ localTmpEnv env $
runRWST m cur_scope' cur_state
tell occs
put new_state
return x
lookupType loc qn = do
outer_env <- liftTypeM askRootEnv
(scope, qn'@(QualName qs name)) <- checkQualNameWithEnv Type qn loc
case M.lookup name $ scopeTypeTable scope of
Nothing -> undefinedType loc qn
Just (TypeAbbr l ps def) ->
return (qn', ps, qualifyTypeVars outer_env (map typeParamName ps) qs def, l)
lookupMod loc name = liftTypeM $ TypeM.lookupMod loc name
lookupMTy loc name = liftTypeM $ TypeM.lookupMTy loc name
lookupImport loc name = liftTypeM $ TypeM.lookupImport loc name
lookupVar loc qn = do
outer_env <- liftTypeM askRootEnv
(scope, qn'@(QualName qs name)) <- checkQualNameWithEnv Term qn loc
let usage = mkUsage loc $ "use of " ++ quote (pretty qn)
t <- case M.lookup name $ scopeVtable scope of
Nothing -> throwError $ TypeError loc $
"Unknown variable " ++ quote (pretty qn) ++ "."
Just (WasConsumed wloc) -> useAfterConsume (baseName name) loc wloc
Just (BoundV _ tparams t)
| "_" `isPrefixOf` baseString name -> underscoreUse loc qn
| otherwise -> do
(tnames, t') <- instantiateTypeScheme loc tparams t
let qual = qualifyTypeVars outer_env tnames qs
qual . anyDimShapeAnnotations <$> normaliseType t'
Just OpaqueF -> do
argtype <- newTypeVar loc "t"
return $ Scalar $ Arrow mempty Unnamed argtype argtype
Just EqualityF -> do
argtype <- newTypeVar loc "t"
equalityType usage argtype
return $
Scalar $ Arrow mempty Unnamed argtype $
Scalar $ Arrow mempty Unnamed argtype $ Scalar $ Prim Bool
Just (OverloadedF ts pts rt) -> do
argtype <- newTypeVar loc "t"
mustBeOneOf ts usage argtype
let (pts', rt') = instOverloaded argtype pts rt
arrow xt yt = Scalar $ Arrow mempty Unnamed xt yt
return $ fromStruct $ vacuousShapeAnnotations $ foldr arrow rt' pts'
observe $ Ident name (Info t) loc
return (qn', t)
where instOverloaded argtype pts rt =
(map (maybe (toStruct argtype) (Scalar . Prim)) pts,
maybe (toStruct argtype) (Scalar . Prim) rt)
checkNamedDim loc v = do
(v', t) <- lookupVar loc v
unify (mkUsage loc "use as array size") (toStructural t) $
Scalar $ Prim $ Signed Int32
return v'
checkQualNameWithEnv :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkQualNameWithEnv space qn@(QualName [q] _) loc
| nameToString q == "intrinsics" = do
(_, QualName _ q') <- liftTypeM $ TypeM.checkQualNameWithEnv Term (qualName q) loc
if baseTag q' <= maxIntrinsicTag
then checkIntrinsic space qn loc
else checkReallyQualName space qn loc
checkQualNameWithEnv space qn@(QualName quals name) loc = do
scope <- ask
case quals of
[] | Just name' <- M.lookup (space, name) $ scopeNameMap scope ->
return (scope, name')
_ -> checkReallyQualName space qn loc
checkIntrinsic :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkIntrinsic space qn@(QualName _ name) loc
| Just v <- M.lookup (space, name) intrinsicsNameMap = do
me <- liftTypeM askImportName
unless ("/futlib" `isPrefixOf` includeToString me) $
warn loc "Using intrinsic functions directly can easily crash the compiler or result in wrong code generation."
scope <- ask
return (scope, v)
| otherwise =
unknownVariableError space qn loc
checkReallyQualName :: Namespace -> QualName Name -> SrcLoc -> TermTypeM (TermScope, QualName VName)
checkReallyQualName space qn loc = do
(env, name') <- liftTypeM $ TypeM.checkQualNameWithEnv space qn loc
return (envToTermScope env, name')
checkTypeDecl :: TypeDeclBase NoInfo Name -> TermTypeM (TypeDeclBase Info VName)
checkTypeDecl tdecl = do
(tdecl', _) <- Types.checkTypeDecl [] tdecl
mapM_ observeDim $ nestedDims $ unInfo $ expandedType tdecl'
return tdecl'
where observeDim (NamedDim v) =
observe $ Ident (qualLeaf v) (Info $ Scalar $ Prim $ Signed Int32) noLoc
observeDim _ = return ()
instantiateTypeScheme :: SrcLoc -> [TypeParam] -> PatternType
-> TermTypeM ([VName], PatternType)
instantiateTypeScheme loc tparams t = do
let tparams' = filter isTypeParam tparams
tnames = map typeParamName tparams'
(fresh_tnames, substs) <- unzip <$> mapM (instantiateTypeParam loc) tparams'
let substs' = M.fromList $ zip tnames substs
t' = substTypesAny (`M.lookup` substs') t
return (fresh_tnames, t')
instantiateTypeParam :: Monoid as => SrcLoc -> TypeParam -> TermTypeM (VName, Subst (TypeBase dim as))
instantiateTypeParam loc tparam = do
i <- incCounter
v <- newID $ mkTypeVarName (takeWhile isAlpha (baseString (typeParamName tparam))) i
modifyConstraints $ M.insert v $ NoConstraint (Just l) $ mkUsage' loc
return (v, Subst $ Scalar $ TypeVar mempty Nonunique (typeName v) [])
where l = case tparam of TypeParamType x _ _ -> x
_ -> Lifted
newArrayType :: SrcLoc -> String -> Int -> TermTypeM (TypeBase () (), TypeBase () ())
newArrayType loc desc r = do
v <- newID $ nameFromString desc
modifyConstraints $ M.insert v $ NoConstraint Nothing $ mkUsage' loc
let rowt = TypeVar () Nonunique (typeName v) []
return (Array () Nonunique rowt (ShapeDecl $ replicate r ()),
Scalar rowt)
useAfterConsume :: MonadTypeChecker m => Name -> SrcLoc -> SrcLoc -> m a
useAfterConsume name rloc wloc =
throwError $ TypeError rloc $
"Variable " ++ quote (pretty name) ++ " previously consumed at " ++
locStr wloc ++ ". (Possibly through aliasing)"
consumeAfterConsume :: MonadTypeChecker m => Name -> SrcLoc -> SrcLoc -> m a
consumeAfterConsume name loc1 loc2 =
throwError $ TypeError loc2 $
"Variable " ++ pretty name ++ " previously consumed at " ++ locStr loc1 ++ "."
badLetWithValue :: MonadTypeChecker m => SrcLoc -> m a
badLetWithValue loc =
throwError $ TypeError loc
"New value for elements in let-with shares data with source array. This is illegal, as it prevents in-place modification."
returnAliased :: MonadTypeChecker m => Name -> Name -> SrcLoc -> m ()
returnAliased fname name loc =
throwError $ TypeError loc $
"Unique return value of function " ++ quote (pretty fname) ++
" is aliased to " ++ quote (pretty name) ++ ", which is not consumed."
uniqueReturnAliased :: MonadTypeChecker m => Name -> SrcLoc -> m a
uniqueReturnAliased fname loc =
throwError $ TypeError loc $
"A unique tuple element of return value of `" ++
quote (pretty fname) ++ "` is aliased to some other tuple component."
unifyExpTypes :: Exp -> Exp -> TermTypeM PatternType
unifyExpTypes e1 e2 = do
e1_t <- expType e1
e2_t <- expType e2
unify (mkUsage (srclocOf e2) "requiring equality of types") (toStructural e1_t) (toStructural e2_t)
return $ unifyTypeAliases e1_t e2_t
unifyTypeAliases :: PatternType -> PatternType -> PatternType
unifyTypeAliases t1 t2 =
case (t1, t2) of
(Array als1 u1 et1 shape1, Array als2 u2 _ _) ->
Array (als1<>als2) (min u1 u2) et1 shape1
(Scalar (Record f1), Scalar (Record f2)) ->
Scalar $ Record $ M.intersectionWith unifyTypeAliases f1 f2
(Scalar (TypeVar als1 u v targs1), Scalar (TypeVar als2 _ _ targs2)) ->
Scalar $ TypeVar (als1 <> als2) u v $ zipWith unifyTypeArg targs1 targs2
_ -> t1
where unifyTypeArg (TypeArgType t1' loc) (TypeArgType _ _) =
TypeArgType t1' loc
unifyTypeArg a _ = a
data InferredType = NoneInferred
| Ascribed PatternType
checkPattern' :: UncheckedPattern -> InferredType
-> TermTypeM Pattern
checkPattern' (PatternParens p loc) t =
PatternParens <$> checkPattern' p t <*> pure loc
checkPattern' (Id name NoInfo loc) (Ascribed t) = do
name' <- checkName Term name loc
return $ Id name' (Info t) loc
checkPattern' (Id name NoInfo loc) NoneInferred = do
name' <- checkName Term name loc
t <- newTypeVar loc "t"
return $ Id name' (Info t) loc
checkPattern' (Wildcard _ loc) (Ascribed t) =
return $ Wildcard (Info $ t `setUniqueness` Nonunique) loc
checkPattern' (Wildcard NoInfo loc) NoneInferred = do
t <- newTypeVar loc "t"
return $ Wildcard (Info t) loc
checkPattern' (TuplePattern ps loc) (Ascribed t)
| Just ts <- isTupleRecord t, length ts == length ps =
TuplePattern <$> zipWithM checkPattern' ps (map Ascribed ts) <*> pure loc
checkPattern' p@(TuplePattern ps loc) (Ascribed t) = do
ps_t <- replicateM (length ps) (newTypeVar loc "t")
unify (mkUsage loc "matching a tuple pattern") (tupleRecord ps_t) $ toStructural t
t' <- normaliseType t
checkPattern' p $ Ascribed t'
checkPattern' (TuplePattern ps loc) NoneInferred =
TuplePattern <$> mapM (`checkPattern'` NoneInferred) ps <*> pure loc
checkPattern' (RecordPattern p_fs _) _
| Just (f, fp) <- find (("_" `isPrefixOf`) . nameToString . fst) p_fs =
typeError fp $ unlines [ "Underscore-prefixed fields are not allowed."
, "Did you mean " ++
quote (drop 1 (nameToString f) ++ "=_") ++ "?"]
checkPattern' (RecordPattern p_fs loc) (Ascribed (Scalar (Record t_fs)))
| sort (map fst p_fs) == sort (M.keys t_fs) =
RecordPattern . M.toList <$> check <*> pure loc
where check = traverse (uncurry checkPattern') $ M.intersectionWith (,)
(M.fromList p_fs) (fmap Ascribed t_fs)
checkPattern' p@(RecordPattern fields loc) (Ascribed t) = do
fields' <- traverse (const $ newTypeVar loc "t") $ M.fromList fields
when (sort (M.keys fields') /= sort (map fst fields)) $
typeError loc $ "Duplicate fields in record pattern " ++ pretty p
unify (mkUsage loc "matching a record pattern") (Scalar (Record fields')) $ toStructural t
t' <- normaliseType t
checkPattern' p $ Ascribed t'
checkPattern' (RecordPattern fs loc) NoneInferred =
RecordPattern . M.toList <$> traverse (`checkPattern'` NoneInferred) (M.fromList fs) <*> pure loc
checkPattern' (PatternAscription p (TypeDecl t NoInfo) loc) maybe_outer_t = do
(t', st, _) <- checkTypeExp t
let st' = fromStruct st
case maybe_outer_t of
Ascribed outer_t -> do
unify (mkUsage loc "explicit type ascription") (toStructural st) (toStructural outer_t)
st'' <- normaliseType st'
outer_t' <- normaliseType outer_t
case unifyTypesU unifyUniqueness st' outer_t' of
Just outer_t'' ->
PatternAscription <$> checkPattern' p (Ascribed outer_t'') <*>
pure (TypeDecl t' (Info st)) <*> pure loc
Nothing ->
typeError loc $ "Cannot match type " ++ quote (pretty outer_t') ++ " with expected type " ++
quote (pretty st'') ++ "."
NoneInferred ->
PatternAscription <$> checkPattern' p (Ascribed st') <*>
pure (TypeDecl t' (Info st)) <*> pure loc
where unifyUniqueness u1 u2 = if u2 `subuniqueOf` u1 then Just u1 else Nothing
checkPattern' (PatternLit e NoInfo loc) (Ascribed t) = do
e' <- checkExp e
t' <- expType e'
unify (mkUsage loc "matching against literal") (toStructural t') (toStructural t)
return $ PatternLit e' (Info t') loc
checkPattern' (PatternLit e NoInfo loc) NoneInferred = do
e' <- checkExp e
t' <- expType e'
return $ PatternLit e' (Info t') loc
checkPattern' (PatternConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs)))
| Just ts <- M.lookup n cs = do
ps' <- zipWithM checkPattern' ps $ map Ascribed ts
return $ PatternConstr n (Info (Scalar (Sum cs))) ps' loc
checkPattern' (PatternConstr n NoInfo ps loc) (Ascribed t) = do
t' <- newTypeVar loc "t"
ps' <- mapM (`checkPattern'` NoneInferred) ps
mustHaveConstr usage n t' (toStructural . patternType <$> ps')
unify usage t' (toStructural t)
t'' <- normaliseType t
return $ PatternConstr n (Info t'') ps' loc
where usage = mkUsage loc "matching against constructor"
checkPattern' (PatternConstr n NoInfo ps loc) NoneInferred = do
ps' <- mapM (`checkPattern'` NoneInferred) ps
t <- newTypeVar loc "t"
mustHaveConstr usage n t (toStructural . patternType <$> ps')
return $ PatternConstr n (Info t) ps' loc
where usage = mkUsage loc "matching against constructor"
bindPatternNames :: PatternBase NoInfo Name -> TermTypeM a -> TermTypeM a
bindPatternNames = bindSpaced . map asTerm . S.toList . patternIdents
where asTerm v = (Term, identName v)
checkPattern :: UncheckedPattern -> InferredType -> (Pattern -> TermTypeM a)
-> TermTypeM a
checkPattern p t m = do
checkForDuplicateNames [p]
bindPatternNames p $
m =<< checkPattern' p t
binding :: [Ident] -> TermTypeM a -> TermTypeM a
binding bnds = check . local (`bindVars` bnds)
where bindVars :: TermScope -> [Ident] -> TermScope
bindVars = foldl bindVar
bindVar :: TermScope -> Ident -> TermScope
bindVar scope (Ident name (Info tp) _) =
let inedges = boundAliases $ aliases tp
update (BoundV l tparams in_t)
| Array{} <- tp = BoundV l tparams (in_t `addAliases` S.insert (AliasBound name))
| otherwise = BoundV l tparams in_t
update b = b
tp' = tp `addAliases` S.insert (AliasBound name)
in scope { scopeVtable = M.insert name (BoundV Local [] tp') $
adjustSeveral update inedges $
scopeVtable scope
}
adjustSeveral f = flip $ foldl $ flip $ M.adjust f
check m = do
(a, usages) <- collectBindingsOccurences m
checkOccurences usages
mapM_ (checkIfUsed usages) bnds
return a
collectBindingsOccurences m = pass $ do
(x, usage) <- listen m
let (relevant, rest) = split usage
return ((x, relevant), const rest)
where split = unzip .
map (\occ ->
let (obs1, obs2) = divide $ observed occ
occ_cons = divide <$> consumed occ
con1 = fst <$> occ_cons
con2 = snd <$> occ_cons
in (occ { observed = obs1, consumed = con1 },
occ { observed = obs2, consumed = con2 }))
names = S.fromList $ map identName bnds
divide s = (s `S.intersection` names, s `S.difference` names)
bindingTypes :: [(VName, (TypeBinding, Constraint))] -> TermTypeM a -> TermTypeM a
bindingTypes types m = do
modifyConstraints (<>M.map snd (M.fromList types))
local extend m
where extend scope = scope {
scopeTypeTable = M.map fst (M.fromList types) <> scopeTypeTable scope
}
bindingTypeParams :: [TypeParam] -> TermTypeM a -> TermTypeM a
bindingTypeParams tparams = binding (mapMaybe typeParamIdent tparams) .
bindingTypes (mapMaybe typeParamType tparams)
where typeParamType (TypeParamType l v loc) =
Just (v, (TypeAbbr l [] (Scalar (TypeVar () Nonunique (typeName v) [])),
ParamType l loc))
typeParamType TypeParamDim{} =
Nothing
typeParamIdent :: TypeParam -> Maybe Ident
typeParamIdent (TypeParamDim v loc) =
Just $ Ident v (Info $ Scalar $ Prim $ Signed Int32) loc
typeParamIdent _ = Nothing
bindingIdent :: IdentBase NoInfo Name -> PatternType -> (Ident -> TermTypeM a)
-> TermTypeM a
bindingIdent (Ident v NoInfo vloc) t m =
bindSpaced [(Term, v)] $ do
v' <- checkName Term v vloc
let ident = Ident v' (Info t) vloc
binding [ident] $ m ident
bindingPatternGroup :: [UncheckedTypeParam]
-> [UncheckedPattern]
-> ([TypeParam] -> [Pattern] -> TermTypeM a) -> TermTypeM a
bindingPatternGroup tps orig_ps m = do
checkForDuplicateNames orig_ps
checkTypeParams tps $ \tps' -> bindingTypeParams tps' $ do
let descend ps' (p:ps) =
checkPattern p NoneInferred $ \p' ->
binding (S.toList $ patternIdents p') $ descend (p':ps') ps
descend ps' [] = do
mapM_ observe $ mapMaybe typeParamIdent tps'
let ps'' = reverse ps'
checkShapeParamUses patternUses tps' ps''
m tps' ps''
descend [] orig_ps
bindingPattern :: PatternBase NoInfo Name -> InferredType
-> (Pattern -> TermTypeM a) -> TermTypeM a
bindingPattern p t m = do
checkForDuplicateNames [p]
checkPattern p t $ \p' -> binding (S.toList $ patternIdents p') $ do
mapM_ observe $ patternDims p'
m p'
patternUses :: Pattern -> ([VName], [VName])
patternUses Id{} = mempty
patternUses Wildcard{} = mempty
patternUses PatternLit{} = mempty
patternUses (PatternParens p _) = patternUses p
patternUses (TuplePattern ps _) = foldMap patternUses ps
patternUses (RecordPattern fs _) = foldMap (patternUses . snd) fs
patternUses (PatternAscription p (TypeDecl declte _) _) =
patternUses p <> typeExpUses declte
patternUses (PatternConstr _ _ ps _) = foldMap patternUses ps
patternDims :: Pattern -> [Ident]
patternDims (PatternParens p _) = patternDims p
patternDims (TuplePattern pats _) = concatMap patternDims pats
patternDims (PatternAscription p (TypeDecl _ (Info t)) _) =
patternDims p <> mapMaybe (dimIdent (srclocOf p)) (nestedDims t)
where dimIdent _ AnyDim = Nothing
dimIdent _ (ConstDim _) = Nothing
dimIdent _ NamedDim{} = Nothing
patternDims _ = []
require :: String -> [PrimType] -> Exp -> TermTypeM Exp
require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . toStructural =<< expType e
return e
unifies :: String -> TypeBase () () -> Exp -> TermTypeM Exp
unifies why t e = do
unify (mkUsage (srclocOf e) why) t =<< toStructural <$> expType e
return e
lexicalClosure :: [Pattern] -> Occurences -> TermTypeM Aliasing
lexicalClosure params closure = do
vtable <- asks scopeVtable
let isLocal v = case v `M.lookup` vtable of
Just (BoundV Local _ _) -> True
_ -> False
return $ S.map AliasBound $ S.filter isLocal $
allOccuring closure S.\\
S.map identName (mconcat (map patternIdents params))
checkExp :: UncheckedExp -> TermTypeM Exp
checkExp (Literal val loc) =
return $ Literal val loc
checkExp (IntLit val NoInfo loc) = do
t <- newTypeVar loc "t"
mustBeOneOf anyNumberType (mkUsage loc "integer literal") t
return $ IntLit val (Info $ vacuousShapeAnnotations $ fromStruct t) loc
checkExp (FloatLit val NoInfo loc) = do
t <- newTypeVar loc "t"
mustBeOneOf anyFloatType (mkUsage loc "float literal") t
return $ FloatLit val (Info $ vacuousShapeAnnotations $ fromStruct t) loc
checkExp (TupLit es loc) =
TupLit <$> mapM checkExp es <*> pure loc
checkExp (RecordLit fs loc) = do
fs' <- evalStateT (mapM checkField fs) mempty
return $ RecordLit fs' loc
where checkField (RecordFieldExplicit f e rloc) = do
errIfAlreadySet f rloc
modify $ M.insert f rloc
RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc
checkField (RecordFieldImplicit name NoInfo rloc) = do
errIfAlreadySet name rloc
(QualName _ name', t) <- lift $ lookupVar rloc $ qualName name
modify $ M.insert name rloc
return $ RecordFieldImplicit name' (Info t) rloc
errIfAlreadySet f rloc = do
maybe_sloc <- gets $ M.lookup f
case maybe_sloc of
Just sloc ->
lift $ typeError rloc $ "Field '" ++ pretty f ++
" previously defined at " ++ locStr sloc ++ "."
Nothing -> return ()
checkExp (ArrayLit all_es _ loc) =
case all_es of
[] -> do et <- newTypeVar loc "t"
t <- arrayOfM loc et (ShapeDecl [AnyDim]) Unique
return $ ArrayLit [] (Info t) loc
e:es -> do
e' <- checkExp e
et <- expType e'
es' <- mapM (unifies "type of first array element" (toStructural et) <=< checkExp) es
et' <- normaliseType et
t <- arrayOfM loc et' (ShapeDecl [AnyDim]) Unique
return $ ArrayLit (e':es') (Info t) loc
checkExp (Range start maybe_step end NoInfo loc) = do
start' <- require "use in range expression" anyIntType =<< checkExp start
start_t <- toStructural <$> expType start'
maybe_step' <- case maybe_step of
Nothing -> return Nothing
Just step -> do
let warning = warn loc "First and second element of range are identical, this will produce an empty array."
case (start, step) of
(Literal x _, Literal y _) -> when (x == y) warning
(Var x_name _ _, Var y_name _ _) -> when (x_name == y_name) warning
_ -> return ()
Just <$> (unifies "use in range expression" start_t =<< checkExp step)
end' <- case end of
DownToExclusive e -> DownToExclusive <$>
(unifies "use in range expression" start_t =<< checkExp e)
UpToExclusive e -> UpToExclusive <$>
(unifies "use in range expression" start_t =<< checkExp e)
ToInclusive e -> ToInclusive <$>
(unifies "use in range expression" start_t =<< checkExp e)
t <- arrayOfM loc start_t (rank 1) Unique
return $ Range start' maybe_step' end'
(Info (vacuousShapeAnnotations t `setAliases` mempty)) loc
checkExp (Ascript e decl NoInfo loc) = do
decl' <- checkTypeDecl decl
e' <- checkExp e
t <- expType e'
let decl_t = unInfo $ expandedType decl'
unify (mkUsage loc "explicit type ascription") (toStructural decl_t) (toStructural t)
t' <- normaliseType t
decl_t' <- normaliseType decl_t
unless (t' `subtypeOf` anyDimShapeAnnotations decl_t') $
typeError loc $ "Type " ++ quote (pretty t') ++ " is not a subtype of " ++
quote (pretty decl_t') ++ "."
return $ Ascript e' decl' (Info (combineTypeShapes t $ fromStruct decl_t)) loc
checkExp (BinOp op NoInfo (e1,_) (e2,_) NoInfo loc) = do
(op', ftype) <- lookupVar loc op
(e1', e1_arg) <- checkArg e1
(e2', e2_arg) <- checkArg e2
(p1_t, rt) <- checkApply loc ftype e1_arg
(p2_t, rt') <- checkApply loc rt e2_arg
return $ BinOp op' (Info ftype)
(e1', Info $ toStruct p1_t) (e2', Info $ toStruct p2_t)
(Info rt') loc
checkExp (Project k e NoInfo loc) = do
e' <- checkExp e
t <- expType e'
kt <- mustHaveField (mkUsage loc $ "projection of field " ++ quote (pretty k)) k t
return $ Project k e' (Info kt) loc
checkExp (If e1 e2 e3 _ loc) =
sequentially checkCond $ \e1' _ -> do
((e2', e3'), dflow) <- tapOccurences $ checkExp e2 `alternative` checkExp e3
brancht <- unifyExpTypes e2' e3'
let t' = addAliases brancht (`S.difference` S.map AliasBound (allConsumed dflow))
zeroOrderType (mkUsage loc "returning value of this type from 'if' expression") "returned from branch" t'
return $ If e1' e2' e3' (Info t') loc
where checkCond = do
e1' <- checkExp e1
unify (mkUsage (srclocOf e1') "use as 'if' condition") (Scalar $ Prim Bool) . toStructural =<< expType e1'
return e1'
checkExp (Parens e loc) =
Parens <$> checkExp e <*> pure loc
checkExp (QualParens modname e loc) = do
(modname',mod) <- lookupMod loc modname
case mod of
ModEnv env -> localEnv (qualifyEnv modname' env) $ do
e' <- checkExp e
return $ QualParens modname' e' loc
ModFun{} ->
typeError loc $ "Module " ++ pretty modname ++ " is a parametric module."
where qualifyEnv modname' env =
env { envNameMap = M.map (qualify' modname') $ envNameMap env }
qualify' modname' (QualName qs name) =
QualName (qualQuals modname' ++ [qualLeaf modname'] ++ qs) name
checkExp (Var qn NoInfo loc) = do
(qn', t, fields) <- findRootVar (qualQuals qn) (qualLeaf qn)
foldM checkField (Var qn' (Info t) loc) fields
where findRootVar qs name =
(whenFound <$> lookupVar loc (QualName qs name)) `catchError` notFound qs name
whenFound (qn', t) = (qn', t, [])
notFound qs name err
| null qs = throwError err
| otherwise = do
(qn', t, fields) <- findRootVar (init qs) (last qs) `catchError`
const (throwError err)
return (qn', t, fields++[name])
checkField e k = do
t <- expType e
let usage = mkUsage loc $ "projection of field " ++ quote (pretty k)
kt <- mustHaveField usage k t
return $ Project k e (Info kt) loc
checkExp (Negate arg loc) = do
arg' <- require "numeric negation" anyNumberType =<< checkExp arg
return $ Negate arg' loc
checkExp (Apply e1 e2 NoInfo NoInfo loc) = do
e1' <- checkExp e1
(e2', arg) <- checkArg e2
t <- expType e1'
(t1, rt) <- checkApply loc t arg
return $ Apply e1' e2' (Info $ diet t1) (Info rt) loc
checkExp (LetPat pat e body NoInfo loc) =
sequentially (checkExp e) $ \e' e_occs -> do
t <- expType e'
case anyConsumption e_occs of
Just c ->
let msg = "of value computed with consumption at " ++ locStr (location c)
in zeroOrderType (mkUsage loc "consumption in right-hand side of 'let'-binding") msg t
_ -> return ()
bindingPattern pat (Ascribed $ anyDimShapeAnnotations t) $ \pat' -> do
body' <- checkExp body
body_t <- unscopeType (S.map identName $ patternIdents pat') <$> expType body'
return $ LetPat pat' e' body' (Info body_t) loc
checkExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) =
sequentially (checkFunDef' (name, maybe_retdecl, tparams, params, e, loc)) $
\(name', tparams', params', maybe_retdecl', rettype, e') closure -> do
closure' <- lexicalClosure params' closure
let arrow (xp, xt) yt = Scalar $ Arrow () xp xt yt
ftype = foldr (arrow . patternParam) rettype params'
entry = BoundV Local tparams' $ ftype `setAliases` closure'
bindF scope = scope { scopeVtable = M.insert name' entry $ scopeVtable scope
, scopeNameMap = M.insert (Term, name) (qualName name') $
scopeNameMap scope }
body' <- local bindF $ checkExp body
return $ LetFun name' (tparams', params', maybe_retdecl', Info rettype, e') body' loc
checkExp (LetWith dest src idxes ve body NoInfo loc) = do
(t, _) <- newArrayType (srclocOf src) "src" $ length idxes
let elemt = stripArray (length $ filter isFix idxes) t
sequentially (checkIdent src) $ \src' _ -> do
let src'' = Var (qualName $ identName src') (identType src') (srclocOf src)
void $ unifies "type of target array" t src''
unless (unique $ unInfo $ identType src') $
typeError loc $ "Source " ++ quote (pretty (identName src)) ++
" has type " ++ pretty (unInfo $ identType src') ++ ", which is not unique."
vtable <- asks scopeVtable
forM_ (aliases $ unInfo $ identType src') $ \v ->
case aliasVar v `M.lookup` vtable of
Just (BoundV Local _ v_t)
| not $ unique v_t ->
typeError loc $ "Source " ++ quote (pretty (identName src)) ++
" aliases " ++ quote (prettyName (aliasVar v)) ++ ", which is not consumable."
_ -> return ()
idxes' <- mapM checkDimIndex idxes
sequentially (unifies "type of target array" elemt =<< checkExp ve) $ \ve' _ -> do
ve_t <- expType ve'
when (AliasBound (identName src') `S.member` aliases ve_t) $
badLetWithValue loc
bindingIdent dest (unInfo (identType src') `setAliases` S.empty) $ \dest' -> do
body' <- consuming src' $ checkExp body
body_t <- unscopeType (S.singleton $ identName dest') <$> expType body'
return $ LetWith dest' src' idxes' ve' body' (Info body_t) loc
where isFix DimFix{} = True
isFix _ = False
checkExp (Update src idxes ve loc) = do
(t, _) <- newArrayType (srclocOf src) "src" $ length idxes
let elemt = stripArray (length $ filter isFix idxes) t
sequentially (checkExp ve >>= unifies "type of target array" elemt) $ \ve' _ ->
sequentially (checkExp src >>= unifies "type of target array" t) $ \src' _ -> do
idxes' <- mapM checkDimIndex idxes
src_t <- expType src'
unless (unique src_t) $
typeError loc $ "Source " ++ quote (pretty src) ++
" has type " ++ pretty src_t ++ ", which is not unique"
let src_als = aliases src_t
ve_t <- expType ve'
unless (S.null $ src_als `S.intersection` aliases ve_t) $ badLetWithValue loc
consume loc src_als
return $ Update src' idxes' ve' loc
where isFix DimFix{} = True
isFix _ = False
checkExp (RecordUpdate src fields ve NoInfo loc) = do
src' <- checkExp src
ve' <- checkExp ve
a <- expType src'
let usage = mkUsage loc "record update"
r <- foldM (flip $ mustHaveField usage) a fields
unify usage (toStructural r) . toStructural =<< expType ve'
return $ RecordUpdate src' fields ve' (Info a) loc
checkExp (Index e idxes NoInfo loc) = do
(t, _) <- newArrayType (srclocOf e) "e" $ length idxes
e' <- unifies "being indexed at" t =<< checkExp e
idxes' <- mapM checkDimIndex idxes
t' <- anyDimShapeAnnotations .
stripArray (length $ filter isFix idxes) <$>
normaliseType (typeOf e')
return $ Index e' idxes' (Info t') loc
where isFix DimFix{} = True
isFix _ = False
checkExp (Unsafe e loc) =
Unsafe <$> checkExp e <*> pure loc
checkExp (Assert e1 e2 NoInfo loc) = do
e1' <- require "being asserted" [Bool] =<< checkExp e1
e2' <- checkExp e2
return $ Assert e1' e2' (Info (pretty e1)) loc
checkExp (Lambda params body rettype_te NoInfo loc) =
removeSeminullOccurences $
bindingPatternGroup [] params $ \_ params' -> do
rettype_checked <- traverse checkTypeExp rettype_te
let declared_rettype =
case rettype_checked of Just (_, st, _) -> Just st
Nothing -> Nothing
(body', closure) <-
tapOccurences $ noUnique $ checkFunBody body declared_rettype loc
body_t <- expType body'
let (rettype', rettype_st) =
case rettype_checked of
Just (te, st, _) -> (Just te, st)
Nothing -> (Nothing, inferReturnUniqueness params' body_t)
checkGlobalAliases params' body_t loc
closure' <- lexicalClosure params' closure
return $ Lambda params' body' rettype' (Info (closure', rettype_st)) loc
checkExp (OpSection op _ loc) = do
(op', ftype) <- lookupVar loc op
return $ OpSection op' (Info ftype) loc
checkExp (OpSectionLeft op _ e _ _ loc) = do
(op', ftype) <- lookupVar loc op
(e', e_arg) <- checkArg e
(t1, rt) <- checkApply loc ftype e_arg
case rt of
Scalar (Arrow _ _ t2 rettype) ->
return $ OpSectionLeft op' (Info ftype) e'
(Info $ toStruct t1, Info $ toStruct t2) (Info rettype) loc
_ -> typeError loc $
"Operator section with invalid operator of type " ++ pretty ftype
checkExp (OpSectionRight op _ e _ _ loc) = do
(op', ftype) <- lookupVar loc op
(e', e_arg) <- checkArg e
case ftype of
Scalar (Arrow as1 m1 t1 (Scalar (Arrow as2 m2 t2 ret))) -> do
(t2', Scalar (Arrow _ _ t1' rettype)) <-
checkApply loc (Scalar $ Arrow as2 m2 t2 $ Scalar $ Arrow as1 m1 t1 ret) e_arg
return $ OpSectionRight op' (Info ftype) e'
(Info $ toStruct t1', Info $ toStruct t2') (Info rettype) loc
_ -> typeError loc $
"Operator section with invalid operator of type " ++ pretty ftype
checkExp (ProjectSection fields NoInfo loc) = do
a <- newTypeVar loc "a"
let usage = mkUsage loc "projection at"
b <- foldM (flip $ mustHaveField usage) a fields
return $ ProjectSection fields (Info $ Scalar $ Arrow mempty Unnamed a b) loc
checkExp (IndexSection idxes NoInfo loc) = do
(t, _) <- newArrayType loc "e" (length idxes)
idxes' <- mapM checkDimIndex idxes
let t' = stripArray (length $ filter isFix idxes) t
return $ IndexSection idxes' (Info $ vacuousShapeAnnotations $ fromStruct $
Scalar $ Arrow mempty Unnamed t t') loc
where isFix DimFix{} = True
isFix _ = False
checkExp (DoLoop mergepat mergeexp form loopbody loc) =
sequentially (checkExp mergeexp) $ \mergeexp' _ -> do
zeroOrderType (mkUsage (srclocOf mergeexp) "use as loop variable")
"used as loop variable" (typeOf mergeexp')
merge_t <- do
merge_t <- expType mergeexp'
return $ Ascribed $ anyDimShapeAnnotations $ merge_t `setAliases` mempty
((mergepat', form', loopbody'), bodyflow) <-
case form of
For i uboundexp -> do
uboundexp' <- require "being the bound in a 'for' loop" anySignedType =<< checkExp uboundexp
bound_t <- expType uboundexp'
bindingIdent i bound_t $ \i' ->
noUnique $ bindingPattern mergepat merge_t $
\mergepat' -> onlySelfAliasing $ tapOccurences $ do
loopbody' <- checkExp loopbody
return (mergepat',
For i' uboundexp',
loopbody')
ForIn xpat e -> do
(arr_t, _) <- newArrayType (srclocOf e) "e" 1
e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e
t <- expType e'
case t of
_ | Just t' <- peelArray 1 t ->
bindingPattern xpat (Ascribed t') $ \xpat' ->
noUnique $ bindingPattern mergepat merge_t $
\mergepat' -> onlySelfAliasing $ tapOccurences $ do
loopbody' <- checkExp loopbody
return (mergepat',
ForIn xpat' e',
loopbody')
| otherwise ->
typeError (srclocOf e) $
"Iteratee of a for-in loop must be an array, but expression has type " ++ pretty t
While cond ->
noUnique $ bindingPattern mergepat merge_t $ \mergepat' ->
onlySelfAliasing $ tapOccurences $
sequentially (checkExp cond >>=
unifies "being the condition of a 'while' loop" (Scalar $ Prim Bool)) $ \cond' _ -> do
loopbody' <- checkExp loopbody
return (mergepat',
While cond',
loopbody')
mergepat'' <- do
loop_t <- expType loopbody'
convergePattern mergepat' (allConsumed bodyflow) loop_t $
mkUsage (srclocOf loopbody') "being (part of) the result of the loop body"
let consumeMerge (Id _ (Info pt) ploc) mt
| unique pt = consume ploc $ aliases mt
consumeMerge (TuplePattern pats _) t | Just ts <- isTupleRecord t =
zipWithM_ consumeMerge pats ts
consumeMerge (PatternParens pat _) t =
consumeMerge pat t
consumeMerge (PatternAscription pat _ _) t =
consumeMerge pat t
consumeMerge _ _ =
return ()
consumeMerge mergepat'' =<< expType mergeexp'
return $ DoLoop mergepat'' mergeexp' form' loopbody' loc
where
convergePattern pat body_cons body_t body_loc = do
let consumed_merge = S.map identName (patternIdents pat) `S.intersection`
body_cons
uniquePat (Wildcard (Info t) wloc) =
Wildcard (Info $ t `setUniqueness` Nonunique) wloc
uniquePat (PatternParens p ploc) =
PatternParens (uniquePat p) ploc
uniquePat (Id name (Info t) iloc)
| name `S.member` consumed_merge =
let t' = t `setUniqueness` Unique `setAliases` mempty
in Id name (Info t') iloc
| otherwise =
let t' = case t of Scalar Record{} -> t
_ -> t `setUniqueness` Nonunique
in Id name (Info t') iloc
uniquePat (TuplePattern pats ploc) =
TuplePattern (map uniquePat pats) ploc
uniquePat (RecordPattern fs ploc) =
RecordPattern (map (fmap uniquePat) fs) ploc
uniquePat (PatternAscription p t ploc) =
PatternAscription p t ploc
uniquePat p@PatternLit{} = p
uniquePat (PatternConstr n t ps ploc) =
PatternConstr n t (map uniquePat ps) ploc
pat' = uniquePat pat
unify body_loc (toStructural body_t) $ toStructural $ patternType pat'
body_t' <- normaliseType body_t
pat_t <- normaliseType $ patternType pat'
unless (body_t' `subtypeOf` pat_t) $
unexpectedType (srclocOf body_loc)
(toStructural body_t')
[toStructural pat_t]
bound_outside <- asks $ S.fromList . M.keys . scopeVtable
let checkMergeReturn (Id pat_v (Info pat_v_t) _) t
| unique pat_v_t,
v:_ <- S.toList $ S.map aliasVar (aliases t) `S.intersection` bound_outside =
lift $ typeError loc $ "Loop return value corresponding to merge parameter " ++
quote (prettyName pat_v) ++ " aliases " ++ prettyName v ++ "."
| otherwise = do
(cons,obs) <- get
unless (S.null $ aliases t `S.intersection` cons) $
lift $ typeError loc $ "Loop return value for merge parameter " ++
quote (prettyName pat_v) ++ " aliases other consumed merge parameter."
when (unique pat_v_t &&
not (S.null (aliases t `S.intersection` (cons<>obs)))) $
lift $ typeError loc $ "Loop return value for consuming merge parameter " ++
quote (prettyName pat_v) ++ " aliases previously returned value."
if unique pat_v_t
then put (cons<>aliases t, obs)
else put (cons, obs<>aliases t)
checkMergeReturn (PatternParens p _) t =
checkMergeReturn p t
checkMergeReturn (PatternAscription p _ _) t =
checkMergeReturn p t
checkMergeReturn (RecordPattern pfs _) (Scalar (Record tfs)) =
sequence_ $ M.elems $ M.intersectionWith checkMergeReturn (M.fromList pfs) tfs
checkMergeReturn (TuplePattern pats _) t | Just ts <- isTupleRecord t =
zipWithM_ checkMergeReturn pats ts
checkMergeReturn _ _ =
return ()
(pat_cons, _) <- execStateT (checkMergeReturn pat' body_t') (mempty, mempty)
let body_cons' = body_cons <> S.map aliasVar pat_cons
if body_cons' == body_cons && patternType pat' == patternType pat
then return pat'
else convergePattern pat' body_cons' body_t' body_loc
checkExp (Constr name es NoInfo loc) = do
t <- newTypeVar loc "t"
es' <- mapM checkExp es
ets <- mapM expType es'
mustHaveConstr (mkUsage loc "use of constructor") name t (toStructural <$> ets)
let als = mconcat (map aliases ets)
return $ Constr name es' (Info $ t `addAliases` (<>als)) loc
checkExp (Match e cs NoInfo loc) =
sequentially (checkExp e) $ \e' _ -> do
mt <- expType e'
(cs', t) <- checkCases mt cs
zeroOrderType (mkUsage loc "being returned 'match'") "returned from pattern match" t
return $ Match e' cs' (Info t) loc
checkCases :: PatternType
-> NE.NonEmpty (CaseBase NoInfo Name)
-> TermTypeM (NE.NonEmpty (CaseBase Info VName), PatternType)
checkCases mt rest_cs =
case NE.uncons rest_cs of
(c, Nothing) -> do
(c', t) <- checkCase mt c
return (c' NE.:| [], t)
(c, Just cs) -> do
(((c', c_t), (cs', cs_t)), dflow) <-
tapOccurences $ checkCase mt c `alternative` checkCases mt cs
unify (mkUsage (srclocOf c) "pattern match") (toStructural c_t) (toStructural cs_t)
let t = unifyTypeAliases c_t cs_t `addAliases`
(`S.difference` S.map AliasBound (allConsumed dflow))
return (NE.cons c' cs', t)
checkCase :: PatternType -> CaseBase NoInfo Name
-> TermTypeM (CaseBase Info VName, PatternType)
checkCase mt (CasePat p caseExp loc) =
bindingPattern p (Ascribed mt) $ \p' -> do
caseExp' <- checkExp caseExp
caseType <- expType caseExp'
return (CasePat p' caseExp' loc, caseType)
data Unmatched p = UnmatchedNum p [ExpBase Info VName]
| UnmatchedBool p
| UnmatchedConstr p
| Unmatched p
deriving (Functor, Show)
instance Pretty (Unmatched (PatternBase Info VName)) where
ppr um = case um of
(UnmatchedNum p nums) -> ppr' p <+> text "where p is not one of" <+> ppr nums
(UnmatchedBool p) -> ppr' p
(UnmatchedConstr p) -> ppr' p
(Unmatched p) -> ppr' p
where
ppr' (PatternAscription p t _) = ppr p <> text ":" <+> ppr t
ppr' (PatternParens p _) = parens $ ppr' p
ppr' (Id v _ _) = pprName v
ppr' (TuplePattern pats _) = parens $ commasep $ map ppr' pats
ppr' (RecordPattern fs _) = braces $ commasep $ map ppField fs
where ppField (name, t) = text (nameToString name) <> equals <> ppr' t
ppr' Wildcard{} = text "_"
ppr' (PatternLit e _ _) = ppr e
ppr' (PatternConstr n _ ps _) = text "#" <> ppr n <+> sep (map ppr' ps)
unpackPat :: Pattern -> [Maybe Pattern]
unpackPat Wildcard{} = [Nothing]
unpackPat (PatternParens p _) = unpackPat p
unpackPat Id{} = [Nothing]
unpackPat (TuplePattern ps _) = Just <$> ps
unpackPat (RecordPattern fs _) = Just . snd <$> sortFields (M.fromList fs)
unpackPat (PatternAscription p _ _) = unpackPat p
unpackPat p@PatternLit{} = [Just p]
unpackPat p@PatternConstr{} = [Just p]
wildPattern :: Pattern -> Int -> Unmatched Pattern -> Unmatched Pattern
wildPattern (TuplePattern ps loc) pos um = wildTuple <$> um
where wildTuple p = TuplePattern (take (pos - 1) ps' ++ [p] ++ drop pos ps') loc
ps' = map wildOut ps
wildOut p = Wildcard (Info (patternType p)) (srclocOf p)
wildPattern (RecordPattern fs loc) pos um = wildRecord <$> um
where wildRecord p =
RecordPattern (take (pos - 1) fs' ++ [(fst (fs!!(pos - 1)), p)] ++ drop pos fs') loc
fs' = map wildOut fs
wildOut (f,p) = (f, Wildcard (Info (patternType p)) (srclocOf p))
wildPattern (PatternAscription p _ _) pos um = wildPattern p pos um
wildPattern (PatternParens p _) pos um = wildPattern p pos um
wildPattern (PatternConstr n t ps loc) pos um = wildConstr <$> um
where wildConstr p = PatternConstr n t (take (pos - 1) ps' ++ [p] ++ drop pos ps') loc
ps' = map wildOut ps
wildOut p = Wildcard (Info (patternType p)) (srclocOf p)
wildPattern _ _ um = um
checkUnmatched :: (MonadBreadCrumbs m, MonadTypeChecker m) => Exp -> m ()
checkUnmatched e = void $ checkUnmatched' e >> astMap tv e
where checkUnmatched' (Match _ cs _ loc) =
let ps = fmap (\(CasePat p _ _) -> p) cs
in case unmatched id $ NE.toList ps of
[] -> return ()
ps' -> typeError loc $ "Unmatched cases in match expression: \n"
++ unlines (map ((" " ++) . pretty) ps')
checkUnmatched' _ = return ()
tv = ASTMapper { mapOnExp =
\e' -> checkUnmatched' e' >> return e'
, mapOnName = pure
, mapOnQualName = pure
, mapOnStructType = pure
, mapOnPatternType = pure
}
data ConstrPat = ConstrPat { constrName :: Name
, constrType :: PatternType
, constrPayload :: [Pattern]
, constrSrcLoc :: SrcLoc
}
instance Eq ConstrPat where
ConstrPat c1 _ _ _ == ConstrPat c2 _ _ _ = c1 == c2
instance Ord ConstrPat where
ConstrPat c1 _ _ _ `compare` ConstrPat c2 _ _ _ = c1 `compare` c2
unmatched :: (Unmatched Pattern -> Unmatched Pattern) -> [Pattern] -> [Unmatched Pattern]
unmatched hole orig_ps
| p:_ <- orig_ps,
sameStructure labeledCols = do
(i, cols) <- labeledCols
let hole' = if isConstr p then hole else hole . wildPattern p i
case sequence cols of
Nothing -> []
Just cs
| all isPatternLit cs -> map hole' $ localUnmatched cs
| otherwise -> unmatched hole' cs
| otherwise = []
where labeledCols = zip [1..] $ transpose $ map unpackPat orig_ps
localUnmatched :: [Pattern] -> [Unmatched Pattern]
localUnmatched [] = []
localUnmatched ps'@(p':_) =
case patternType p' of
Scalar (Sum cs'') ->
let constrs = M.keys cs''
matched = mapMaybe constr ps'
unmatched' = map (UnmatchedConstr . buildConstr cs'') $
constrs \\ map constrName matched
in case unmatched' of
[] ->
let constrGroups = group (sort matched)
removedConstrs = mapMaybe stripConstrs constrGroups
transposed = (fmap . fmap) transpose removedConstrs
findUnmatched (pc, trans) = do
col <- trans
case col of
[] -> []
((i, _):_) -> unmatched (wilder i pc) (map snd col)
wilder i pc s = (`PatternParens` noLoc) <$> wildPattern pc i s
in concatMap findUnmatched transposed
_ -> unmatched'
Scalar (Prim t) | not (any idOrWild ps') ->
case t of
Bool ->
let matched = nub $ mapMaybe (pExp >=> bool) $ filter isPatternLit ps'
in map (UnmatchedBool . buildBool (Scalar (Prim t))) $ [True, False] \\ matched
_ ->
let matched = mapMaybe pExp $ filter isPatternLit ps'
in [UnmatchedNum (buildId (Info $ Scalar $ Prim t) "p") matched]
_ -> []
isConstr PatternConstr{} = True
isConstr (PatternParens p _) = isConstr p
isConstr _ = False
stripConstrs :: [ConstrPat] -> Maybe (Pattern, [[(Int, Pattern)]])
stripConstrs (pc@ConstrPat{} : cs') = Just (unConstr pc, stripConstr pc : map stripConstr cs')
stripConstrs [] = Nothing
stripConstr :: ConstrPat -> [(Int, Pattern)]
stripConstr (ConstrPat _ _ ps' _) = zip [1..] ps'
sameStructure [] = True
sameStructure (x:xs) = all (\y -> length y == length x' ) xs'
where (x':xs') = map snd (x:xs)
pExp (PatternLit e' _ _) = Just e'
pExp _ = Nothing
constr (PatternConstr c (Info t) ps loc) = Just $ ConstrPat c t ps loc
constr (PatternParens p _) = constr p
constr (PatternAscription p' _ _) = constr p'
constr _ = Nothing
unConstr p =
PatternConstr (constrName p) (Info $ constrType p) (constrPayload p) (constrSrcLoc p)
isPatternLit PatternLit{} = True
isPatternLit (PatternAscription p' _ _) = isPatternLit p'
isPatternLit (PatternParens p' _) = isPatternLit p'
isPatternLit PatternConstr{} = True
isPatternLit _ = False
idOrWild Id{} = True
idOrWild Wildcard{} = True
idOrWild (PatternAscription p' _ _) = idOrWild p'
idOrWild (PatternParens p' _) = idOrWild p'
idOrWild _ = False
bool (Literal (BoolValue b) _ ) = Just b
bool _ = Nothing
buildConstr m c =
let t = Scalar $ Sum m
cs = m M.! c
wildCS = map (\ct -> Wildcard (Info ct) noLoc) cs
in if null wildCS
then PatternConstr c (Info t) [] noLoc
else PatternParens (PatternConstr c (Info t) wildCS noLoc) noLoc
buildBool t b =
PatternLit (Literal (BoolValue b) noLoc) (Info (vacuousShapeAnnotations t)) noLoc
buildId t n =
Id (VName (nameFromString n) (-1)) t noLoc
checkIdent :: IdentBase NoInfo Name -> TermTypeM Ident
checkIdent (Ident name _ loc) = do
(QualName _ name', vt) <- lookupVar loc (qualName name)
return $ Ident name' (Info vt) loc
checkDimIndex :: DimIndexBase NoInfo Name -> TermTypeM DimIndex
checkDimIndex (DimFix i) =
DimFix <$> (unifies "use as index" (Scalar $ Prim $ Signed Int32) =<< checkExp i)
checkDimIndex (DimSlice i j s) =
DimSlice <$> check i <*> check j <*> check s
where check = maybe (return Nothing) $
fmap Just . unifies "use as index" (Scalar $ Prim $ Signed Int32) <=< checkExp
sequentially :: TermTypeM a -> (a -> Occurences -> TermTypeM b) -> TermTypeM b
sequentially m1 m2 = do
(a, m1flow) <- collectOccurences m1
(b, m2flow) <- collectOccurences $ m2 a m1flow
occur $ m1flow `seqOccurences` m2flow
return b
type Arg = (PatternType, Occurences, SrcLoc)
argType :: Arg -> PatternType
argType (t, _, _) = t
checkArg :: UncheckedExp -> TermTypeM (Exp, Arg)
checkArg arg = do
(arg', dflow) <- collectOccurences $ checkExp arg
arg_t <- expType arg'
return (arg', (arg_t, dflow, srclocOf arg'))
checkApply :: SrcLoc -> PatternType -> Arg
-> TermTypeM (PatternType, PatternType)
checkApply loc (Scalar (Arrow as _ tp1 tp2)) (argtype, dflow, argloc) = do
unify (mkUsage argloc "use as function argument") (toStructural tp1) (toStructural argtype)
tp1' <- normaliseType tp1
tp2' <- normaliseType tp2
argtype' <- normaliseType argtype
occur [observation as loc]
checkOccurences dflow
occurs <- consumeArg argloc argtype' (diet tp1')
case anyConsumption dflow of
Just c ->
let msg = "of value computed with consumption at " ++ locStr (location c)
in zeroOrderType (mkUsage argloc "potential consumption in expression") msg tp1
_ -> return ()
occur $ dflow `seqOccurences` occurs
let tp2'' = anyDimShapeAnnotations $ returnType tp2' (diet tp1') argtype'
return (tp1', tp2'')
checkApply loc tfun@(Scalar TypeVar{}) arg = do
tv <- newTypeVar loc "b"
unify (mkUsage loc "use as function") (toStructural tfun) $
Scalar $ Arrow mempty Unnamed (toStructural (argType arg)) tv
constraints <- getConstraints
checkApply loc (applySubst (`lookupSubst` constraints) tfun) arg
checkApply loc ftype arg =
typeError loc $
"Attempt to apply an expression of type " ++ pretty ftype ++
" to an argument of type " ++ pretty (argType arg) ++ "."
returnType :: PatternType
-> Diet
-> PatternType
-> PatternType
returnType (Array _ Unique et shape) _ _ =
Array mempty Unique et shape
returnType (Array als Nonunique et shape) d arg =
Array (als<>arg_als) Unique et shape
where arg_als = aliases $ maskAliases arg d
returnType (Scalar (Record fs)) d arg =
Scalar $ Record $ fmap (\et -> returnType et d arg) fs
returnType (Scalar (Prim t)) _ _ =
Scalar $ Prim t
returnType (Scalar (TypeVar _ Unique t targs)) _ _ =
Scalar $ TypeVar mempty Unique t targs
returnType (Scalar (TypeVar als Nonunique t targs)) d arg =
Scalar $ TypeVar (als<>arg_als) Unique t targs
where arg_als = aliases $ maskAliases arg d
returnType (Scalar (Arrow _ v t1 t2)) d arg =
Scalar $ Arrow als v (t1 `setAliases` mempty) (t2 `setAliases` als)
where als = aliases $ maskAliases arg d
returnType (Scalar (Sum cs)) d arg =
Scalar $ Sum $ (fmap . fmap) (\et -> returnType et d arg) cs
maskAliases :: Monoid as =>
TypeBase shape as
-> Diet
-> TypeBase shape as
maskAliases t Consume = t `setAliases` mempty
maskAliases t Observe = t
maskAliases (Scalar (Record ets)) (RecordDiet ds) =
Scalar $ Record $ M.intersectionWith maskAliases ets ds
maskAliases t FuncDiet{} = t
maskAliases _ _ = error "Invalid arguments passed to maskAliases."
consumeArg :: SrcLoc -> PatternType -> Diet -> TermTypeM [Occurence]
consumeArg loc (Scalar (Record ets)) (RecordDiet ds) =
concat . M.elems <$> traverse (uncurry $ consumeArg loc) (M.intersectionWith (,) ets ds)
consumeArg loc (Array _ Nonunique _ _) Consume =
typeError loc "Consuming parameter passed non-unique argument."
consumeArg loc (Scalar (Arrow _ _ t1 _)) (FuncDiet d _)
| not $ contravariantArg t1 d =
typeError loc "Non-consuming higher-order parameter passed consuming argument."
where contravariantArg (Array _ Unique _ _) Observe =
False
contravariantArg (Scalar (TypeVar _ Unique _ _)) Observe =
False
contravariantArg (Scalar (Record ets)) (RecordDiet ds) =
and (M.intersectionWith contravariantArg ets ds)
contravariantArg (Scalar (Arrow _ _ tp tr)) (FuncDiet dp dr) =
contravariantArg tp dp && contravariantArg tr dr
contravariantArg _ _ =
True
consumeArg loc (Scalar (Arrow _ _ _ t2)) (FuncDiet _ pd) =
consumeArg loc t2 pd
consumeArg loc at Consume = return [consumption (aliases at) loc]
consumeArg loc at _ = return [observation (aliases at) loc]
checkOneExp :: UncheckedExp -> TypeM ([TypeParam], Exp)
checkOneExp e = fmap fst . runTermTypeM $ do
e' <- checkExp e
let t = toStruct $ typeOf e'
tparams <- letGeneralise [] t mempty
fixOverloadedTypes
e'' <- updateExpTypes e'
return (tparams, e'')
checkFunDef :: (Name, Maybe UncheckedTypeExp,
[UncheckedTypeParam], [UncheckedPattern],
UncheckedExp, SrcLoc)
-> TypeM (VName, [TypeParam], [Pattern], Maybe (TypeExp VName), StructType, Exp)
checkFunDef f = fmap fst $ runTermTypeM $ do
(fname, tparams, params, maybe_retdecl, rettype, body) <- checkFunDef' f
fixOverloadedTypes
body' <- updateExpTypes body
params' <- updateExpTypes params
maybe_retdecl' <- traverse updateExpTypes maybe_retdecl
rettype' <- normaliseType rettype
checkUnmatched body'
return (fname, tparams, params', maybe_retdecl', rettype', body')
fixOverloadedTypes :: TermTypeM ()
fixOverloadedTypes = getConstraints >>= mapM_ fixOverloaded . M.toList
where fixOverloaded (v, Overloaded ots usage)
| Signed Int32 `elem` ots = do
unify usage (Scalar (TypeVar () Nonunique (typeName v) [])) $
Scalar $ Prim $ Signed Int32
warn usage "Defaulting ambiguous type to `i32`."
| FloatType Float64 `elem` ots = do
unify usage (Scalar (TypeVar () Nonunique (typeName v) [])) $
Scalar $ Prim $ FloatType Float64
warn usage "Defaulting ambiguous type to `f64`."
| otherwise =
typeError usage $
unlines ["Type is ambiguous (could be one of " ++ intercalate ", " (map pretty ots) ++ ").",
"Add a type annotation to disambiguate the type."]
fixOverloaded (_, NoConstraint _ usage) =
typeError usage $ unlines ["Type of expression is ambiguous.",
"Add a type annotation to disambiguate the type."]
fixOverloaded (_, Equality usage) =
typeError usage $ unlines ["Type is ambiguous (must be equality type).",
"Add a type annotation to disambiguate the type."]
fixOverloaded (_, HasFields fs usage) =
typeError usage $ unlines ["Type is ambiguous. Must be record with fields:",
unlines $ map field $ M.toList fs,
"Add a type annotation to disambiguate the type."]
where field (l, t) = pretty $ indent 2 $ ppr l <> colon <+> align (ppr t)
fixOverloaded (_, HasConstrs cs usage) =
typeError usage $ unlines [ "Type is ambiguous (must be a sum type with constructors: " ++
pretty (Sum cs) ++ ")."
, "Add a type annotation to disambiguate the type."]
fixOverloaded _ = return ()
checkFunDef' :: (Name, Maybe UncheckedTypeExp,
[UncheckedTypeParam], [UncheckedPattern],
UncheckedExp, SrcLoc)
-> TermTypeM (VName, [TypeParam], [Pattern], Maybe (TypeExp VName), StructType, Exp)
checkFunDef' (fname, maybe_retdecl, tparams, params, body, loc) = noUnique $ do
when (nameToString fname == "&&") $
typeError loc "The && operator may not be redefined."
when (nameToString fname == "||") $
typeError loc "The || operator may not be redefined."
then_substs <- getConstraints
bindingPatternGroup tparams params $ \tparams' params' -> do
maybe_retdecl' <- traverse checkTypeExp maybe_retdecl
body' <- checkFunBody body ((\(_,t,_)->t) <$> maybe_retdecl') (maybe loc srclocOf maybe_retdecl)
params'' <- updateExpTypes params'
body_t <- expType body'
(maybe_retdecl'', rettype) <- case maybe_retdecl' of
Just (retdecl', retdecl_type, _) -> do
let rettype_structural = toStructural retdecl_type
checkReturnAlias rettype_structural params'' body_t
when (null params) $ nothingMustBeUnique loc rettype_structural
warnOnDubiousShapeAnnotations loc params'' retdecl_type
return (Just retdecl', retdecl_type)
Nothing
| null params ->
return (Nothing, toStruct $ body_t `setUniqueness` Nonunique)
| otherwise ->
return (Nothing, inferReturnUniqueness params'' body_t)
let fun_t = foldFunType (map patternStructType params'') rettype
tparams'' <- letGeneralise tparams' fun_t then_substs
bindSpaced [(Term, fname)] $ do
fname' <- checkName Term fname loc
checkGlobalAliases params'' body_t loc
return (fname', tparams'', params'', maybe_retdecl'', rettype, body')
where
checkReturnAlias rettp params' =
foldM_ (checkReturnAlias' params') S.empty . returnAliasing rettp
checkReturnAlias' params' seen (Unique, names)
| any (`S.member` S.map snd seen) $ S.toList names =
uniqueReturnAliased fname loc
| otherwise = do
notAliasingParam params' names
return $ seen `S.union` tag Unique names
checkReturnAlias' _ seen (Nonunique, names)
| any (`S.member` seen) $ S.toList $ tag Unique names =
uniqueReturnAliased fname loc
| otherwise = return $ seen `S.union` tag Nonunique names
notAliasingParam params' names =
forM_ params' $ \p ->
let consumedNonunique p' =
not (unique $ unInfo $ identType p') && (identName p' `S.member` names)
in case find consumedNonunique $ S.toList $ patternIdents p of
Just p' ->
returnAliased fname (baseName $ identName p') loc
Nothing ->
return ()
tag u = S.map (u,)
returnAliasing (Scalar (Record ets1)) (Scalar (Record ets2)) =
concat $ M.elems $ M.intersectionWith returnAliasing ets1 ets2
returnAliasing expected got =
[(uniqueness expected, S.map aliasVar $ aliases got)]
warnOnDubiousShapeAnnotations :: SrcLoc -> [Pattern] -> StructType -> TermTypeM ()
warnOnDubiousShapeAnnotations loc params rettype =
onDubiousNames $ S.filter patternNameButNotParamName $
mconcat $ map typeDimNames $
rettype : map patternStructType params
where named (Named v) = Just v
named Unnamed = Nothing
param_names = S.fromList $ mapMaybe (named . fst . patternParam) params
all_pattern_names = S.map identName $ mconcat $ map patternIdents params
patternNameButNotParamName v = v `S.member` all_pattern_names && not (v `S.member` param_names)
onDubiousNames dubious
| S.null dubious = return ()
| otherwise = warn loc $ unlines
[ "Size annotations in parameter and/or return type refers to the following names,"
, "which will not be visible to the caller, because they are nested in tuples or records:"
, " " ++ intercalate ", " (map (quote . prettyName) $ S.toList dubious)
, "To eliminate this warning, make these names parameters on their own."]
checkGlobalAliases :: [Pattern] -> PatternType -> SrcLoc -> TermTypeM ()
checkGlobalAliases params body_t loc = do
vtable <- asks scopeVtable
let isLocal v = case v `M.lookup` vtable of
Just (BoundV Local _ _) -> True
_ -> False
let als = filter (not . isLocal) $ S.toList $
boundArrayAliases body_t `S.difference`
S.map identName (mconcat (map patternIdents params))
case als of
v:_ | not $ null params ->
typeError loc $
unlines [ "Function result aliases the free variable " <>
quote (prettyName v) <> "."
, "Use " ++ quote "copy" ++ " to break the aliasing."]
_ ->
return ()
inferReturnUniqueness :: [Pattern] -> PatternType -> StructType
inferReturnUniqueness params t =
let forbidden = aliasesMultipleTimes t
uniques = uniqueParamNames params
delve (Scalar (Record fs)) =
Scalar $ Record $ M.map delve fs
delve t'
| all (`S.member` uniques) (boundArrayAliases t'),
not $ any ((`S.member` forbidden) . aliasVar) (aliases t') =
toStruct t'
| otherwise =
toStruct $ t' `setUniqueness` Nonunique
in delve t
aliasesMultipleTimes :: PatternType -> Names
aliasesMultipleTimes = S.fromList . map fst . filter ((>1) . snd) . M.toList . delve
where delve (Scalar (Record fs)) =
foldl' (M.unionWith (+)) mempty $ map delve $ M.elems fs
delve t =
M.fromList $ zip (map aliasVar $ S.toList (aliases t)) $ repeat (1::Int)
uniqueParamNames :: [Pattern] -> Names
uniqueParamNames =
S.fromList . map identName
. filter (unique . unInfo . identType)
. S.toList . mconcat . map patternIdents
boundArrayAliases :: PatternType -> S.Set VName
boundArrayAliases (Array als _ _ _) = boundAliases als
boundArrayAliases (Scalar Prim{}) = mempty
boundArrayAliases (Scalar (Record fs)) = foldMap boundArrayAliases fs
boundArrayAliases (Scalar (TypeVar als _ _ _)) = boundAliases als
boundArrayAliases (Scalar Arrow{}) = mempty
boundArrayAliases (Scalar (Sum fs)) =
mconcat $ concatMap (map boundArrayAliases) $ M.elems fs
boundAliases :: Aliasing -> S.Set VName
boundAliases = S.map aliasVar . S.filter bound
where bound AliasBound{} = True
bound AliasFree{} = False
nothingMustBeUnique :: SrcLoc -> TypeBase () () -> TermTypeM ()
nothingMustBeUnique loc = check
where check (Array _ Unique _ _) = bad
check (Scalar (TypeVar _ Unique _ _)) = bad
check (Scalar (Record fs)) = mapM_ check fs
check _ = return ()
bad = typeError loc "A top-level constant cannot have a unique type."
letGeneralise :: [TypeParam]
-> StructType
-> Constraints
-> TermTypeM [TypeParam]
letGeneralise tparams t then_substs = do
now_substs <- getConstraints
let then_type_variables = S.fromList $ M.keys then_substs
then_type_constraints = constraintTypeVars $
M.filterWithKey (\k _ -> k `S.member` then_type_variables) now_substs
keep_type_variables = then_type_variables <>
then_type_constraints <>
overloadedTypeVars now_substs
let new_substs = M.filterWithKey (\k _ -> not (k `S.member` keep_type_variables)) now_substs
tparams' <- closeOverTypes new_substs tparams t
modifyConstraints $ M.filterWithKey $ \k _ -> k `notElem` map typeParamName tparams'
return tparams'
checkFunBody :: ExpBase NoInfo Name
-> Maybe StructType
-> SrcLoc
-> TermTypeM Exp
checkFunBody body maybe_rettype _loc = do
body' <- checkExp body
case maybe_rettype of
Just rettype -> do
let rettype_structural = toStructural rettype
void $ unifies "return type annotation" rettype_structural body'
rettype' <- normaliseType rettype
body_t <- expType body'
unless (body_t `subtypeOf` anyDimShapeAnnotations rettype') $
typeError (srclocOf body) $ "Body type " ++ quote (pretty body_t) ++
" is not a subtype of annotated type " ++
quote (pretty rettype') ++ "."
Nothing -> return ()
return body'
closeOverTypes :: Constraints -> [TypeParam] -> StructType -> TermTypeM [TypeParam]
closeOverTypes substs tparams t =
fmap ((tparams++) . catMaybes) $ mapM closeOver $ M.toList to_close_over
where to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs
visible = typeVars t
closeOver (k, _)
| k `elem` map typeParamName tparams =
return Nothing
closeOver (k, NoConstraint (Just Unlifted) usage) =
return $ Just $ TypeParamType Unlifted k $ srclocOf usage
closeOver (k, NoConstraint _ usage) =
return $ Just $ TypeParamType Lifted k $ srclocOf usage
closeOver (k, ParamType l loc) =
return $ Just $ TypeParamType l k loc
closeOver (_, _) =
return Nothing
occur :: Occurences -> TermTypeM ()
occur = tell
observe :: Ident -> TermTypeM ()
observe (Ident nm (Info t) loc) =
let als = AliasBound nm `S.insert` aliases t
in occur [observation als loc]
consume :: SrcLoc -> Aliasing -> TermTypeM ()
consume loc als = do
vtable <- asks scopeVtable
let consumable v = case M.lookup v vtable of
Just (BoundV Local _ t)
| arrayRank t > 0 -> unique t
| otherwise -> True
_ -> False
case filter (not . consumable) $ map aliasVar $ S.toList als of
v:_ -> typeError loc $ "Attempt to consume variable " ++ quote (prettyName v)
++ ", which is not allowed."
[] -> occur [consumption als loc]
consuming :: Ident -> TermTypeM a -> TermTypeM a
consuming (Ident name (Info t) loc) m = do
consume loc $ AliasBound name `S.insert` aliases t
local consume' m
where consume' scope =
scope { scopeVtable = M.insert name (WasConsumed loc) $ scopeVtable scope }
collectOccurences :: TermTypeM a -> TermTypeM (a, Occurences)
collectOccurences m = pass $ do
(x, dataflow) <- listen m
return ((x, dataflow), const mempty)
tapOccurences :: TermTypeM a -> TermTypeM (a, Occurences)
tapOccurences = listen
removeSeminullOccurences :: TermTypeM a -> TermTypeM a
removeSeminullOccurences = censor $ filter $ not . seminullOccurence
checkIfUsed :: Occurences -> Ident -> TermTypeM ()
checkIfUsed occs v
| not $ identName v `S.member` allOccuring occs,
not $ "_" `isPrefixOf` prettyName (identName v) =
warn (srclocOf v) $ "Unused variable " ++ quote (pretty $ baseName $ identName v) ++ "."
| otherwise =
return ()
alternative :: TermTypeM a -> TermTypeM b -> TermTypeM (a,b)
alternative m1 m2 = pass $ do
(x, occurs1) <- listen m1
(y, occurs2) <- listen m2
checkOccurences occurs1
checkOccurences occurs2
let usage = occurs1 `altOccurences` occurs2
return ((x, y), const usage)
noUnique :: TermTypeM a -> TermTypeM a
noUnique = local (\scope -> scope { scopeVtable = M.map set $ scopeVtable scope})
where set (BoundV l tparams t) = BoundV l tparams $ t `setUniqueness` Nonunique
set (OverloadedF ts pts rt) = OverloadedF ts pts rt
set EqualityF = EqualityF
set OpaqueF = OpaqueF
set (WasConsumed loc) = WasConsumed loc
onlySelfAliasing :: TermTypeM a -> TermTypeM a
onlySelfAliasing = local (\scope -> scope { scopeVtable = M.mapWithKey set $ scopeVtable scope})
where set k (BoundV l tparams t) = BoundV l tparams $
t `addAliases` S.intersection (S.singleton (AliasBound k))
set _ (OverloadedF ts pts rt) = OverloadedF ts pts rt
set _ EqualityF = EqualityF
set _ OpaqueF = OpaqueF
set _ (WasConsumed loc) = WasConsumed loc
arrayOfM :: (Pretty (ShapeDecl dim), Monoid as) =>
SrcLoc
-> TypeBase dim as -> ShapeDecl dim -> Uniqueness
-> TermTypeM (TypeBase dim as)
arrayOfM loc t shape u = do
zeroOrderType (mkUsage loc "use as array element") "used in array" t
return $ arrayOf t shape u
updateExpTypes :: ASTMappable e => e -> TermTypeM e
updateExpTypes e = do
constraints <- getConstraints
let look = (`lookupSubst` constraints)
tv = ASTMapper { mapOnExp = astMap tv
, mapOnName = pure
, mapOnQualName = pure
, mapOnStructType = pure . applySubst look
, mapOnPatternType = pure . applySubst look
}
astMap tv e