module Language.Haskell.TH.SCCs
  (binding_group, binding_groups,
   scc, sccs,
   Dependencies(..), type_dependencies, td_recur, td_descend,
   Named(..),
   printQ
  ) where

import Language.Haskell.TH.Syntax

import qualified Data.Set as Set; import Data.Set (Set)
import qualified Data.Map as Map
import qualified Data.Traversable as Traversable
import Control.Monad (liftM, liftM2, (<=<))

import Data.Graph (stronglyConnComp, SCC(..))


-- | Debugging aid for generated code: runs the given computation, writes the
-- optional prefix (via 'putStr', so no newline is appended) followed by the
-- 'show'n result, and splices in no declarations.
printQ :: Show a => Maybe String -> Q a -> Q [Dec]
printQ prefix action = do
  result <- action
  runIO $ do
    maybe (return ()) putStr prefix
    print result
  return []


-- | Computes the SCC that includes the declaration of the given name; @Left@
-- is a single acyclic declaration, @Right@ is a mutually recursive group
-- (possibly of size one: a singly recursive declaration).
--
-- The requested name is always one of the nodes analysed by 'sccs', so a
-- component containing it is expected to exist; if none is found an error
-- naming the offending 'Name' is raised instead of an opaque empty-list
-- failure.
scc :: Name -> Q (Either Name (Set Name))
scc n = pick `liftM` sccs [n]
  where
    -- Select the first component that mentions the requested name.
    pick components =
      case filter (either (== n) (Set.member n)) components of
        component : _ -> component
        [] -> error $ "Language.Haskell.TH.SCCs.scc: no strongly connected " ++
                      "component contains " ++ show n ++ "."


-- | Computes all SCCs for the given names (including those it dominates).
--
-- Every requested name is reified and its type dependencies are collected;
-- any dependency that has not been analysed yet is then reified in turn, and
-- so on until no new names appear. The resulting dependency map is handed to
-- 'stronglyConnComp'. @Left@ marks an acyclic declaration, @Right@ a
-- recursive group.
sccs :: [Name] -> Q [Either Name (Set Name)]
sccs ns = do
  depsByName <- close =<< analyze (Set.fromList ns)
  return (map classify (stronglyConnComp [(n, n, Set.toList deps) |
                                          (n, deps) <- Map.assocs depsByName]))
  where
    -- Reify each name in the set and map it to its type dependencies.
    analyze = Traversable.mapM (fmap type_dependencies . reify) .
              Map.fromList . map (\ x -> (x, x)) . Set.toList

    -- Keep analysing the names that are depended upon but not yet present
    -- as keys, until that fringe is empty.
    close known
      | Set.null fringe = return known
      | otherwise = Map.union known `liftM` analyze fringe >>= close
      where fringe =
              Set.unions (Map.elems known) `Set.difference` Map.keysSet known

    classify (AcyclicSCC v) = Left v
    classify (CyclicSCC vs) = Right (Set.fromList vs)

-- | Like 'scc', but without the distinction between an acyclic declaration
-- and a recursive group: the result is always the set of names in the
-- component (a singleton set for the acyclic case).
binding_group :: Name -> Q (Set Name)
binding_group n = binding_group' `liftM` scc n

-- | Wrapper for 'sccs' that forgets the distinction between a single acyclic
-- SCC and a singly recursive SCC.
--
-- Only the groups that contain at least one of the requested names are
-- returned; components that were merely reached as dependencies are dropped.
binding_groups :: [Name] -> Q [Set Name]
binding_groups ns = (filter relevant . map binding_group') `liftM` sccs ns
  where -- Built once and shared by every 'relevant' check.
        requested = Set.fromList ns
        relevant bg = not (Set.null (Set.intersection requested bg))

-- | Collapse an SCC result into a plain set: a @Left@ element becomes a
-- singleton set, a @Right@ set is returned unchanged.
binding_group' :: Either a (Set a) -> Set a
binding_group' = either Set.singleton id

-- | Constructs that \"naturally\" /define/ a name. The notion is semantically
-- murky: 'name_of' yields the defined name, and it is an error to ask for
-- the name of something that does not define one.
class Named t where
  name_of :: t -> Name

-- | The name a reified 'Info' describes; for classes and type constructors
-- it is taken from the wrapped declaration.
instance Named Info where
  name_of (ClassI dec _) = name_of dec
  name_of (ClassOpI nm _ _ _) = nm
  name_of (TyConI dec) = name_of dec
  name_of (PrimTyConI nm _ _) = nm
  name_of (DataConI nm _ _ _) = nm
  name_of (VarI nm _ _ _) = nm
  name_of (TyVarI nm _) = nm

-- | The name bound by a declaration. Value declarations defer to their
-- pattern; any declaration form not listed here raises an error.
instance Named Dec where
  name_of (FunD nm _) = nm
  name_of (ValD pat _ _) = name_of pat
  name_of (DataD _ nm _ _ _) = nm
  name_of (NewtypeD _ nm _ _ _) = nm
  name_of (TySynD nm _ _) = nm
  name_of (ClassD _ nm _ _ _) = nm
  name_of (FamilyD _ nm _ _) = nm
  name_of other = error $ show other ++ " is not a named declaration."

-- | The constructor's own name; a quantified constructor is named after the
-- constructor it wraps.
instance Named Con where
  name_of (NormalC nm _) = nm
  name_of (RecC nm _) = nm
  name_of (InfixC _ nm _) = nm
  name_of (ForallC _ _ inner) = name_of inner

-- | The single name bound by a variable pattern, an as-pattern, or a
-- signature around one of those; every other pattern raises an error.
instance Named Pat where
  name_of (VarP nm) = nm
  name_of (AsP nm _) = nm
  name_of (SigP inner _) = name_of inner
  name_of other =
    error $ "The pattern `" ++ show other ++ "' does not define exactly one name."



-- | Calculate the type declarations upon which this construct syntactically
-- depends. The first argument tracks the bindings traversed so far (outermost
-- first); use 'td_descend' to extend it when entering a named construct, or
-- 'td_recur' to pass it along unchanged.
class Dependencies t where
  -- | Names of the types referenced by the construct, given the names of the
  -- enclosing bindings already traversed.
  type_dependencies' :: [Name] -> t -> Set Name

-- | Entry point for 'type_dependencies'': computes the dependencies of a
-- construct starting with no enclosing bindings tracked.
type_dependencies :: Dependencies t => t -> Set Name
type_dependencies x = type_dependencies' [] x


-- | A shorter synonym for 'type_dependencies'': recurs into a sub-construct
-- without changing the tracked bindings.
td_recur :: Dependencies t => [Name] -> t -> Set Name
td_recur = type_dependencies'

-- | Shorter than 'type_dependencies'' and also appends the name of the second
-- argument to the tracked bindings before computing the dependencies of the
-- third argument.
td_descend :: (Named a, Dependencies t) => [Name] -> a -> t -> Set Name
td_descend ns x t = type_dependencies' (ns ++ [name_of x]) t

-- | Only type constructors are supported: an ordinary type constructor
-- contributes the dependencies of its declaration, a primitive one has none,
-- and every other kind of 'Info' raises an error.
instance Dependencies Info where
  type_dependencies' ns info@(TyConI dec) = td_descend ns info dec
  type_dependencies' _ (PrimTyConI _ _ _) = Set.empty
  type_dependencies' _ info =
    error $ "This version of th-sccs only calculates mutually " ++
            "recursive groups for types; " ++ show (name_of info) ++
            " is not a type."

-- | Dependencies of a type declaration: the union over all constructors of a
-- data type, the single constructor of a newtype, or the right-hand side of
-- a type synonym. Type families are not supported and any other declaration
-- form is unexpected; both raise an error.
instance Dependencies Dec where
  type_dependencies' ns d = case d of
    DataD _ _ _ cons _ -> Set.unions (map w cons)
    NewtypeD _ _ _ c _ -> w c
    TySynD _ _ ty -> w ty
    FamilyD {} ->
      error $ "This version of th-sccs cannot calculate mutually recursive " ++
              "groups for types involving type families; " ++ culprit
    o -> error $ "Unexpected declaration: " ++ show o ++ "."
    where -- Descend into a sub-construct with this declaration's name tracked.
          w x = td_descend ns d x
          -- The innermost tracked binding is reported as the user of the
          -- family. When nothing has been tracked yet (the declaration was
          -- passed in directly) there is no user to name, and 'last' must
          -- not be applied to the empty list.
          culprit
            | null ns = show (name_of d) ++ " is a type family."
            | otherwise = show (last ns) ++ " uses " ++ show (name_of d) ++ "."

-- | Dependencies of a data constructor: the union of the dependencies of its
-- field types, each computed with the constructor's name added to the
-- tracked bindings. A quantified constructor simply defers to the
-- constructor it wraps.
instance Dependencies Con where
  type_dependencies' ns con = case con of
    NormalC _ fields -> collect (map snd fields)
    RecC _ fields -> collect [RecordField n t | (n, _, t) <- fields]
    InfixC lhs _ rhs -> collect [snd lhs, snd rhs]
    ForallC _ _ inner -> type_dependencies' ns inner
    where -- Takes an explicit argument so that it stays polymorphic: it is
          -- applied both to lists of types and to lists of record fields.
          collect xs = Set.unions (map (td_descend ns con) xs)

-- | A record field (its name paired with its type), so that the field name
-- can be added to the tracked bindings when descending into the field type.
data RecordField = RecordField Name Type

instance Named RecordField where
  name_of (RecordField fieldName _) = fieldName

instance Dependencies RecordField where
  type_dependencies' ns field = case field of
    RecordField _ fieldType -> td_descend ns field fieldType

-- | Dependencies of a type expression: every type constructor name that
-- occurs in it. Quantifiers and kind signatures are looked through,
-- applications are traversed on both sides, and every other form contributes
-- nothing.
instance Dependencies Type where
  type_dependencies' ns (ForallT _ _ body) = td_recur ns body
  type_dependencies' _ (ConT n) = Set.singleton n
  type_dependencies' ns (AppT fn arg) = td_recur ns fn `Set.union` td_recur ns arg
  type_dependencies' ns (SigT inner _) = td_recur ns inner
  type_dependencies' _ _ = Set.empty