{-# LANGUAGE CPP #-}
module Utils
  ( nodeHasAnnotation
  , getNodeInfo
  , foldNodeChildren
  ) where

#if __GLASGOW_HASKELL__ >= 900
import qualified Data.Map as M
#endif
import qualified Data.Set as S
import           Data.String

import           GHC.Api

#if __GLASGOW_HASKELL__ >= 900
-- | Combine two 'NodeInfo' values field by field.
--
-- * The annotation sets are combined with '<>'.
-- * The type lists are appended, with the left argument's types first.
-- * The identifier maps are unioned. Details for an identifier that occurs
--   on both sides are combined with '<>', left argument first.
mergeNodeInfo :: NodeInfo a -> NodeInfo a -> NodeInfo a
mergeNodeInfo (NodeInfo annsL typesL identsL) (NodeInfo annsR typesR identsR) =
  NodeInfo mergedAnns mergedTypes mergedIdents
  where
    mergedAnns   = annsL <> annsR
    mergedTypes  = typesL <> typesR
    mergedIdents = M.unionWith (<>) identsL identsR
#endif

-- | Extract the 'NodeInfo' for an AST node.
--
-- On GHC 9 and later a node stores its info keyed by origin
-- ('getSourcedNodeInfo' . 'sourcedNodeInfo'). The entry for GHC-generated
-- code ('GeneratedInfo') is discarded. Any remaining entries are merged with
-- 'mergeNodeInfo', starting from 'emptyNodeInfo'.
--
-- On older compilers the node's single 'nodeInfo' field is returned as-is.
getNodeInfo :: HieAST a -> NodeInfo a
#if __GLASGOW_HASKELL__ >= 900
getNodeInfo = M.foldl' mergeNodeInfo emptyNodeInfo
            . M.delete GeneratedInfo -- drop info for GHC-generated nodes
            . getSourcedNodeInfo . sourcedNodeInfo
#else
getNodeInfo = nodeInfo
#endif

-- | Check whether an AST node carries the annotation identified by the given
-- constructor name and type name.
--
-- The lookup is done on the info returned by 'getNodeInfo'. On GHC 9+ this
-- means annotations recorded only on GHC-generated info are not considered.
--
-- NOTE(review): this assumes 'nodeAnnotations' is a 'S.Set' of
-- (constructor, type) string pairs. Confirm the element type exposed by
-- GHC.Api for every supported compiler version.
nodeHasAnnotation :: String -> String -> HieAST a -> Bool
nodeHasAnnotation constructor ty =
  S.member (fromString constructor, fromString ty)
    . nodeAnnotations
    . getNodeInfo

-- | Map every direct child of a node to a monoid and combine the results in
-- child order.
--
-- The node itself is not passed to @f@, and this does not recurse into
-- grandchildren. Recursion, if wanted, is up to @f@.
foldNodeChildren :: Monoid m => (HieAST a -> m) -> HieAST a -> m
foldNodeChildren f = foldMap f . nodeChildren