module Language.Haskell.TH.SCCs
(binding_group, binding_groups,
scc, sccs,
Dependencies(..), type_dependencies,
printQ
) where
import Language.Haskell.TH.Syntax
import qualified Data.Set as Set; import Data.Set (Set)
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import Control.Monad (liftM, liftM2, (<=<))
import Data.Graph (stronglyConnComp, SCC(..))
-- | Debugging helper for splices: run the 'Q' action and print its result
-- at compile time. If a label is given, it is written first with 'putStr',
-- so no newline is added after it. The result is an empty list, so the
-- call can sit in a declaration splice without adding anything.
printQ :: Show a => Maybe String -> Q a -> Q [b]
printQ label action = action >>= report
  where
    report result = do
      runIO $ do
        maybe (return ()) putStr label
        print result
      return []
-- | The set of names in the binding group (strongly connected component)
-- that contains the given type name. An acyclic component becomes a
-- singleton set.
binding_group :: Name -> Q (Set Name)
binding_group n = do
  component <- scc n
  return (binding_group' component)
-- | The binding groups reachable from the given names, keeping only the
-- groups that contain at least one of those names.
binding_groups :: [Name] -> Q [Set Name]
binding_groups ns = do
  components <- sccs ns
  let wanted = Set.fromList ns
      touchesWanted grp = not (Set.null (wanted `Set.intersection` grp))
  return [grp | grp <- map binding_group' components, touchesWanted grp]
-- | Flatten a component into a set: a lone (acyclic) name becomes a
-- singleton, and a cyclic group is returned unchanged.
binding_group' :: Either a (Set a) -> Set a
binding_group' (Left single) = Set.singleton single
binding_group' (Right grp)   = grp
-- | The strongly connected component that contains the given type name, as
-- computed by 'sccs' starting from that name alone. The result is @Left n@
-- for an acyclic component and @Right group@ for a cyclic one.
scc :: Name -> Q (Either Name (Set Name))
scc n = do
  components <- sccs [n]
  case filter (either (== n) (Set.member n)) components of
    component : _ -> return component
    -- 'sccs' makes every seed name a graph vertex, so this branch should be
    -- unreachable. Fail with a descriptive message instead of a bare 'head'.
    [] -> error $ "th-sccs: " ++ show n ++
                  " is missing from its own strongly connected components."
-- | Strongly connected components of the type-dependency graph reachable
-- from the given names.
--
-- Every given name is reified, and so is every name its declaration
-- mentions, transitively, until no new names appear. Each component is
-- reported as @Left n@ for an 'AcyclicSCC' and as @Right group@ for a
-- 'CyclicSCC'. Reifying something that is not a type (or a type family)
-- fails with an 'error' raised by the 'Dependencies' instances.
sccs :: [Name] -> Q [Either Name (Set Name)]
sccs ns = do
  let -- Close a set of seed keys under @f@: apply @f@ to every key, then
      -- keep applying it to names that occur in some result but are not
      -- yet keys ("the fringe"), until the fringe is empty.
      chaotic f = loop <=< analyze
        where
          analyze = Traversable.mapM f . Map.fromList
                  . map (\x -> (x, x)) . Set.toList
          loop m
            | Set.null fringe = return m
            | otherwise       = (Map.union m `liftM` analyze fringe) >>= loop
            where
              fringe = Set.unions (Map.elems m) `Set.difference` Map.keysSet m
  names <- chaotic (type_dependencies <=< reify) (Set.fromList ns)
  let listify (AcyclicSCC v) = Left v
      listify (CyclicSCC vs) = Right (Set.fromList vs)
      -- Node payload and key are both the name. At the fixed point every
      -- dependency is itself a key, so every edge target is a vertex.
      edges = [(n, n, Set.toList deps) | (n, deps) <- Map.assocs names]
  return (map listify (stronglyConnComp edges))
-- | Things from which a single defining 'Name' can be read off.
class Named t where
  -- | The name bound or described by the value. Instances in this module
  -- call 'error' for values that do not carry exactly one such name.
  name_of :: t -> Name
-- | The name described by a 'reify' result. Class and type-constructor
-- results defer to the name of their declaration.
instance Named Info where
  name_of (ClassI dec _)     = name_of dec
  name_of (ClassOpI n _ _ _) = n
  name_of (TyConI dec)       = name_of dec
  name_of (PrimTyConI n _ _) = n
  name_of (DataConI n _ _ _) = n
  name_of (VarI n _ _ _)     = n
  name_of (TyVarI n _)       = n
-- | The name a declaration introduces. A value binding defers to its
-- pattern. Declarations not listed here are rejected with 'error'.
instance Named Dec where
  name_of (FunD n _)           = n
  name_of (ValD pat _ _)       = name_of pat
  name_of (DataD _ n _ _ _)    = n
  name_of (NewtypeD _ n _ _ _) = n
  name_of (TySynD n _ _)       = n
  name_of (ClassD _ n _ _ _)   = n
  name_of (FamilyD _ n _ _)    = n
  name_of other = error $ show other ++ " is not a named declaration."
-- | The name of a data constructor, looking through 'ForallC' wrappers.
instance Named Con where
  name_of (NormalC n _)     = n
  name_of (RecC n _)        = n
  name_of (InfixC _ n _)    = n
  name_of (ForallC _ _ con) = name_of con
-- | The single variable bound by a pattern: a plain variable, the outer
-- name of an as-pattern, or the pattern under a type signature. Any other
-- pattern is rejected with 'error'.
instance Named Pat where
  name_of (VarP n)     = n
  name_of (AsP n _)    = n
  name_of (SigP pat _) = name_of pat
  name_of other        =
    error $ "The pattern `" ++ show other ++ "' does not define exactly one name."
-- | Things whose referenced type-constructor names can be collected.
class Dependencies t where
  -- | Collect the referenced names. The @[Name]@ argument is a trail of
  -- enclosing names: 'recur'' extends it, and it is read only when building
  -- an error message.
  type_dependencies' :: [Name] -> t -> Q (Set Name)
-- | Collect the names a value refers to, starting from an empty trail of
-- enclosing names.
type_dependencies :: Dependencies t => t -> Q (Set Name)
type_dependencies x = type_dependencies' [] x
-- | Recurse into a sub-term without extending the trail of enclosing names.
recur :: Dependencies t => [Name] -> t -> Q (Set Name)
recur trail x = type_dependencies' trail x
-- | Recurse into a sub-term of @parent@, appending @parent@'s name to the
-- end of the trail of enclosing names.
recur' :: (Named p, Dependencies t) => [Name] -> p -> t -> Q (Set Name)
recur' trail parent = type_dependencies' (trail ++ [name_of parent])
-- | Only type constructors are supported. A 'TyConI' recurses into its
-- declaration, a primitive type constructor has no dependencies, and any
-- other 'Info' is rejected with 'error'.
instance Dependencies Info where
  type_dependencies' trail info@(TyConI dec) = recur' trail info dec
  type_dependencies' _ (PrimTyConI {})       = return Set.empty
  type_dependencies' _ info =
    error $ "This version of th-sccs only calculates mutually " ++
            "recursive groups for types; " ++ show (name_of info) ++
            " is not a type."
-- | Dependencies of a type declaration: the types mentioned by a data
-- type's constructors, a newtype's constructor, or a synonym's right-hand
-- side. Type families and all other declarations are rejected with 'error'.
instance Dependencies Dec where
  type_dependencies' ns d =
    case d of
      DataD _ _ _ cons _ -> Set.unions `liftM` mapM w cons
      NewtypeD _ _ _ c _ -> w c
      TySynD _ _ ty      -> w ty
      FamilyD {}         ->
        error $ "This version of th-sccs cannot calculate mutually recursive " ++
                "groups for types involving type families; " ++ culprit ++ "."
      o -> error $ "Unexpected declaration: " ++ show o ++ "."
    where
      -- Recurse with this declaration's name appended to the trail.
      w x = recur' ns d x
      -- Name whoever mentioned the family. The trail is empty when this
      -- instance is called directly (e.g. via 'type_dependencies'), so it
      -- need not have a last element.
      culprit = case reverse ns of
        user : _ -> show user ++ " uses " ++ show (name_of d)
        []       -> show (name_of d) ++ " is a type family"
-- | Types mentioned in a constructor's fields. Record fields are wrapped in
-- 'RecordField' so that the field name joins the trail. 'ForallC' wrappers
-- are looked through.
instance Dependencies Con where
  type_dependencies' ns c =
    case c of
      NormalC _ fields    -> fieldDeps (map snd fields)
      RecC _ named        -> fieldDeps [RecordField v t | (v, _, t) <- named]
      InfixC left _ right -> fieldDeps [snd left, snd right]
      ForallC _ _ inner   -> type_dependencies' ns inner
    where
      -- Union of the dependencies of each field, with this constructor's
      -- name appended to the trail.
      fieldDeps items = liftM Set.unions (mapM (recur' ns c) items)
-- | A single record field, pairing its name with its type. It is built from
-- the fields of a 'RecC' constructor.
data RecordField
  = RecordField
      Name -- ^ the field's name
      Type -- ^ the field's type
-- | A record field is named by its field name.
instance Named RecordField where
  name_of (RecordField fieldName _) = fieldName
-- | A field's dependencies are those of its type, with the field's name
-- appended to the trail.
instance Dependencies RecordField where
  type_dependencies' trail field =
    case field of
      RecordField _ fieldType -> recur' trail field fieldType
-- | Type-constructor names mentioned by a type. Each 'ConT' contributes its
-- name. Foralls (body only), applications and kind signatures are traversed
-- without extending the trail. Every other form contributes nothing.
instance Dependencies Type where
  type_dependencies' trail (ForallT _ _ body) = recur trail body
  type_dependencies' _     (ConT n)           = return (Set.singleton n)
  type_dependencies' trail (AppT fun arg)     =
    liftM2 Set.union (recur trail fun) (recur trail arg)
  type_dependencies' trail (SigT body _)      = recur trail body
  type_dependencies' _     _                  = return Set.empty