-- | Helpers for substituting types inside Template Haskell declarations,
-- constructors and types, and for recomputing a declaration's type-variable
-- list after such a substitution.
module Language.Haskell.TH.TypeSub (
sub_types_dec,
sub_type_dec,
sub_type_con,
sub_type,
has_type,
get_cons,
get_con_types,
update_ty_vars,
Result) where
import Language.Haskell.TH
import Control.Arrow
import Data.List
import Data.Generics.Uniplate.Data
import Control.Applicative
import Control.Monad
import Data.Tuple.Select
-- | Outcome of an operation that can fail: 'Left' carries a textual error
-- message, 'Right' carries the successful value.
type Result a = Either String a
-- | Substitute the given types for the type variables of a declaration,
-- positionally: the first type replaces the first declared type variable,
-- the second type replaces the second one, and so on.
--
-- Returns 'Left' when more types are supplied than the declaration has
-- type variables. An empty list of types leaves the constructors untouched
-- (the declaration is returned as is).
sub_types_dec :: [Type] -> Dec -> Result Dec
sub_types_dec types dec = do
  -- Resolve all variable names against the original declaration first:
  -- every substitution step rewrites the declaration's binder list, so the
  -- positions are only meaningful before any substitution has happened.
  names <- mapM (get_ty_var_name dec) [0 .. length types - 1]
  return $ foldl' sub_type_dec' dec $ zip names types
  where
    -- Replace one named type variable by its concrete type.
    sub_type_dec' dec' (n, t) = sub_type_dec t (VarT n) dec'
-- | Replace every occurrence of @old_type@ by @new_type@ in all constructors
-- of a declaration, then recompute the declaration's type variables with
-- 'update_ty_vars'.
sub_type_dec :: Type -> Type -> Dec -> Dec
sub_type_dec new_type old_type dec = update_ty_vars substituted
  where
    replace_in_con = sub_type_con new_type old_type
    substituted = modify_cons dec (map replace_in_con)
-- | Replace every occurrence of @old_type@ by @new_type@ in each field type
-- of a single constructor.
sub_type_con :: Type -> Type -> Con -> Con
sub_type_con new_type old_type con = modify_types con replace_all
  where
    replace_all = map (sub_type new_type old_type)
-- | Replace every sub-term of @input@ that equals @old_type@ by @new_type@.
-- The rewrite is applied throughout the type via uniplate's 'transform'.
sub_type :: Type -> Type -> Type -> Type
sub_type new_type old_type input = transform replace input
  where
    replace t = if t == old_type then new_type else t
-- | Extract the name of a type-variable binder, ignoring any kind annotation.
ty_var_name :: TyVarBndr -> Name
ty_var_name binder = case binder of
  PlainTV name -> name
  KindedTV name _ -> name
-- | Apply a function to the constructor list of a data or newtype
-- declaration (including data/newtype instances). Any other declaration is
-- returned unchanged.
--
-- For newtypes the single constructor is passed as a one-element list and
-- the first element of the result is used; @f@ must therefore return a
-- non-empty list in that case, otherwise 'head' fails.
modify_cons :: Dec -> ([Con] -> [Con]) -> Dec
modify_cons dec f = case dec of
  NewtypeD ctx name vars con derivs ->
    NewtypeD ctx name vars (only con) derivs
  DataD ctx name vars cons derivs ->
    DataD ctx name vars (f cons) derivs
  DataInstD ctx name args cons derivs ->
    DataInstD ctx name args (f cons) derivs
  NewtypeInstD ctx name args con derivs ->
    NewtypeInstD ctx name args (only con) derivs
  _ -> dec
  where
    -- Run @f@ on a singleton list and take the first result.
    only con = head $ f [con]
-- | Collect the constructors of a data or newtype declaration (including
-- data/newtype instances). Every other kind of declaration yields @[]@.
get_cons :: Dec -> [Con]
get_cons dec = case dec of
  NewtypeD _ _ _ con _ -> [con]
  DataD _ _ _ cons _ -> cons
  DataInstD _ _ _ cons _ -> cons
  NewtypeInstD _ _ _ con _ -> [con]
  _ -> []
-- | Return the type-variable binders of a newtype, data, type-synonym,
-- class or family declaration. Every other declaration yields @[]@.
get_ty_vars :: Dec -> [TyVarBndr]
get_ty_vars dec = case dec of
  NewtypeD _ _ binders _ _ -> binders
  DataD _ _ binders _ _ -> binders
  TySynD _ binders _ -> binders
  ClassD _ _ binders _ _ -> binders
  FamilyD _ _ binders _ -> binders
  _ -> []
-- | Overwrite the type-variable binders of a newtype, data, type-synonym,
-- class or family declaration. Every other declaration is returned
-- unchanged.
set_ty_vars :: Dec -> [TyVarBndr] -> Dec
set_ty_vars dec ty_vars = case dec of
  NewtypeD ctx name _ con derivs -> NewtypeD ctx name ty_vars con derivs
  DataD ctx name _ cons derivs -> DataD ctx name ty_vars cons derivs
  TySynD name _ rhs -> TySynD name ty_vars rhs
  ClassD ctx name _ deps decs -> ClassD ctx name ty_vars deps decs
  FamilyD flavour name _ kind -> FamilyD flavour name ty_vars kind
  _ -> dec
-- | Extract the name of a type variable or type constructor, looking
-- through any enclosing @forall@. Any other type yields a 'Left' with a
-- message that includes the shown type.
get_type_name :: Type -> Result Name
get_type_name typ = case typ of
  ForallT _ _ inner -> get_type_name inner
  VarT n -> Right n
  ConT n -> Right n
  other -> Left ("No name for " ++ show other)
-- | Unwrap a successful 'Result'. Partial: a 'Left' aborts with 'error',
-- using the carried message followed by @" is not Right!"@.
from_right :: Result a -> a
from_right = either failure id
  where
    failure msg = error $ msg ++ " is not Right!"
-- | Look up the name of the @i@-th (zero-based) type variable of a
-- declaration. Fails with 'Left' when the index lookup fails.
get_ty_var_name :: Dec -> Int -> Result Name
get_ty_var_name dec i = fmap ty_var_name binder
  where
    binder = get_value (get_ty_vars dec) i
-- | Safe zero-based list indexing.
--
-- Returns 'Right' with the element at position @i@, or 'Left' with an
-- \"Index out of bounds\" message when @i@ is negative or not smaller than
-- the length of the list.
get_value :: [a] -> Int -> Result a
get_value xs i
  -- A negative index must be rejected explicitly: 'drop' would treat it
  -- like 0 and '!!' would raise an exception.
  | i < 0 = out_of_bounds
  | otherwise =
      -- 'drop' only walks the first @i@ cells, so the list is never
      -- traversed beyond the requested position.
      case drop i xs of
        (x : _) -> Right x
        [] -> out_of_bounds
  where
    out_of_bounds = Left $ show i ++ " Index out of bounds"
-- | List every type-variable node ('VarT') occurring anywhere inside a
-- type, in the order produced by uniplate's 'universe'. Duplicates are kept.
collect_vars :: Type -> [Type]
collect_vars typ = filter is_var (universe typ)
  where
    is_var (VarT _) = True
    is_var _ = False
-- | Build one plain (un-kinded) binder for each distinct type variable that
-- occurs in a type, in order of first occurrence.
make_ty_vars :: Type -> [TyVarBndr]
make_ty_vars typ = [PlainTV n | VarT n <- nub (collect_vars typ)]
-- | Recompute a declaration's type-variable list from the variables that
-- actually occur in it.
--
-- For a type synonym the variables are taken from its right-hand side; for
-- any other declaration they are taken from the field types of its
-- constructors (duplicates removed) and stored via 'set_ty_vars'.
update_ty_vars :: Dec -> Dec
update_ty_vars (TySynD n _ t) = TySynD n (make_ty_vars t) t
update_ty_vars dec = set_ty_vars dec binders
  where
    field_types = concatMap get_con_types (get_cons dec)
    binders = nub (concatMap make_ty_vars field_types)
-- | Apply a function to the third component of a triple, leaving the first
-- two components untouched.
third :: (c -> d) -> (a, b, c) -> (a, b, d)
third f = \(x, y, z) -> (x, y, f z)
-- | List the field types of a constructor, in declaration order. A
-- @forall@-quantified constructor is unwrapped first.
get_con_types :: Con -> [Type]
get_con_types con = case con of
  NormalC _ fields -> [t | (_, t) <- fields]
  RecC _ fields -> [t | (_, _, t) <- fields]
  InfixC lhs _ rhs -> [snd lhs, snd rhs]
  ForallC _ _ inner -> get_con_types inner
-- | Apply a function to the list of field types of a constructor, keeping
-- strictness annotations and record field names paired up positionally
-- with the resulting types.
--
-- For an infix constructor @f@ must return exactly two types; otherwise the
-- list pattern below fails when the result is inspected.
modify_types :: Con -> ([Type] -> [Type]) -> Con
modify_types (NormalC n strict_types) f = NormalC n (zip stricts (f types))
  where
    (stricts, types) = unzip strict_types
modify_types (RecC n var_strict_types) f =
  RecC n (zip3 names stricts (f types))
  where
    (names, stricts, types) = unzip3 var_strict_types
modify_types (InfixC x n y) f = InfixC x' n y'
  where
    (stricts, types) = unzip [x, y]
    [x', y'] = zip stricts (f types)
modify_types (ForallC t context con) f =
  ForallC t context (modify_types con f)
-- | Test whether a type variable with the given name occurs anywhere inside
-- a type.
has_var :: Name -> Type -> Bool
has_var name typ = name `elem` var_names
  where
    var_names = [n | VarT n <- universe typ]
-- | Test whether @typ_to_find@ occurs as @typ@ itself or as any sub-term
-- of it.
has_type :: Type -> Type -> Bool
has_type typ_to_find typ = typ_to_find `elem` universe typ