{-# LANGUAGE RankNTypes #-}
{-# OPTIONS -Wall #-}
module Language.Haskell.HBB.Internal.AST (
     transformRenamedAST
    ,queryRenamedAST
    ,transformTypecheckedAST) where

import Data.Generics (GenericM,GenericQ,gmapM,Data,extQ,gmapQ)
import NameSet (NameSet)
import GHC

-- | Generic query over a renamed AST, modelled on SYB's 'everything' but with
-- a cut-off: a node of type 'PostTcType' is neither passed to the query nor
-- descended into. It simply contributes the given neutral element, because
-- 'PostTcType' values must not be evaluated after the renamer.
--
-- For every other node, the query result for the node itself is combined
-- left-to-right (with the given combining function) with the results
-- obtained recursively for each of its immediate children.
queryRenamedAST :: r -> (r -> r -> r) -> GenericQ r -> GenericQ r
queryRenamedAST neutral combine query node
    | skip node = neutral
      -- Lazy 'foldl', exactly as SYB's 'everything' combines results.
      -- Recursion goes through the top-level (rank-2 typed) function so the
      -- argument to 'gmapQ' is polymorphic in the child type.
    | otherwise = foldl combine (query node)
                        (gmapQ (queryRenamedAST neutral combine query) node)
    where
      -- True only for 'PostTcType' placeholders.
      skip :: GenericQ Bool
      skip = extQ (const False) (const True :: PostTcType -> Bool)

-- | Applies a monadic SYB transformation throughout a renamed AST while
-- leaving every 'PostTcType' subtree untouched. The transformation is never
-- applied to 'PostTcType' values because, after the renamer, they must not be
-- evaluated.
transformRenamedAST :: (Monad m, Data a) => GenericM m -> a -> m a
transformRenamedAST trans tree = everywhereButM skipPostTcType trans tree
    where
      -- Recognises the 'PostTcType' placeholders that must be skipped.
      skipPostTcType :: GenericQ Bool
      skipPostTcType = (const False) `extQ` (const True :: PostTcType -> Bool)

-- | Applies a monadic SYB transformation throughout a typechecked AST while
-- leaving every 'NameSet' subtree untouched. 'NameSet' values in this AST
-- (for example the 'bind_fvs' field of function bindings) must not be
-- evaluated, so the transformation never reaches them.
transformTypecheckedAST :: (Monad m, Data a) => GenericM m -> a -> m a
transformTypecheckedAST trans tree = everywhereButM skipNameSet trans tree
    where
      -- Recognises the 'NameSet' values that must be skipped.
      skipNameSet :: GenericQ Bool
      skipNameSet = (const False) `extQ` (const True :: NameSet -> Bool)

-- | Generic monadic transformation that skips every subtree whose root
-- satisfies the given predicate.
--
-- The SYB package offers no monadic traversal of this kind (an
-- "everywhereButM"). It is needed here because the traversal runs in a monad
-- (for example a Writer monad) and certain nodes must not be evaluated at all.
--
-- Traversal order is top-down.
--
-- * The exclusion predicate is checked on a node before anything else.
-- * An excluded node is returned unchanged. Neither it nor its descendants
--   are passed to the transformation.
-- * Otherwise the transformation is applied to the node first. The traversal
--   then continues into the immediate children of the /transformed/ node.
everywhereButM :: Monad m => GenericQ Bool -> GenericM m -> GenericM m
everywhereButM shouldExclude f x
    | shouldExclude x = return x
    | otherwise       = f x >>= gmapM (everywhereButM shouldExclude f)