{-# LANGUAGE FlexibleContexts
           , LambdaCase
           , RankNTypes
           , ScopedTypeVariables
           #-}
-- | Helper functions for using the AST.
module Language.Haskell.Tools.AST.Helpers where

import SrcLoc
import qualified Name as GHC

import Control.Reference hiding (element)
import Control.Monad
import Data.List
import Data.Maybe
import Data.Function hiding ((&))
import Data.Generics.Uniplate.Operations

import Language.Haskell.Tools.AST.Ann
import Language.Haskell.Tools.AST.Modules
import Language.Haskell.Tools.AST.Decls
import Language.Haskell.Tools.AST.Binds
import Language.Haskell.Tools.AST.Types
import Language.Haskell.Tools.AST.Base
import Language.Haskell.Tools.AST.References

import Debug.Trace

-- | Orders two names by comparing their 'nameElements' (the qualifiers
-- followed by the unqualified name).
ordByOccurrence :: SimpleName a -> SimpleName a -> Ordering
ordByOccurrence = compare `on` nameElements

-- | The name as a single string: the qualifiers and the unqualified name
-- joined with dots.
nameString :: SimpleName a -> String
nameString = intercalate "." . nameElements

-- | The qualifiers and the unqualified name
nameElements :: SimpleName a -> [String]
nameElements n = nameQualifier n
                    ++ [n ^. unqualifiedName&element&simpleNameStr]

-- | The qualifier of the name
nameQualifier :: SimpleName a -> [String]
nameQualifier n = n ^? qualifiers&annList&element&simpleNameStr

-- | Does the import declaration import only the explicitly listed elements?
importIsExact :: ImportDecl a -> Bool
importIsExact = isJust . (^? importSpec&annJust&element&importSpecList)

-- | Does the import declaration have a 'hiding' clause?
importIsHiding :: ImportDecl a -> Bool
importIsHiding = isJust . (^? importSpec&annJust&element&importSpecHiding)

-- | All elements that are explicitly listed to be imported in the import declaration
importExacts :: Simple Traversal (ImportDecl a) (IESpec a)
importExacts = importSpec&annJust&element&importSpecList&annList&element

-- | All elements that are hidden in an import
importHidings :: Simple Traversal (ImportDecl a) (IESpec a)
-- Goes through 'importSpecHiding' (the same reference 'importIsHiding' checks),
-- not 'importSpecList', which would yield the explicitly imported elements.
importHidings = importSpec&annJust&element&importSpecHiding&annList&element

-- | Possible qualifiers to use imported definitions. Contains the empty
-- qualifier when the 'importQualified' annotation is absent, and the
-- elements of the renaming name when the import has an 'importAs' part.
importQualifiers :: ImportDecl a -> [[String]]
importQualifiers imp
  = (if isAnnNothing (imp ^. importQualified) then [[]] else [])
      ++ maybe [] (\n -> [nameElements n]) (imp ^? importAs&annJust&element&importRename&element)

-- | The semantic information of the names defined by a value binding:
-- the name of a pattern binding, or the name / operator on the left-hand
-- side of each match of a function binding.
bindingSemantics :: Simple Traversal (Ann ValueBind (NodeInfo (SemanticInfo n) s)) (SemanticInfo n)
bindingSemantics = element&(valBindPat&element&patternName&element&simpleName
                        &+& funBindMatches&annList&element&matchLhs&element
                              &(matchLhsName&element&simpleName &+& matchLhsOperator&element&operatorName))
                     &semantics

-- | The 'nameInfo' of the names reached by 'bindingSemantics'.
bindingName :: Simple Traversal (Ann ValueBind (NodeInfo (SemanticInfo n) s)) n
bindingName = bindingSemantics&nameInfo

-- | The names in a declaration head. Recurses through 'dhBody' and
-- 'dhAppFun', and also yields the name of an operator head.
declHeadNames :: Simple Traversal (Ann DeclHead a) (Ann SimpleName a)
declHeadNames = element & (dhName&element&simpleName
                             &+& dhBody&declHeadNames
                             &+& dhAppFun&declHeadNames
                             &+& dhOperator&element&operatorName)

-- | The argument types and the final result type of a function type.
-- Looks through @forall@ quantification, contexts and parentheses; a type
-- that is none of these (and not a function type) is focused as a whole.
typeParams :: Simple Traversal (Ann Type a) (Ann Type a)
typeParams = fromTraversal typeParamsTrav
  where typeParamsTrav f (Ann a (TyFun p r)) = Ann a <$> (TyFun <$> f p <*> typeParamsTrav f r)
        typeParamsTrav f (Ann a (TyForall vs t)) = Ann a <$> (TyForall vs <$> typeParamsTrav f t)
        typeParamsTrav f (Ann a (TyCtx ctx t)) = Ann a <$> (TyCtx ctx <$> typeParamsTrav f t)
        typeParamsTrav f (Ann a (TyParen t)) = Ann a <$> (TyParen <$> typeParamsTrav f t)
        typeParamsTrav f t = f t

-- | Access the semantic information of an AST node.
semantics :: Simple Lens (Ann a (NodeInfo sema src)) sema
semantics = annotation&semanticInfo

-- | The 'nameInfo' of every name found by 'declHeadNames' in a declaration head.
dhNames :: Simple Traversal (Ann DeclHead (NodeInfo (SemanticInfo n) src)) n
dhNames = declHeadNames & semantics & nameInfo

-- | A type class for transformations that work on both top-level and local definitions
class BindingElem d where
  -- | Partial access to the type signature stored in the element.
  sigBind :: Simple Partial (d a) (Ann TypeSignature a)
  -- | Partial access to the value binding stored in the element.
  valBind :: Simple Partial (d a) (Ann ValueBind a)
  -- | Wraps a type signature into an element.
  createTypeSig :: Ann TypeSignature a -> d a
  -- | Wraps a value binding into an element.
  createBinding :: Ann ValueBind a -> d a
  -- | Is the element a type signature?
  isTypeSig :: d a -> Bool
  -- | Is the element a value binding?
  isBinding :: d a -> Bool

instance BindingElem Decl where
  sigBind = declTypeSig
  valBind = declValBind
  createTypeSig = TypeSigDecl
  createBinding = ValueBinding
  isTypeSig (TypeSigDecl _) = True
  isTypeSig _ = False
  isBinding (ValueBinding _) = True
  isBinding _ = False

instance BindingElem LocalBind where
  sigBind = localSig
  valBind = localVal
  createTypeSig = LocalSignature
  createBinding = LocalValBind
  isTypeSig (LocalSignature _) = True
  isTypeSig _ = False
  isBinding (LocalValBind _) = True
  isBinding _ = False

-- | The names of a binding element: the names bound by its value binding
-- (see 'bindingName') and the names listed in its type signature.
bindName :: BindingElem d => Simple Traversal (d (NodeInfo (SemanticInfo n) src)) n
bindName = valBind&bindingName &+& sigBind&element&tsName&annList&element&simpleName&semantics&nameInfo

-- | All value bindings in a list of binding elements.
valBindsInList :: BindingElem d => Simple Traversal (AnnList d a) (Ann ValueBind a)
valBindsInList = annList & element & valBind

-- | Finds the value binding in the list whose range contains the given span.
-- Gives 'Nothing' when there is no such binding and calls 'error' when
-- more than one binding qualifies.
getValBindInList :: (BindingElem d, HasRange a) => RealSrcSpan -> AnnList d a -> Maybe (Ann ValueBind a)
getValBindInList sp ls = case ls ^? valBindsInList & filtered (isInside sp) of
  [] -> Nothing
  [n] -> Just n
  _ -> error "getValBindInList: Multiple nodes"

-- | All inner nodes (found with 'biplateRef') whose range contains the given span.
nodesContaining :: forall node inner a . (Biplate (node a) (inner a), HasAnnot node, HasAnnot inner, HasRange a)
                => RealSrcSpan -> Simple Traversal (node a) (inner a)
nodesContaining rng = biplateRef & filtered (isInside rng)

-- | Checks that the range of the node is a real source span that contains
-- the given span. Nodes without a real source span never qualify.
isInside :: (HasAnnot node, HasRange a) => RealSrcSpan -> node a -> Bool
isInside rng nd = case getRange (getAnnot nd) of RealSrcSpan sp -> sp `containsSpan` rng
                                                 _ -> False

-- | All inner nodes (found with 'biplateRef') whose range is exactly the given span.
nodesWithRange :: forall node inner a . (Biplate (node a) (inner a), HasAnnot node, HasAnnot inner, HasRange a)
               => RealSrcSpan -> Simple Traversal (node a) (inner a)
nodesWithRange rng = biplateRef & filtered (hasRange rng)

-- | Checks that the range of the node is a real source span equal to the
-- given span. Nodes without a real source span never qualify.
hasRange :: (HasAnnot node, HasRange a) => RealSrcSpan -> node a -> Bool
hasRange rng node = case getRange (getAnnot node) of RealSrcSpan sp -> sp == rng
                                                     _ -> False

-- | Selects one of the nodes whose range contains the given span: the
-- minimum of the candidates according to 'compareRangeLength'. Gives
-- 'Nothing' when no node contains the span.
getNodeContaining :: (Biplate (node a) (Ann inner a), HasAnnot node, HasRange a)
                  => RealSrcSpan -> node a -> Maybe (Ann inner a)
getNodeContaining sp node = case node ^? nodesContaining sp of
  [] -> Nothing
  results -> Just $ minimumBy (compareRangeLength `on` (getRange . (^. annotation))) results

-- | Compares two source spans, first by their line difference and then by
-- their column difference. If either span is not a real source span, the
-- two are considered equal instead of failing with a pattern match error.
--
-- NOTE(review): the differences are computed as start minus end, so for
-- well-formed spans they are non-positive and a span covering more lines
-- (or columns) compares as /smaller/. With 'minimumBy' in
-- 'getNodeContaining' this picks the widest candidate -- confirm that this
-- is the intended node before relying on it for nested nodes.
compareRangeLength :: SrcSpan -> SrcSpan -> Ordering
compareRangeLength (RealSrcSpan sp1) (RealSrcSpan sp2)
  = (lineDiff sp1 `compare` lineDiff sp2) `mappend` (colDiff sp1 `compare` colDiff sp2)
  where lineDiff sp = srcLocLine (realSrcSpanStart sp) - srcLocLine (realSrcSpanEnd sp)
        colDiff sp = srcLocCol (realSrcSpanStart sp) - srcLocCol (realSrcSpanEnd sp)
compareRangeLength _ _ = EQ

-- | Finds the single inner node whose range is exactly the given span.
-- Calls 'error' when there is no such node or when there are several.
getNode :: (Biplate (node a) (inner a), HasAnnot node, HasAnnot inner, HasRange a)
        => RealSrcSpan -> node a -> inner a
getNode sp node = case node ^? nodesWithRange sp of
  [] -> error "getNode: The node cannot be found"
  [n] -> n
  _ -> error "getNode: Multiple nodes"