{-# LANGUAGE FlexibleInstances #-}
-- |
--
-- Functions for generic traversals across Futhark syntax trees.  The
-- motivation for this module came from dissatisfaction with rewriting
-- the same trivial tree recursions for every module.  A possible
-- alternative would be to use normal \"Scrap your
-- boilerplate\"-techniques, but these are rejected for two reasons:
--
--    * They are too slow.
--
--    * More importantly, they do not tell you whether you have missed
--      some cases.
--
-- Instead, this module defines various traversals of the Futhark syntax
-- tree.  The implementation is rather tedious, but the interface is
-- easy to use.
--
-- A traversal of the Futhark syntax tree is expressed as a record of
-- functions ('ASTMapper') expressing the operations to be performed
-- on the various types of nodes.
module Language.Futhark.Traversals
  ( ASTMapper(..)
  , ASTMappable(..)
  ) where

import qualified Data.Set as S

import           Language.Futhark.Syntax

-- | Express a monad mapping operation on a syntax node.  Each element
-- of this structure expresses the operation to be performed on a
-- given child.
data ASTMapper m = ASTMapper {
    mapOnExp         :: ExpBase Info VName -> m (ExpBase Info VName)
    -- ^ Applied to every immediate subexpression.
  , mapOnName        :: VName -> m VName
    -- ^ Applied to unqualified names.
  , mapOnQualName    :: QualName VName -> m (QualName VName)
    -- ^ Applied to qualified names.
  , mapOnType        :: TypeBase () () -> m (TypeBase () ())
    -- ^ Applied to the type annotations of numeric literals.
  , mapOnCompType    :: CompType -> m CompType
  , mapOnStructType  :: StructType -> m StructType
  , mapOnPatternType :: PatternType -> m PatternType
  }

class ASTMappable x where
  -- | Map a monadic action across the immediate children of an
  -- object.  Importantly, the 'astMap' action is not invoked for
  -- the object itself, and the mapping does not descend recursively
  -- into subexpressions.  The mapping is done left-to-right.
  astMap :: Monad m => ASTMapper m -> x -> m x

-- | Every immediate subexpression is handed to 'mapOnExp'; it is the
-- responsibility of that function to recurse further if desired.
instance ASTMappable (ExpBase Info VName) where
  astMap tv (Var name t loc) =
    Var <$> mapOnQualName tv name <*> traverse (mapOnPatternType tv) t <*>
    pure loc
  astMap _ (Literal val loc) =
    pure $ Literal val loc
  astMap tv (IntLit val t loc) =
    IntLit val <$> traverse (mapOnType tv) t <*> pure loc
  astMap tv (FloatLit val t loc) =
    FloatLit val <$> traverse (mapOnType tv) t <*> pure loc
  astMap tv (Parens e loc) =
    Parens <$> mapOnExp tv e <*> pure loc
  astMap tv (QualParens name e loc) =
    QualParens <$> mapOnQualName tv name <*> mapOnExp tv e <*> pure loc
  astMap tv (TupLit els loc) =
    TupLit <$> traverse (mapOnExp tv) els <*> pure loc
  astMap tv (RecordLit fields loc) =
    RecordLit <$> astMap tv fields <*> pure loc
  astMap tv (ArrayLit els t loc) =
    ArrayLit <$> traverse (mapOnExp tv) els <*>
    traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Range start next end t loc) =
    Range <$> mapOnExp tv start <*> traverse (mapOnExp tv) next <*>
    traverse (mapOnExp tv) end <*> traverse (mapOnCompType tv) t <*>
    pure loc
  astMap tv (Ascript e tdecl loc) =
    Ascript <$> mapOnExp tv e <*> astMap tv tdecl <*> pure loc
  astMap tv (BinOp fname t (x,xt) (y,yt) (Info rt) loc) =
    BinOp <$> mapOnQualName tv fname <*>
    traverse (mapOnPatternType tv) t <*>
    ((,) <$> mapOnExp tv x <*> traverse (mapOnStructType tv) xt) <*>
    ((,) <$> mapOnExp tv y <*> traverse (mapOnStructType tv) yt) <*>
    (Info <$> mapOnPatternType tv rt) <*> pure loc
  astMap tv (Negate x loc) =
    Negate <$> mapOnExp tv x <*> pure loc
  astMap tv (If c texp fexp t loc) =
    If <$> mapOnExp tv c <*> mapOnExp tv texp <*> mapOnExp tv fexp <*>
    traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Apply f arg d (Info t) loc) =
    Apply <$> mapOnExp tv f <*> mapOnExp tv arg <*> pure d <*>
    (Info <$> mapOnPatternType tv t) <*> pure loc
  astMap tv (LetPat tparams pat e body loc) =
    LetPat <$> traverse (astMap tv) tparams <*> astMap tv pat <*>
    mapOnExp tv e <*> mapOnExp tv body <*> pure loc
  astMap tv (LetFun name (fparams, params, ret, t, e) body loc) =
    LetFun <$> mapOnName tv name <*>
    ((,,,,) <$> traverse (astMap tv) fparams <*>
     traverse (astMap tv) params <*>
     traverse (astMap tv) ret <*>
     traverse (mapOnStructType tv) t <*>
     mapOnExp tv e) <*>
    mapOnExp tv body <*> pure loc
  astMap tv (LetWith dest src idxexps vexp body loc) =
    LetWith <$> astMap tv dest <*> astMap tv src <*>
    traverse (astMap tv) idxexps <*>
    mapOnExp tv vexp <*> mapOnExp tv body <*> pure loc
  astMap tv (Update src slice v loc) =
    Update <$> mapOnExp tv src <*> traverse (astMap tv) slice <*>
    mapOnExp tv v <*> pure loc
  astMap tv (RecordUpdate src fs v (Info t) loc) =
    RecordUpdate <$> mapOnExp tv src <*> pure fs <*>
    mapOnExp tv v <*> (Info <$> mapOnPatternType tv t) <*> pure loc
  astMap tv (Project field e t loc) =
    Project field <$> mapOnExp tv e <*>
    traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Index arr idxexps t loc) =
    -- The indexed array is an ordinary subexpression, so it goes
    -- through 'mapOnExp' like the source of an 'Update'.
    Index <$> mapOnExp tv arr <*>
    traverse (astMap tv) idxexps <*>
    traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Map fun e t loc) =
    Map <$> mapOnExp tv fun <*> mapOnExp tv e <*>
    traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Reduce comm fun startexp arrexp loc) =
    Reduce comm <$> mapOnExp tv fun <*>
    mapOnExp tv startexp <*> mapOnExp tv arrexp <*> pure loc
  astMap tv (GenReduce hist op ne bfun img loc) =
    GenReduce <$> mapOnExp tv hist <*> mapOnExp tv op <*>
    mapOnExp tv ne <*> mapOnExp tv bfun <*> mapOnExp tv img <*>
    pure loc
  astMap tv (Unsafe e loc) =
    Unsafe <$> mapOnExp tv e <*> pure loc
  astMap tv (Assert e1 e2 desc loc) =
    Assert <$> mapOnExp tv e1 <*> mapOnExp tv e2 <*>
    pure desc <*> pure loc
  astMap tv (Scan fun startexp arrexp loc) =
    Scan <$> mapOnExp tv fun <*> mapOnExp tv startexp <*>
    mapOnExp tv arrexp <*> pure loc
  astMap tv (Filter fun arrexp loc) =
    Filter <$> mapOnExp tv fun <*> mapOnExp tv arrexp <*> pure loc
  astMap tv (Partition k fun arrexp loc) =
    Partition k <$> mapOnExp tv fun <*> mapOnExp tv arrexp <*> pure loc
  astMap tv (Stream form fun arr loc) =
    Stream <$> mapOnStreamForm form <*> mapOnExp tv fun <*>
    mapOnExp tv arr <*> pure loc
    where mapOnStreamForm (MapLike o) =
            pure $ MapLike o
          mapOnStreamForm (RedLike o comm lam) =
            RedLike o comm <$> mapOnExp tv lam
  astMap tv (Lambda tparams params body ret t loc) =
    -- The body is a subexpression and must be seen by 'mapOnExp'.
    Lambda <$> traverse (astMap tv) tparams <*>
    traverse (astMap tv) params <*>
    mapOnExp tv body <*> traverse (astMap tv) ret <*>
    traverse (traverse $ mapOnStructType tv) t <*> pure loc
  astMap tv (OpSection name t loc) =
    OpSection <$> mapOnQualName tv name <*>
    traverse (mapOnPatternType tv) t <*> pure loc
  astMap tv (OpSectionLeft name t arg (t1a, t1b) t2 loc) =
    OpSectionLeft <$> mapOnQualName tv name <*>
    traverse (mapOnPatternType tv) t <*> mapOnExp tv arg <*>
    ((,) <$> traverse (mapOnStructType tv) t1a <*>
     traverse (mapOnStructType tv) t1b) <*>
    traverse (mapOnPatternType tv) t2 <*> pure loc
  astMap tv (OpSectionRight name t arg (t1a, t1b) t2 loc) =
    OpSectionRight <$> mapOnQualName tv name <*>
    traverse (mapOnPatternType tv) t <*> mapOnExp tv arg <*>
    ((,) <$> traverse (mapOnStructType tv) t1a <*>
     traverse (mapOnStructType tv) t1b) <*>
    traverse (mapOnPatternType tv) t2 <*> pure loc
  astMap tv (ProjectSection fields t loc) =
    ProjectSection fields <$> traverse (mapOnPatternType tv) t <*> pure loc
  astMap tv (IndexSection idxs t loc) =
    IndexSection <$> traverse (astMap tv) idxs <*>
    traverse (mapOnPatternType tv) t <*> pure loc
  astMap tv (DoLoop tparams mergepat mergeexp form loopbody loc) =
    DoLoop <$> traverse (astMap tv) tparams <*> astMap tv mergepat <*>
    mapOnExp tv mergeexp <*> astMap tv form <*>
    mapOnExp tv loopbody <*> pure loc
  astMap tv (VConstr0 name t loc) =
    VConstr0 name <$> traverse (mapOnCompType tv) t <*> pure loc
  astMap tv (Match e cases t loc) =
    Match <$> mapOnExp tv e <*> astMap tv cases <*>
    traverse (mapOnCompType tv) t <*> pure loc

-- | The expressions embedded in a loop form (bound, iterated array,
-- condition) are immediate children and are therefore passed to
-- 'mapOnExp' rather than being traversed with 'astMap' directly.
instance ASTMappable (LoopFormBase Info VName) where
  astMap tv (For i bound) = For <$> astMap tv i <*> mapOnExp tv bound
  astMap tv (ForIn pat e) = ForIn <$> astMap tv pat <*> mapOnExp tv e
  astMap tv (While e)     = While <$> mapOnExp tv e
-- | Type expressions have no dedicated mapper function, so their
-- component type expressions are traversed with 'astMap' directly.
instance ASTMappable (TypeExp VName) where
  astMap tv (TEVar qn loc) =
    TEVar <$> mapOnQualName tv qn <*> pure loc
  astMap tv (TETuple ts loc) =
    TETuple <$> traverse (astMap tv) ts <*> pure loc
  astMap tv (TERecord ts loc) =
    TERecord <$> traverse (traverse $ astMap tv) ts <*> pure loc
  astMap tv (TEArray te dim loc) =
    TEArray <$> astMap tv te <*> astMap tv dim <*> pure loc
  astMap tv (TEUnique t loc) =
    TEUnique <$> astMap tv t <*> pure loc
  astMap tv (TEApply t1 t2 loc) =
    TEApply <$> astMap tv t1 <*> astMap tv t2 <*> pure loc
  astMap tv (TEArrow v t1 t2 loc) =
    TEArrow v <$> astMap tv t1 <*> astMap tv t2 <*> pure loc
  astMap _ te@TEEnum{} =
    pure te

instance ASTMappable (TypeArgExp VName) where
  astMap tv (TypeArgExpDim dim loc) =
    TypeArgExpDim <$> astMap tv dim <*> pure loc
  astMap tv (TypeArgExpType te) =
    TypeArgExpType <$> astMap tv te

instance ASTMappable (DimDecl VName) where
  astMap tv (NamedDim vn) = NamedDim <$> mapOnQualName tv vn
  astMap _ (ConstDim k)   = pure $ ConstDim k
  astMap _ AnyDim         = pure AnyDim

instance ASTMappable (TypeParamBase VName) where
  astMap = traverse . mapOnName

-- | The index expressions are immediate subexpressions, so they are
-- passed to 'mapOnExp' instead of being traversed with 'astMap'
-- (which would skip the expression itself and rewrite its children).
instance ASTMappable (DimIndexBase Info VName) where
  astMap tv (DimFix j) =
    DimFix <$> mapOnExp tv j
  astMap tv (DimSlice i j stride) =
    DimSlice <$>
    traverse (mapOnExp tv) i <*>
    traverse (mapOnExp tv) j <*>
    traverse (mapOnExp tv) stride

instance ASTMappable Alias where
  astMap tv (AliasBound v) = AliasBound <$> mapOnName tv v
  astMap tv (AliasFree v)  = AliasFree <$> mapOnName tv v

-- | Maps every alias in the set; the result is rebuilt with
-- 'S.fromList', so aliases that become equal are merged.
instance ASTMappable Aliasing where
  astMap tv = fmap S.fromList . traverse (astMap tv) . S.toList

-- | The shape of a traversal over a type-like structure @t@: one
-- function for type names, one for dimensions and one for aliasing
-- information.
type TypeTraverser f t dim1 als1 dim2 als2 =
  (TypeName -> f TypeName) -> (dim1 -> f dim2) -> (als1 -> f als2) ->
  t dim1 als1 -> f (t dim2 als2)

-- | Traverse the type names, dimensions and aliases of a type.
traverseType :: Applicative f =>
                TypeTraverser f TypeBase dim1 als1 dim2 als2
traverseType _ _ _ (Prim t) =
  pure $ Prim t
traverseType f g h (Array als u et shape) =
  Array <$> h als <*> pure u <*>
  traverseArrayElemType f g et <*> traverse g shape
traverseType f g h (Record fs) =
  Record <$> traverse (traverseType f g h) fs
traverseType f g h (TypeVar als u t args) =
  TypeVar <$> h als <*> pure u <*> f t <*>
  traverse (traverseTypeArg f g) args
traverseType f g h (Arrow als v t1 t2) =
  Arrow <$> h als <*> pure v <*>
  traverseType f g h t1 <*> traverseType f g h t2
traverseType _ _ _ (Enum cs) =
  pure $ Enum cs

-- | Traverse the type names and dimensions of an array element type.
traverseArrayElemType :: Applicative f =>
                         (TypeName -> f TypeName) -> (dim1 -> f dim2)
                      -> ArrayElemTypeBase dim1 -> f (ArrayElemTypeBase dim2)
traverseArrayElemType _ _ (ArrayPrimElem t) =
  pure $ ArrayPrimElem t
traverseArrayElemType f g (ArrayPolyElem t args) =
  ArrayPolyElem <$> f t <*> traverse (traverseTypeArg f g) args
traverseArrayElemType f g (ArrayRecordElem fs) =
  ArrayRecordElem <$> traverse (traverseRecordArrayElemType f g) fs
traverseArrayElemType _ _ (ArrayEnumElem cs) =
  pure $ ArrayEnumElem cs

-- | Traverse the type names and dimensions of a record field inside
-- an array element type.
traverseRecordArrayElemType :: Applicative f =>
                               (TypeName -> f TypeName) -> (dim1 -> f dim2)
                            -> RecordArrayElemTypeBase dim1
                            -> f (RecordArrayElemTypeBase dim2)
traverseRecordArrayElemType f g (RecordArrayElem et) =
  RecordArrayElem <$> traverseArrayElemType f g et
traverseRecordArrayElemType f g (RecordArrayArrayElem et shape) =
  RecordArrayArrayElem <$> traverseArrayElemType f g et <*> traverse g shape

-- | Traverse the type names and dimensions of a type argument.
-- Aliasing inside a type argument is left untouched.
traverseTypeArg :: Applicative f =>
                   (TypeName -> f TypeName) -> (dim1 -> f dim2)
                -> TypeArg dim1 -> f (TypeArg dim2)
traverseTypeArg _ g (TypeArgDim d loc) =
  TypeArgDim <$> g d <*> pure loc
traverseTypeArg f g (TypeArgType t loc) =
  TypeArgType <$> traverseType f g pure t <*> pure loc

-- | Map a type name by round-tripping it through a qualified name, so
-- that 'mapOnQualName' can be reused for type names.
mapOnTypeName :: Functor m => ASTMapper m -> TypeName -> m TypeName
mapOnTypeName tv =
  fmap typeNameFromQualName . mapOnQualName tv . qualNameFromTypeName

instance ASTMappable (TypeBase () ()) where
  astMap tv = traverseType (mapOnTypeName tv) pure pure

instance ASTMappable CompType where
  astMap tv = traverseType (mapOnTypeName tv) pure (astMap tv)

instance ASTMappable StructType where
  astMap tv = traverseType (mapOnTypeName tv) (astMap tv) pure

instance ASTMappable PatternType where
  astMap tv = traverseType (mapOnTypeName tv) (astMap tv) (astMap tv)

instance ASTMappable (TypeDeclBase Info VName) where
  astMap tv (TypeDecl dt (Info et)) =
    TypeDecl <$> astMap tv dt <*> (Info <$> mapOnStructType tv et)

instance ASTMappable (IdentBase Info VName) where
  astMap tv (Ident name (Info t) loc) =
    Ident <$> mapOnName tv name <*> (Info <$> mapOnCompType tv t) <*>
    pure loc

instance ASTMappable (PatternBase Info VName) where
  astMap tv (Id name (Info t) loc) =
    Id <$> mapOnName tv name <*> (Info <$> mapOnPatternType tv t) <*>
    pure loc
  astMap tv (TuplePattern pats loc) =
    TuplePattern <$> traverse (astMap tv) pats <*> pure loc
  astMap tv (RecordPattern fields loc) =
    RecordPattern <$> traverse (traverse $ astMap tv) fields <*> pure loc
  astMap tv (PatternParens pat loc) =
    PatternParens <$> astMap tv pat <*> pure loc
  astMap tv (PatternAscription pat t loc) =
    PatternAscription <$> astMap tv pat <*> astMap tv t <*> pure loc
  astMap tv (Wildcard (Info t) loc) =
    Wildcard <$> (Info <$> mapOnPatternType tv t) <*> pure loc
  astMap tv (PatternLit e (Info t) loc) =
    -- The literal is an expression child, hence 'mapOnExp'.
    PatternLit <$> mapOnExp tv e <*> (Info <$> mapOnPatternType tv t) <*>
    pure loc

instance ASTMappable (FieldBase Info VName) where
  astMap tv (RecordFieldExplicit name e loc) =
    RecordFieldExplicit name <$> mapOnExp tv e <*> pure loc
  astMap tv (RecordFieldImplicit name t loc) =
    RecordFieldImplicit <$> mapOnName tv name <*>
    traverse (mapOnCompType tv) t <*> pure loc

-- | The case body is an expression child and is passed to 'mapOnExp'.
instance ASTMappable (CaseBase Info VName) where
  astMap tv (CasePat pat e loc) =
    CasePat <$> astMap tv pat <*> mapOnExp tv e <*> pure loc

instance ASTMappable a => ASTMappable (Info a) where
  astMap tv = traverse $ astMap tv

instance ASTMappable a => ASTMappable [a] where
  astMap tv = traverse $ astMap tv

instance (ASTMappable a, ASTMappable b) => ASTMappable (a,b) where
  astMap tv (x,y) = (,) <$> astMap tv x <*> astMap tv y

instance (ASTMappable a, ASTMappable b, ASTMappable c) => ASTMappable (a,b,c) where
  astMap tv (x,y,z) = (,,) <$> astMap tv x <*> astMap tv y <*> astMap tv z