module Data.Comp.Trans.Collect
  (
    collectTypes
  ) where

import Control.Monad ( liftM, liftM2 )

import Data.Foldable ( fold )
import Data.Monoid ( Monoid(..) )
import Data.Set as Set ( Set, singleton, union, difference, toList, member, empty )

import Language.Haskell.TH.Syntax
import Language.Haskell.TH.ExpandSyns ( expandSyns )

import Data.Comp.Trans.Names ( standardNameSet )

-- | Computes every type name reachable from the given name through
-- constructor fields (see 'collectTypes''), then drops all members of
-- 'standardNameSet' and returns the remainder as a list.
collectTypes :: Name -> Q [Name]
collectTypes n = liftM withoutStandard (fixpoint collectTypes' n)
  where
    withoutStandard reachable =
      Set.toList (Set.difference reachable standardNameSet)

-- | Iterates a monadic set-valued function until the set stops changing.
--
-- Starting from the singleton set of the seed, each round applies @f@ to
-- every element of the current set and unions the results. Iteration ends
-- as soon as a round produces a set equal to the one it started from.
--
-- Note that an element survives a round only if some @f y@ contains it;
-- the seed itself is not kept automatically.
fixpoint :: (Ord a, Monad m) => (a -> m (Set a)) -> a -> m (Set a)
fixpoint f x = loop (Set.singleton x)
  where
    loop current = do
      next <- liftM fold (mapSetM f current)
      case next == current of
        True  -> return next
        False -> loop next

-- | Monadic map over a 'Set': runs the action on each element (in the
-- order given by 'toList') and gathers the results into a new set.
mapSetM :: (Monad m, Ord b) => (a -> m b) -> Set a -> m (Set b)
mapSetM f x = do
  images <- mapM f (Set.toList x)
  return (mconcat (map Set.singleton images))

-- | One expansion step for 'collectTypes'.
--
-- A name in 'standardNameSet' yields the empty set without being reified.
-- Any other name is reified; the result is the name itself together with
-- every type-constructor name found in the fields of its constructors.
-- Only @data@ and @newtype@ declarations contribute constructors; every
-- other kind of 'Info' contributes none.
collectTypes' :: Name -> Q (Set Name)
collectTypes' n
  | Set.member n standardNameSet = return Set.empty
  | otherwise = do
      info <- reify n
      fieldNames <- liftM concat (mapM extractNames (constructorsOf info))
      return (Set.union (Set.singleton n)
                        (mconcat (map Set.singleton fieldNames)))
  where
    constructorsOf (TyConI (DataD _ _ _ cns _))    = cns
    constructorsOf (TyConI (NewtypeD _ _ _ con _)) = [con]
    constructorsOf _                               = []

-- | Syntax fragments from which the referenced type-constructor names
-- can be gathered.
class ExtractNames a where
  extractNames :: a -> Q [Name]

-- | Gathers names from all field types of a constructor, looking
-- through a @forall@ wrapper to the constructor underneath.
instance ExtractNames Con where
  extractNames con = case con of
    NormalC _ fields  -> liftM concat (mapM extractNames fields)
    RecC _ fields     -> liftM concat (mapM extractNames fields)
    InfixC lhs _ rhs  -> liftM2 (++) (extractNames lhs) (extractNames rhs)
    ForallC _ _ inner -> extractNames inner

-- NOTE(review): the two instances below are declared on tuple-shaped
-- types; this presumably relies on FlexibleInstances /
-- TypeSynonymInstances being enabled by the build configuration - confirm.

-- | Ignores the first component and inspects the field's type.
instance ExtractNames StrictType where
  extractNames (_, fieldType) = extractNames fieldType

-- | Ignores the first two components and inspects the field's type.
instance ExtractNames VarStrictType where
  extractNames (_, _, fieldType) = extractNames fieldType

-- | Expands type synonyms, then collects the names of all type
-- constructors reachable through type application. Any other form of
-- type contributes no names.
instance ExtractNames Type where
  extractNames tSyn = expandSyns tSyn >>= namesIn
    where
      namesIn (AppT fun arg) = liftM2 (++) (extractNames fun) (extractNames arg)
      namesIn (ConT name)    = return [name]
      namesIn _              = return []