{-# LANGUAGE GeneralizedNewtypeDeriving, NoMonomorphismRestriction,
    FlexibleInstances, FlexibleContexts #-}
module Language.Haskell.TH.Universe (
    get_universe, 
    sub_universe,
    -- ** Utils ... Not sure how to hide these.
    get_type_names,
    filter_dups',
    collect_new_dec_names,
    collect_dec_type_names,
    eval_state,
    Universe) where
import Language.Haskell.TH
import Language.Haskell.TH.Syntax hiding (lift)
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Trans
import Data.Generics.Uniplate.Data
import Data.List
import Data.Tuple.Select
import Control.Applicative
import Data.Composition
import Control.Monad
import Control.Monad.Reader
import Control.Monad.Identity

-- | The collected declarations, kept as an association list from a
--   declaration's name to the declaration itself.
type Universe = [(Name, Dec)]

-- | The underlying transformer stack: 'ErrorT' over 'StateT' over 'ReaderT'
--   over @m@. The state and the reader environment share the same type @s@.
type ErrorStateType m e s a = ErrorT e (StateT s (ReaderT s m)) a

-- | Newtype wrapper around 'ErrorStateType'. The listed class instances are
--   taken from the wrapped stack through the @deriving@ clause
--   (GeneralizedNewtypeDeriving is enabled for this module).
newtype ErrorStateT e s m a = ErrorStateT { runErrorStateT :: ErrorStateType m e s a }
    deriving (Monad, MonadState s, MonadError e, Functor, MonadPlus, MonadReader s)
    
-- | Lift an action of the base monad through the three wrapped layers
--   ('ReaderT', 'StateT', 'ErrorT') and re-wrap it in 'ErrorStateT'.
instance MonadTrans (ErrorStateT String Universe) where
    lift action = ErrorStateT (lift (lift (lift action)))

-- | 'ErrorStateT' specialised to 'String' errors and a 'Universe' as both
--   state and reader environment, running over the 'Q' monad.
type UniverseState a = ErrorStateT String Universe Q a
 
-- | Either a 'String' error message ('Left') or a value ('Right').
type Result a = Either String a


-- | Collect the 'Dec's of whatever is passed in by name, together with the
--   'Dec's of everything it transitively refers to.
--   For instance if we have
--
-- > data OtherThing = OtherThing Float
--
-- > data Thing = Thing OtherThing Int
--
--   then
--
-- > get_universe ''Thing
--
--   would return the 'Dec's for Thing, OtherThing, Int and Float.
--
--   A declaration that is reachable along several paths may be recorded more
--   than once during the traversal, so the accumulated list is de-duplicated
--   with 'nub' before it is returned (the same thing 'sub_universe' does).
get_universe :: Name -> Q (Universe)
get_universe name = nub <$> exec_state (create_universe_name name)

-- | Look @name@ up among the supplied declarations and collect it together
--   with every declaration it (transitively) refers to that is also present
--   in @decs@. Names that have no declaration in @decs@ are skipped.
--
--   The traversal runs purely (over 'Identity') with @decs@ as the reader
--   environment and an empty initial state; only the final state is kept, and
--   duplicates are removed with 'nub'.
sub_universe :: [Dec] -> Name -> Universe
sub_universe decs name =
    nub . snd . runIdentity
        . flip runReaderT (make_dec_dict decs)
        . flip runStateT []
        . runErrorT
        . runErrorStateT
        $ sub_universe' name

-- | Pair every declaration with its name. Declarations for which
--   'get_dec_name' yields no name are dropped; input order is preserved.
make_dec_dict :: [Dec] -> [(Name, Dec)]
make_dec_dict decs =
    [ (name, dec) | dec <- decs, Right name <- [get_dec_name dec] ]

-- | The entries accumulated in the state so far, followed by the entries
--   supplied through the reader environment.
get_all_decs :: (MonadState [a] m, MonadReader [a] m) => m [a]
get_all_decs = liftM2 (++) get ask

-- | Visit @name@: search the known declarations (state first, then reader
--   environment) for the first one carrying that name. If one is found it is
--   recorded and its not-yet-seen type names are visited recursively; if none
--   is found nothing happens.
sub_universe' :: (Monad m, Functor m,  MonadPlus m,
    MonadState Universe m, 
    MonadError String   m,
    MonadReader Universe m) => Name -> m ()
sub_universe' name =
    get_all_decs
        >>= maybe (return ()) (create_universe_dec' sub_universe')
                . find (is_dec_name name)
                . map snd

-- | Does the declaration's name, as reported by 'get_dec_name', equal @name@?
--   Declarations for which no name can be obtained never match.
is_dec_name :: Name -> Dec -> Bool
is_dec_name name dec = either (const False) (== name) (get_dec_name dec)

-- | Reify @name@ in 'Q' and extend the universe according to what it denotes:
--
--   * classes and type constructors: their declaration is collected;
--   * class methods and data constructors: the traversal continues with the
--     name stored in the third field of the reified info (presumably the
--     owning class / parent type -- confirm against the TH version in use);
--   * primitive type constructors: nothing is collected;
--   * variables: their declaration is collected when reification supplies one;
--   * type variables: not supported, aborts with 'error'.
create_universe_name :: Name -> UniverseState ()
create_universe_name name = lift (reify name) >>= visit
  where
    visit :: Info -> UniverseState ()
    visit (ClassI dec _)          = create_universe_dec dec
    visit (ClassOpI _ _ owner _)  = create_universe_name owner
    visit (TyConI dec)            = create_universe_dec dec
    visit (PrimTyConI _ _ _)      = return ()
    visit (DataConI _ _ owner _)  = create_universe_name owner
    visit (VarI _ _ m_dec _)      = maybe (return ()) create_universe_dec m_dec
    visit (TyVarI _ _)            = error "Don't know what a TyVarI is"
        
-- | Record the given 'Dec' and then collect, by reifying each one through
--   'create_universe_name', every not-yet-seen type name it mentions.
create_universe_dec :: Dec -> UniverseState ()
create_universe_dec = create_universe_dec' create_universe_name

-- | Record @dec@ via 'collect_new_dec_names' and run the visitor @f@ on every
--   name that call reports as new, in order. The visitor's results are
--   discarded.
create_universe_dec' :: (Monad m, MonadPlus m,
    MonadState Universe m,
    MonadError String   m) => (Name -> m a) -> Dec -> m ()
create_universe_dec' f dec = collect_new_dec_names dec >>= mapM_ f

-- | Add @dec@ to the universe held in the state and return the type names it
--   mentions that are not yet in that universe.
--
--   Throws (through 'MonadError') the message from 'get_dec_name' when the
--   declaration has no name; in that case the state is left untouched.
collect_new_dec_names :: (Monad m, MonadPlus m,
    MonadState Universe m, 
    MonadError String   m) => Dec -> m [Name]
collect_new_dec_names dec = do
    dec_name <- either throwError return (get_dec_name dec)
    modify ((dec_name, dec) :)
    case show dec_name of
        -- The list type is recorded but its own dependencies are never
        -- followed. NOTE(review): presumably because 'type_to_name' reports
        -- lists under the unqualified name "[]", which would never match the
        -- qualified name recorded here -- confirm.
        "GHC.Types.[]" -> return []
        -- Only names that have not been looked up before are returned.
        _              -> filter_dups (collect_dec_type_names dec)
    

-- | All type-constructor names mentioned by a declaration, without
--   duplicates.
--
--   For a type synonym the names come from its right-hand side; for every
--   other declaration they come from the field types of the constructors
--   reported by 'get_cons' (so declarations without constructors yield @[]@).
--   Both cases are de-duplicated with 'nub', so a synonym such as
--   @type P = (Int, Int)@ reports @Int@ only once.
collect_dec_type_names :: Dec -> [Name]
collect_dec_type_names dec = nub (concatMap get_type_names (mentioned dec))
  where
    -- The types whose names should be harvested for this declaration.
    mentioned (TySynD _ _ typ) = [typ]
    mentioned other            = concatMap get_con_types (get_cons other)

-- | Keep only those @names@ that do not occur as a key in @uni@.
--
--   Every occurrence of a known name is removed, even when it appears several
--   times in @names@; list difference ('\\') would strip just one occurrence
--   per key and let the remaining copies through. The relative order of the
--   surviving names is preserved.
filter_dups' :: Eq a => [a] -> [(a, b)] -> [a]
filter_dups' names uni = filter (`notElem` known) names
  where
    known = map fst uni
    
-- | Stateful variant of 'filter_dups'': filters @names@ against the
--   association list currently held in the state. The state is only read.
filter_dups :: (Eq a, MonadState [(a, b)] m) =>[a] -> m [a]
filter_dups names = do
    seen <- get
    return (filter_dups' names seen)

-- | Inject a 'Result' into a 'MonadError' computation: a 'Right' value is
--   returned, a 'Left' message is raised with 'throwError'.
throw_either :: (Monad m, MonadError String m) => Result a -> m a
throw_either = either throwError return
    
--------------------------------------------------------------------------------
--utility functions without a home
    
-- | The data constructors of a declaration. @data@ and @data instance@
--   declarations yield all of their constructors, @newtype@ and
--   @newtype instance@ declarations yield their single constructor, and any
--   other declaration yields the empty list.
get_cons :: Dec -> [Con]
get_cons dec = case dec of
    DataD _ _ _ cons _       -> cons
    NewtypeD _ _ _ con _     -> [con]
    DataInstD _ _ _ cons _   -> cons
    NewtypeInstD _ _ _ con _ -> [con]
    _                        -> []

-- | The field types of a constructor, in declaration order. For an infix
--   constructor these are the types of its two operands; a quantified
--   constructor is unwrapped to the constructor it quantifies over.
get_con_types :: Con -> [Type]
get_con_types con = case con of
    NormalC _ fields    -> [ typ | (_, typ) <- fields ]
    RecC _ fields       -> [ typ | (_, _, typ) <- fields ]
    InfixC lhs _ rhs    -> [snd lhs, snd rhs]
    ForallC _ _ inner   -> get_con_types inner

-- | The names of all type constructors occurring anywhere inside a type.
--   Every sub-term produced by 'universe' is mapped through 'type_to_name';
--   sub-terms without a name are skipped and duplicates are kept.
get_type_names :: Type -> [Name]
get_type_names typ = [ name | Just name <- map type_to_name (universe typ) ]
    
-- | The name of the type constructor denoted by a single 'Type' node, if any.
--
--   * @ConT n@ gives @n@.
--   * @ListT@ gives the unqualified name @[]@.
--   * @TupleT n@ gives the unqualified tuple constructor name, which has
--     @n - 1@ commas: @(,)@ for @TupleT 2@, @(,,)@ for @TupleT 3@, and @()@
--     for @TupleT 0@.
--
--   Every other node yields 'Nothing'.
type_to_name :: Type -> Maybe Name
type_to_name (ConT n)       = Just n
type_to_name ListT          = Just (mkName "[]")
-- An n-tuple is written with n - 1 commas; 'replicate' returns "" for a
-- non-positive count, so the unit tuple still comes out as "()".
type_to_name (TupleT count) = Just (mkName ("(" ++ replicate (count - 1) ',' ++ ")"))
type_to_name _              = Nothing

-- | Extract the 'Right' value; a 'Left' aborts with 'error', using the
--   carried message.
from_right :: Either String b -> b
from_right = either error id

-- | Partial variant of 'get_type_name': aborts with 'error' when the type has
--   no name.
get_type_name' :: Type -> Name
get_type_name' typ = from_right (get_type_name typ)

-- | The name at the head of a type: the variable or constructor name of a
--   @VarT@ / @ConT@, looking through any enclosing @ForallT@. Every other
--   form of type gives a 'Left' describing the offending type.
get_type_name :: Type -> Result Name
get_type_name typ = case typ of
    ForallT _ _ inner -> get_type_name inner
    VarT n            -> Right n
    ConT n            -> Right n
    other             -> Left ("No name for " ++ show other)

-- | Every @ConT@ node occurring anywhere inside a type, in the order
--   'universe' enumerates the sub-terms.
get_constr_types :: Type -> [Type]
get_constr_types typ = [ sub | sub <- universe typ, is_cont sub ]

-- | Is this type node a bare type constructor (@ConT@)?
is_cont :: Type -> Bool
is_cont typ = case typ of
    ConT _ -> True
    _      -> False

-- | The name a declaration introduces (or, for family instances, the name of
--   the family it belongs to).
--
--   Returns 'Left' with an explanatory message for declarations that carry no
--   single name (value bindings, instances, foreign declarations, pragmas)
--   and for any declaration form not handled explicitly below, so the
--   function never fails with a pattern-match error.
get_dec_name :: Dec -> Result Name
get_dec_name (FunD name _)               = return name
get_dec_name (ValD _ _ _)                = Left "ValD does not have a name"
get_dec_name (DataD _ name _ _ _)        = return name
get_dec_name (NewtypeD _ name _ _ _)     = return name
get_dec_name (TySynD name _ _)           = return name
get_dec_name (ClassD _ name _ _ _ )      = return name
get_dec_name (InstanceD _ _ _)           = Left "InstanceD does not have a name"
get_dec_name (SigD name _)               = return name
get_dec_name (ForeignD _)                = Left "ForeignD does not have a name"
get_dec_name (PragmaD  _)                = Left "PragmaD does not have a name"
get_dec_name (FamilyD _ name _ _)        = return name
get_dec_name (DataInstD _ name _ _ _)    = return name
-- 'get_cons' extracts the constructor of a newtype instance, so its name has
-- to be obtainable as well.
get_dec_name (NewtypeInstD _ name _ _ _) = return name
-- Any remaining declaration form is reported as an error value instead of
-- crashing the traversal.
get_dec_name other                       =
    Left ("No name available for declaration: " ++ show other)

-- | Run the computation with an empty initial state and an empty reader
--   environment and return only the final state. The 'Either' result of the
--   error layer is discarded, so the state accumulated before a thrown error
--   is still returned.
exec_state :: (Monad m, Functor m) => ErrorStateT e [a1] m a -> m ([a1])
exec_state x =
    fmap snd . flip runReaderT [] . flip runStateT [] . runErrorT $ runErrorStateT x

-- | Run the computation with an empty initial state and an empty reader
--   environment and return only its outcome: 'Left' for a thrown error,
--   'Right' for a result. The final state is discarded.
eval_state :: (Monad m, Functor m) => ErrorStateT e [a1] m a -> m (Either e a)
eval_state x =
    fmap fst . flip runReaderT [] . flip runStateT [] . runErrorT $ runErrorStateT x