{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise
  ( transformProg ) where

import qualified Control.Arrow as Arrow
import Control.Monad.RWS hiding (Sum)
import Data.Bifunctor
import Data.Foldable
import Data.List
import qualified Data.List.NonEmpty as NE
import Data.Loc
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.Sequence as Seq

import Futhark.MonadFreshNames
import Language.Futhark
import Futhark.Representation.AST.Pretty ()

-- | An expression or an extended 'Lambda' (with size parameters,
-- which AST lambdas do not support).  An 'ExtLambda' is used as the
-- body of a 'LambdaSV' when a multi-parameter function still has
-- parameters left after its first one has been split off.
data ExtExp = ExtLambda [TypeParam] [Pattern] Exp (Aliasing, StructType) SrcLoc
            | ExtExp Exp
  deriving (Show)

-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal = Dynamic PatternType
                 -- ^ No static information beyond the given type.
               | LambdaSV [VName] Pattern StructType ExtExp Env
                 -- ^ The 'VName's are shape parameters that are bound
                 -- by the 'Pattern'.  The remaining fields are the
                 -- parameter pattern, the return type, the body, and
                 -- the closure environment of the lambda.
               | RecordSV [(Name, StaticVal)]
                 -- ^ The static values of the individual record (or
                 -- tuple) fields.
               | SumSV Name [StaticVal] [(Name, [PatternType])]
                 -- ^ The constructor that is actually present, plus
                 -- the others that are not.
               | DynamicFun (Exp, StaticVal) StaticVal
                 -- ^ Produced by 'defuncLet' for a let-bound function
                 -- whose first parameter has an order-zero type.  The
                 -- pair is its closure representation (closure
                 -- expression and static value), used when the function
                 -- is referenced without being applied.  The second
                 -- 'StaticVal' describes the result of applying it to
                 -- one argument.
               | IntrinsicSV
                 -- ^ The static value of names that refer to the
                 -- intrinsics module (see 'lookupVar').
  deriving (Show)

-- | Environment mapping variable names to their associated static value.
type Env = M.Map VName StaticVal

-- | Run a computation with extra bindings in the local environment.
-- The given bindings shadow existing ones of the same name, because
-- 'M.Map's '<>' is a left-biased union.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ Arrow.second (env<>)

-- Even when using a "new" environment (for evaluating closures) we
-- still carry over the entries for the global DynamicFuns.
-- | Replace the local environment with @env@, keeping only those old
-- entries whose names are registered as globals (see 'isGlobal').
-- On a name clash the kept global entry wins.
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local replaceEnv
  where replaceEnv (globals, old_env) =
          let kept = M.filterWithKey (\k _ -> k `S.member` globals) old_env
          in (globals, kept <> env)

-- | Bind a single name to a static value for the duration of the
-- given computation.
extendEnv :: VName -> StaticVal -> DefM a -> DefM a
extendEnv vn sv = localEnv $ M.singleton vn sv

-- | The current local environment.
askEnv :: DefM Env
askEnv = snd <$> ask

-- | Run a computation with the given name added to the set of
-- globally defined dynamic functions.
isGlobal :: VName -> DefM a -> DefM a
isGlobal v = local $ \(globals, env) -> (S.insert v globals, env)

-- | Returns the defunctionalization environment restricted to the
-- given set of variable names and types.
--
-- * Names registered as globals are always dropped.
-- * For a name recorded as 'Nonunique' in the 'NameSet', every
--   'Dynamic' type inside its static value is made nonunique.
restrictEnvTo :: NameSet -> DefM Env
restrictEnvTo (NameSet wanted) = do
  (globals, env) <- ask
  let keep k sv
        | k `S.member` globals = Nothing
        | otherwise = (\u -> restrictSV u sv) <$> M.lookup k wanted
  return $ M.mapMaybeWithKey keep env
  where
    restrictSV u sv = case sv of
      Dynamic t -> case u of
        Nonunique -> Dynamic $ t `setUniqueness` Nonunique
        _ -> Dynamic t
      LambdaSV dims pat t e closure_env ->
        LambdaSV dims pat t e $ M.map (restrictSV u) closure_env
      RecordSV fields ->
        RecordSV [ (f, restrictSV u field_sv) | (f, field_sv) <- fields ]
      SumSV c svs others ->
        SumSV c (map (restrictSV u) svs) others
      DynamicFun (e, sv1) sv2 ->
        DynamicFun (e, restrictSV u sv1) (restrictSV u sv2)
      IntrinsicSV ->
        IntrinsicSV

-- | The defunctionalization monad.
--
-- * Reader: a pair of the names of globally defined dynamic functions
--   (see 'isGlobal') and the current 'Env'.  Knowing the globals lets
--   'restrictEnvTo' leave them out of closure environments, which keeps
--   those environments small.
-- * Writer: collects the lifted top-level value bindings.
-- * State: the name source used for fresh names.
newtype DefM a = DefM (RWS (S.Set VName, Env) (Seq.Seq ValBind) VNameSource a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader (S.Set VName, Env)
           , MonadWriter (Seq.Seq ValBind)
           , MonadFreshNames
           )

-- | Run a computation in the defunctionalization monad, starting from
-- an empty set of globals and an empty environment.  Returns the result
-- of the computation, a new name source, and a list of lifted function
-- declarations.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, Seq.Seq ValBind)
runDefM src (DefM m) = runRWS m mempty src

-- | Run a computation and return the value bindings it emitted
-- together with its result.  The captured bindings are removed from
-- the writer output of the enclosing computation.
collectFuns :: DefM a -> DefM (a, Seq.Seq ValBind)
collectFuns m = pass $ do
  (x, decs) <- listen m
  return ((x, decs), const mempty)

-- | Looks up the associated static value for a given name in the environment.
lookupVar :: SrcLoc -> VName -> DefM StaticVal
lookupVar loc x = do
  env <- askEnv
  case M.lookup x env of
    Just sv -> return sv
    Nothing -- If the variable is unknown, it may refer to the 'intrinsics'
            -- module, which we will have to treat specially.
      | baseTag x <= maxIntrinsicTag -> return IntrinsicSV
      | otherwise -> error $ "Variable " ++ pretty x ++ " at "
                          ++ locStr loc ++ " is out of scope."

-- | Defunctionalize a function given by its size parameters,
-- parameters, body, closure aliasing and return type.
--
-- Returns a record literal that closes over the free variables of the
-- function, together with a 'LambdaSV' for its first parameter.  Any
-- further parameters are pushed into an 'ExtLambda' body.
defuncFun :: [TypeParam] -> [Pattern] -> Exp -> (Aliasing, StructType) -> SrcLoc
          -> DefM (Exp, StaticVal)
defuncFun tparams pats e0 (closure, ret) loc = do
  when (any isTypeParam tparams) $
    error $ "Received a lambda with type parameters at " ++ locStr loc
         ++ ", but the defunctionalizer expects a monomorphic input program."
  -- Extract the first parameter of the lambda and "push" the
  -- remaining ones (if there are any) into the body of the lambda.
  let (dims, pat, ret', e0') = case pats of
        [] -> error "Received a lambda with no parameters."
        [pat'] -> (map typeParamName tparams, pat', ret, ExtExp e0)
        (pat' : pats') ->
          -- Split shape parameters into those that are determined by
          -- the first pattern, and those that are determined by later
          -- patterns.
          let bound_by_pat = (`S.member` patternDimNames pat') . typeParamName
              (pat_dims, rest_dims) = partition bound_by_pat tparams
          in (map typeParamName pat_dims, pat',
              foldFunType (map (toStruct . patternType) pats') ret,
              ExtLambda rest_dims pats' e0 (closure, ret) loc)

  -- Construct a record literal that closes over the environment of
  -- the lambda.  Closed-over 'DynamicFun's are converted to their
  -- closure representation.
  env <- restrictEnvTo $
         freeVars (Lambda pats e0 Nothing (Info (closure, ret)) loc) `without`
         mconcat (map (oneName . typeParamName) tparams)
  let (fields, env') = unzip $ map closureFromDynamicFun $ M.toList env
  return (RecordLit fields loc, LambdaSV dims pat ret' e0' $ M.fromList env')

  where closureFromDynamicFun (vn, DynamicFun (clsr_env, sv) _) =
          let name = nameFromString $ pretty vn
          in (RecordFieldExplicit name clsr_env noLoc, (vn, sv))

        closureFromDynamicFun (vn, sv) =
          let name = nameFromString $ pretty vn
              tp' = typeFromSV sv
          in (RecordFieldExplicit name (Var (qualName vn) (Info tp') noLoc) noLoc,
              (vn, sv))

-- | Defunctionalization of an expression. Returns the residual expression and
-- the associated static value in the defunctionalization monad.
defuncExp :: Exp -> DefM (Exp, StaticVal)

defuncExp e@Literal{} =
  return (e, Dynamic $ typeOf e)

defuncExp e@IntLit{} =
  return (e, Dynamic $ typeOf e)

defuncExp e@FloatLit{} =
  return (e, Dynamic $ typeOf e)

defuncExp (Parens e loc) = do
  (e', sv) <- defuncExp e
  return (Parens e' loc, sv)

defuncExp (QualParens qn e loc) = do
  (e', sv) <- defuncExp e
  return (QualParens qn e' loc, sv)

defuncExp (TupLit es loc) = do
  (es', svs) <- unzip <$> mapM defuncExp es
  return (TupLit es' loc, RecordSV $ zip fields svs)
  where fields = map (nameFromString . show) [(1 :: Int) ..]

defuncExp (RecordLit fs loc) = do
  (fs', names_svs) <- unzip <$> mapM defuncField fs
  return (RecordLit fs' loc, RecordSV names_svs)

  where defuncField (RecordFieldExplicit vn e loc') = do
          (e', sv) <- defuncExp e
          return (RecordFieldExplicit vn e' loc', (vn, sv))
        defuncField (RecordFieldImplicit vn _ loc') = do
          sv <- lookupVar loc' vn
          case sv of
            -- If the implicit field refers to a dynamic function, we
            -- convert it to an explicit field with a record closing over
            -- the environment and bind the corresponding static value.
            DynamicFun (e, sv') _ ->
              let vn' = baseName vn
              in return (RecordFieldExplicit vn' e loc', (vn', sv'))
            -- The field may refer to a functional expression, so we get the
            -- type from the static value and not the one from the AST.
            _ ->
              let tp = Info $ typeFromSV sv
              in return (RecordFieldImplicit vn tp loc', (baseName vn, sv))

defuncExp (ArrayLit es t@(Info t') loc) = do
  es' <- mapM defuncExp' es
  return (ArrayLit es' t loc, Dynamic t')

defuncExp (Range e1 me incl t@(Info t') loc) = do
  e1' <- defuncExp' e1
  me' <- mapM defuncExp' me
  incl' <- mapM defuncExp' incl
  return (Range e1' me' incl' t loc, Dynamic t')

defuncExp e@(Var qn _ loc) = do
  sv <- lookupVar loc (qualLeaf qn)
  case sv of
    -- If the variable refers to a dynamic function, we return its closure
    -- representation (i.e., a record expression capturing the free variables
    -- and a 'LambdaSV' static value) instead of the variable itself.
    DynamicFun closure _ -> return closure
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    IntrinsicSV -> do
      (pats, body, tp) <- etaExpand e
      defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) noLoc
    _ ->
      let tp = typeFromSV sv
      in return (Var qn (Info tp) loc, sv)

defuncExp (Ascript e0 tydecl t loc)
  | orderZero (typeOf e0) = do
      (e0', sv) <- defuncExp e0
      return (Ascript e0' tydecl t loc, sv)
  | otherwise = defuncExp e0

defuncExp (LetPat pat e1 e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  let env = matchPatternSV pat sv1
      pat' = updatePattern pat sv1
  (e2', sv2) <- localEnv env $ defuncExp e2
  return (LetPat pat' e1' e2' (Info $ typeOf e2') loc, sv2)

-- Local functions are handled by rewriting them to lambdas, so that
-- the same machinery can be re-used.
defuncExp (LetFun vn (dims, pats, _, Info ret, e1) e2 loc) = do
  (e1', sv1) <- defuncFun dims pats e1 (mempty, ret) loc
  (e2', sv2) <- localEnv (M.singleton vn sv1) $ defuncExp e2
  return (LetPat (Id vn (Info (typeOf e1')) loc) e1' e2' (Info $ typeOf e2') loc, sv2)

-- NOTE: the static value of the 'then' branch is used for the whole
-- conditional; the static value of the 'else' branch is discarded.
defuncExp (If e1 e2 e3 tp loc) = do
  (e1', _ ) <- defuncExp e1
  (e2', sv) <- defuncExp e2
  (e3', _ ) <- defuncExp e3
  return (If e1' e2' e3' tp loc, sv)

defuncExp e@(Apply f@(Var f' _ _) arg d t loc)
  | baseTag (qualLeaf f') <= maxIntrinsicTag,
    TupLit es tuploc <- arg = do
      -- defuncSoacExp also works fine for non-SOACs.
      es' <- mapM defuncSoacExp es
      return (Apply f (TupLit es' tuploc) d t loc,
              Dynamic $ typeOf e)

defuncExp e@Apply{} = defuncApply 0 e

defuncExp (Negate e0 loc) = do
  (e0', sv) <- defuncExp e0
  return (Negate e0' loc, sv)

defuncExp (Lambda pats e0 _ (Info (closure, ret)) loc) =
  defuncFun [] pats e0 (closure, ret) loc

-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
defuncExp OpSection{}      = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft{}  = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight{} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection{} = error "defuncExp: unexpected projection section."
defuncExp IndexSection{}   = error "defuncExp: unexpected index section."

defuncExp (DoLoop pat e1 form e3 loc) = do
  (e1', sv1) <- defuncExp e1
  let env1 = matchPatternSV pat sv1
  (form', env2) <- case form of
    For v e2 -> do
      e2' <- defuncExp' e2
      return (For v e2', envFromIdent v)
    ForIn pat2 e2 -> do
      e2' <- defuncExp' e2
      return (ForIn pat2 e2', envFromPattern pat2)
    While e2 -> do
      -- The condition can refer to the loop parameters.
      e2' <- localEnv env1 $ defuncExp' e2
      return (While e2', mempty)
  (e3', sv) <- localEnv (env1 <> env2) $ defuncExp e3
  return (DoLoop pat e1' form' e3' loc, sv)
  where envFromIdent (Ident vn (Info tp) _) =
          M.singleton vn $ Dynamic tp

-- We handle BinOps by turning them into ordinary function applications.
defuncExp (BinOp (qn, qnloc) (Info t) (e1, Info pt1) (e2, Info pt2) (Info ret) loc) =
  defuncExp $ Apply (Apply (Var qn (Info t) qnloc)
                     e1 (Info (diet pt1))
                     (Info (Scalar $ Arrow mempty Unnamed (fromStruct pt2) ret)) loc)
                    e2 (Info (diet pt2)) (Info ret) loc

defuncExp (Project vn e0 tp@(Info tp') loc) = do
  (e0', sv0) <- defuncExp e0
  case sv0 of
    RecordSV svs -> case lookup vn svs of
      Just sv -> return (Project vn e0' (Info $ typeFromSV sv) loc, sv)
      Nothing -> error "Invalid record projection."
    Dynamic _ -> return (Project vn e0' tp loc, Dynamic tp')
    _ -> error $ "Projection of an expression with static value " ++ show sv0

defuncExp (LetWith id1 id2 idxs e1 body t loc) = do
  e1' <- defuncExp' e1
  sv1 <- lookupVar (identSrcLoc id2) $ identName id2
  idxs' <- mapM defuncDimIndex idxs
  (body', sv) <- extendEnv (identName id1) sv1 $ defuncExp body
  return (LetWith id1 id2 idxs' e1' body' t loc, sv)

defuncExp expr@(Index e0 idxs info loc) = do
  e0' <- defuncExp' e0
  idxs' <- mapM defuncDimIndex idxs
  return (Index e0' idxs' info loc, Dynamic $ typeOf expr)

defuncExp (Update e1 idxs e2 loc) = do
  (e1', sv) <- defuncExp e1
  idxs' <- mapM defuncDimIndex idxs
  e2' <- defuncExp' e2
  return (Update e1' idxs' e2' loc, sv)

-- Note that we might change the type of the record field here.  This
-- is not permitted in the type checker due to problems with type
-- inference, but it actually works fine.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  let sv = staticField sv1 sv2 fs
  return (RecordUpdate e1' fs e2' (Info $ typeFromSV sv1) loc, sv)
  where staticField (RecordSV svs) sv2 (f:fs') =
          case lookup f svs of
            Just sv -> RecordSV $
                       (f, staticField sv sv2 fs') : filter ((/=f) . fst) svs
            Nothing -> error "Invalid record projection."
        staticField (Dynamic t@(Scalar Record{})) sv2 fs'@(_:_) =
          staticField (svFromType t) sv2 fs'
        staticField _ sv2 _ = sv2

defuncExp (Unsafe e1 loc) = do
  (e1', sv) <- defuncExp e1
  return (Unsafe e1' loc, sv)

defuncExp (Assert e1 e2 desc loc) = do
  (e1', _) <- defuncExp e1
  (e2', sv) <- defuncExp e2
  return (Assert e1' e2' desc loc, sv)

defuncExp (Constr name es (Info (Scalar (Sum all_fs))) loc) = do
  (es', svs) <- unzip <$> mapM defuncExp es
  -- The other constructors of the sum type keep their (defunctionalized)
  -- payload types; function types in them become empty records.
  let sv = SumSV name svs $ M.toList $
           name `M.delete` M.map (map defuncType) all_fs
  return (Constr name es' (Info (typeFromSV sv)) loc, sv)
  where defuncType :: Monoid als =>
                      TypeBase (DimDecl VName) als -> TypeBase (DimDecl VName) als
        defuncType (Array as u t shape) = Array as u (defuncScalar t) shape
        defuncType (Scalar t) = Scalar $ defuncScalar t

        defuncScalar :: Monoid als =>
                        ScalarTypeBase (DimDecl VName) als -> ScalarTypeBase (DimDecl VName) als
        defuncScalar (Record fs) = Record $ M.map defuncType fs
        defuncScalar Arrow{} = Record mempty
        defuncScalar (Sum fs) = Sum $ M.map (map defuncType) fs
        defuncScalar (Prim t) = Prim t
        defuncScalar (TypeVar as u tn targs) = TypeVar as u tn targs

defuncExp (Constr name _ (Info t) loc) =
  error $ "Constructor " ++ pretty name ++
  " given type " ++ pretty t ++ " at " ++ locStr loc

-- NOTE: the static value of the first case is used for the whole match.
defuncExp (Match e cs t loc) = do
  (e', sv) <- defuncExp e
  csPairs <- mapM (defuncCase sv) cs
  let cs' = fmap fst csPairs
      sv' = snd $ NE.head csPairs
  return (Match e' cs' t loc, sv')

-- | Same as 'defuncExp', except it ignores the static value.
defuncExp' :: Exp -> DefM Exp
defuncExp' = fmap fst . defuncExp

-- | Defunctionalize an 'ExtExp'.  A plain expression is passed to
-- 'defuncExp'; an 'ExtLambda' is passed to 'defuncFun'.
defuncExtExp :: ExtExp -> DefM (Exp, StaticVal)
defuncExtExp (ExtExp e) = defuncExp e
defuncExtExp (ExtLambda tparams pats e0 (closure, ret) loc) =
  defuncFun tparams pats e0 (closure, ret) loc

-- | Defunctionalize one case of a 'Match', given the static value of
-- the scrutinee.
defuncCase :: StaticVal -> Case -> DefM (Case, StaticVal)
defuncCase sv (CasePat p e loc) = do
  let p' = updatePattern p sv
      env = matchPatternSV p sv
  (e', sv') <- localEnv env $ defuncExp e
  return (CasePat p' e' loc, sv')

-- | Defunctionalize the function argument to a SOAC by eta-expanding if
-- necessary and then defunctionalizing the body of the introduced lambda.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e@OpSection{}      = return e
defuncSoacExp e@OpSectionLeft{}  = return e
defuncSoacExp e@OpSectionRight{} = return e
defuncSoacExp e@ProjectSection{} = return e

defuncSoacExp (Parens e loc) =
  Parens <$> defuncSoacExp e <*> pure loc

defuncSoacExp (Lambda params e0 decl tp loc) = do
  let env = foldMap envFromPattern params
  e0' <- localEnv env $ defuncSoacExp e0
  return $ Lambda params e0' decl tp loc

defuncSoacExp e
  | Scalar Arrow{} <- typeOf e = do
      (pats, body, tp) <- etaExpand e
      let env = foldMap envFromPattern pats
      body' <- localEnv env $ defuncExp' body
      return $ Lambda pats body' Nothing (Info (mempty, tp)) noLoc
  | otherwise = defuncExp' e

-- | Eta-expand an expression of (curried) function type.
--
-- Returns one fresh 'Id' pattern per parameter (named from @"x"@), the
-- expression applied to the corresponding variables, and the type of
-- the final result.
etaExpand :: Exp -> DefM ([Pattern], Exp, StructType)
etaExpand e = do
  let (ps, ret) = getType $ typeOf e
  (pats, vars) <- fmap unzip . forM ps $ \t -> do
    x <- newNameFromString "x"
    return (Id x (Info t) noLoc,
            Var (qualName x) (Info t) noLoc)
  -- Each partial application gets the type of the remaining function.
  let e' = foldl' (\e1 (e2, t2, argtypes) ->
                     Apply e1 e2 (Info $ diet t2)
                     (Info (foldFunType argtypes ret)) noLoc)
           e $ zip3 vars ps (drop 1 $ tails ps)
  return (pats, e', toStruct ret)

  where getType (Scalar (Arrow _ _ t1 t2)) =
          let (ps, r) = getType t2 in (t1 : ps, r)
        getType t = ([], t)

-- | Defunctionalize an indexing of a single array dimension.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex (DimFix e1) = DimFix <$> defuncExp' e1
defuncDimIndex (DimSlice me1 me2 me3) =
  DimSlice <$> defunc' me1 <*> defunc' me2 <*> defunc' me3
    where defunc' = mapM defuncExp'

-- | Defunctionalize a let-bound function, while preserving parameters
-- that have order 0 types (i.e., non-functional).
--
-- Each leading order-zero parameter is kept, and the result is wrapped
-- in a 'DynamicFun'.  At the first parameter that is not order zero,
-- the remaining function is turned into a closure by 'defuncFun'.
defuncLet :: [TypeParam] -> [Pattern] -> Exp -> Info StructType
          -> DefM ([Pattern], Exp, StaticVal)
defuncLet dims ps@(pat:pats) body (Info rettype)
  | patternOrderZero pat = do
      let env = envFromPattern pat
          bound_by_pat = (`S.member` patternDimNames pat) . typeParamName
          -- The remaining parameters only get the size parameters
          -- that are not bound by this pattern.
          (_pat_dims, rest_dims) = partition bound_by_pat dims
      (pats', body', sv) <- localEnv env $ defuncLet rest_dims pats body (Info rettype)
      closure <- defuncFun dims ps body (mempty, rettype) noLoc
      return (pat : pats', body', DynamicFun closure sv)
  | otherwise = do
      (e, sv) <- defuncFun dims ps body (mempty, rettype) noLoc
      return ([], e, sv)

defuncLet _ [] body (Info rettype) = do
  (body', sv) <- defuncExp body
  return ([], body', imposeType sv rettype )
  where imposeType Dynamic{} t =
          Dynamic $ fromStruct t
        imposeType (RecordSV fs1) (Scalar (Record fs2)) =
          RecordSV $ M.toList $ M.intersectionWith imposeType (M.fromList fs1) fs2
        imposeType sv _ = sv

-- | Defunctionalize an application expression at a given depth of application.
-- Calls to dynamic (first-order) functions are preserved as much as possible,
-- but a new lifted function is created if a dynamic function is only partially
-- applied.
-- 'depth' counts how many enclosing applications this expression is
-- (transitively) the function part of: 'defuncExp' starts at 0, and the
-- function part of an 'Apply' is processed at @depth+1@.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply depth e@(Apply e1 e2 d t@(Info ret) loc) = do
  let (argtypes, _) = unfoldFunType ret
  (e1', sv1) <- defuncApply (depth+1) e1
  (e2', sv2) <- defuncExp e2
  let e' = Apply e1' e2' d t loc
  case sv1 of
    -- Applying a static lambda: defunctionalize its body in its closure
    -- environment extended with the argument's bindings and its shape
    -- parameters, then lift it to a fresh top-level function that takes
    -- the closure record and the argument.
    LambdaSV dims pat e0_t e0 closure_env -> do
      let env' = matchPatternSV pat sv2
          env_dim = envFromDimNames dims
      (e0', sv) <- localNewEnv (env' <> closure_env <> env_dim) $ defuncExtExp e0

      let closure_pat = buildEnvPattern closure_env
          pat' = updatePattern pat sv2

      -- Lift lambda to top-level function definition.  We put in
      -- a lot of effort to try to infer the uniqueness attributes
      -- of the lifted function, but this is ultimately all a sham
      -- and a hack.  There is some piece we're missing.
      let params = [closure_pat, pat']
          params_for_rettype = params ++ svParams sv1 ++ svParams sv2
          svParams (LambdaSV _ sv_pat _ _ _) = [sv_pat]
          svParams _ = []
          rettype = buildRetType closure_env params_for_rettype e0_t $
                    anyDimShapeAnnotations $ typeOf e0'

          -- Embed some information about the original function
          -- into the name of the lifted function, to make the
          -- result slightly more human-readable.
          liftedName i (Var f _ _) =
            "lifted_" ++ show i ++ "_" ++ baseString (qualLeaf f)
          liftedName i (Apply f _ _ _ _) =
            liftedName (i+1) f
          liftedName _ _ = "lifted"
      fname <- newNameFromString $ liftedName (0::Int) e1
      liftValDec fname rettype dims params e0'

      -- The residual expression calls the lifted function with the
      -- closure record (e1') and the argument (e2').
      let t1 = toStruct $ typeOf e1'
          t2 = toStruct $ typeOf e2'
          fname' = qualName fname
      return (Parens (Apply (Apply (Var fname' (Info (Scalar $ Arrow mempty Unnamed (fromStruct t1) $
                                                        Scalar $ Arrow mempty Unnamed (fromStruct t2) rettype))
                                    loc)
                             e1' (Info Observe)
                             (Info $ Scalar $ Arrow mempty Unnamed (fromStruct t2) rettype) loc)
                      e2' d (Info rettype) loc) noLoc, sv)

    -- If e1 is a dynamic function, we just leave the application in place,
    -- but we update the types since it may be partially applied or return
    -- a higher-order term.
    DynamicFun _ sv ->
      let (argtypes', rettype) = dynamicFunType sv argtypes
          apply_e = Apply e1' e2' d (Info $ foldFunType argtypes' rettype `setAliases` aliases ret) loc
      in return (apply_e, sv)

    -- Propagate the 'IntrinsicSV' until we reach the outermost application,
    -- where we construct a dynamic static value with the appropriate type.
    IntrinsicSV
      | depth == 0 -> return (e', Dynamic $ typeOf e)
      | otherwise  -> return (e', IntrinsicSV)

    _ -> error $ "Application of an expression that is neither a static lambda "
              ++ "nor a dynamic function, but has static value: " ++ show sv1

defuncApply depth e@(Var qn (Info t) loc) = do
    let (argtypes, _) = unfoldFunType t
    sv <- lookupVar loc (qualLeaf qn)
    case sv of
      DynamicFun _ _
        | fullyApplied sv depth ->
            -- We still need to update the types in case the dynamic
            -- function returns a higher-order term.
            let (argtypes', rettype) = dynamicFunType sv argtypes
            in return (Var qn (Info (foldFunType argtypes' rettype)) loc, sv)

        -- Not fully applied: lift a new top-level function (see
        -- 'liftDynFun') and refer to that instead.
        | otherwise -> do
            fname <- newName $ qualLeaf qn
            let (dims, pats, e0, sv') = liftDynFun sv depth
                (argtypes', rettype) = dynamicFunType sv' argtypes
            liftValDec fname (fromStruct rettype) dims pats e0
            return (Var (qualName fname) (Info (foldFunType argtypes' $ fromStruct rettype)) loc, sv')

      IntrinsicSV -> return (e, IntrinsicSV)

      _ -> return (Var qn (Info (typeFromSV sv)) loc, sv)

-- Parentheses do not change the application depth.
defuncApply depth (Parens e _) = defuncApply depth e

defuncApply _ expr = defuncExp expr

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied (DynamicFun _ sv) depth
  | depth == 0   = False
  | depth >  0   = fullyApplied sv (depth-1)
fullyApplied _ _ = True

-- | Converts a dynamic function 'StaticVal' into a list of
-- dimensions, a list of parameters, a function body, and the
-- appropriate static value for applying the function at the given
-- depth of partial application.
liftDynFun :: StaticVal -> Int -> ([VName], [Pattern], Exp, StaticVal)
liftDynFun sv d = case (sv, d) of
  -- At depth zero, the closure representation becomes the body.
  (DynamicFun (clsr_e, clsr_sv) _, 0) -> ([], [], clsr_e, clsr_sv)
  -- Otherwise peel off one parameter (the pattern of the closure's
  -- 'LambdaSV', with its shape parameters) and recurse on the rest.
  (DynamicFun clsr@(_, LambdaSV dims pat _ _ _) rest, _)
    | d > 0 ->
        let (more_dims, more_pats, body, rest') = liftDynFun rest (d - 1)
        in (dims ++ more_dims, pat : more_pats, body, DynamicFun clsr rest')
  _ -> error $ "Tried to lift a StaticVal " ++ show sv
            ++ ", but expected a dynamic function."

-- | Build an environment from a pattern.  Every name bound by the
-- pattern is mapped to a 'Dynamic' static value of its type.
envFromPattern :: Pattern -> Env
envFromPattern (Id vn (Info t) _) = M.singleton vn $ Dynamic t
envFromPattern (TuplePattern ps _) = foldMap envFromPattern ps
envFromPattern (RecordPattern fs _) = foldMap (\(_, p) -> envFromPattern p) fs
envFromPattern (PatternConstr _ _ ps _) = foldMap envFromPattern ps
envFromPattern (PatternParens p _) = envFromPattern p
envFromPattern (PatternAscription p _ _) = envFromPattern p
envFromPattern Wildcard{} = mempty
envFromPattern PatternLit{} = mempty

-- | Bind the shape parameters.  Any non-shape type parameter is an
-- error, since the input program must be monomorphic.
envFromShapeParams :: [TypeParamBase VName] -> Env
envFromShapeParams tparams =
  envFromDimNames [ shapeName tparam | tparam <- tparams ]
  where shapeName (TypeParamDim vn _) = vn
        shapeName other =
          error $ "The defunctionalizer expects a monomorphic input program,\n" ++
                  "but it received a type parameter " ++ pretty other ++
                  " at " ++ locStr (srclocOf other) ++ "."

-- | Bind each of the given size names to a 'Dynamic' static value of
-- type @i32@.
envFromDimNames :: [VName] -> Env
envFromDimNames dim_names = M.fromList [ (v, size_sv) | v <- dim_names ]
  where size_sv = Dynamic $ Scalar $ Prim $ Signed Int32

-- | Emit a new top-level value binding with the given function name,
-- return type, shape parameters, value parameters, and body expression.
liftValDec :: VName -> PatternType -> [VName] -> [Pattern] -> Exp -> DefM ()
liftValDec fname rettype dims pats body = tell $ Seq.singleton dec
  where dims' = map (flip TypeParamDim noLoc) dims
        -- The declared return type goes through 'anyDimShapeAnnotations'.
        -- NOTE(review): presumably because the size annotations computed
        -- by this pass cannot be trusted here -- confirm.
        rettype_st = anyDimShapeAnnotations $ toStruct rettype

        dec = ValBind
          { valBindEntryPoint = Nothing
          , valBindName       = fname
          , valBindRetDecl    = Nothing
          , valBindRetType    = Info rettype_st
          , valBindTypeParams = dims'
          , valBindParams     = pats
          , valBindBody       = body
          , valBindDoc        = Nothing
          , valBindLocation   = noLoc
          }

-- | Given a closure environment, construct a record pattern that
-- binds the closed over variables.
buildEnvPattern :: Env -> Pattern
buildEnvPattern env = RecordPattern (map buildField $ M.toList env) noLoc
  where buildField (vn, sv) = (nameFromString (pretty vn),
                               Id vn (Info $ anyDimShapeAnnotations $ typeFromSV sv) noLoc)

-- | Given a closure environment pattern and the type of a term,
-- construct the type of that term, where uniqueness is set to
-- `Nonunique` for those arrays that are bound in the environment or
-- pattern (except if they are unique there).  This ensures that a
-- lifted function can create unique arrays as long as they do not
-- alias any of its parameters.  XXX: it is not clear that this is a
-- sufficient property, unfortunately.
--
-- The first type is the annotated return type and the second the type
-- actually computed.  Records are combined field-wise.  Where the
-- annotation is a function type (which defunctionalization removes),
-- the computed type is used.  Elsewhere, the annotated type gets the
-- uniqueness and aliases of the computed type.
buildRetType :: Env -> [Pattern] -> StructType -> PatternType -> PatternType
buildRetType env pats = comb
  where bound = foldMap oneName (M.keys env) <> foldMap patternVars pats
        boundAsUnique v =
          maybe False (unique . unInfo . identType) $
          find ((==v) . identName) $ S.toList $ foldMap patternIdents pats
        problematic v = (v `member` bound) && not (boundAsUnique v)
        comb (Scalar (Record fs_annot)) (Scalar (Record fs_got)) =
          Scalar $ Record $ M.intersectionWith comb fs_annot fs_got
        comb (Scalar Arrow{}) t =
          descend t
        comb got et =
          descend $ fromStruct got `setUniqueness` uniqueness et `setAliases` aliases et

        descend t@Array{}
          | any (problematic . aliasVar) (aliases t) = t `setUniqueness` Nonunique
        descend (Scalar (Record t)) = Scalar $ Record $ fmap descend t
        descend t = t

-- | Compute the corresponding type for a given static value.
typeFromSV :: StaticVal -> PatternType
typeFromSV (Dynamic tp) =
  anyDimShapeAnnotations tp
typeFromSV (LambdaSV _ _ _ _ env) =
  typeFromEnv env
typeFromSV (RecordSV ls) =
  Scalar $ Record $ M.fromList $ map (fmap typeFromSV) ls
typeFromSV (DynamicFun (_, sv) _) =
  typeFromSV sv
typeFromSV (SumSV name svs fields) =
  Scalar $ Sum $ M.insert name (map typeFromSV svs) $ M.fromList fields
typeFromSV IntrinsicSV =
  error $ "Tried to get the type from the "
       ++ "static value of an intrinsic."

-- | The record type of a closure environment: one field per bound
-- name, with the type of its static value.
typeFromEnv :: Env -> PatternType
typeFromEnv = Scalar . Record . M.fromList .
              map (bimap (nameFromString . pretty) typeFromSV) . M.toList

-- | Construct the type for a fully-applied dynamic function from its
-- static value and the original types of its arguments.
dynamicFunType :: StaticVal -> [PatternType] -> ([PatternType], PatternType)
dynamicFunType (DynamicFun _ sv) (p:ps) =
  let (ps', ret) = dynamicFunType sv ps in (p : ps', ret)
dynamicFunType sv _ = ([], typeFromSV sv)

-- | Match a pattern with its static value.  Returns an environment with
-- the identifier components of the pattern mapped to the corresponding
-- subcomponents of the static value.
matchPatternSV :: PatternBase Info VName -> StaticVal -> Env
matchPatternSV (TuplePattern ps _) (RecordSV ls) =
  mconcat $ zipWith (\p (_, sv) -> matchPatternSV p sv) ps ls
matchPatternSV (RecordPattern ps _) (RecordSV ls)
  | ps' <- sortOn fst ps, ls' <- sortOn fst ls,
    map fst ps' == map fst ls' =
      mconcat $ zipWith (\(_, p) (_, sv) -> matchPatternSV p sv) ps' ls'
matchPatternSV (PatternParens pat _) sv = matchPatternSV pat sv
matchPatternSV (Id vn (Info t) _) sv =
  -- When matching a pattern with a zero-order StaticVal, the type of
  -- the pattern wins out.  This is important when matching a
  -- nonunique pattern with a unique value.
  if orderZeroSV sv
  then M.singleton vn $ Dynamic t
  else M.singleton vn sv
matchPatternSV (Wildcard _ _) _ = mempty
matchPatternSV (PatternAscription pat _ _) sv = matchPatternSV pat sv
matchPatternSV PatternLit{} _ = mempty
matchPatternSV (PatternConstr c1 _ ps _) (SumSV c2 ls fs)
  | c1 == c2 =
      mconcat $ zipWith matchPatternSV ps ls
  | Just ts <- lookup c1 fs =
      mconcat $ zipWith matchPatternSV ps $ map svFromType ts
  | otherwise =
      error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
matchPatternSV (PatternConstr c1 _ ps _) (Dynamic (Scalar (Sum fs)))
  | Just ts <- M.lookup c1 fs =
      mconcat $ zipWith matchPatternSV ps $ map svFromType ts
  | otherwise =
      error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
matchPatternSV pat (Dynamic t) = matchPatternSV pat $ svFromType t
matchPatternSV pat sv = error $ "Tried to match pattern " ++ pretty pat
                             ++ " with static value " ++ show sv ++ "."

-- | Whether a static value is of order zero.  This holds when it
-- consists only of 'Dynamic' leaves, possibly nested in records.
orderZeroSV :: StaticVal -> Bool
orderZeroSV Dynamic{} = True
orderZeroSV (RecordSV fields) = all (orderZeroSV . snd) fields
orderZeroSV _ = False

-- | Given a pattern and the static value for the defunctionalized argument,
-- update the pattern to reflect the changes in the types.
updatePattern :: Pattern -> StaticVal -> Pattern
updatePattern (TuplePattern ps loc) (RecordSV svs) =
  TuplePattern (zipWith updatePattern ps $ map snd svs) loc
updatePattern (RecordPattern ps loc) (RecordSV svs)
  | ps' <- sortOn fst ps,
    svs' <- sortOn fst svs =
      RecordPattern (zipWith (\(n, p) (_, sv) ->
                                (n, updatePattern p sv)) ps' svs') loc
updatePattern (PatternParens pat loc) sv =
  PatternParens (updatePattern pat sv) loc
updatePattern pat@(Id vn (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Id vn (Info $ typeFromSV sv `setUniqueness` Nonunique) loc
updatePattern pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ typeFromSV sv) loc
updatePattern (PatternAscription pat tydecl loc) sv
  | orderZero . unInfo $ expandedType tydecl =
      PatternAscription (updatePattern pat sv) tydecl loc
  -- A functional ascription no longer describes the defunctionalized
  -- value, so it is dropped.
  | otherwise = updatePattern pat sv
updatePattern p@PatternLit{} _ = p
updatePattern pat@(PatternConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatternConstr c1 (Info t') ps' loc
  where t' = typeFromSV sv `setUniqueness` Nonunique
        ps' = zipWith updatePattern ps svs
updatePattern (PatternConstr c1 _ ps loc) (Dynamic t) =
  PatternConstr c1 (Info t) ps loc
updatePattern pat (Dynamic t) = updatePattern pat (svFromType t)
updatePattern pat sv =
  error $ "Tried to update pattern " ++ pretty pat
       ++ " to reflect the static value " ++ show sv

-- | Convert a record (or tuple) type to a record static value.  This is used for
-- "unwrapping" tuples and records that are nested in 'Dynamic' static values.
svFromType :: PatternType -> StaticVal
svFromType (Scalar (Record fs)) = RecordSV . M.toList $ M.map svFromType fs
svFromType t = Dynamic t

-- | A set of names where we also track uniqueness.  When a name occurs
-- with different uniqueness, the larger one (by the 'Ord' instance of
-- 'Uniqueness') is kept.
newtype NameSet = NameSet (M.Map VName Uniqueness)

instance Semigroup NameSet where
  NameSet x <> NameSet y = NameSet $ M.unionWith max x y

instance Monoid NameSet where
  mempty = NameSet mempty

-- | The names of the first set that do not occur in the second.
without :: NameSet -> NameSet -> NameSet
without (NameSet x) (NameSet y) = NameSet $ x `M.difference` y

-- | Is the name in the set?
member :: VName -> NameSet -> Bool
member v (NameSet m) = v `M.member` m

-- | The name of an identifier, with the uniqueness of its type.
ident :: Ident -> NameSet
ident v = NameSet $ M.singleton (identName v) (uniqueness $ unInfo $ identType v)

-- | A single name, recorded as 'Nonunique'.
oneName :: VName -> NameSet
oneName v = NameSet $ M.singleton v Nonunique

-- | A set of names, all recorded as 'Nonunique'.
names :: S.Set VName -> NameSet
names = foldMap oneName

-- | Compute the set of free variables of an expression.
-- Binding constructs (LetPat, LetFun, Lambda, DoLoop, LetWith, Match)
-- remove the names they bind via 'without'.  Size names from patterns
-- and type annotations are counted as free.  A 'Var' records the
-- uniqueness of its type; other names are recorded as 'Nonunique'
-- (see 'oneName').
freeVars :: Exp -> NameSet
freeVars expr = case expr of
  Literal{} -> mempty
  IntLit{} -> mempty
  FloatLit{} -> mempty
  Parens e _ -> freeVars e
  QualParens _ e _ -> freeVars e
  TupLit es _ -> foldMap freeVars es

  RecordLit fs _ -> foldMap freeVarsField fs
    where freeVarsField (RecordFieldExplicit _ e _) = freeVars e
          freeVarsField (RecordFieldImplicit vn t _) = ident $ Ident vn t noLoc

  ArrayLit es _ _ -> foldMap freeVars es
  Range e me incl _ _ -> freeVars e <> foldMap freeVars me <> foldMap freeVars incl
  Var qn (Info t) _ -> NameSet $ M.singleton (qualLeaf qn) $ uniqueness t
  Ascript e t _ _ -> freeVars e <> names (typeDimNames $ unInfo $ expandedType t)
  LetPat pat e1 e2 _ _ -> freeVars e1 <> ((names (patternDimNames pat) <> freeVars e2)
                                          `without` patternVars pat)

  -- The function name is bound only in the body of the let.
  LetFun vn (_, pats, _, _, e1) e2 _ ->
    ((freeVars e1 <> names (foldMap patternDimNames pats))
      `without` foldMap patternVars pats) <>
    (freeVars e2 `without` oneName vn)

  If e1 e2 e3 _ _ -> freeVars e1 <> freeVars e2 <> freeVars e3
  Apply e1 e2 _ _ _ -> freeVars e1 <> freeVars e2
  Negate e _ -> freeVars e
  Lambda pats e0 _ _ _ -> (names (foldMap patternDimNames pats) <> freeVars e0)
                          `without` foldMap patternVars pats
  OpSection{} -> mempty
  OpSectionLeft _ _ e _ _ _ -> freeVars e
  OpSectionRight _ _ e _ _ _ -> freeVars e
  ProjectSection{} -> mempty
  IndexSection idxs _ _ -> foldMap freeDimIndex idxs

  -- The loop parameters and the loop form's variable are bound only
  -- in the loop body.
  DoLoop pat e1 form e3 _ ->
    let (e2fv, e2ident) = formVars form
    in freeVars e1 <> e2fv <>
       (freeVars e3 `without` (patternVars pat <> e2ident))
    where formVars (For v e2) = (freeVars e2, ident v)
          formVars (ForIn p e2) = (freeVars e2, patternVars p)
          formVars (While e2) = (freeVars e2, mempty)

  -- The operator itself is a free name.
  BinOp (qn, _) _ (e1, _) (e2, _) _ _ -> oneName (qualLeaf qn) <>
                                          freeVars e1 <> freeVars e2
  Project _ e _ _ -> freeVars e

  LetWith id1 id2 idxs e1 e2 _ _ ->
    ident id2 <> foldMap freeDimIndex idxs <> freeVars e1 <>
    (freeVars e2 `without` ident id1)

  Index e idxs _ _ -> freeVars e <> foldMap freeDimIndex idxs
  Update e1 idxs e2 _ -> freeVars e1 <> foldMap freeDimIndex idxs <> freeVars e2
  RecordUpdate e1 _ e2 _ _ -> freeVars e1 <> freeVars e2
  Unsafe e _ -> freeVars e
  Assert e1 e2 _ _ -> freeVars e1 <> freeVars e2
  Constr _ es _ _ -> foldMap freeVars es
  Match e cs _ _ -> freeVars e <> foldMap caseFV cs
    where caseFV (CasePat p eCase _) = (names (patternDimNames p) <> freeVars eCase)
                                       `without` patternVars p

-- | The free variables of a single dimension index or slice.
freeDimIndex :: DimIndexBase Info VName -> NameSet
freeDimIndex (DimFix e) = freeVars e
freeDimIndex (DimSlice me1 me2 me3) =
  foldMap (foldMap freeVars) [me1, me2, me3]

-- | Extract all the variable names bound in a pattern.
patternVars :: Pattern -> NameSet
patternVars = mconcat . map ident . S.toList . patternIdents

-- | Defunctionalize a top-level value binding. Returns the
-- transformed result as well as an environment that binds the name of
-- the value binding to the static value of the transformed body. The
-- boolean is true if the function is a 'DynamicFun'.
defuncValBind :: ValBind -> DefM (ValBind, Env, Bool)

-- Eta-expand entry points with a functional return type.  The new
-- parameters are appended to the existing ones.  In the new return
-- type, every size annotation other than a 'ConstDim' becomes 'AnyDim'.
defuncValBind (ValBind entry@Just{} name _ (Info rettype) tparams params body _ loc)
  | (rettype_ps, rettype') <- unfoldFunType rettype,
    not $ null rettype_ps = do
      (body_pats, body', _) <- etaExpand body
      -- FIXME: we should also handle non-constant size annotations
      -- here.
      defuncValBind $ ValBind entry name Nothing
        (Info $ onlyConstantDims rettype')
        tparams (params <> body_pats) body' Nothing loc
  where onlyConstantDims = first onDim
        onDim (ConstDim x) = ConstDim x
        onDim _ = AnyDim

defuncValBind valbind@(ValBind _ name retdecl rettype tparams params body _ _) = do
  let env = envFromShapeParams tparams
  (params', body', sv) <- localEnv env $ defuncLet tparams params body rettype

  -- Remove any shape parameters that no longer occur in the value parameters.
  let dim_names = foldMap patternDimNames params'
      tparams' = filter ((`S.member` dim_names) . typeParamName) tparams

  -- The new return type is built by 'combineTypeShapes' from the
  -- original annotation and the type of the residual body.
  -- NOTE(review): the exact merge rule lives in 'combineTypeShapes'.
  let rettype' = anyDimShapeAnnotations $ toStruct $ typeOf body'
  return ( valbind { valBindRetDecl    = retdecl
                   , valBindRetType    = Info $ combineTypeShapes
                                         (unInfo rettype) rettype'
                   , valBindTypeParams = tparams'
                   , valBindParams     = params'
                   , valBindBody       = body'
                   }
         , M.singleton name sv
         , case sv of
             DynamicFun{} -> True
             _            -> False)

-- | Defunctionalize a list of top-level declarations.
--
-- Each binding is processed under 'collectFuns'.  The functions lifted
-- while defunctionalizing it are therefore emitted just before it.  Its
-- static value is then in scope for the remaining bindings, and dynamic
-- functions are also registered as globals (see 'isGlobal').
defuncVals :: [ValBind] -> DefM (Seq.Seq ValBind)
defuncVals [] = return mempty
defuncVals (valbind : ds) = do
  ((valbind', env, dyn), defs) <- collectFuns $ defuncValBind valbind
  ds' <- localEnv env $ if dyn
                        then isGlobal (valBindName valbind') $ defuncVals ds
                        else defuncVals ds
  return $ defs <> Seq.singleton valbind' <> ds'

-- | Transform a list of top-level value bindings. May produce new
-- lifted function definitions.  Within the result, 'defuncVals' places
-- each one before the binding that produced it.  Any definitions
-- emitted outside 'collectFuns' are placed in front of the resulting
-- list of declarations.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource $ \namesrc ->
  let (decs', namesrc', liftedDecs) = runDefM namesrc $ defuncVals decs
  in (toList $ liftedDecs <> decs', namesrc')