module Language.Haskell.TH.SCCs
  (binding_group, binding_groups,
   scc, sccs,
   Dependencies(..), type_dependencies,
   printQ
  ) where

import Language.Haskell.TH.Syntax

import qualified Data.Set as Set; import Data.Set (Set)
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import Control.Monad (liftM, liftM2, (<=<))

import Data.Graph (stronglyConnComp, SCC(..))



printQ s m = do
 x <- m
 runIO (maybe (return ()) putStr s >> print x) >> return []



binding_group :: Name -> Q (Set Name)
binding_group = liftM binding_group' . scc

binding_groups :: [Name] -> Q [Set Name]
binding_groups ns = (filter relevant . map binding_group') `liftM` sccs ns
  where relevant bg = not (Set.null (Set.intersection (Set.fromList ns) bg))

binding_group' = either Set.singleton id

scc :: Name -> Q (Either Name (Set Name))
scc n = (head . filter (either (==n) (Set.member n))) `liftM` sccs [n]
  

sccs :: [Name] -> Q [Either Name (Set Name)]
sccs ns = do
  let withK f k = (,) k `liftM` f k
      chaotic f = loop <=< analyze where
        analyze = Traversable.mapM f . Map.fromList .
                  map (\ x -> (x, x)) . Set.toList
        loop m | Set.null fringe = return m
               | otherwise = Map.union m `liftM` analyze fringe >>= loop
          where fringe =
                  Set.unions (Map.elems m) `Set.difference` Map.keysSet m
  names <- chaotic (type_dependencies <=< reify) (Set.fromList ns)
  let listify (AcyclicSCC v) = Left v
      listify (CyclicSCC vs) = Right (Set.fromList vs)
  return (map listify (stronglyConnComp [(n, n, Set.toList deps) |
                                         (n, deps) <- Map.assocs names]))



class Named t where name_of :: t -> Name

instance Named Info where
  name_of i = case i of
    ClassI d _ -> name_of d
    ClassOpI n _ _ _ -> n
    TyConI d -> name_of d
    PrimTyConI n _ _ -> n
    DataConI n _ _ _ -> n
    VarI n _ _ _ -> n
    TyVarI n _ -> n

instance Named Dec where
  name_of d = case d of
    FunD n _ -> n
    ValD p _ _ -> name_of p
    DataD _ n _ _ _ -> n
    NewtypeD _ n _ _ _ -> n
    TySynD n _ _ -> n
    ClassD _ n _ _ _ -> n
    FamilyD _ n _ _ -> n
    o -> error $ show o ++ " is not a named declaration."

instance Named Con where
  name_of c = case c of
    NormalC n _ -> n
    RecC n _ -> n
    InfixC _ n _ -> n
    ForallC _ _ c -> name_of c

instance Named Pat where
  name_of p = case p of
    VarP n -> n
    AsP n _ -> n
    SigP p _ -> name_of p
    o -> error $ "The pattern `" ++ show o ++ "' does not define exactly one name."



-- | Calculate the type declaration upon which this syntactic construct
-- syntactically dependends.
class Dependencies t where
  type_dependencies' :: [Name] -> t -> Q (Set Name)

type_dependencies :: Dependencies t => t -> Q (Set Name)
type_dependencies = type_dependencies' []



recur ns = type_dependencies' ns
recur' ns x = type_dependencies' (ns ++ [name_of x])

instance Dependencies Info where
  type_dependencies' ns i = case i of
    TyConI d -> recur' ns i d
    PrimTyConI n _ _ -> return Set.empty
    _ -> error $ "This version of th-sccs only calculates mutually " ++
                 "recursive groups for types; " ++ show (name_of i) ++
                 " is not a type."

instance Dependencies Dec where
  type_dependencies' ns d = case d of
    DataD _ _ _ cons _ -> Set.unions `liftM` mapM w cons
    NewtypeD _ _ _ c _ -> w c
    TySynD _ _ ty -> w ty
    FamilyD {} ->
      error $ "This version of th-sccs cannot calculate mutually recursive " ++
              "groups for types involving type families; " ++
              show (last ns) ++ " uses " ++ show (name_of d) ++ "."
    o -> error $ "Unexpected declaration: " ++ show o ++ "."
    where w x = recur' ns d x

instance Dependencies Con where
  type_dependencies' ns c = case c of
    NormalC _ sts -> w $ map snd sts
    RecC _ vsts -> w $ map (\ (n, _, t) -> RecordField n t) vsts
    InfixC stL _ stR -> w $ map snd [stL, stR]
    ForallC _ _ c -> type_dependencies' ns c
    where w xs = Set.unions `liftM` mapM (recur' ns c) xs

data RecordField = RecordField Name Type
instance Named RecordField where name_of (RecordField n _) = n
instance Dependencies RecordField where
  type_dependencies' ns rf@(RecordField _ ty) = recur' ns rf ty

instance Dependencies Type where
  type_dependencies' ns t = case t of
    ForallT _ _ t -> w t
    ConT n -> return (Set.singleton n)
    AppT tfn targ -> liftM2 Set.union (w tfn) (w targ)
    SigT t _ -> w t
    _ -> return Set.empty
    where w x = recur ns x