{-# LANGUAGE TemplateHaskell, MultiParamTypeClasses, QuasiQuotes, GADTs, FlexibleContexts, FlexibleInstances, NoMonomorphismRestriction, GeneralizedNewtypeDeriving #-} module Data.Generic.Diff.TH ( -- ** Main Interface make_family_gadt, -- ** Utils collect_type_args) where import Language.Haskell.TH import Control.Applicative import Data.List import Control.Monad import Data.Generic.Diff ((:=:) (..), Nil (..), Cons (..)) -- import Language.Haskell.TH.ExpandSyns import Language.Haskell.TH.TypeSub import Language.Haskell.TH.Universe import Language.Haskell.TH.Specialize (expand_and_specialize, expand_and_specialize_syns) import Data.Tuple.Select import Control.Monad.Error import Control.Monad.State import Control.Monad.Reader collect_type_args :: Type -> [Type] collect_type_args (AppT x y) = x:(collect_type_args y) collect_type_args x = [x] run_state x xs = runReaderT (runErrorT (runErrorStateT x)) xs type ERType m e r a = ErrorT e (ReaderT r m) a newtype ERT e r m a = ERT { runErrorStateT :: ERType m e r a } deriving (Monad, MonadError e, Functor, MonadPlus, MonadReader r) instance MonadTrans (ERT String [Dec]) where lift = ERT . lift . lift type DecState = ERT String [Dec] Q -------------------------------------------------------- mk_specialized_universe :: Name -> Q [Dec] mk_specialized_universe name = do first_universe <- (map snd . filter ((/=name) . fst)) <$> get_universe name specialized_universe <- expand_and_specialize name name let combined_universe = map snd $ sub_universe (first_universe ++ specialized_universe) name return combined_universe -- | Pass in the name of to generate the GDiff GADT and instances. make_family_gadt :: Name -> Q [Dec] make_family_gadt name = do result <- run_state (make_family_gadt' name) =<< mk_specialized_universe name case result of Right x -> return x Left x -> do error x return [] is_primitive name = any ((nameBase name) ==) ["Int", "Char", "String", "Float", "Double", "Int8", "Int16", "Int32", "Int64", "Word", "Word8", "Word16", "Word32", "Word64", "Addr"] convert_to_gadt_constrs (name, cons) = map (mk_gadt_con name) cons mk_gadt_con :: Name -> Con -> Con mk_gadt_con name (NormalC c_name stys) = mk_gadt_con' name c_name $ map snd stys mk_gadt_con name (RecC c_name vtys) = mk_gadt_con' name c_name $ map sel3 vtys mk_gadt_con name (InfixC (_, x) c_name (_, y)) = mk_gadt_con' name c_name [x, y] mk_gadt_con' name c_name tys | is_primitive name = ForallC [] [EqualP (VarT $ mkName "a") $ ConT name, EqualP (VarT $ mkName "b") $ ConT $ mkName "Nil" ] (NormalC (con_name_to_display (nameBase name) name) [(NotStrict, ConT $ name)]) mk_gadt_con' name c_name tys | otherwise = ForallC [] [EqualP (VarT $ mkName "a") $ ConT name, EqualP (VarT $ mkName "b") $ foldr AppT (ConT $ mkName "Nil") (add_cons tys) ] (NormalC (con_name_to_display (nameBase c_name) name) []) con_name_to_display c_name name | c_name == "[]" = mkName $ "Nil" ++ show name ++ "'" con_name_to_display c_name name | c_name == ":" = mkName $ show name ++ "__Cons'" con_name_to_display c_name _ | otherwise = mkName $ c_name ++ "'" display_to_con_name display_name name | display_name == "Nil" ++ show name = mkName $ "[]" display_to_con_name display_name name | display_name == (show name ++ "__Cons") = mkName $ ":" display_to_con_name display_name _ | otherwise = mkName display_name add_cons tys = map (AppT (ConT $ mkName "Cons")) tys mk_gadt :: Name -> [Con] -> DecState Dec mk_gadt gadt_name gadt_constrs = return $ DataD [] gadt_name [PlainTV $ mkName "a", PlainTV $ mkName "b"] gadt_constrs [] is_dec_name name x = result where dec_name = get_dec_name x result = case dec_name of Right y -> name == y Left _ -> False make_family_gadt' :: Name -> DecState [Dec] make_family_gadt' name = do first_universe <- (map snd . filter ((/=name) . fst)) <$> (lift $ get_universe name) specialized_universe <- (filter (not . is_dec_name name)) <$> (lift $ expand_and_specialize_syns name name) all_constructor <- get_cons_from_name let gadt_constrs = nub $ concatMap convert_to_gadt_constrs all_constructor gadt_name = mkName $ (nameBase name) ++ "Fam" data_dec <- mk_gadt gadt_name gadt_constrs family_instance <- mk_family_instance data_dec type_instances <- mk_type_instances gadt_name $ nub all_constructor return $ data_dec:family_instance:(specialized_universe ++ type_instances) is_right (Right _) = True is_right _ = False get_dec_name :: Dec -> Result Name get_dec_name (FunD name _) = Right $ name get_dec_name (ValD _ _ _) = Left "ValD does not have a name" get_dec_name (DataD _ name _ _ _) = Right $ name get_dec_name (NewtypeD _ name _ _ _) = Right $ name get_dec_name (TySynD name _ _) = Right $ name get_dec_name (ClassD _ name _ _ _) = Right $ name get_dec_name (InstanceD _ _ _) = Left "InstanceD does not have a name" get_dec_name (SigD name _) = Right $ name get_dec_name (ForeignD _) = Left "ForeignD does not have a name" get_dec_name (PragmaD _ ) = Left "PragmaD does not have a name" get_dec_name (FamilyD _ name _ _) = Right $ name get_dec_name (DataInstD _ name _ _ _ ) = Right $ name get_name_and_cons :: [Dec] -> [(Name, [Con])] get_name_and_cons decs = go decs [] where go [] output = output go (x:xs) output = result where name_result = get_dec_name x cons = get_cons x result = case name_result of Right x -> (x, cons):(go xs output) _ -> go xs output get_cons_from_name :: DecState [(Name, [Con])] get_cons_from_name = do decs <- ask return $ get_name_and_cons decs find_dec :: [Dec] -> Name -> Result Dec find_dec decs name = maybe_to_either ("could not find dec " ++ show name ++ " in " ++ show decs) $ find (is_dec_name name) decs maybe_to_either :: String -> Maybe a -> Result a maybe_to_either msg (Just a) = Right a maybe_to_either msg Nothing = Left msg throw_either :: (Monad m, MonadError String m) => Result a -> m a throw_either (Right x) = return x throw_either (Left x) = throwError x -------------------------------------------------------------------------------------------------- mk_family_instance :: Dec -> DecState Dec mk_family_instance (DataD _ name _ cons _) = do dec_eqs <- mapM mk_deceq cons fields <- mapM mk_field cons --todo add the default cases applys <- mapM mk_apply cons strings <- lift $ mapM mk_string cons let dec_fun = FunD (mkName "decEq") (dec_eqs ++ [default_dec]) fields_fun = FunD (mkName "fields") (fields ++ [default_field]) apply_fun = FunD (mkName "apply") applys string_fun = FunD (mkName "string") strings dec = InstanceD [] (AppT (ConT $ mkName "Family") (ConT name)) [dec_fun, fields_fun, apply_fun, string_fun] return dec default_dec = Clause [WildP, WildP] (NormalB $ ConE $ mkName "Nothing") [] mk_deceq :: Con -> DecState Clause mk_deceq (ForallC _ _ (NormalC name [])) = lift $ clause [conP name [], conP name []] (normalB [| Just (Refl, Refl) |]) [] mk_deceq (ForallC _ _ (NormalC name typs)) = lift $ clause [conP name [varP $ mkName "x"], conP name [varP $ mkName "y"]] (normalB [| if $(varE $ mkName "x") == $(varE $ mkName "y") then Just (Refl, Refl) else Nothing |]) [] default_field = Clause [WildP, WildP] (NormalB $ ConE $ mkName "Nothing") [] default_apply = Clause [WildP, WildP] (NormalB $ (AppE (VarE $ mkName "error") (LitE $ StringL "apply failed"))) [] mk_field :: Con -> DecState Clause mk_field (ForallC _ (x:y:[]) (NormalC c_name [])) = do pat <- name_to_con_pat (mkName $ reverse $ tail $ reverse (show c_name)) $ get_pred_name x let vars = collect_vars pat lift $ clause [conP c_name [], return pat] (normalB [| Just $(ccons_vars (reverse vars) [| CNil |]) |]) [] mk_field (ForallC _ _ (NormalC name typs)) = lift $ clause [conP name [wildP], wildP] (normalB [| Just CNil |]) [] get_pred_name (EqualP _ (ConT n)) = n name_to_con_pat :: Name -> Name -> DecState Pat name_to_con_pat con_name n = do decs <- ask dec <- throw_either $ find_dec decs n data_dec_to_con_p con_name dec get_con_name :: Con -> Name get_con_name (NormalC n _) = n get_con_name (RecC n _) = n get_con_name (InfixC _ n _) = n get_con_name (ForallC _ _ con) = get_con_name con is_nil_list con_name name typ | "Nil" `isPrefixOf` show con_name = True is_nil_list con_name name typ | otherwise = False reduce_type = undefined reduce_type' = undefined head_either [] = Left "Called head on empty list" head_either (x:_) = Right x data_dec_to_con_p :: Name -> Dec -> DecState Pat data_dec_to_con_p display_name x = do name <- throw_either $ get_dec_name x let con_name = display_to_con_name (nameBase display_name) name let cons = get_cons x let con_result = head_either $ filter ((show con_name ==) . nameBase . get_con_name) cons con <- case con_result of Left _ -> throwError $ "data_dec_to_con_p failed with con_name = " ++ show con_name ++ " and x " ++ show x Right x -> return x return $ con_to_con_p con rec_to_normal (RecC name vts) = NormalC name $ map (\(_, x, y) -> (x, y)) vts con_to_con_p :: Con -> Pat con_to_con_p x = result where name = get_con_name x typs = get_con_types x result = ConP name $ concat $ zipWith type_to_pat (map (:[]) ['a'..]) typs type_to_pat var_preface (AppT ListT x) = [InfixP (VarP $ mkName (var_preface)) (mkName ":") (VarP $ mkName (var_preface ++ "s"))] type_to_pat var_preface (VarT x) = [VarP $ mkName var_preface] type_to_pat var_preface (AppT (ConT x) y) = [ConP x $ type_to_pat (var_preface ++ "_a") y] type_to_pat var_preface (ConT x) = [(VarP $ mkName var_preface)] collect_vars (ConP _ pvars) = concatMap collect_vars pvars collect_vars (InfixP x _ y) = concatMap collect_vars [x, y] collect_vars (VarP x) = [x] collect_vars (ListP []) = [] ccons_vars [] e = e ccons_vars (x:xs) e = ccons_vars xs [| CCons $(varE x) $(e) |] mk_apply :: Con -> DecState Clause mk_apply (ForallC _ (x:y:[]) (NormalC name [])) = do ex <- name_to_con_exp (mkName $ reverse $ tail $ reverse (show name)) $ get_pred_name x let vars = collect_vars_exp ex lift $ clause [conP name [], ccons_vars_pat (reverse vars) $ conP (mkName "CNil") []] (normalB (return ex)) [] mk_apply (ForallC _ _ (NormalC name typs)) = lift $ clause [conP name [varP $ mkName "x"], conP (mkName "CNil") []] (normalB $ varE $ mkName "x") [] name_to_con_exp :: Name -> Name -> DecState Exp name_to_con_exp con_name n = do decs <- ask dec <- throw_either $ find_dec decs n data_dec_to_con_exp con_name dec data_dec_to_con_exp :: Name -> Dec -> DecState Exp data_dec_to_con_exp display_name dec = do name <- throw_either $ get_dec_name dec let con_name = display_to_con_name (nameBase display_name) name cons = get_cons dec con_result = head_either $ filter (((nameBase con_name)==) . nameBase . get_con_name) cons con <- case con_result of Left _ -> throwError $ "data_dec_to_con_p failed with con_name = " ++ show display_name ++ " and x " ++ show dec Right x -> return x return $ con_to_con_e con collect_vars_exp (AppE x y) = collect_vars_exp x ++ collect_vars_exp y collect_vars_exp (VarE x) = [x] collect_vars_exp (ListE xs) = concatMap collect_vars_exp xs collect_vars_exp (ConE _) = [] collect_vars_exp (InfixE (Just (VarE x)) _ (Just (VarE y))) = [x, y] ccons_vars_pat [] p = p ccons_vars_pat (x:xs) p = ccons_vars_pat xs (conP (mkName "CCons") [varP x, p]) con_to_con_e :: Con -> Exp con_to_con_e con = result where name = get_con_name con typs = get_con_types con result = foldl' AppE (ConE name) $ zipWith type_to_exp (map (:[]) ['a'..]) typs type_to_exp var_preface (AppT ListT x) = InfixE (Just $ VarE $ mkName (var_preface)) (ConE $ mkName ":") (Just $ VarE $ mkName (var_preface ++ "s")) type_to_exp var_preface (VarT x) = VarE $ mkName var_preface type_to_exp var_preface (AppT x y) = AppE (type_to_exp (var_preface ++ "_a") x) (type_to_exp (var_preface ++ "_b") y) type_to_exp var_preface (ConT x) = VarE $ mkName var_preface mk_string :: Con -> Q Clause mk_string (ForallC _ _ (NormalC name [])) = clause [conP name []] (normalB $ stringE $ reverse $ tail $ reverse $ show $ name) [] mk_string (ForallC _ _ (NormalC name typs)) = clause [conP name [varP $ mkName "x"]] (normalB $ appE (varE $ mkName "show") $ varE $ mkName "x") [] --------------------------------------------------------------------------------------------- mk_type_instances gadt_name cons = mapM (mk_type_instance gadt_name) cons mk_type_instance gadt_name (name, cons) = do let cons_exps = map (make_const name) cons return $ InstanceD [] (AppT (AppT (ConT $ mkName "Type") (ConT gadt_name)) (ConT name)) [ FunD (mkName "constructors") [Clause [] (NormalB $ ListE cons_exps) []]] make_const name con | is_primitive name = AppE (ConE $ mkName "Abstr") (ConE $ con_name_to_display (nameBase $ name) name) make_const name con | otherwise = AppE (ConE $ mkName "Concr") (ConE $ con_name_to_display (nameBase $ get_con_name con) name)