-- | Substitute one type for another in Template Haskell 'Dec's, 'Con's and 'Type's.
--
-- Warning! There are a few known issues:
--
-- * Types inside contexts ('Cxt') are not substituted.
--
-- * After substitution, the type variables of a 'Dec' are regenerated, and all
--   explicitly kinded type variables are converted to implicitly kinded type
--   variables.
module Language.Haskell.TH.TypeSub (
 sub_types_dec, 
 sub_type_dec, 
 sub_type_con, 
 sub_type,
 has_type,
 get_cons,
 get_con_types,
 update_ty_vars,
 Result) where
import Language.Haskell.TH
import Control.Arrow
import Data.List
import Data.Generics.Uniplate.Data
import Control.Applicative
import Control.Monad
import Data.Tuple.Select

-- | The outcome of an operation that can fail: 'Left' carries a
-- human-readable error message, 'Right' carries the successful value.
type Result = Either String

-- | Create a new data declaration in which the first @n@ type variables of the
-- declaration are replaced, position by position, by the @n@ supplied types.
-- Returns 'Left' if more types are supplied than the declaration has type
-- variables.
--
-- The substitution is simultaneous: each supplied type is inserted exactly
-- once and is not rewritten again by the remaining substitutions. For
-- example, substituting @[VarT b, ConT ''Int]@ into @data T a b = T a b@
-- turns the constructor fields into @b Int@. Applying the substitutions one
-- after another would instead give @Int Int@.
--
-- Supplying no types returns the declaration unchanged. Otherwise the type
-- variables of the result are regenerated by 'update_ty_vars'.
sub_types_dec :: [Type] -> Dec -> Result Dec
sub_types_dec []    dec = Right dec
sub_types_dec types dec = do
    names <- mapM (get_ty_var_name dec) [0 .. length types - 1]
    let substitution = zip (map VarT names) types
        -- 'transform' works bottom-up and never revisits the result of a
        -- replacement, so each variable occurrence is replaced at most once.
        replace t = case lookup t substitution of
            Just t' -> t'
            Nothing -> t
    return $ update_ty_vars $
        modify_cons dec (map (\con -> modify_types con (map (transform replace))))
    
-- | Replace every occurrence of @old_type@ by @new_type@ in the field types
-- of all constructors of a declaration, then regenerate the declaration's
-- type variables with 'update_ty_vars'.
--
-- When @old_type@ does not occur, the constructors are left as they are,
-- but the type variables are still regenerated.
sub_type_dec :: Type -> Type -> Dec -> Dec
sub_type_dec new_type old_type dec =
    let rewrite_cons = map (sub_type_con new_type old_type)
        rewritten    = modify_cons dec rewrite_cons
    in  update_ty_vars rewritten
    
-- | Replace every occurrence of @old_type@ by @new_type@ in the field types
-- of a single constructor. The rewriting of the field list is delegated to
-- 'modify_types'.
sub_type_con :: Type -> Type -> Con -> Con
sub_type_con new_type old_type con = modify_types con rewrite_fields
  where
    rewrite_fields = map (sub_type new_type old_type)

-- | Replace every subterm of the input type that is equal to @old_type@ by
-- @new_type@.
--
-- Uses Uniplate's 'transform', which rewrites bottom-up. An inserted
-- @new_type@ is not searched again, so a @new_type@ that itself contains
-- @old_type@ does not cause repeated substitution.
sub_type :: Type -> Type -> Type -> Type
sub_type new_type old_type = transform swap_if_match
  where
    swap_if_match candidate
        | candidate /= old_type = candidate
        | otherwise             = new_type

------------------------------------------------------------------------------------------------
--Various helper functions that will get moved somewhere else eventually

-- | The name bound by a type variable binder, with or without a kind
-- annotation.
ty_var_name :: TyVarBndr -> Name
ty_var_name bndr = case bndr of
    PlainTV  name   -> name
    KindedTV name _ -> name

-- | Apply a function to the constructors of a data or newtype declaration
-- (including data/newtype family instances). Any other kind of declaration
-- is returned unchanged.
--
-- For newtypes, the function receives a one-element list and must return
-- exactly one constructor. Any other result is reported with a descriptive
-- 'error', instead of failing inside 'head' or silently dropping extra
-- constructors.
modify_cons :: Dec -> ([Con] -> [Con]) -> Dec
modify_cons (NewtypeD x y z con w) f     = NewtypeD x y z (single_con "NewtypeD" $ f [con]) w
modify_cons (DataD x y z cons w)   f     = DataD x y z (f cons) w
modify_cons (DataInstD x y z cons w) f   = DataInstD x y z (f cons) w
modify_cons (NewtypeInstD x y z con w) f = NewtypeInstD x y z (single_con "NewtypeInstD" $ f [con]) w
modify_cons x _                          = x

-- | Extract the sole constructor produced for a newtype-like declaration,
-- failing with a message naming the declaration form when there is not
-- exactly one.
single_con :: String -> [Con] -> Con
single_con _    [con] = con
single_con what cons  = error $ "modify_cons: " ++ what
    ++ " requires exactly one constructor, but got " ++ show (length cons)

-- | The constructors of a data or newtype declaration (including family
-- instances). Every other declaration yields an empty list.
get_cons :: Dec -> [Con]
get_cons dec = case dec of
    DataD        _ _ _ cons _ -> cons
    DataInstD    _ _ _ cons _ -> cons
    NewtypeD     _ _ _ con  _ -> [con]
    NewtypeInstD _ _ _ con  _ -> [con]
    _                         -> []
 
-- | The type variable binders of a data, newtype, type synonym, class or
-- type family declaration. Every other declaration (including data/newtype
-- family instances) yields an empty list.
get_ty_vars :: Dec -> [TyVarBndr]
get_ty_vars dec = case dec of
    DataD    _ _ vs _ _ -> vs
    NewtypeD _ _ vs _ _ -> vs
    TySynD   _ vs _     -> vs
    ClassD   _ _ vs _ _ -> vs
    FamilyD  _ _ vs _   -> vs
    _                   -> []

-- | Replace the type variable binders of a declaration. Handles the same
-- declaration forms as 'get_ty_vars'; every other declaration is returned
-- unchanged.
set_ty_vars :: Dec -> [TyVarBndr] -> Dec
set_ty_vars dec vs = case dec of
    NewtypeD ctx name _ con derivs  -> NewtypeD ctx name vs con derivs
    DataD    ctx name _ cons derivs -> DataD    ctx name vs cons derivs
    TySynD   name _ rhs             -> TySynD   name vs rhs
    ClassD   ctx name _ fds decs    -> ClassD   ctx name vs fds decs
    FamilyD  flavour name _ kind    -> FamilyD  flavour name vs kind
    _                               -> dec


-- | The name of a type variable or type constructor, looking through any
-- 'ForallT' wrappers. Any other type yields 'Left' with a message that
-- shows the offending type.
get_type_name :: Type -> Result Name
get_type_name typ = case typ of
    ForallT _ _ body -> get_type_name body
    VarT name        -> Right name
    ConT name        -> Right name
    other            -> Left ("No name for " ++ show other)

-- | Unwrap a 'Right'. A 'Left' becomes a runtime 'error' whose message is
-- the error string followed by " is not Right!".
from_right :: Result a -> a
from_right = either (\msg -> error $ msg ++ " is not Right!") id

-- | The name of the declaration's type variable at the given zero-based
-- position. Index errors from 'get_value' are passed through as 'Left'.
get_ty_var_name :: Dec -> Int -> Result Name
get_ty_var_name dec = fmap ty_var_name . get_value (get_ty_vars dec)

-- | Safe list indexing: the element at zero-based position @i@. Returns
-- 'Left' with an "Index out of bounds" message when @i@ is negative or not
-- less than the length of the list.
--
-- Uses 'drop' instead of 'length' plus '!!': the list is only walked up to
-- position @i@, and a negative index can no longer reach the partial '!!'.
get_value :: [a] -> Int -> Result a
get_value xs i
    | i >= 0, (x : _) <- drop i xs = Right x
    | otherwise                    = Left $ show i ++ " Index out of bounds"

-- | Every type-variable occurrence inside a type (the type itself included),
-- in the order Uniplate's 'universe' lists them, with duplicates kept.
collect_vars :: Type -> [Type]
collect_vars = filter is_var . universe
  where
    is_var (VarT _) = True
    is_var _        = False

-- | One plain (unkinded) binder per distinct type variable occurring in a
-- type. The binders follow the order of first occurrence in 'collect_vars'.
make_ty_vars :: Type -> [TyVarBndr]
make_ty_vars typ = [PlainTV name | VarT name <- nub (collect_vars typ)]

-- | Regenerate a declaration's type variable binders from the types it
-- mentions:
--
-- * A type synonym gets one binder per distinct variable in its right-hand
--   side.
--
-- * Any other declaration handled by 'set_ty_vars' gets one binder per
--   distinct variable in its constructors' field types ('get_con_types'),
--   in order of first occurrence. Classes and type families have no
--   constructors here, so they end up with no binders.
--
-- Other declarations are returned unchanged. All regenerated binders are
-- plain, so kind annotations are dropped.
--
-- NOTE(review): parameters that occur in no field (phantom parameters) are
-- dropped, variables bound by a 'ForallC' are collected like any other, and
-- parameter order follows occurrence rather than the original declaration.
-- Confirm this is acceptable for callers.
update_ty_vars :: Dec -> Dec
update_ty_vars dec = case dec of
    TySynD n _ t -> TySynD n (make_ty_vars t) t
    _            -> set_ty_vars dec (nub (concatMap make_ty_vars field_types))
  where
    field_types = concatMap get_con_types (get_cons dec)

    
-- | Map a function over the last component of a triple, leaving the first
-- two components untouched.
third :: (c -> d) -> (a, b, c) -> (a, b, d)
third f = \(first_, second_, last_) -> (first_, second_, f last_)
    
-- | The field types of a constructor, in declaration order, with field names
-- and strictness annotations stripped. For a 'ForallC' constructor, the
-- types of the wrapped constructor are returned.
get_con_types :: Con -> [Type]
get_con_types con = case con of
    NormalC _ fields  -> map snd fields
    RecC _ fields     -> map sel3 fields
    InfixC lhs _ rhs  -> [snd lhs, snd rhs]
    ForallC _ _ inner -> get_con_types inner

-- | Apply a function to the field types of a constructor, keeping field
-- names and strictness annotations. 'ForallC' wrappers (binders and
-- context) are preserved, and the function is applied to the wrapped
-- constructor.
--
-- The function is expected to return as many types as it receives.
--
-- * For 'NormalC' and 'RecC', the fields are re-assembled with
--   'zip' / 'zip3', so a result of the wrong length is silently truncated
--   to the shorter of the two lists.
--
-- * For 'InfixC', anything other than exactly two types is reported with a
--   descriptive 'error' instead of an irrefutable-pattern failure.
modify_types :: Con -> ([Type] -> [Type]) -> Con
modify_types (NormalC n strict_types)  f = NormalC n $ uncurry zip $ second f $ unzip strict_types
modify_types (RecC n var_strict_types) f = RecC    n $ (\(x, y, z) -> zip3 x y z) (third f $ unzip3 var_strict_types)
modify_types (InfixC (sx, tx) n (sy, ty)) f = case f [tx, ty] of
    [tx', ty'] -> InfixC (sx, tx') n (sy, ty')
    other      -> error $ "modify_types: infix constructor " ++ show n
        ++ " needs exactly 2 field types, but got " ++ show (length other)
modify_types (ForallC t context con) f = ForallC t context $ modify_types con f
       
-- | Does a type variable with the given name occur anywhere in the type?
has_var :: Name -> Type -> Bool
has_var name = elem name . var_names
  where
    var_names typ = [n | VarT n <- universe typ]

-- | Does the first type occur anywhere inside the second type? The second
-- type itself counts as an occurrence.
has_type :: Type -> Type -> Bool
has_type typ_to_find typ = or [typ_to_find == sub | sub <- universe typ]