{-# LANGUAGE TupleSections #-} module Language.Futhark.TypeChecker.Modules ( matchMTys , newNamesForMTy , refineEnv , applyFunctor ) where import Control.Monad.Except import Control.Monad.Writer hiding (Sum) import Data.List import Data.Loc import Data.Maybe import Data.Either import Data.Ord import qualified Data.Map.Strict as M import qualified Data.Set as S import Prelude hiding (abs, mod) import Language.Futhark import Language.Futhark.Semantic import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Unify (doUnification) import Language.Futhark.TypeChecker.Types import Futhark.Util.Pretty (Pretty) substituteTypesInMod :: TypeSubs -> Mod -> Mod substituteTypesInMod substs (ModEnv e) = ModEnv $ substituteTypesInEnv substs e substituteTypesInMod substs (ModFun (FunSig abs mod mty)) = ModFun $ FunSig abs (substituteTypesInMod substs mod) (substituteTypesInMTy substs mty) substituteTypesInMTy :: TypeSubs -> MTy -> MTy substituteTypesInMTy substs (MTy abs mod) = MTy abs $ substituteTypesInMod substs mod substituteTypesInEnv :: TypeSubs -> Env -> Env substituteTypesInEnv substs env = env { envVtable = M.map (substituteTypesInBoundV substs) $ envVtable env , envTypeTable = M.mapWithKey subT $ envTypeTable env , envModTable = M.map (substituteTypesInMod substs) $ envModTable env } where subT name _ | Just (TypeSub (TypeAbbr l ps t)) <- M.lookup name substs = TypeAbbr l ps t subT _ (TypeAbbr l ps t) = TypeAbbr l ps $ substituteTypes substs t substituteTypesInBoundV :: TypeSubs -> BoundV -> BoundV substituteTypesInBoundV substs (BoundV tps t) = BoundV tps (substituteTypes substs t) allNamesInMTy :: MTy -> S.Set VName allNamesInMTy (MTy abs mod) = S.fromList (map qualLeaf $ M.keys abs) <> allNamesInMod mod allNamesInMod :: Mod -> S.Set VName allNamesInMod (ModEnv env) = allNamesInEnv env allNamesInMod ModFun{} = mempty -- All names defined anywhere in the env. allNamesInEnv :: Env -> S.Set VName allNamesInEnv (Env vtable ttable stable modtable _names) = S.fromList (M.keys vtable ++ M.keys ttable ++ M.keys stable ++ M.keys modtable) <> mconcat (map allNamesInMTy (M.elems stable) ++ map allNamesInMod (M.elems modtable) ++ map allNamesInType (M.elems ttable)) where allNamesInType (TypeAbbr _ ps _) = S.fromList $ map typeParamName ps newNamesForMTy :: MTy -> TypeM (MTy, M.Map VName VName) newNamesForMTy orig_mty = do -- Create unique renames for the module type. pairs <- forM (S.toList $ allNamesInMTy orig_mty) $ \v -> do v' <- newName v return (v, v') let substs = M.fromList pairs rev_substs = M.fromList $ map (uncurry $ flip (,)) pairs return (substituteInMTy substs orig_mty, rev_substs) where substituteInMTy :: M.Map VName VName -> MTy -> MTy substituteInMTy substs (MTy mty_abs mty_mod) = MTy (M.mapKeys (fmap substitute) mty_abs) (substituteInMod mty_mod) where substituteInEnv (Env vtable ttable _stable modtable names) = let vtable' = substituteInMap substituteInBinding vtable ttable' = substituteInMap substituteInTypeBinding ttable mtable' = substituteInMap substituteInMod modtable in Env { envVtable = vtable' , envTypeTable = ttable' , envSigTable = mempty , envModTable = mtable' , envNameMap = M.map (fmap substitute) names } substitute v = fromMaybe v $ M.lookup v substs substituteInMap f m = let (ks, vs) = unzip $ M.toList m in M.fromList $ zip (map (\k -> fromMaybe k $ M.lookup k substs) ks) (map f vs) substituteInBinding (BoundV ps t) = BoundV (map substituteInTypeParam ps) (substituteInType t) substituteInMod (ModEnv env) = ModEnv $ substituteInEnv env substituteInMod (ModFun funsig) = ModFun $ substituteInFunSig funsig substituteInFunSig (FunSig abs mod mty) = FunSig (M.mapKeys (fmap substitute) abs) (substituteInMod mod) (substituteInMTy substs mty) substituteInTypeBinding (TypeAbbr l ps t) = TypeAbbr l (map substituteInTypeParam ps) $ substituteInType t substituteInTypeParam (TypeParamDim p loc) = TypeParamDim (substitute p) loc substituteInTypeParam (TypeParamType l p loc) = TypeParamType l (substitute p) loc substituteInType :: StructType -> StructType substituteInType (Scalar (TypeVar () u (TypeName qs v) targs)) = Scalar $ TypeVar () u (TypeName (map substitute qs) $ substitute v) $ map substituteInTypeArg targs substituteInType (Scalar (Prim t)) = Scalar $ Prim t substituteInType (Scalar (Record ts)) = Scalar $ Record $ fmap substituteInType ts substituteInType (Scalar (Sum ts)) = Scalar $ Sum $ (fmap . fmap) substituteInType ts substituteInType (Array () u t shape) = arrayOf (substituteInType $ Scalar t) (substituteInShape shape) u substituteInType (Scalar (Arrow als v t1 t2)) = Scalar $ Arrow als v (substituteInType t1) (substituteInType t2) substituteInShape (ShapeDecl ds) = ShapeDecl $ map substituteInDim ds substituteInDim (NamedDim (QualName qs v)) = NamedDim $ QualName (map substitute qs) $ substitute v substituteInDim d = d substituteInTypeArg (TypeArgDim (NamedDim (QualName qs v)) loc) = TypeArgDim (NamedDim $ QualName (map substitute qs) $ substitute v) loc substituteInTypeArg (TypeArgDim (ConstDim x) loc) = TypeArgDim (ConstDim x) loc substituteInTypeArg (TypeArgDim AnyDim loc) = TypeArgDim AnyDim loc substituteInTypeArg (TypeArgType t loc) = TypeArgType (substituteInType t) loc mtyTypeAbbrs :: MTy -> M.Map VName TypeBinding mtyTypeAbbrs (MTy _ mod) = modTypeAbbrs mod modTypeAbbrs :: Mod -> M.Map VName TypeBinding modTypeAbbrs (ModEnv env) = envTypeAbbrs env modTypeAbbrs (ModFun (FunSig _ mod mty)) = modTypeAbbrs mod <> mtyTypeAbbrs mty envTypeAbbrs :: Env -> M.Map VName TypeBinding envTypeAbbrs env = envTypeTable env <> (mconcat . map modTypeAbbrs . M.elems . envModTable) env -- | Refine the given type name in the given env. refineEnv :: SrcLoc -> TySet -> Env -> QualName Name -> [TypeParam] -> StructType -> TypeM (QualName VName, TySet, Env) refineEnv loc tset env tname ps t | Just (tname', TypeAbbr l cur_ps (Scalar (TypeVar () _ (TypeName qs v) _))) <- findTypeDef tname (ModEnv env), QualName (qualQuals tname') v `M.member` tset = if paramsMatch cur_ps ps then return (tname', QualName qs v `M.delete` tset, substituteTypesInEnv (M.fromList [(qualLeaf tname', TypeSub $ TypeAbbr l cur_ps t), (v, TypeSub $ TypeAbbr l ps t)]) env) else throwError $ TypeError loc $ "Cannot refine a type having " <> tpMsg ps <> " with a type having " <> tpMsg cur_ps <> "." | otherwise = throwError $ TypeError loc $ pretty tname ++ " is not an abstract type in the module type." where tpMsg [] = "no type parameters" tpMsg xs = "type parameters " <> unwords (map pretty xs) paramsMatch :: [TypeParam] -> [TypeParam] -> Bool paramsMatch ps1 ps2 = length ps1 == length ps2 && all match (zip ps1 ps2) where match (TypeParamType l1 _ _, TypeParamType l2 _ _) = l1 <= l2 match (TypeParamDim _ _, TypeParamDim _ _) = True match _ = False findBinding :: (Env -> M.Map VName v) -> Namespace -> Name -> Env -> Maybe (VName, v) findBinding table namespace name the_env = do QualName _ name' <- M.lookup (namespace, name) $ envNameMap the_env (name',) <$> M.lookup name' (table the_env) findTypeDef :: QualName Name -> Mod -> Maybe (QualName VName, TypeBinding) findTypeDef _ ModFun{} = Nothing findTypeDef (QualName [] name) (ModEnv the_env) = do (name', tb) <- findBinding envTypeTable Type name the_env return (qualName name', tb) findTypeDef (QualName (q:qs) name) (ModEnv the_env) = do (q', q_mod) <- findBinding envModTable Term q the_env (QualName qs' name', tb) <- findTypeDef (QualName qs name) q_mod return (QualName (q':qs') name', tb) resolveAbsTypes :: TySet -> Mod -> TySet -> SrcLoc -> Either TypeError (M.Map VName (QualName VName, TypeBinding)) resolveAbsTypes mod_abs mod sig_abs loc = do let abs_mapping = M.fromList $ zip (map (fmap baseName . fst) $ M.toList mod_abs) (M.toList mod_abs) fmap M.fromList $ forM (M.toList sig_abs) $ \(name, name_l) -> case findTypeDef (fmap baseName name) mod of Just (name', TypeAbbr mod_l ps t) | Unlifted <- name_l, not (orderZero t) || mod_l == Lifted -> mismatchedLiftedness (map qualLeaf $ M.keys mod_abs) name (ps, t) | Just (abs_name, _) <- M.lookup (fmap baseName name) abs_mapping -> return (qualLeaf name, (abs_name, TypeAbbr name_l ps t)) | otherwise -> return (qualLeaf name, (name', TypeAbbr name_l ps t)) _ -> missingType loc $ fmap baseName name where mismatchedLiftedness abs name mod_t = Left $ TypeError loc $ unlines ["Module defines", indent $ ppTypeAbbr abs name mod_t, "but module type requires this type to be non-functional."] missingType :: Pretty a => SrcLoc -> a -> Either TypeError b missingType loc name = Left $ TypeError loc $ "Module does not define a type named " ++ pretty name ++ "." missingVal :: Pretty a => SrcLoc -> a -> Either TypeError b missingVal loc name = Left $ TypeError loc $ "Module does not define a value named " ++ pretty name ++ "." missingMod :: Pretty a => SrcLoc -> a -> Either TypeError b missingMod loc name = Left $ TypeError loc $ "Module does not define a module named " ++ pretty name ++ "." mismatchedType :: Pretty a => SrcLoc -> [VName] -> a -> ([TypeParam], StructType) -> ([TypeParam], StructType) -> Either TypeError b mismatchedType loc abs name spec_t env_t = Left $ TypeError loc $ unlines ["Module defines", indent $ ppTypeAbbr abs name env_t, "but module type requires", indent $ ppTypeAbbr abs name spec_t] indent :: String -> String indent = intercalate "\n" . map (" "++) . lines ppTypeAbbr :: Pretty a => [VName] -> a -> ([TypeParam], StructType) -> String ppTypeAbbr abs name (ps, t) = "type " ++ unwords (pretty name : map pretty ps) ++ t' where t' = case t of Scalar (TypeVar () _ tn args) | typeLeaf tn `elem` abs, map typeParamToArg ps == args -> "" _ -> " = " ++ pretty t -- Return new renamed/abstracted env, as well as a mapping from -- names in the signature to names in the new env. This is used for -- functor application. The first env is the module env, and the -- second the env it must match. matchMTys :: MTy -> MTy -> SrcLoc -> Either TypeError (M.Map VName VName) matchMTys = matchMTys' mempty where matchMTys' :: TypeSubs -> MTy -> MTy -> SrcLoc -> Either TypeError (M.Map VName VName) matchMTys' _ (MTy _ ModFun{}) (MTy _ ModEnv{}) loc = Left $ TypeError loc "Cannot match parametric module with non-paramatric module type." matchMTys' _ (MTy _ ModEnv{}) (MTy _ ModFun{}) loc = Left $ TypeError loc "Cannot match non-parametric module with paramatric module type." matchMTys' old_abs_subst_to_type (MTy mod_abs mod) (MTy sig_abs sig) loc = do -- Check that abstract types in 'sig' have an implementation in -- 'mod'. This also gives us a substitution that we use to check -- the types of values. abs_substs <- resolveAbsTypes mod_abs mod sig_abs loc let abs_subst_to_type = old_abs_subst_to_type <> M.map (TypeSub . snd) abs_substs abs_name_substs = M.map (qualLeaf . fst) abs_substs substs <- matchMods abs_subst_to_type mod sig loc return (substs <> abs_name_substs) matchMods :: TypeSubs -> Mod -> Mod -> SrcLoc -> Either TypeError (M.Map VName VName) matchMods _ ModEnv{} ModFun{} loc = Left $ TypeError loc "Cannot match non-parametric module with paramatric module type." matchMods _ ModFun{} ModEnv{} loc = Left $ TypeError loc "Cannot match parametric module with non-paramatric module type." matchMods abs_subst_to_type (ModEnv mod) (ModEnv sig) loc = matchEnvs abs_subst_to_type mod sig loc matchMods old_abs_subst_to_type (ModFun (FunSig mod_abs mod_pmod mod_mod)) (ModFun (FunSig sig_abs sig_pmod sig_mod)) loc = do abs_substs <- resolveAbsTypes mod_abs mod_pmod sig_abs loc let abs_subst_to_type = old_abs_subst_to_type <> M.map (TypeSub . snd) abs_substs abs_name_substs = M.map (qualLeaf . fst) abs_substs pmod_substs <- matchMods abs_subst_to_type mod_pmod sig_pmod loc mod_substs <- matchMTys' abs_subst_to_type mod_mod sig_mod loc return (pmod_substs <> mod_substs <> abs_name_substs) matchEnvs :: TypeSubs -> Env -> Env -> SrcLoc -> Either TypeError (M.Map VName VName) matchEnvs abs_subst_to_type env sig loc = do -- XXX: we only want to create substitutions for visible names. -- This must be wrong in some cases. Probably we need to -- rethink how we do shadowing for module types. let visible = S.fromList $ map qualLeaf $ M.elems $ envNameMap sig isVisible name = name `S.member` visible -- Check that all values are defined correctly, substituting the -- abstract types first. val_substs <- fmap M.fromList $ forM (M.toList $ envVtable sig) $ \(name, spec_bv) -> do let spec_bv' = substituteTypesInBoundV abs_subst_to_type spec_bv case findBinding envVtable Term (baseName name) env of Just (name', bv) -> matchVal loc name spec_bv' name' bv _ -> missingVal loc (baseName name) -- Check that all type abbreviations are correctly defined. abbr_name_substs <- fmap M.fromList $ forM (filter (isVisible . fst) $ M.toList $ envTypeTable sig) $ \(name, TypeAbbr _ spec_ps spec_t) -> case findBinding envTypeTable Type (baseName name) env of Just (name', TypeAbbr _ ps t) -> matchTypeAbbr loc abs_subst_to_type val_substs name spec_ps spec_t name' ps t Nothing -> missingType loc $ baseName name -- Check for correct modules. mod_substs <- fmap M.unions $ forM (M.toList $ envModTable sig) $ \(name, modspec) -> case findBinding envModTable Term (baseName name) env of Just (name', mod) -> M.insert name name' <$> matchMods abs_subst_to_type mod modspec loc Nothing -> missingMod loc $ baseName name return $ val_substs <> mod_substs <> abbr_name_substs matchTypeAbbr :: SrcLoc -> TypeSubs -> M.Map VName VName -> VName -> [TypeParam] -> StructType -> VName -> [TypeParam] -> StructType -> Either TypeError (VName, VName) matchTypeAbbr loc abs_subst_to_type val_substs spec_name spec_ps spec_t name ps t = do -- We have to create substitutions for the type parameters, too. unless (length spec_ps == length ps) nomatch param_substs <- mconcat <$> zipWithM matchTypeParam spec_ps ps let val_substs' = M.map (DimSub . NamedDim . qualName) val_substs spec_t' = substituteTypes (val_substs'<>param_substs<>abs_subst_to_type) spec_t if spec_t' == t then return (spec_name, name) else nomatch where nomatch = mismatchedType loc (M.keys abs_subst_to_type) (baseName spec_name) (spec_ps, spec_t) (ps, t) matchTypeParam (TypeParamDim x _) (TypeParamDim y _) = pure $ M.singleton x $ DimSub $ NamedDim $ qualName y matchTypeParam (TypeParamType Unlifted x _) (TypeParamType Unlifted y _) = pure $ M.singleton x $ TypeSub $ TypeAbbr Unlifted [] $ Scalar $ TypeVar () Nonunique (typeName y) [] matchTypeParam (TypeParamType _ x _) (TypeParamType Lifted y _) = pure $ M.singleton x $ TypeSub $ TypeAbbr Lifted [] $ Scalar $ TypeVar () Nonunique (typeName y) [] matchTypeParam _ _ = nomatch matchVal :: SrcLoc -> VName -> BoundV -> VName -> BoundV -> Either TypeError (VName, VName) matchVal loc spec_name spec_t name t | matchFunBinding loc spec_t t = return (spec_name, name) matchVal loc spec_name spec_v _ v = Left $ TypeError loc $ unlines $ ["Module type specifies"] ++ map (" "++) (lines $ ppValBind spec_name spec_v) ++ ["but module provides"] ++ map (" "++) (lines $ppValBind spec_name v) matchFunBinding :: SrcLoc -> BoundV -> BoundV -> Bool matchFunBinding loc (BoundV _ orig_spec_t) (BoundV tps orig_t) = -- Would be nice if we could propagate the actual error here. case doUnification loc tps (toStructural orig_spec_t) (toStructural orig_t) of Left _ -> False Right t -> t `subtypeOf` toStructural orig_spec_t ppValBind v (BoundV tps t) = unwords $ ["val", prettyName v] ++ map pretty tps ++ [":", pretty t] applyFunctor :: SrcLoc -> FunSig -> MTy -> TypeM (MTy, M.Map VName VName, M.Map VName VName) applyFunctor applyloc (FunSig p_abs p_mod body_mty) a_mty = do p_subst <- badOnLeft $ matchMTys a_mty (MTy p_abs p_mod) applyloc -- Apply type abbreviations from a_mty to body_mty. let a_abbrs = mtyTypeAbbrs a_mty isSub v = case M.lookup v a_abbrs of Just abbr -> Just $ TypeSub abbr _ -> Just $ DimSub $ NamedDim $ qualName v type_subst = M.mapMaybe isSub p_subst body_mty' = substituteTypesInMTy type_subst body_mty (body_mty'', body_subst) <- newNamesForMTy body_mty' return (body_mty'', p_subst, body_subst)