{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}
-- Template Haskell support for Test.Agata: derives Buildable instances
-- (and, unless disabled, QuickCheck Arbitrary instances) for data types.
--
-- The commented-out lines around the module header below appear to be a
-- manual switch between the explicit export list and exporting everything.
-- {-
module Test.AgataTH
  ( agatath
  , derive, deriveall
  , DerivOption(..), (<++>)
  , echoAgata
  , module Test.Agata
  , module Test.QuickCheck
  ) where
-- }-
-- module Test.AgataTH where

-- The TH lift is hidden so that unqualified lift refers to the monad
-- transformer lift from Control.Monad.State.Lazy (used in allTypesT).
import Language.Haskell.TH.Syntax hiding (lift)
import qualified Language.Haskell.TH.Syntax as TH (lift)
import Language.Haskell.TH
import Control.Monad
import Test.Agata
import Test.QuickCheck(Arbitrary(..))
import Data.List(nub, union)
import Data.Maybe(fromMaybe)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.State.Lazy

---------------------------------------------------------------------
-- Some day this file might be tidied up into a presentable state...
--

-- | A request to derive instances for one or more types, together with
-- options that apply to all of them.
data Derivation = Derivation
  { derivNames   :: [Name]               -- ^ types to derive instances for
  , derivOptions :: Set.Set DerivOption  -- ^ options used for every type
  }

-- | Options understood by @agatath@.
data DerivOption
  = Inline Name  -- ^ build this constructor with @inline@ instead of the
                 --   per-field @autorec@ / @automutrec@ chain
  | NoArbitrary  -- ^ do not emit an @Arbitrary@ instance
  deriving (Show,Eq,Ord)

-- | Derive for several types at once, with no options.
deriveall :: [Name] -> Derivation
deriveall ns = Derivation ns Set.empty

-- | Derive for a single type, with no options.
derive :: Name -> Derivation
derive n = deriveall [n]

-- | Add an option to a derivation request.
(<++>) :: Derivation -> DerivOption -> Derivation
(<++>) d o = d{derivOptions = o `Set.insert` derivOptions d}

-- | Debugging aid: generates a top-level binding named @s@ whose value is
-- the pretty-printed code that @agatath@ produces for type @n@.
echoAgata :: String -> Name -> Q [Dec]
echoAgata s n = agatath (derive n) >>= (\r -> return [FunD (mkName s) [Clause [] (NormalB $ LitE $ StringL r) []]]) . dump

-- | Generate instance declarations for every type named in the
-- @Derivation@.  For each type @T@ this produces
--
--   * a @Buildable (T ...)@ instance, with a @Buildable v@ constraint for
--     every type parameter @v@, defining @improve@, @build@ and
--     @dimension@; and
--
--   * unless @NoArbitrary@ is set, an @Arbitrary (T ...)@ instance with
--     the same context whose @arbitrary@ is @agata@.
--
-- Only names that reify to a @data@ or @newtype@ declaration are
-- supported; anything else makes the splice fail.
agatath :: Derivation -> Q [Dec]
agatath der@(Derivation ts ss) = fmap concat $ mapM deriveAgata ts
  where
    isSet o = o `Set.member` ss

    deriveAgata n = do
        i@(TyConI d) <- reify n
        -- fresh type variables: nns for the instance head, nns1 for the
        -- forall in the signature attached to the dimension body
        nns <- replicateM (length $ dParams d) (newName "a")
        nns1 <- replicateM (length $ dParams d) (newName "b") -- >>= mapM unVarBndr
        let vs = map VarT nns
        expanded <- fmap reTuple $ expand n nns1
        -- Quote a template instance so that the class name, the method
        -- names, retag and Dimension are taken from real syntax rather
        -- than spelled out as strings.
        m@[InstanceD [] (AppT (ConT cBuildable_) _)
            [ ValD (VarP improve_) _ _
            , ValD (VarP build_) _ _
            , ValD (VarP dimension_)
                (NormalB (SigE (AppE rerelate_ _) (AppT tDimension_ _))) []
            ]]
          <- [d| instance Buildable T1 where
                   improve = undefined
                   build = undefined
                   dimension = retag dimension :: Dimension T1 |]
        -- one improve clause per constructor
        impbody <- mapM impClause (dConsts d)
        buildbody <- fmap NormalB $ bldClauses (dConsts d)
        -- (earlier variant: mapM (bldClause t) (dConsts d) >>= return . NormalB . ListE)
        allTypesT_t <- fmap (nub . concat) $ mapM (recs n . cFields) (dConsts d)
        let isRecursive = Mut `elem` allTypesT_t
            -- dimension + 1
            dimplus = InfixE (Just $ VarE dimension_) (VarE $ mkName "+") (Just (LitE (IntegerL 1)))
            -- forall nns1. Dimension expanded -> Dimension (n applied to nns1)
            dimtyp = ForallT (map varBndr nns1) [] $ AppT (AppT ArrowT (AppT tDimension_ expanded)) (AppT tDimension_ (getType n nns1))
            -- the retag function, at type dimtyp, applied to dimension
            -- (plus one when some field type reaches n again)
            dimbody = NormalB $ AppE (SigE rerelate_ dimtyp) (if isRecursive then dimplus else VarE dimension_)
        let preqs = allInClass cBuildable_ vs
        arb <- arbInstance preqs vs
        return $ [ InstanceD preqs (AppT (ConT cBuildable_) (rt vs n))
                     [ FunD improve_ impbody
                     , ValD (VarP build_) buildbody []
                     , ValD (VarP dimension_) dimbody []
                     ]
                 ] ++ if isSet NoArbitrary then [] else [arb]
      where
        -- apply type constructor n to the given argument types
        rt :: [Type] -> Name -> Type
        rt [] n = ConT n
        rt (v:vs) n = AppT (rt vs n) v

        -- n fresh variables, as patterns and as expressions
        genPE n = do
          ids <- replicateM n (newName "x")
          return (map varP ids, map varE ids)

        -- build body: the per-constructor builders joined with (++).
        -- A type with no constructors is rejected with a clear message
        -- instead of a pattern-match failure.
        bldClauses [] = fail "Test.AgataTH: cannot derive Buildable for a type with no constructors"
        bldClauses [c] = bldClause c
        bldClauses (c:cs) = [| $(bldClause c) ++ $(bldClauses cs) |]

        -- builder for one constructor C: "inline C" when C was marked with
        -- Inline, otherwise "C $> f1 .> (... .> id)" where each fi is
        -- automutrec for a field whose type reaches n and autorec otherwise
        bldClause :: Con -> Q Exp
        bldClause c
          | isSet $ Inline $ cName c = [| inline $(conE $ cName c) |]
          | otherwise = do
              let ts = cFields c
                  name = cName c
                  f [] = [| id |]
                  f (Auto:vars) = [| autorec .> ($(f vars)) |]
                  f (Mut:vars) = [| automutrec .> ($(f vars)) |]
              [| $(conE name) $> $(recs n ts >>= f) |]

        -- improve clause for one constructor C with fields x1 .. xk:
        --   improve (C x1 .. xk) = rebuild C (rb x1 *> (... *> (return . id)))
        impClause c = do
          let fields = cFields c
          let name = cName c
          let idExp = cId c
          (pats,vars) <- genPE (length fields)
          let f [] = [| return . id |]
              f (v:vars) = [| rb $v *> $(f vars) |]
          clause [conP name pats] -- (A x1 x2)
                 (normalB [| rebuild $(idExp) $(f vars) |]) []
          -- "A "++show x1++" "++show x2

        -- Arbitrary instance with context preqs; the class name and the
        -- method body (arbitrary = agata) come from a quoted template
        arbInstance preqs vs = do
          m@[InstanceD [] (AppT cArbitrary_ _) body_] <- [d| instance Arbitrary T1 where arbitrary = agata |]
          return $ InstanceD preqs (AppT cArbitrary_ (rt vs n)) body_

-- | Classification of a constructor field: @Mut@ when the field type
-- (transitively) mentions the type being derived, @Auto@ otherwise.
data Recu = Mut | Auto deriving (Eq,Show)

-- | Classify each field type against type name @n@.
recs :: Name -> [Type] -> Q [Recu]
recs n [] = return []
recs n (t:ts) = do
  ats <- allTypesT t
  rest <- recs n ts
  return $ (if n `Set.member` ats then Mut else Auto) : rest

-- | All type-constructor names reachable from @t@: those occurring in @t@
-- plus, via @reify@, those in the field types of every data type found.
-- Each name is reified at most once, so recursive types terminate.
allTypesT :: Type -> Q (Set.Set Name)
allTypesT t = getCollected (xf t)
  where
    f n1 = do
      i <- lift $ reify n1
      mapM_ xf (iTypes i)
    -- NOTE(review): only the Type forms below are handled; any other
    -- form (e.g. ForallT, SigT) fails with a non-exhaustive case.
    xf :: Type -> Collecting Name ()
    xf t = case t of
      ConT n2 -> collectIf n2 (f n2)
      AppT t1 t2 -> xf t1 >> xf t2
      VarT n -> return ()
      TupleT x -> return ()
      ArrowT -> return ()
      ListT -> return ()

-- | Does type name @n@ occur in @t@, directly or through the field types
-- of any data type reachable from @t@?
contains :: Type -> Name -> Q Bool
contains t n = fmap (Set.member n) $ allTypesT t

-- | Split a type application into its head and its arguments, e.g.
-- @T a b@ becomes @(T, [a, b])@.
flat :: Type -> (Type,[Type])
flat = flat'
  where
    flat' (AppT t1 t2) = case flat' t1 of
      (t,ts) -> (t,ts++[t2])
    flat' x = (x,[])

-- | Apply type constructor @n@ to type variables.  The list is applied in
-- reverse: the last name becomes the first argument.
getType :: Name -> [Name] -> Type
getType n [] = ConT n
getType n (n1:ns) = AppT (getType n ns) (VarT n1)

-- | Starting from @n0@ applied to @ns@, repeatedly replace every type
-- that mentions @n0@ by the (substituted) field types of its type
-- constructor; a constructor already being unfolded is replaced by
-- @n0@ itself.  The result is flattened into a tuple of the distinct
-- component types that do not mention @n0@.
expand :: Name -> [Name] -> Q Type
expand n0 ns = fmap simplify $ applic [] (getType n0 ns)
  where
    -- nts: constructor applications currently being unfolded
    applic :: [(Type,[Type])] -> Type -> Q Type
    applic nts t0 = do
        b <- t0 `contains` n0
        -- NOTE(review): only tuple- and ConT-headed types are matched
        -- below; a type mentioning n0 under any other head (e.g. ListT)
        -- would fail this case.  Confirm whether that can occur.
        if not b then return t0 else case flat t0 of
          (TupleT _,ts) -> fmap toTuple $ mapM (applic nts) ts
          (ConT n, ts) ->
            if (ConT n,ts) `elem` nts
              then return (ConT n0)
              else do
                let rec = applic $ (ConT n,ts) : nts
                i <- reify n
                let fs = toTuple $ nub $ iTypes i
                rec $ subst (zip (iParams i) ts) fs
      where
        -- replace type variables according to nmap
        subst nmap t1 = case t1 of
          AppT t2 t3 -> AppT (subst nmap t2) (subst nmap t3)
          VarT n1 -> fromMaybe t1 $ lookup n1 nmap
          _ -> t1
    simplify :: Type -> Type
    simplify = toTuple . filter filt . nub . toList
    -- keep only types that do not mention n0
    filt t = case t of
      ConT n -> n0/=n
      AppT t1 t2 -> filt t1 && filt t2
      _ -> True

-- | Flatten nested tuple types into the list of their non-tuple
-- components; any other type is returned as a singleton list.
toList :: Type -> [Type]
toList t = toList' $ flat t
  where
    toList' :: (Type,[Type]) -> [Type]
    toList' (TupleT _,ts) = concatMap toList ts
    toList' _ = [t]

-- | Build a tuple type from component types.  A single type is returned
-- unchanged and the empty list gives @TupleT 0@.  Components are applied
-- in reverse list order.
toTuple :: [Type] -> Type
toTuple [t] = t
toTuple ts = toTuple' ts
  where
    toTuple' [] = TupleT (length ts)
    toTuple' (t:ts') = AppT (toTuple' ts') t

-- | Re-express a (possibly nested) tuple type as right-nested pairs
-- @(t1, (t2, ...))@.  No components gives @TupleT 0@; a single component
-- is returned as-is.
reTuple :: Type -> Type
reTuple = reTuple' . toList
  where
    reTuple' [] = TupleT 0
    reTuple' [t] = t
    reTuple' (t:ts) = AppT (AppT (TupleT 2) t) $ reTuple' ts

-- | Name of a reified type constructor.
iName :: Info -> Name
iName i = case i of
  TyConI d -> dName d

-- | Field types of a reified type constructor.  A primitive type
-- constructor yields itself; any other kind of @Info@ is an error.
iTypes :: Info -> [Type]
iTypes i = case i of
  TyConI d -> dTypes d
  PrimTyConI n _ _ -> [ConT n]
  _ -> error (show i)

-- | Type parameters of a reified type constructor.
iParams :: Info -> [Name]
iParams i = case i of
  TyConI d -> dParams d

-- | Name declared by a @data@ or @newtype@ declaration.
dName :: Dec -> Name
dName d = case d of
  DataD _ n _ _ _ -> n
  NewtypeD _ n _ _ _ -> n

-- | Field types of all constructors of a declaration (for a type synonym,
-- the right-hand side).
dTypes :: Dec -> [Type]
dTypes d = case d of
  DataD _ _ _ cs _ -> concatMap cFields cs
  NewtypeD _ _ _ c _ -> cFields c
  TySynD _ _ t -> [t]

-- | Type parameter names of a @data@ or @newtype@ declaration.
dParams :: Dec -> [Name]
dParams d = case d of
  DataD _ _ ns _ _ -> map unVarBndr ns
  NewtypeD _ _ ns _ _ -> map unVarBndr ns

-- | Constructors of a @data@ or @newtype@ declaration.
dConsts :: Dec -> [Con]
dConsts d = case d of
  DataD _ _ _ cs _ -> cs
  NewtypeD _ _ _ c _ -> [c]

-- | Name of a constructor.
cName :: Con -> Name
cName c = case c of
  NormalC n sts -> n
  RecC n _ -> n
  InfixC _ n _ -> n
  ForallC _ _ c1 -> cName c1

-- | A constructor as an expression.
cId :: Con -> Q Exp
cId = conE . cName

-- | Field types of a constructor, in declaration order.  Handles the same
-- constructor forms as @cName@, so record-syntax types can be derived.
cFields :: Con -> [Type]
cFields c = case c of
  NormalC n sts -> map snd sts
  RecC _ vsts -> [t | (_, _, t) <- vsts]
  InfixC st n st' -> [snd st,snd st']
  ForallC _ _ c1 -> cFields c1

-- | Placeholder type used in the quoted template instances inside
-- @agatath@; only the names extracted from those quotes are used.
data T1 = T1

-- | Pretty-print any TH syntax tree.
dump :: Ppr a => a -> String
dump = show . ppr

-- | A Q computation that also accumulates a set of seen values.
type Collecting b a = StateT (Set.Set b) Q a

-- | Has @b@ been recorded yet?
collected :: (Ord b) => b -> Collecting b Bool
collected = gets . Set.member

-- | Record @b@ as seen.
collect :: (Ord b) => b -> Collecting b ()
collect b = modify (Set.insert b)

-- | Run a collecting computation from the empty set and return the final
-- set.
getCollected :: Collecting b a -> Q (Set.Set b)
getCollected = flip execStateT Set.empty

-- | Record @b@ and run @x@, but only the first time @b@ is seen.
collectIf :: Ord b => b -> Collecting b () -> Collecting b ()
collectIf b x = do
  collected_b <- collected b
  unless collected_b $ collect b >> x

-- TH 2.4 compatibility
-- #if __GLASGOW_HASKELL__ >= 611
#if MIN_VERSION_template_haskell(2,4,0)
-- | Name bound by a type-variable binder.
unVarBndr :: TyVarBndr -> Name
unVarBndr (PlainTV n) = n
unVarBndr (KindedTV n _) = n

-- | Plain (unkinded) binder for a type variable.
varBndr :: Name -> TyVarBndr
varBndr n = (PlainTV n)

-- | One constraint of class @n@ for each type in @vs@.
allInClass :: Name -> [Type] -> [Pred]
allInClass n vs = map (ClassP n) (map (:[]) vs)
#else
unVarBndr = id
varBndr = id
allInClass n vs = map (AppT (ConT n)) vs
#endif

-- DEBUG

-- | Debugging aid: calls @error@ with the pretty-printed result of
-- @expand@ for the data type @n@.
topApp :: Name -> Q [Dec]
topApp n = do
  i@(TyConI (DataD _ _ ns _ _)) <- reify n
  nns1 <- replicateM (length ns) (newName "b")
  expand n nns1 >>= error . dump

-- | Placeholder; currently generates no declarations.
testDimVal :: Name -> Q [Dec]
testDimVal n = return []