{-# LANGUAGE FlexibleContexts, MultiParamTypeClasses, FlexibleInstances, GeneralizedNewtypeDeriving, TemplateHaskell #-}
module Language.Haskell.TH.Specialize (
    -- ** Main Interface
    expand_and_specialize,
    expand_and_specialize',
    -- *** Helper Types
    ConstructorName (..),
    TypeName (..),
    -- *** Renamer Interface
    DecRenamer,
    ConstrRenamer,
    -- *** Stock Renamers
    mk_new_dec_name,
    id_constr_renamer,
    -- ** Utils ... Pretend these are not here
    sub_dec_and_rename,
    create_dec_from_type,
    find_con,
    get_con_vars,
    get_ty_vars,
    concat_type_names,
    rename_dec,
    run_state',
    Result
    ) where
import Language.Haskell.TH
import Language.Haskell.TH.Universe
import Language.Haskell.TH.TypeSub
import Control.Monad.Error
import Control.Monad.State
import Data.Generics.Uniplate.Data
import Data.List
import Control.Applicative
import Data.List.Utils
import Control.Applicative
import Control.Newtype
import Control.Newtype.TH
import Language.Haskell.TH.ExpandSyns
import Data.Composition
import Data.List

-- | Expand all type synonyms and create specialized types for any
--   polymorphic types.
--
--   All of the new specialized declarations are returned, along with the
--   original 'Dec' with substituted types and a new name.
--
--   The first 'Name' is the name of the 'Dec' to create specialized instances
--   for. The second 'Name' is the new name for that 'Dec'.
--
--   Uses 'mk_new_dec_name' for the 'Dec' renaming and 'id_constr_renamer' for
--   the constructor renaming.
expand_and_specialize :: Name -> Name -> Q [Dec]
expand_and_specialize = expand_and_specialize' mk_new_dec_name id_constr_renamer

-- | Expand all type synonyms and create specialized types for any
--   polymorphic types.
--
--   All of the new specialized declarations are returned, along with the
--   original 'Dec' with substituted types and a new name.
--
--   The first 'Name' is the name of the 'Dec' to create specialized instances
--   for. The second 'Name' is the new name for that 'Dec'.
--
--   The 'DecRenamer' and 'ConstrRenamer' are used to rename 'Dec's and 'Con's
--   respectively.
expand_and_specialize' :: DecRenamer -> ConstrRenamer -> Name -> Name -> Q [Dec]
expand_and_specialize' dr cr name new_name = do
    universe <- map snd <$> get_universe name
    decs     <- expand_type_syn_decs universe
    (new_dec, new_decs) <- run_state' (create_decs_from_name dr cr (TypeName name)) decs
    -- NOTE(review): a 'Left' (e.g. the Dec was not found) is dropped silently
    -- and only the declarations accumulated so far are returned -- confirm
    -- that this best-effort behaviour is intended.
    let result = case new_dec of
            Right x -> (from_right $ set_dec_name (pack new_name) x) : new_decs
            Left _  -> new_decs
    -- Only hand back declarations that are not already part of the universe.
    return $ nub result \\ decs

-- | Expand the type synonyms in every declaration of the list.
expand_type_syn_decs :: [Dec] -> Q [Dec]
expand_type_syn_decs = mapM expand_type_syn_dec

-- | Expand the type synonyms in every constructor of a declaration.
--   Calls 'error' (via 'from_right') when 'set_cons' rejects the declaration.
expand_type_syn_dec :: Dec -> Q Dec
expand_type_syn_dec dec =
    (from_right . flip set_cons dec) <$> mapM expand_type_syn_con (get_cons dec)

-- | Expand the type synonyms in every field type of a constructor.
expand_type_syn_con :: Con -> Q Con
expand_type_syn_con con =
    set_con_types' con <$> mapM expand_type_syn_type (get_con_types con)

-- | Expand the type synonyms of a single type with 'expandSyns'.
expand_type_syn_type :: Type -> Q Type
expand_type_syn_type = expandSyns

-- | Look up the first declaration carrying the given name.
find_dec :: [Dec] -> TypeName -> Result Dec
find_dec decs name = maybe_to_either "could not find dec" $ find (is_dec_name name) decs

-- | 'True' when the declaration has a name and it equals the given one.
is_dec_name :: TypeName -> Dec -> Bool
is_dec_name name dec = either (const False) (== name) (get_dec_name dec)

-- | Look up the first declaration that owns a constructor with the given name.
find_dec_from_constr :: (Monad m, Functor m, MonadError String m) => [Dec] -> ConstructorName -> m Dec
find_dec_from_constr decs name = throw_maybe "could not find dec" $ find (has_constr name) decs

-- | 'True' when one of the declaration's constructors has the given name.
has_constr :: ConstructorName -> Dec -> Bool
has_constr name dec = any (\con -> name == get_con_name con) (get_cons dec)

-- | Find the named declaration in the state and specialize its constructors.
create_decs_from_name :: (Functor m, Monad m, MonadState [Dec] m, MonadError String m)
                      => DecRenamer -> ConstrRenamer -> TypeName -> m Dec
create_decs_from_name dr cr name = do
    decs <- get
    dec  <- throw_either $ find_dec decs name
    create_decs_from_dec dr cr dec

-- | Specialize every constructor of a declaration.
--   Calls 'error' (via 'from_right') when 'set_cons' rejects the declaration.
create_decs_from_dec :: (Functor m, Monad m, MonadState [Dec] m, MonadError String m)
                     => DecRenamer -> ConstrRenamer -> Dec -> m Dec
create_decs_from_dec dr cr dec =
    (from_right . flip set_cons dec) <$> mapM (create_decs_from_con dr cr) (get_cons dec)

-- | Specialize every field type of a constructor.
create_decs_from_con :: (Functor m, Monad m, MonadState [Dec] m, MonadError String m)
                     => DecRenamer -> ConstrRenamer -> Con -> m Con
create_decs_from_con dr cr con =
    set_con_types' con <$> mapM (create_dec_from_type dr cr) (get_con_types con)

-- | Specialize a type. Only type applications are processed; every other
--   type is returned unchanged.
create_dec_from_type :: (Functor m, Monad m, MonadState [Dec] m, MonadError String m)
                     => DecRenamer -> ConstrRenamer -> Type -> m Type
create_dec_from_type dr cr typ@(AppT _ _) =
    case collect_type_args typ of
        x:args -> create_dec_from_type' dr cr x =<< mapM (create_dec_from_type dr cr) args
        -- Unreachable: 'collect_type_args' never returns an empty list.
        []     -> return typ
create_dec_from_type _ _ typ = return typ

-- | Computes the name of a specialized declaration from the type arguments
--   and the original name.
type DecRenamer = ([Type] -> TypeName -> Result TypeName)

-- | Rewrites a constructor, given the type arguments it is specialized to.
type ConstrRenamer = ([Type] -> Con -> Con)

-- | 'True' when the state already holds a declaration with the given name.
has_dec :: (Monad m, MonadState [Dec] m) => TypeName -> m Bool
has_dec name = gets (any (is_dec_name name))

-- | Prepend a declaration to the state.
add_dec :: (Monad m, MonadState [Dec] m) => Dec -> m ()
add_dec dec = modify (dec:)

-- | Default Con renamer: returns the constructor unchanged.
id_constr_renamer :: [Type] -> Con -> Con
id_constr_renamer _ con = con

newtype ConstructorName = ConstructorName { runConstructorName :: Name }
    deriving(Show, Eq)

newtype TypeName = TypeName { runTypeName :: Name }
    deriving(Show, Eq)

-- | Specialize the head of a type application to the given (already
--   processed) arguments. For a 'ConT' head the specialized declaration is
--   added to the state unless one with the new name is already there, and a
--   'ConT' of the new name is returned.
create_dec_from_type' :: (Monad m, MonadState [Dec] m, Functor m, MonadError String m)
                      => DecRenamer -> ConstrRenamer -> Type -> [Type] -> m Type
create_dec_from_type' dr cr (ConT name) args = do
    decs         <- get
    dec          <- throw_either $ find_dec decs $ TypeName name
    dec_name     <- throw_either $ get_dec_name dec
    new_dec_name <- throw_either $ dr args dec_name
    has_dec'     <- has_dec new_dec_name
    when (not has_dec') $ do
        -- Name the new declaration with the caller's renamer so that it
        -- agrees with the 'has_dec' check above and the 'ConT' returned below.
        new_dec <- set_dec_name new_dec_name =<< sub_dec_by_con cr dec args
        add_dec new_dec
    return $ ConT $ runTypeName new_dec_name
create_dec_from_type' dr cr ListT args =
    create_dec_from_type' dr cr (ConT $ mkName "GHC.Types.[]") args
create_dec_from_type' dr cr (TupleT count) args =
    -- An n-tuple type constructor is written with n-1 commas.
    -- NOTE(review): GHC defines tuples in GHC.Tuple, not GHC.Types -- confirm
    -- which module prefix the declaration universe uses.
    create_dec_from_type' dr cr (ConT $ mkName ("GHC.Types.(" ++ replicate (count - 1) ',' ++ ")")) args
create_dec_from_type' dr cr t@(AppT _ _) args = do
    typ <- create_dec_from_type dr cr t
    case typ of
        -- The head could not be reduced to a constructor (e.g. a partially
        -- applied arrow or type variable); recursing would never terminate.
        AppT _ _ -> return $ foldl' AppT typ args
        _        -> create_dec_from_type' dr cr typ args
create_dec_from_type' dr cr (SigT t _) args = do
    typ <- create_dec_from_type dr cr t
    create_dec_from_type' dr cr typ args
create_dec_from_type' dr cr (ForallT _ _ t) args = do
    typ <- create_dec_from_type dr cr t
    create_dec_from_type' dr cr typ args
create_dec_from_type' _ _ t args =
    --just return what was passed in if we can't do anything
    return $ foldl' AppT t args

-- | Apply a function to every constructor of a declaration.
rename_cons :: (Con -> Con) -> Dec -> Result Dec
rename_cons cr dec = set_cons (map cr $ get_cons dec) dec

-- | Replace the constructors of a declaration. Newtype declarations accept
--   exactly one constructor; declarations without constructors are rejected.
set_cons :: [Con] -> Dec -> Result Dec
set_cons (con:[]) (NewtypeD x y z _ w)     = Right $ NewtypeD x y z con w
set_cons cons     (NewtypeD _ _ _ _ _)     = Left $ show cons ++ " is not a appropiate arg for setting the NewtypeD's constructor arg"
set_cons cons     (DataD x y z _ w)        = Right $ DataD x y z cons w
set_cons cons     (DataInstD x y z _ w)    = Right $ DataInstD x y z cons w
set_cons (con:[]) (NewtypeInstD x y z _ w) = Right $ NewtypeInstD x y z con w
set_cons cons     (NewtypeInstD x y z _ w) = Left $ show cons ++ " is not a appropiate arg for setting the NewtypeInstD's constructor arg"
set_cons _        x                        = Left $ "Can't set the constructors for " ++ show x

-- | The type variable binders of a declaration, or @[]@ when it has none.
get_ty_vars :: Dec -> [TyVarBndr]
get_ty_vars (NewtypeD _ _ ty_vars _ _) = ty_vars
get_ty_vars (DataD _ _ ty_vars _ _)    = ty_vars
get_ty_vars (TySynD _ ty_vars _)       = ty_vars
get_ty_vars (ClassD _ _ ty_vars _ _)   = ty_vars
get_ty_vars (FamilyD _ _ ty_vars _ )   = ty_vars
get_ty_vars _                          = []

-- | The name bound by a type variable binder.
ty_var_name :: TyVarBndr -> Name
ty_var_name (KindedTV name _ ) = name
ty_var_name (PlainTV name)     = name

-- | Substitute the given types for the declaration's type variables
--   (pairwise, in order) and then run the constructor renamer.
sub_dec_by_con :: (Monad m, MonadError String m) => ConstrRenamer -> Dec -> [Type] -> m Dec
sub_dec_by_con cr dec args = do
    --get the names of the ty vars
    let tv_vars = get_ty_vars dec
    --substitute the types into the dec
    throw_either $ rename_cons (cr args)
                 $ foldl' sub_type_dec' dec
                 $ zip args
                 $ map (VarT . ty_var_name) tv_vars

-- | Substitute the types into the declaration and rename it with
--   'mk_new_dec_name'.
sub_dec_and_rename :: (Monad m, Functor m, MonadError String m) => ConstrRenamer -> Dec -> [Type] -> m Dec
sub_dec_and_rename cr dec types = rename_dec types =<< sub_dec_by_con cr dec types

-- | Join the 'show'n types with underscores, with spaces replaced by
--   underscores as well.
concat_type_names :: [Type] -> String
concat_type_names types = intercalate "_" $ map (replace " " "_" . show) types

-- | Default Dec renamer: the shown original name, an underscore and the
--   joined type names. A name whose shown form ends in @[]@ becomes the
--   joined type names followed by @_List@.
mk_new_dec_name :: [Type] -> TypeName -> Result TypeName
mk_new_dec_name types dec_name = do
    let suffix          = concat_type_names types
    let dec_name_string = show $ unpack dec_name
    let name_string     = if isSuffixOf "[]" dec_name_string
                            then suffix ++ "_List"
                            else dec_name_string ++ "_" ++ suffix
    return $ pack $ mkName $ name_string

-- | Rename a declaration with 'mk_new_dec_name'.
rename_dec :: (Monad m, MonadError String m) => [Type] -> Dec -> m Dec
rename_dec types dec = do
    new_name <- (throw_either . mk_new_dec_name types) =<< (throw_either $ get_dec_name dec)
    set_dec_name new_name dec

-- 'sub_type_dec' with the declaration first, so it can be used as a fold step.
sub_type_dec' dec (new, old) = sub_type_dec new old dec

-- | Find the constructor with the given name in a declaration.
find_con :: (Monad m, MonadError String m) => ConstructorName -> Dec -> m Con
find_con name dec = throw_maybe err_msg $ find (\x -> name == get_con_name x) $ get_cons dec
    where err_msg = "constructor " ++ show name ++ " not found"

-- | Unwrap a 'Just', or throw the given error for 'Nothing'.
throw_maybe :: (Monad m, MonadError String m) => String -> Maybe a -> m a
throw_maybe _   (Just x) = return x
throw_maybe err Nothing  = throwError err

-- | Turn a 'Maybe' into a 'Result', using the message for 'Nothing'.
maybe_to_either :: String -> Maybe a -> Result a
maybe_to_either _   (Just a) = Right a
maybe_to_either msg Nothing  = Left msg

-- | Lift a 'Result' into any 'MonadError' 'String' monad.
throw_either :: (Monad m, MonadError String m) => Result a -> m a
throw_either (Right x) = return x
throw_either (Left x)  = throwError x

-- | 'True' for type variables only.
is_vart :: Type -> Bool
is_vart (VarT _) = True
is_vart _        = False

-- | The name of a constructor, looking through 'ForallC'.
get_con_name :: Con -> ConstructorName
get_con_name (NormalC n _)      = pack n
get_con_name (RecC n _)         = pack n
get_con_name (InfixC _ n _)     = pack n
get_con_name (ForallC _ _ con)  = get_con_name con

-- | Every type variable occurring anywhere in the constructor's field types.
get_con_vars :: Con -> [Type]
get_con_vars = filter is_vart . concatMap universe . get_con_types

-- | Flatten a type application into a list: @AppT x y@ yields @x@ followed
--   by the flattening of @y@; any other type yields a singleton list.
--   NOTE(review): this recurses on the argument side, whereas Template
--   Haskell nests multi-argument applications on the function side --
--   confirm this is the intended traversal.
collect_type_args :: Type -> [Type]
collect_type_args (AppT x y) = x:(collect_type_args y)
collect_type_args x          = [x]

-- | Run an 'ErrorStateT' computation from an initial state, returning the
--   result (or the error) together with the final state.
run_state' :: ErrorStateT e s m a -> s -> m (Either e a, s)
run_state' x xs = runStateT (runErrorT (runErrorStateT x)) xs

type ErrorStateType m e s a = ErrorT e (StateT s m) a

newtype ErrorStateT e s m a = ErrorStateT { runErrorStateT :: ErrorStateType m e s a }
    deriving (Monad, MonadState s, MonadError e, Functor, MonadPlus)

instance MonadTrans (ErrorStateT String [Dec]) where
    lift = ErrorStateT . lift . lift

-- | Pair every named declaration with its constructors; declarations without
--   a name are skipped.
collect_constr :: [Dec] -> [(TypeName, [Con])]
collect_constr decs = right_only $ map get_cons_pair decs

is_right :: Either a b -> Bool
is_right (Right _) = True
is_right (Left _)  = False

-- | Partial: calls 'error' on a 'Left'.
from_right :: Either a b -> b
from_right (Right x) = x
from_right (Left _)  = error "from_right"

-- | Keep the payloads of the 'Right' values, in order.
right_only :: [Either a b] -> [b]
right_only xs = [x | Right x <- xs]

--get_cons_pair :: (Monad m, MonadError String m) => Dec -> m (Name, [Con])
-- | The name of a declaration together with its constructors.
get_cons_pair :: Dec -> Result (TypeName, [Con])
get_cons_pair dec = do
    name <- get_dec_name dec
    return (name, get_cons dec)

--get_dec_name :: (Monad m, MonadError String m) => Dec -> m Name
-- | The name of a declaration, or a 'Left' for declarations without one.
get_dec_name :: Dec -> Result TypeName
get_dec_name (FunD name _)                = Right $ pack name
get_dec_name (ValD _ _ _)                 = Left "ValD does not have a name"
get_dec_name (DataD _ name _ _ _)         = Right $ pack name
get_dec_name (NewtypeD _ name _ _ _)      = Right $ pack name
get_dec_name (TySynD name _ _)            = Right $ pack name
get_dec_name (ClassD _ name _ _ _)        = Right $ pack name
get_dec_name (InstanceD _ _ _)            = Left "InstanceD does not have a name"
get_dec_name (SigD name _)                = Right $ pack name
get_dec_name (ForeignD _)                 = Left "ForeignD does not have a name"
get_dec_name (PragmaD _ )                 = Left "PragmaD does not have a name"
get_dec_name (FamilyD _ name _ _)         = Right $ pack name
get_dec_name (DataInstD _ name _ _ _ )    = Right $ pack name
get_dec_name (NewtypeInstD _ name _ _ _ ) = Right $ pack name
-- Keep the function total for any remaining kind of declaration.
get_dec_name dec                          = Left $ "Can't get the name of " ++ show dec

-- | Replace the name of a declaration, or throw for declarations without one.
set_dec_name :: (Monad m, MonadError String m) => TypeName -> Dec -> m Dec
set_dec_name name (FunD _ x)                = return $ FunD (unpack name) x
set_dec_name _    (ValD _ _ _)              = throwError "ValD does not have a name"
set_dec_name name (DataD x _ y z w)         = return $ DataD x (unpack name) y z w
set_dec_name name (NewtypeD x _ y z w)      = return $ NewtypeD x (unpack name) y z w
set_dec_name name (TySynD _ x y)            = return $ TySynD (unpack name) x y
set_dec_name name (ClassD x _ y z w)        = return $ ClassD x (unpack name) y z w
set_dec_name _    (InstanceD _ _ _)         = throwError "InstanceD does not have a name"
set_dec_name name (SigD _ x)                = return $ SigD (unpack name) x
set_dec_name _    (ForeignD _)              = throwError "ForeignD does not have a name"
set_dec_name _    (PragmaD _ )              = throwError "PragmaD does not have a name"
set_dec_name name (FamilyD x _ y z)         = return $ FamilyD x (unpack name) y z
set_dec_name name (DataInstD x _ y z w )    = return $ DataInstD x (unpack name) y z w
set_dec_name name (NewtypeInstD x _ y z w ) = return $ NewtypeInstD x (unpack name) y z w
-- Keep the function total for any remaining kind of declaration.
set_dec_name _    dec                       = throwError $ "Can't set the name of " ++ show dec

-- | Replace the field types of a constructor, keeping strictness marks and
--   field names. 'NormalC' and 'RecC' pair fields with types positionally
--   ('zipWith'); 'InfixC' needs exactly two types.
set_con_types' :: Con -> [Type] -> Con
set_con_types' (NormalC n st) types            = NormalC n $ zipWith (\(x, _) t -> (x, t)) st types
set_con_types' (RecC n st) types               = RecC n $ zipWith (\(x, y, _) t -> (x,y,t)) st types
set_con_types' (InfixC (x, _) n (y, _)) [a, b] = InfixC (x, a) n (y, b)
set_con_types' (ForallC x y con) types         = ForallC x y $ set_con_types' con types
set_con_types' con types                       =
    error $ "set_con_types': cannot set " ++ show (length types) ++ " type(s) on " ++ show con

$(mkNewTypes [''ConstructorName, ''TypeName])