module Language.Haskell.TH.TypeGraph.Free
( freeTypeVars
) where
import Control.Lens hiding (Strict, cons)
import Control.Monad.State (MonadState, execStateT)
import Data.Set as Set (Set, delete, difference, empty, fromList, insert, member, union)
import Language.Haskell.TH
import Language.Haskell.TH.Desugar ()
import Language.Haskell.TH.Syntax (Quasi(qReify))
import Language.Haskell.TH.TypeGraph.Prelude (pprint1)
-- | State threaded through the traversal.
--
-- Both fields are strict: the record is rebuilt on every update inside a
-- lazy 'StateT', and lazy fields would let chains of unevaluated 'Set'
-- operations pile up until the final result is demanded.
data St
  = St { _result  :: !(Set Name)
         -- ^ Type variables collected as free so far.
       , _visited :: !(Set Name)
         -- ^ Type constructors that have already been reified and expanded.
       } deriving Show
-- | Starting state: nothing collected, no type constructor expanded yet.
st0 :: St
st0 = St { _visited = Set.empty, _result = Set.empty }
-- Derive the 'result' and 'visited' lenses from the underscored fields of 'St'.
makeLenses ''St
-- | Run the traversal from the empty state 'st0' and return the set of
-- type variables it reports as free.
freeTypeVars :: (FreeTypeVars t, Quasi m) => t -> m (Set Name)
freeTypeVars x = fmap (^. result) (execStateT (ftv x) st0)
-- | Things whose free type variables can be accumulated into 'St'.
class FreeTypeVars t where
    -- | Walk a value, updating the 'result' set held in the state.
    -- 'Quasi' is required because type constructors are reified.
    ftv :: (MonadState St m, Quasi m)
        => t
        -> m ()
-- | Elements are visited left to right; their contributions accumulate.
instance FreeTypeVars a => FreeTypeVars [a] where
    ftv = mapM_ ftv
-- | Free variables of a 'Type'.  Type constructors ('ConT'), whether bare
-- or at the head of an 'AppT' spine, are handed to 'go_app', which reifies
-- and expands them.
instance FreeTypeVars Type where
    -- Variables bound by a @forall@ scope only over its context and body.
    -- Collect those in isolation, remove the binders from that local set,
    -- and merge what is left back into the previously accumulated result.
    -- Subtracting the binders from the whole accumulated result instead
    -- would also discard an unrelated free variable with the same name
    -- (e.g. the outer @a@ in @a -> (forall a. a)@).
    ftv (ForallT tvbs cx ty) = do
        outer <- use result
        result .= Set.empty
        ftv ty
        mapM_ go_pred cx
        result %= \inner -> Set.union outer (inner `Set.difference` bound)
      where
        bound = Set.fromList (map tvbName tvbs)
#if __GLASGOW_HASKELL__ >= 709
        -- Constraints are ordinary types here.
        go_pred typ =
            ftv typ
#else
        go_pred (ClassP _ tys) = ftv tys
        go_pred (EqualP t1 t2) = do
            ftv t1
            ftv t2
#endif
    ftv (SigT ty _) = ftv ty
    ftv (VarT n) = result %= Set.insert n
    ftv (AppT t1 t2) = go_app [t2] t1
    ftv typ@(ConT _) = go_app [] typ
    ftv _ = return ()
-- | Free variables of a type application @f p1 ... pn@.
--
-- Walking down the left spine of 'AppT' nodes prepends each argument, so
-- @params@ ends up in source order (@p1@ first).  That is the same order
-- as the binders of the reified declaration, so no reversal is needed.
--
-- A type constructor is reified and expanded at most once; 'visited'
-- stops infinite recursion through recursive types.  On later occurrences
-- the constructor itself is skipped but its arguments are still walked,
-- because they can mention variables not seen before
-- (e.g. the @b@ in @(Maybe Int, Maybe b)@).
go_app :: (Quasi m, MonadState St m) => [Type] -> Type -> m ()
go_app params (AppT t1 t2) = go_app (t2 : params) t1
go_app params (ConT n) = do
    stk <- use visited
    if Set.member n stk
      then mapM_ ftv params
      else do
        visited %= Set.insert n
        qReify n >>= go_info params
go_app params typ = mapM_ ftv (typ : params)
-- | Expand a reified type constructor that was applied to @params@.
go_info :: (Quasi m, MonadState St m) => [Type] -> Info -> m ()
go_info params (TyConI dec) = go_dec params dec
go_info params (FamilyI dec _insts) = go_dec params dec
-- Primitive constructors have no declaration to expand, but their
-- arguments (e.g. the @s@ and @a@ of @MutVar# s a@) can still be free.
go_info params (PrimTyConI _name _arity _unlifted) = mapM_ ftv params
-- A class name reaches here when a constraint is walked as a type (see the
-- ForallT case).  The class's own binders are not free in the constraint;
-- only its arguments can contribute.
go_info params (ClassI _dec _insts) = mapM_ ftv params
go_info _params info = error $ "go_info - unexpected: " ++ pprint1 info
-- | Expand a type constructor declaration applied to @params@.
--
-- Data/newtype constructors contribute their field types.  A synonym
-- contributes its right-hand side.  Families contribute nothing of their
-- own.  In every case 'go_params' then pairs the declared binders with
-- @params@.
go_dec :: (Quasi m, MonadState St m) => [Type] -> Dec -> m ()
#if MIN_VERSION_template_haskell(2,11,0)
go_dec params (NewtypeD cx tname tvs m con supers) = go_dec params (DataD cx tname tvs m [con] supers)
go_dec params (DataD _ tname tvs _ _ _) | length params > length tvs = error $ "Too many arguments to " ++ show tname
go_dec params (DataD _cx tname tvs _ cons _supers) = do
    ftv cons
    go_params tname tvs params
#else
go_dec params (NewtypeD cx tname tvs con supers) = go_dec params (DataD cx tname tvs [con] supers)
go_dec params (DataD _ tname tvs _ _) | length params > length tvs = error $ "Too many arguments to " ++ show tname
go_dec params (DataD _cx tname tvs cons _supers) = do
    ftv cons
    go_params tname tvs params
#endif
go_dec params (TySynD tname tvs typ) = do
    ftv typ
    go_params tname tvs params
#if MIN_VERSION_template_haskell(2,11,0)
-- From template-haskell 2.11 on, families are split into separate
-- constructors.  Handle all of them, as the single 'FamilyD' case does
-- for older versions.
go_dec params (DataFamilyD tname tvs _mkind) = go_params tname tvs params
go_dec params (OpenTypeFamilyD (TypeFamilyHead tname tvs _resultSig _inj)) = go_params tname tvs params
go_dec params (ClosedTypeFamilyD (TypeFamilyHead tname tvs _resultSig _inj) _eqns) = go_params tname tvs params
#else
go_dec params (FamilyD _flavour tname tvs _mkind) = go_params tname tvs params
#endif
go_dec params dec = error $ "go_dec - unexpected: " ++ pprint1 dec ++ ", params=" ++ show params
-- | Pair each declared binder with its argument, if one was supplied.
-- Binders beyond the supplied arguments are paired with 'Nothing'.
-- Supplying more arguments than there are binders is an error.
go_params :: (Quasi m, MonadState St m) => Name -> [TyVarBndr] -> [Type] -> m ()
go_params tname tvs params
    | length params > length tvs = error $ "Too many arguments to " ++ show tname
    | otherwise = sequence_ (zipWith go_param tvs padded)
  where
    padded = map Just params ++ repeat Nothing
-- | Handle one declared binder of an expanded type constructor.
-- If it was given an argument, walk that argument and then drop the
-- binder's name from 'result'.  Otherwise add the binder's name to
-- 'result'.
go_param :: (Quasi m, MonadState St m) => TyVarBndr -> Maybe Type -> m ()
go_param tvb = maybe unapplied applied
  where
    name = tvbName tvb
    applied param = ftv param >> (result %= Set.delete name)
    unapplied = result %= Set.insert name
-- | Free variables of a data constructor are those of its field types,
-- minus any variables quantified by an enclosing 'ForallC'.
instance FreeTypeVars Con where
    ftv (NormalC _name sts) = ftv sts
    ftv (RecC _name vsts) = ftv vsts
    ftv (InfixC st1 _ st2) = ftv [st1, st2]
    -- Existential (and GADT-quantified) variables are bound here and must
    -- not escape as free.  Collect the inner constructor in isolation,
    -- remove the binders, then merge with what was already collected.
    ftv (ForallC tvbs _cx con) = do
        outer <- use result
        result .= Set.empty
        ftv con
        result %= \inner -> Set.union outer (inner `Set.difference` bound)
      where
        bound = Set.fromList (map tvbName tvbs)
#if MIN_VERSION_template_haskell(2,11,0)
    -- GADT constructors: walk the field types.  Their quantified
    -- variables are removed by an enclosing ForallC when present.
    ftv (GadtC _names bts _resultType) = ftv bts
    ftv (RecGadtC _names vbts _resultType) = ftv vbts
#endif
-- | A positional field: only its type matters, the strictness is ignored.
instance FreeTypeVars (Strict, Type) where
    ftv = ftv . snd
-- | A record field: only its type matters, not its name or strictness.
instance FreeTypeVars (Name, Strict, Type) where
    ftv field = ftv (fieldType field)
      where
        fieldType (_, _, t) = t
-- | The variable bound by a binder, with or without a kind annotation.
tvbName :: TyVarBndr -> Name
tvbName tvb = case tvb of
    PlainTV name -> name
    KindedTV name _kind -> name