-- | This monomorphization module converts a well-typed, polymorphic,
-- module-free Futhark program into an equivalent monomorphic program.
--
-- This pass also does a few other simplifications to make the job of
-- subsequent passes easier.  Specifically, it does the following:
--
-- * Turns operator sections into explicit lambdas.
--
-- * Converts identifiers of record type into record patterns (and
--   similarly for tuples).
--
-- * Converts applications of intrinsic SOACs into SOAC AST nodes
--   (Map, Reduce, etc).
--
-- * Elides functions that are not reachable from an entry point (this
--   is a side effect of the monomorphisation algorithm, which uses
--   the entry points as roots).
--
-- * Turns implicit record fields into explicit record fields.
--
-- Note that these changes are unfortunately not visible in the AST
-- representation.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Futhark.Internalise.Monomorphise
  ( transformProg
  , transformDecs
  , runMonoM
  ) where

import Control.Monad.RWS
import Control.Monad.State
import Data.Loc
import qualified Data.Map.Strict as M
import qualified Data.Sequence as Seq
import Data.Foldable

import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad (TypeBinding(..))
import Language.Futhark.TypeChecker.Types

-- | The monomorphization monad reads 'PolyBinding's and writes 'ValBind's.
-- The 'TypeParam's in an emitted 'ValBind' can only be shape parameters.
--
-- Each 'PolyBinding' is also connected with the 'RecordReplacements'
-- that were active when the binding was defined.  This is used only
-- in local functions.
--
-- The tuple holds, in order: the function name, its type parameters,
-- its parameters, the optional return-type annotation, the return
-- type, the body, and the source location of the binding.
data PolyBinding = PolyBinding RecordReplacements
                   (VName, [TypeParam], [Pattern], Maybe (TypeExp VName), StructType, Exp, SrcLoc)

-- | Mapping from record names to the variable names that contain the
-- fields.  This is used because the monomorphiser also expands all
-- record patterns.
type RecordReplacements = M.Map VName RecordReplacement

-- | The replacement for a single record variable: each field name maps
-- to the (fresh) variable that now holds that field, and its type.
type RecordReplacement = M.Map Name (VName, PatternType)

-- | Monomorphization environment mapping names of polymorphic functions to a
-- representation of their corresponding function bindings, together with
-- the type abbreviations and record replacements currently in scope.
data Env = Env { envPolyBindings :: M.Map VName PolyBinding
               , envTypeBindings :: M.Map VName TypeBinding
               , envRecordReplacements :: RecordReplacements
               }

-- | Field-wise union.  'M.Map''s '<>' is left-biased, so entries in the
-- left operand shadow entries with the same key in the right operand.
instance Semigroup Env where
  Env pb1 tb1 rr1 <> Env pb2 tb2 rr2 =
    Env (pb1 <> pb2) (tb1 <> tb2) (rr1 <> rr2)

instance Monoid Env where
  mempty = Env mempty mempty mempty

-- | Run an action with extra environment entries in scope.  The new
-- entries shadow any existing ones with the same names.
localEnv :: Env -> MonoM a -> MonoM a
localEnv env = local (env <>)

-- | Run an action with one additional polymorphic binding in scope.
extendEnv :: VName -> PolyBinding -> MonoM a -> MonoM a
extendEnv vn binding =
  localEnv mempty { envPolyBindings = M.singleton vn binding }

-- | Run an action with additional record replacements in scope.
withRecordReplacements :: RecordReplacements -> MonoM a -> MonoM a
withRecordReplacements rr = localEnv mempty { envRecordReplacements = rr }

-- | Run an action with exactly the given record replacements in scope,
-- discarding (rather than extending) the ones currently active.
replaceRecordReplacements :: RecordReplacements -> MonoM a -> MonoM a
replaceRecordReplacements rr = local $ \env -> env { envRecordReplacements = rr }

-- | The monomorphization monad.  The writer output collects pairs of
-- (original function name, monomorphic binding generated for it); the
-- inner state records which instantiations have already been made.
newtype MonoM a = MonoM (RWST Env (Seq.Seq (VName, ValBind)) VNameSource (State Lifts) a)
  deriving (Functor, Applicative, Monad,
            MonadReader Env,
            MonadWriter (Seq.Seq (VName, ValBind)),
            MonadFreshNames)

-- | Run a monomorphization action, starting from an empty environment and
-- no recorded instantiations.  Returns the result paired with all emitted
-- bindings, and the updated name source.
runMonoM :: VNameSource -> MonoM a -> ((a, Seq.Seq (VName, ValBind)), VNameSource)
runMonoM src (MonoM m) = ((a, defs), src')
  where (a, src', defs) = evalState (runRWST m mempty src) mempty

-- | Find the polymorphic binding for a function name, if one is in scope.
lookupFun :: VName -> MonoM (Maybe PolyBinding)
lookupFun vn = asks $ M.lookup vn . envPolyBindings

-- | Find the field replacements for a record-typed variable, if any.
lookupRecordReplacement :: VName -> MonoM (Maybe RecordReplacement)
lookupRecordReplacement v = asks $ M.lookup v . envRecordReplacements

-- | Mapping from a function name and the concrete type it was
-- instantiated at to the name of the generated monomorphic function,
-- used to avoid instantiating the same function at the same type twice.
type Lifts = [((VName, TypeBase () ()), VName)]

-- | Fetch the table of instantiations produced so far.
getLifts :: MonoM Lifts
getLifts = MonoM (lift get)

-- | Apply a function to the table of instantiations.
modifyLifts :: (Lifts -> Lifts) -> MonoM ()
modifyLifts f = MonoM (lift (modify f))

-- | Remember that @fname@ instantiated at type @il@ is now called
-- @lifted_fname@.  New entries are prepended to the table.
addLifted :: VName -> TypeBase () () -> VName -> MonoM ()
addLifted fname il lifted_fname =
  modifyLifts $ \lifts -> ((fname, il), lifted_fname) : lifts

-- | The name under which @fname@ was previously instantiated at type @t@.
lookupLifted :: VName -> TypeBase () () -> MonoM (Maybe VName)
lookupLifted fname t = fmap (lookup (fname, t)) getLifts

-- | Resolve a use of @fname@ at type @t@ to the name that should appear
-- in the output program, instantiating a polymorphic function on demand.
transformFName :: VName -> TypeBase () () -> MonoM VName
transformFName fname t
  -- Intrinsics are never specialised.
  | baseTag fname <= maxIntrinsicTag = return fname
  | otherwise = do
      previous <- lookupLifted fname t
      case previous of
        -- Already instantiated at this type: reuse that instance.
        Just fname' -> return fname'
        -- Names with no polymorphic binding in scope (for instance,
        -- locally bound variables) are left alone.
        Nothing -> lookupFun fname >>= maybe (return fname) instantiate
  where
    instantiate funbind = do
      (fname', funbind') <- monomorphizeBinding funbind t
      tell $ Seq.singleton (fname, funbind')
      addLifted fname t fname'
      return fname'

-- | Monomorphization of expressions.
transformExp :: Exp -> MonoM Exp

transformExp e@Literal{} = return e
transformExp e@IntLit{} = return e
transformExp e@FloatLit{} = return e

transformExp (Parens e loc) =
  Parens <$> transformExp e <*> pure loc

transformExp (QualParens qn e loc) =
  QualParens qn <$> transformExp e <*> pure loc

transformExp (TupLit es loc) =
  TupLit <$> mapM transformExp es <*> pure loc

transformExp (RecordLit fs loc) =
  RecordLit <$> mapM transformField fs <*> pure loc
  where
    transformField (RecordFieldExplicit name e loc') =
      RecordFieldExplicit name <$> transformExp e <*> pure loc'
    -- Make an implicit field explicit.  The field keeps its own source
    -- location rather than inheriting that of the whole record.
    transformField (RecordFieldImplicit v t loc') =
      transformField $ RecordFieldExplicit (baseName v)
        (Var (qualName v) (vacuousShapeAnnotations <$> t) loc') loc'

transformExp (ArrayLit es tp loc) =
  ArrayLit <$> mapM transformExp es <*> pure tp <*> pure loc

transformExp (Range e1 me incl tp loc) = do
  e1' <- transformExp e1
  me' <- mapM transformExp me
  incl' <- mapM transformExp incl
  return $ Range e1' me' incl' tp loc

-- A variable whose record pattern was expanded is rebuilt as a record
-- literal of its per-field variables; any other variable may name a
-- polymorphic function and is resolved to its instance at this type.
transformExp (Var (QualName qs fname) (Info t) loc) = do
  maybe_fs <- lookupRecordReplacement fname
  case maybe_fs of
    Just fs -> do
      let toField (f, (f_v, f_t)) =
            let f_v' = Var (qualName f_v) (Info $ vacuousShapeAnnotations f_t) loc
            in RecordFieldExplicit f f_v' loc
      return $ RecordLit (map toField $ M.toList fs) loc
    Nothing -> do
      fname' <- transformFName fname (toStructural t)
      return $ Var (QualName qs fname') (Info t) loc

transformExp (Ascript e tp loc) =
  Ascript <$> transformExp e <*> pure tp <*> pure loc

-- The bound expression is transformed outside the scope of the new
-- record replacements; only the body sees them.
transformExp (LetPat tparams pat e1 e2 loc) = do
  (pat', rr) <- expandRecordPattern pat
  LetPat tparams pat' <$> transformExp e1 <*>
    withRecordReplacements rr (transformExp e2) <*> pure loc

transformExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc)
  | any isTypeParam tparams = do
      -- Retrieve the lifted monomorphic function bindings that are produced,
      -- filter those that are monomorphic versions of the current let-bound
      -- function and insert them at this point, and propagate the rest.
      rr <- asks envRecordReplacements
      let funbind = PolyBinding rr (fname, tparams, params, retdecl, ret, body, loc)
      pass $ do
        (e', bs) <- listen $ extendEnv fname funbind $ transformExp e
        let (bs_local, bs_prop) = Seq.partition ((== fname) . fst) bs
        return (unfoldLetFuns (map snd $ toList bs_local) e', const bs_prop)
  | otherwise =
      -- Not type-polymorphic: treat it as a let-bound lambda.
      transformExp $ LetPat [] (Id fname (Info ft) loc) lam e loc
  where
    lam = Lambda tparams params body Nothing (Info (mempty, ret)) loc
    ft = foldFunType (map (vacuousShapeAnnotations . patternType) params) $ fromStruct ret

transformExp (If e1 e2 e3 tp loc) = do
  e1' <- transformExp e1
  e2' <- transformExp e2
  e3' <- transformExp e3
  return $ If e1' e2' e3' tp loc

transformExp (Apply e1 e2 d tp loc) =
  -- We handle on an ad-hoc basis certain polymorphic higher-order
  -- intrinsics here.  They can only be used in very particular ways,
  -- or the compiler will fail.  In practice they will only be used
  -- once, in the basis library, to define normal functions.
  case (e1, e2) of
    (Var v _ _, TupLit [op, ne, arr] _)
      | intrinsic "reduce" v ->
          transformExp $ Reduce Noncommutative op ne arr loc
      | intrinsic "reduce_comm" v ->
          transformExp $ Reduce Commutative op ne arr loc
      | intrinsic "scan" v ->
          transformExp $ Scan op ne arr loc
    (Var v _ _, TupLit [f, arr] _)
      | intrinsic "map" v ->
          transformExp $ Map f arr (removeShapeAnnotations <$> tp) loc
      | intrinsic "filter" v ->
          transformExp $ Filter f arr loc
      | intrinsic "stream_map" v ->
          transformExp $ Stream (MapLike InOrder) f arr loc
      | intrinsic "stream_map_per" v ->
          transformExp $ Stream (MapLike Disorder) f arr loc
    (Var v _ _, TupLit [k, f, arr] _)
      | intrinsic "partition" v,
        Just k' <- isInt32 k ->
          transformExp $ Partition (fromIntegral k') f arr loc
    (Var v _ _, TupLit [op, f, arr] _)
      | intrinsic "stream_red" v ->
          transformExp $ Stream (RedLike InOrder Noncommutative op) f arr loc
      | intrinsic "stream_red_per" v ->
          transformExp $ Stream (RedLike Disorder Commutative op) f arr loc
    (Var v _ _, TupLit [dest, op, ne, buckets, img] _)
      | intrinsic "gen_reduce" v ->
          transformExp $ GenReduce dest op ne buckets img loc
    _ -> do
      e1' <- transformExp e1
      e2' <- transformExp e2
      return $ Apply e1' e2' d tp loc
  where
    intrinsic s (QualName _ v) =
      baseTag v <= maxIntrinsicTag && baseName v == nameFromString s

    isInt32 (Literal (SignedValue (Int32Value k)) _) = Just k
    isInt32 (IntLit k (Info (Prim (Signed Int32))) _) = Just $ fromInteger k
    isInt32 _ = Nothing

transformExp (Negate e loc) =
  Negate <$> transformExp e <*> pure loc

transformExp (Lambda tparams params e0 decl tp loc) = do
  e0' <- transformExp e0
  return $ Lambda tparams params e0' decl tp loc

transformExp (OpSection qn t loc) =
  transformExp $ Var qn t loc

transformExp (OpSectionLeft (QualName qs fname) (Info t) e
              (Info xtype, Info ytype) (Info rettype) loc) = do
  fname' <- transformFName fname (toStructural t)
  e' <- transformExp e
  desugarBinOpSection (QualName qs fname') (Just e') Nothing t xtype ytype rettype loc

transformExp (OpSectionRight (QualName qs fname) (Info t) e
              (Info xtype, Info ytype) (Info rettype) loc) = do
  fname' <- transformFName fname (toStructural t)
  e' <- transformExp e
  desugarBinOpSection (QualName qs fname') Nothing (Just e') t xtype ytype rettype loc

transformExp (ProjectSection fields (Info t) loc) =
  desugarProjectSection fields t loc

-- The index expressions may themselves contain uses of polymorphic
-- functions or of expanded record variables, so they must be
-- transformed like any other subexpression.
transformExp (IndexSection idxs (Info t) loc) = do
  idxs' <- mapM transformDimIndex idxs
  desugarIndexSection idxs' t loc

transformExp (DoLoop tparams pat e1 form e3 loc) = do
  e1' <- transformExp e1
  form' <- case form of
    For ident e2 -> For ident <$> transformExp e2
    ForIn pat2 e2 -> ForIn pat2 <$> transformExp e2
    While e2 -> While <$> transformExp e2
  e3' <- transformExp e3
  return $ DoLoop tparams pat e1' form' e3' loc

transformExp (BinOp (QualName qs fname) (Info t) (e1, d1) (e2, d2) tp loc) = do
  fname' <- transformFName fname (toStructural t)
  e1' <- transformExp e1
  e2' <- transformExp e2
  return $ BinOp (QualName qs fname') (Info t) (e1', d1) (e2', d2) tp loc

-- Projecting a field out of an expanded record variable becomes a direct
-- reference to the variable holding that field.
transformExp (Project n e tp loc) = do
  maybe_fs <- case e of
    Var qn _ _ -> lookupRecordReplacement (qualLeaf qn)
    _ -> return Nothing
  case maybe_fs of
    Just m | Just (v, _) <- M.lookup n m ->
      return $ Var (qualName v) (vacuousShapeAnnotations <$> tp) loc
    _ -> do
      e' <- transformExp e
      return $ Project n e' tp loc

transformExp (LetWith id1 id2 idxs e1 body loc) = do
  idxs' <- mapM transformDimIndex idxs
  e1' <- transformExp e1
  body' <- transformExp body
  return $ LetWith id1 id2 idxs' e1' body' loc

transformExp (Index e0 idxs info loc) =
  Index <$> transformExp e0 <*> mapM transformDimIndex idxs <*> pure info <*> pure loc

transformExp (Update e1 idxs e2 loc) =
  Update <$> transformExp e1 <*> mapM transformDimIndex idxs
         <*> transformExp e2 <*> pure loc

transformExp (RecordUpdate e1 fs e2 t loc) =
  RecordUpdate <$> transformExp e1 <*> pure fs
               <*> transformExp e2 <*> pure t <*> pure loc

transformExp (Map e1 es t loc) =
  Map <$> transformExp e1 <*> transformExp es <*> pure t <*> pure loc

transformExp (Reduce comm e1 e2 e3 loc) =
  Reduce comm <$> transformExp e1 <*> transformExp e2
              <*> transformExp e3 <*> pure loc

transformExp (Scan e1 e2 e3 loc) =
  Scan <$> transformExp e1 <*> transformExp e2 <*> transformExp e3 <*> pure loc

transformExp (Filter e1 e2 loc) =
  Filter <$> transformExp e1 <*> transformExp e2 <*> pure loc

transformExp (Partition k f e0 loc) =
  Partition k <$> transformExp f <*> transformExp e0 <*> pure loc

transformExp (Stream form e1 e2 loc) = do
  form' <- case form of
    MapLike _ -> return form
    RedLike so comm e -> RedLike so comm <$> transformExp e
  Stream form' <$> transformExp e1 <*> transformExp e2 <*> pure loc

transformExp (GenReduce e1 e2 e3 e4 e5 loc) =
  GenReduce
  <$> transformExp e1 -- hist
  <*> transformExp e2 -- operator
  <*> transformExp e3 -- neutral element
  <*> transformExp e4 -- buckets
  <*> transformExp e5 -- input image
  <*> pure loc

transformExp (Zip i e1 es t loc) = do
  e1' <- transformExp e1
  es' <- mapM transformExp es
  return $ Zip i e1' es' t loc

transformExp (Unzip e0 tps loc) =
  Unzip <$> transformExp e0 <*> pure tps <*> pure loc

transformExp (Unsafe e1 loc) =
  Unsafe <$> transformExp e1 <*> pure loc

transformExp (Assert e1 e2 desc loc) =
  Assert <$> transformExp e1 <*> transformExp e2 <*> pure desc <*> pure loc

transformExp e@VConstr0{} = return e

transformExp (Match e cs t loc) =
  Match <$> transformExp e <*> mapM transformCase cs <*> pure t <*> pure loc

-- | Transform a match case; record-typed identifiers bound by its
-- pattern are expanded, and the replacements are visible in its body.
transformCase :: Case -> MonoM Case
transformCase (CasePat p e loc) = do
  (p', rr) <- expandRecordPattern p
  CasePat p' <$> withRecordReplacements rr (transformExp e) <*> pure loc

-- | Transform every expression inside an index or slice.
transformDimIndex :: DimIndexBase Info VName -> MonoM (DimIndexBase Info VName)
transformDimIndex (DimFix e) = DimFix <$> transformExp e
transformDimIndex (DimSlice me1 me2 me3) =
  DimSlice <$> trans me1 <*> trans me2 <*> trans me3
  where trans = mapM transformExp

-- | Transform an operator section into a lambda.  Each missing operand
-- becomes a fresh lambda parameter.
desugarBinOpSection :: QualName VName -> Maybe Exp -> Maybe Exp
                    -> PatternType
                    -> StructType -> StructType -> PatternType -> SrcLoc
                    -> MonoM Exp
desugarBinOpSection qn e_left e_right t xtype ytype rettype loc = do
  (e1, p1) <- makeVarParam e_left $ fromStruct xtype
  (e2, p2) <- makeVarParam e_right $ fromStruct ytype
  let body = BinOp qn (Info t) (e1, Info xtype) (e2, Info ytype) (Info rettype) loc
      rettype' = vacuousShapeAnnotations $ toStruct rettype
  return $ Lambda [] (p1 ++ p2) body Nothing (Info (mempty, rettype')) loc
  where
    makeVarParam (Just e) _ = return (e, [])
    makeVarParam Nothing argtype = do
      x <- newNameFromString "x"
      return (Var (qualName x) (Info argtype) noLoc,
              [Id x (Info $ fromStruct argtype) noLoc])

-- | Transform a projection section such as @(.a.b)@ into a lambda that
-- projects the given fields, in order, out of its argument.
desugarProjectSection :: [Name] -> PatternType -> SrcLoc -> MonoM Exp
desugarProjectSection fields (Arrow _ _ t1 t2) loc = do
  p <- newVName "project_p"
  let body = foldl project (Var (qualName p) (Info t1) noLoc) fields
  return $ Lambda [] [Id p (Info t1) noLoc] body Nothing (Info (mempty, toStruct t2)) loc
  where
    project e field =
      case typeOf e of
        Record fs | Just t <- M.lookup field fs ->
                      Project field e (Info t) noLoc
        t -> error $ "desugarProjectSection: type " ++ pretty t ++
             " does not have field " ++ pretty field
desugarProjectSection _ t _ =
  error $ "desugarProjectSection: not a function type: " ++ pretty t

-- | Transform an index section such as @(.[i])@ into a lambda that
-- indexes its argument with the given indices.
desugarIndexSection :: [DimIndex] -> PatternType -> SrcLoc -> MonoM Exp
desugarIndexSection idxs (Arrow _ _ t1 t2) loc = do
  p <- newVName "index_i"
  let body = Index (Var (qualName p) (Info t1) loc) idxs (Info t2') loc
  return $ Lambda [] [Id p (Info t1) noLoc] body Nothing (Info (mempty, toStruct t2)) loc
  where t2' = removeShapeAnnotations t2
desugarIndexSection _ t _ =
  error $ "desugarIndexSection: not a function type: " ++ pretty t

-- | Resolve every named dimension in a type through 'transformFName' at
-- type @i32@, so that any binding with that name gets instantiated.
-- NOTE(review): presumably these names refer to top-level size
-- constants; confirm.
noticeDims :: TypeBase (DimDecl VName) as -> MonoM ()
noticeDims = mapM_ notice . nestedDims
  where notice (NamedDim v) = void $ transformFName (qualLeaf v) $ Prim $ Signed Int32
        notice _ = return ()

-- | Convert a collection of 'ValBind's to a nested sequence of let-bound,
-- monomorphic functions with the given expression at the bottom.
unfoldLetFuns :: [ValBind] -> Exp -> Exp
unfoldLetFuns [] e = e
unfoldLetFuns (ValBind _ fname _ rettype dim_params params body _ loc : rest) e =
  LetFun fname (dim_params, params, Nothing, rettype, body) e' loc
  where e' = unfoldLetFuns rest e

-- | Replace every identifier pattern of record type with a record pattern
-- that binds a fresh variable per field.  Also returns the replacements,
-- so that uses of the original identifier can be rewritten.
expandRecordPattern :: Pattern -> MonoM (Pattern, RecordReplacements)
expandRecordPattern (Id v (Info (Record fs)) loc) = do
  let fs' = M.toList fs
  (fs_ks, fs_ts) <- fmap unzip $ forM fs' $ \(f, ft) ->
    (,) <$> newVName (nameToString f) <*> pure ft
  return (RecordPattern (zip (map fst fs')
                             (zipWith3 Id fs_ks (map Info fs_ts) $ repeat loc))
                        loc,
          M.singleton v $ M.fromList $ zip (map fst fs') $ zip fs_ks fs_ts)
expandRecordPattern (Id v t loc) = return (Id v t loc, mempty)
expandRecordPattern (TuplePattern pats loc) = do
  (pats', rrs) <- unzip <$> mapM expandRecordPattern pats
  return (TuplePattern pats' loc, mconcat rrs)
expandRecordPattern (RecordPattern fields loc) = do
  let (field_names, field_pats) = unzip fields
  (field_pats', rrs) <- unzip <$> mapM expandRecordPattern field_pats
  return (RecordPattern (zip field_names field_pats') loc, mconcat rrs)
expandRecordPattern (PatternParens pat loc) = do
  (pat', rr) <- expandRecordPattern pat
  return (PatternParens pat' loc, rr)
expandRecordPattern (Wildcard t loc) = return (Wildcard t loc, mempty)
expandRecordPattern (PatternAscription pat td loc) = do
  (pat', rr) <- expandRecordPattern pat
  return (PatternAscription pat' td loc, rr)
expandRecordPattern (PatternLit e t loc) = return (PatternLit e t loc, mempty)

-- | Monomorphize a polymorphic function at the given (full function) type.
-- Monomorphizes the body of the function as well.  Returns the name of the
-- generated monomorphic function and its 'ValBind' representation.  A new
-- name is generated only if the binding has type or shape parameters;
-- otherwise the original name is reused.  The body is transformed under the
-- record replacements stored in the 'PolyBinding', not the current ones.
monomorphizeBinding :: PolyBinding -> TypeBase () () -> MonoM (VName, ValBind)
monomorphizeBinding (PolyBinding rr (name, tparams, params, retdecl, rettype, body, loc)) t =
  replaceRecordReplacements rr $ do
    t' <- removeTypeVariablesInType t
    let bind_t = foldFunType (map (toStructural . patternType) params) $
                 toStructural rettype
        substs = M.map Subst $ typeSubsts bind_t t'
        rettype' = applySubst (`M.lookup` substs) rettype
        params' = map (substPattern $ applySubst (`M.lookup` substs)) params
    (params'', rrs) <- unzip <$> mapM expandRecordPattern params'

    mapM_ noticeDims $ rettype : map patternStructType params''

    body' <- updateExpTypes (`M.lookup` substs) body
    body'' <- withRecordReplacements (mconcat rrs) $ transformExp body'
    name' <- if null tparams then return name else newName name
    return (name', toValBinding name' params'' rettype' body'')

  where
    -- Only shape parameters survive monomorphization.
    shape_params = filter (not . isTypeParam) tparams

    updateExpTypes substs = astMap $ mapper substs

    mapper substs = ASTMapper { mapOnExp = astMap $ mapper substs
                              , mapOnName = pure
                              , mapOnQualName = pure
                              , mapOnType = pure . applySubst substs
                              , mapOnCompType = pure . applySubst substs
                              , mapOnStructType = pure . applySubst substs
                              , mapOnPatternType = pure . applySubst substs
                              }

    toValBinding name' params'' rettype' body'' =
      ValBind { valBindEntryPoint = False
              , valBindName = name'
              , valBindRetDecl = retdecl
              , valBindRetType = Info rettype'
              , valBindTypeParams = shape_params
              , valBindParams = params''
              , valBindBody = body''
              , valBindDoc = Nothing
              , valBindLocation = loc
              }

-- | Compute the substitution for the type variables of the first type by
-- structurally matching it against the second (concrete) type.  Calls
-- 'error' if the two types do not have the same shape.
typeSubsts :: TypeBase () () -> TypeBase () ()
           -> M.Map VName (TypeBase () ())
typeSubsts (Record fields1) (Record fields2) =
  mconcat $ zipWith typeSubsts
    (map snd $ sortFields fields1) (map snd $ sortFields fields2)
typeSubsts (TypeVar _ _ v _) t =
  M.singleton (typeLeaf v) t
typeSubsts Prim{} Prim{} = mempty
typeSubsts (Arrow _ _ t1a t1b) (Arrow _ _ t2a t2b) =
  typeSubsts t1a t2a <> typeSubsts t1b t2b
typeSubsts t1@Array{} t2@Array{}
  | Just t1' <- peelArray (arrayRank t1) t1,
    Just t2' <- peelArray (arrayRank t1) t2 =
      typeSubsts t1' t2'
typeSubsts Enum{} Enum{} = mempty
typeSubsts t1 t2 = error $ unlines ["typeSubsts: mismatched types:", pretty t1, pretty t2]

-- | Perform a given substitution on the types in a pattern.
substPattern :: (PatternType -> PatternType) -> Pattern -> Pattern
substPattern f pat = case pat of
  TuplePattern pats loc -> TuplePattern (map (substPattern f) pats) loc
  RecordPattern fs loc -> RecordPattern (map substField fs) loc
    where substField (n, p) = (n, substPattern f p)
  PatternParens p loc -> PatternParens (substPattern f p) loc
  Id vn (Info tp) loc -> Id vn (Info $ f tp) loc
  Wildcard (Info tp) loc -> Wildcard (Info $ f tp) loc
  PatternAscription p td loc -> PatternAscription (substPattern f p) td loc
  PatternLit e (Info tp) loc -> PatternLit e (Info $ f tp) loc

-- | View a top-level value binding as a 'PolyBinding' with no record
-- replacements in effect.
toPolyBinding :: ValBind -> PolyBinding
toPolyBinding (ValBind _ name retdecl (Info rettype) tparams params body _ loc) =
  PolyBinding mempty (name, tparams, params, retdecl, rettype, body, loc)

-- | Remove all type variables and type abbreviations from a value binding.
removeTypeVariables :: ValBind -> MonoM ValBind
removeTypeVariables valbind@(ValBind _ _ _ (Info rettype) _ pats body _ _) = do
  -- Every type abbreviation currently in scope, as a substitution.
  tySubs <- asks (M.map TypeSub . envTypeBindings)
  let onPatType = fromStruct . substituteTypes tySubs . toStruct
      stripper = ASTMapper
        { mapOnExp = astMap stripper
        , mapOnName = pure
        , mapOnQualName = pure
        , mapOnType = pure . removeShapeAnnotations . substituteTypes tySubs . vacuousShapeAnnotations
        , mapOnCompType = pure . fromStruct . removeShapeAnnotations . substituteTypes tySubs . vacuousShapeAnnotations . toStruct
        , mapOnStructType = pure . substituteTypes tySubs
        , mapOnPatternType = pure . onPatType
        }
  body' <- astMap stripper body
  return valbind { valBindRetType = Info $ substituteTypes tySubs rettype
                 , valBindParams = map (substPattern onPatType) pats
                 , valBindBody = body'
                 }

-- | Expand all type abbreviations in scope within a type, discarding its
-- shape information.
removeTypeVariablesInType :: TypeBase dim () -> MonoM (TypeBase () ())
removeTypeVariablesInType t = do
  tySubs <- asks (M.map TypeSub . envTypeBindings)
  pure . removeShapeAnnotations . substituteTypes tySubs $ vacuousShapeAnnotations t

-- | Register a top-level value binding as polymorphic.  Entry points are
-- additionally instantiated immediately at their declared type, since
-- they are the roots from which all other instantiations are reached.
transformValBind :: ValBind -> MonoM Env
transformValBind valbind = do
  poly <- toPolyBinding <$> removeTypeVariables valbind
  when (valBindEntryPoint valbind) $ emitEntryPoint poly
  return mempty { envPolyBindings = M.singleton fname poly }
  where
    fname = valBindName valbind
    emitEntryPoint binding = do
      t <- removeTypeVariablesInType $ removeShapeAnnotations $
           foldFunType (map patternStructType (valBindParams valbind)) $
           unInfo $ valBindRetType valbind
      (entry_name, mono) <- monomorphizeBinding binding t
      tell $ Seq.singleton (entry_name, mono { valBindEntryPoint = True })
      addLifted fname t entry_name

-- | Register a type abbreviation, with the abbreviations already in scope
-- expanded in its definition.
transformTypeBind :: TypeBind -> MonoM Env
transformTypeBind (TypeBind name tparams tydecl _ _) = do
  tySubs <- asks (M.map TypeSub . envTypeBindings)
  let expanded = unInfo (expandedType tydecl)
  noticeDims expanded
  -- The 'Lifted' liftedness here is an arbitrary choice.
  let abbr = TypeAbbr Lifted tparams (substituteTypes tySubs expanded)
  pure mempty { envTypeBindings = M.singleton name abbr }

-- | Monomorphize a list of top-level declarations. A module-free input program
-- is expected, so only value declarations and type declaration are accepted.
-- Each declaration is in scope for all the declarations after it.
transformDecs :: [Dec] -> MonoM ()
transformDecs [] = pure ()
transformDecs (dec : ds) = case dec of
  ValDec valbind -> transformValBind valbind >>= inScope
  TypeDec typebind -> transformTypeBind typebind >>= inScope
  _ -> error $ "The monomorphization module expects a module-free " ++
               "input program, but received: " ++ pretty dec
  where inScope env = localEnv env $ transformDecs ds

-- | Monomorphize a module-free program, returning the generated
-- monomorphic bindings in the order they were emitted.
transformProg :: MonadFreshNames m => [Dec] -> m [ValBind]
transformProg decs = toList . fmap snd . snd <$> modifyNameSource run
  where run src = runMonoM src $ transformDecs decs