module Test.AgataTH (
agatath
, derive, deriveall
, DerivOption(..), (<++>)
, echoAgata
, module Test.Agata
, module Test.QuickCheck
) where
import Language.Haskell.TH.Syntax hiding (lift)
import qualified Language.Haskell.TH.Syntax as TH (lift)
import Language.Haskell.TH
import Control.Monad
import Test.Agata
import Test.QuickCheck(Arbitrary(..))
import Data.List(nub, union)
import Data.Maybe(fromMaybe)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.State.Lazy
-- | A request for instance generation, built with 'derive' or 'deriveall',
-- refined with '<++>', and consumed by 'agatath'.
data Derivation = Derivation {
    derivNames :: [Name]
    -- ^ Type constructors to generate instances for.
  , derivOptions :: Set.Set DerivOption
    -- ^ Options that apply to every type in 'derivNames'.
  }
-- | Options that alter the code generated by 'agatath'.
data DerivOption =
    Inline Name
    -- ^ For the named constructor, the generated @build@ uses
    -- @inline Con@ instead of the @Con $> autorec .> ...@ chain.
  | NoArbitrary
    -- ^ Do not emit an 'Arbitrary' instance; only @Buildable@ is generated.
  deriving (Show,Eq,Ord)
-- | Request instances for several types at once, with no options set.
deriveall :: [Name] -> Derivation
deriveall ns = Derivation { derivNames = ns, derivOptions = Set.empty }
-- | Request instances for a single type, with no options set.
derive :: Name -> Derivation
derive = deriveall . (: [])
-- | Add an option to a derivation request (adding it twice has no effect).
(<++>) :: Derivation -> DerivOption -> Derivation
d <++> o = d { derivOptions = Set.insert o (derivOptions d) }
-- | Inspection helper. Runs 'agatath' on @'derive' n@ and, instead of
-- splicing the resulting instances, defines a top-level binding named @s@
-- whose value is the pretty-printed declarations as a 'String'.
echoAgata :: String -> Name -> Q [Dec]
echoAgata s n = do
  decs <- agatath (derive n)
  let rendered = dump decs
  return [FunD (mkName s) [Clause [] (NormalB (LitE (StringL rendered))) []]]
-- | Generate instance declarations for every type named in a 'Derivation'.
--
-- For each name the result contains a @Buildable@ instance and, unless
-- 'NoArbitrary' was requested, an 'Arbitrary' instance defined as
-- @arbitrary = agata@. The declarations for all names are concatenated.
agatath :: Derivation -> Q [Dec]
agatath der@(Derivation ts ss) = fmap concat $ mapM deriveAgata ts where
  -- Whether an option was requested for this derivation.
  isSet o = o `Set.member` ss
  -- Instances for one type constructor. The bind below only matches when
  -- reify returns TyConI; any other Info fails the splice.
  deriveAgata n = do
      i@(TyConI d) <- reify n
      -- Fresh type variables, one per type parameter: a.. for the instance
      -- heads and contexts, b.. for the forall in the annotation on retag.
      nns <- replicateM (length $ dParams d) (newName "a")
      nns1 <- replicateM (length $ dParams d) (newName "b")
      let vs = map VarT nns
      expanded <- fmap reTuple $ expand n nns1
      -- A template instance is quoted and pattern-matched only to obtain the
      -- resolved Names of Buildable, improve, build, dimension, retag and
      -- Dimension; T1 itself does not reach the output.
      -- NOTE(review): this pattern is tied to the exact AST the installed
      -- template-haskell produces for the quote (e.g. a 3-field InstanceD
      -- with an empty context); a different shape makes the bind fail.
      m@[InstanceD [] (AppT (ConT cBuildable_) _) [ValD (VarP improve_) _ _,ValD (VarP build_) _ _,ValD (VarP dimension_) (NormalB (SigE (AppE rerelate_ _) (AppT tDimension_ _))) []]] <-
        [d| instance Buildable T1 where
              improve = undefined
              build = undefined
              dimension = retag dimension :: Dimension T1
          |]
      impbody <- mapM impClause (dConsts d)
      buildbody <- fmap NormalB $ bldClauses (dConsts d)
      -- Recursion flags (Mut/Auto) of every field of every constructor.
      allTypesT_t <- fmap (nub . concat) $ mapM (recs n . cFields) (dConsts d)
      let
        -- True when some field's type mentions the type being derived.
        isRecursive = Mut `elem` allTypesT_t
        -- The expression: dimension + 1
        dimplus = InfixE (Just $ VarE dimension_) (VarE $ mkName "+") (Just (LitE (IntegerL 1)))
        -- forall b.. . Dimension expanded -> Dimension (n b..)
        dimtyp = ForallT (map varBndr nns1) [] $ AppT (AppT ArrowT (AppT tDimension_ expanded)) (AppT tDimension_ (getType n nns1))
        -- dimension = (retag :: dimtyp) (dimension + 1)   for recursive types
        -- dimension = (retag :: dimtyp) dimension         otherwise
        dimbody = NormalB $ AppE (SigE rerelate_ dimtyp) (if isRecursive then dimplus else VarE dimension_)
      -- Instance context: Buildable for every type parameter.
      let preqs = allInClass cBuildable_ vs
      arb <- arbInstance preqs vs
      return $ [
          InstanceD preqs (AppT (ConT cBuildable_) (rt vs n))
            [FunD improve_ impbody
            , ValD (VarP build_) buildbody []
            , ValD (VarP dimension_) dimbody []
            ]] ++ if isSet NoArbitrary then [] else [arb]
    where
      -- Type constructor n applied to the given types. The list is applied
      -- last-first: rt [a, b] n = n b a. The variables passed in are all
      -- fresh, so only their naming is affected.
      rt :: [Type] -> Name -> Type
      rt [] n = ConT n
      rt (v:vs) n = AppT (rt vs n) v
      -- n fresh variables, returned both as patterns and as expressions.
      genPE n = do
          ids <- replicateM n (newName "x")
          return (map varP ids, map varE ids)
      -- build body: the per-constructor expressions joined with ++.
      -- NOTE(review): there is no equation for [], so a type without
      -- constructors fails here.
      bldClauses [c] = bldClause c
      bldClauses (c:cs) = [| $(bldClause c) ++ $(bldClauses cs) |]
      -- Expression for one constructor: `inline Con` when Inline was requested
      -- for it; otherwise `Con $> (r1 .> (r2 .> (... .> id)))`, where ri is
      -- automutrec for a field whose type mentions the derived type and
      -- autorec for any other field.
      bldClause :: Con -> Q Exp
      bldClause c
        | isSet $ Inline $ cName c =
            [| inline $(conE $ cName c) |]
        | otherwise = do
            let ts = cFields c
                name = cName c
                f [] = [| id |]
                f (Auto:vars) = [| autorec .> ($(f vars)) |]
                f (Mut:vars) = [| automutrec .> ($(f vars)) |]
            [| $(conE name) $> $(recs n ts >>= f) |]
      -- improve clause for one constructor:
      --   improve (Con x1 .. xk) = rebuild Con (rb x1 *> (... *> (return . id)))
      impClause c = do
          let fields = cFields c
          let name = cName c
          let idExp = cId c
          (pats,vars) <- genPE (length fields)
          let f [] = [| return . id |]
              f (v:vars) = [| rb $v *> $(f vars) |]
          clause [conP name pats]
            (normalB [| rebuild $(idExp) $(f vars) |]) []
      -- Arbitrary instance (arbitrary = agata) with the given context, for
      -- the type n applied to vs. The quote is used only to get the resolved
      -- class name and method body.
      arbInstance preqs vs = do
          m@[InstanceD [] (AppT cArbitrary_ _) body_] <-
            [d| instance Arbitrary T1 where
                  arbitrary = agata
              |]
          return $ InstanceD preqs (AppT cArbitrary_ (rt vs n)) body_
-- | Recursion classification of a constructor field, computed by 'recs'.
-- 'Mut': the field's type (transitively) mentions the type being derived.
-- 'Auto': it does not.
data Recu = Mut | Auto deriving (Eq,Show)
-- | Classify each field type, in order: 'Mut' if the type constructor
-- @n@ is reachable from it, 'Auto' otherwise.
recs :: Name -> [Type] -> Q [Recu]
recs n = mapM classify
  where
    classify fieldType = do
      reachable <- allTypesT fieldType
      return (if n `Set.member` reachable then Mut else Auto)
-- | All type-constructor names reachable from a type.
--
-- Every 'ConT' name in the type is collected. Each newly seen name is
-- reified, and the types of its definition are walked as well. The
-- collected set doubles as a visited set ('collectIf' only descends into
-- unseen names), so recursive and mutually recursive types terminate.
--
-- Quantified types ('ForallT') are walked through their body. A type form
-- this function does not understand aborts the splice with a message
-- naming it, instead of an anonymous pattern-match failure.
allTypesT :: Type -> Q (Set.Set Name)
allTypesT t = getCollected (xf t) where
  -- Reify a type constructor and walk the types in its definition.
  f n1 = do
    i <- lift $ reify n1
    mapM_ xf (iTypes i)
  xf :: Type -> Collecting Name ()
  xf t' = case t' of
    ConT n2 -> collectIf n2 (f n2)
    AppT t1 t2 -> xf t1 >> xf t2
    -- e.g. a rank-N field type; only the body can mention type constructors
    ForallT _ _ t1 -> xf t1
    VarT _ -> return ()
    TupleT _ -> return ()
    ArrowT -> return ()
    ListT -> return ()
    _ -> lift $ fail $ "Test.AgataTH.allTypesT: unsupported type: " ++ show t'
-- | Whether the type constructor @n@ is reachable from type @t@
-- (see 'allTypesT').
contains :: Type -> Name -> Q Bool
contains t n = do
  reachable <- allTypesT t
  return (Set.member n reachable)
-- | Split a type application into its head and its arguments, in
-- application order: @flat (AppT (AppT f a) b) == (f, [a, b])@.
-- A type that is not an application is returned with no arguments.
flat :: Type -> (Type,[Type])
flat = go []
  where
    go args (AppT fun arg) = go (arg : args) fun
    go args headType = (headType, args)
-- | Type constructor @n@ applied to the given type variables. The list is
-- applied last-first: @getType n [a, b]@ is @n b a@.
getType :: Name -> [Name] -> Type
getType n = foldr (flip AppT . VarT) (ConT n)
-- | Unfold type constructor @n0@, applied to the variables @ns@, into a
-- tuple of the component types it is built from.
--
-- Starting from @n0 ns@, every type that (transitively) mentions @n0@ is
-- replaced by the tuple of the field types of its definition, with the
-- definition's parameters substituted by the actual arguments. A
-- (constructor, arguments) pair that is met again on the current path is
-- replaced by @ConT n0@ to stop the unfolding. Finally, nested tuples are
-- flattened, duplicates are removed, and every component that syntactically
-- contains @ConT n0@ is dropped.
expand :: Name -> [Name] -> Q Type
expand n0 ns = fmap simplify $ applic [] (getType n0 ns) where
  -- nts: the (head, arguments) pairs already being unfolded on this path.
  applic :: [(Type,[Type])] -> Type -> Q Type
  applic nts t0 = do
      b <- t0 `contains` n0
      -- NOTE(review): only TupleT and ConT heads are handled. A type that
      -- mentions n0 under a ListT head ([T a]) or an ArrowT head falls
      -- through this case and fails the splice — confirm whether intended.
      if not b then return t0 else case flat t0 of
        (TupleT _,ts) -> fmap toTuple $ mapM (applic nts) ts
        (ConT n, ts) ->
          if (ConT n,ts) `elem` nts then return (ConT n0) else do
            let rec = applic $ (ConT n,ts) : nts
            i <- reify n
            let fs = toTuple $ nub $ iTypes i
            rec $ subst (zip (iParams i) ts) fs
    where
      -- Replace type variables according to nmap; others are left as is.
      subst nmap t1 = case t1 of
        AppT t2 t3 -> AppT (subst nmap t2) (subst nmap t3)
        VarT n1 -> fromMaybe t1 $ lookup n1 nmap
        _ -> t1
  -- Flatten, deduplicate and drop components that contain n0.
  simplify :: Type -> Type
  simplify = toTuple . filter filt . nub . toList
  -- False when ConT n0 occurs anywhere in the type.
  filt t = case t of
    ConT n -> n0/=n
    AppT t1 t2 -> filt t1 && filt t2
    _ -> True
-- | Flatten (possibly nested) tuple types into their components. A type
-- whose head is not a tuple constructor is a single component.
toList :: Type -> [Type]
toList t = case flat t of
  (TupleT _, parts) -> concatMap toList parts
  _ -> [t]
-- | Build a tuple type from its components, in list order.
--
-- A single component is returned as is, and the empty list gives the unit
-- type @TupleT 0@. Components are applied left to right, so
-- @toTuple [a, b] == AppT (AppT (TupleT 2) a) b@, i.e. @(a, b)@.
-- The previous version reversed them and built @(b, a)@, so flattening the
-- result again did not give back the original order.
toTuple :: [Type] -> Type
toTuple [t] = t
toTuple ts = foldl AppT (TupleT (length ts)) ts
-- | Re-encode a (possibly nested) tuple type as right-nested pairs:
-- components @[t1, t2, t3]@ become @(t1, (t2, t3))@. A single component is
-- returned as is, and no components give the unit type @TupleT 0@.
reTuple :: Type -> Type
reTuple t = nest (toList t)
  where
    nest [] = TupleT 0
    nest [x] = x
    nest (x:rest) = AppT (AppT (TupleT 2) x) (nest rest)
-- | Name of the reified type constructor. Any 'Info' other than 'TyConI'
-- is rejected with a message showing it, like 'iTypes' does.
iName :: Info -> Name
iName i = case i of
  TyConI d -> dName d
  _ -> error ("Test.AgataTH.iName: not a type constructor: " ++ show i)
-- | Component types of a reified type: the field types of a data/newtype
-- definition, the right-hand side of a synonym, or the primitive type
-- itself. Any other 'Info' is an error showing it.
iTypes :: Info -> [Type]
iTypes (TyConI d) = dTypes d
iTypes (PrimTyConI n _ _) = [ConT n]
iTypes i = error (show i)
-- | Type-parameter names of a reified type constructor. Any 'Info' other
-- than 'TyConI' is rejected with a message showing it, like 'iTypes' does.
iParams :: Info -> [Name]
iParams i = case i of
  TyConI d -> dParams d
  _ -> error ("Test.AgataTH.iParams: not a type constructor: " ++ show i)
-- | Name declared by a data, newtype or type-synonym declaration (the same
-- declaration forms 'dTypes' handles). Other declarations are rejected with
-- a message showing them.
dName :: Dec -> Name
dName d = case d of
  DataD _ n _ _ _ -> n
  NewtypeD _ n _ _ _ -> n
  TySynD n _ _ -> n
  _ -> error ("Test.AgataTH.dName: unsupported declaration: " ++ show d)
-- | Component types of a type declaration: all field types of all
-- constructors (data/newtype), or the right-hand side of a synonym.
-- Other declarations are rejected with a message showing them.
dTypes :: Dec -> [Type]
dTypes d = case d of
  DataD _ _ _ cs _ -> concatMap cFields cs
  NewtypeD _ _ _ c _ -> cFields c
  TySynD _ _ t -> [t]
  _ -> error ("Test.AgataTH.dTypes: unsupported declaration: " ++ show d)
-- | Type-parameter names of a data, newtype or type-synonym declaration.
--
-- Type synonyms are included because 'dTypes' accepts them, and 'expand'
-- reifies whatever type constructor mentions the derived type, which may be
-- a synonym. Other declarations are rejected with a message showing them.
dParams :: Dec -> [Name]
dParams d = case d of
  DataD _ _ ns _ _ -> map unVarBndr ns
  NewtypeD _ _ ns _ _ -> map unVarBndr ns
  TySynD _ ns _ -> map unVarBndr ns
  _ -> error ("Test.AgataTH.dParams: unsupported declaration: " ++ show d)
-- | Constructors of a data or newtype declaration. Other declarations have
-- no constructors and are rejected with a message showing them. Without
-- this, e.g. deriving for a type synonym would fail with an anonymous
-- pattern-match error.
dConsts :: Dec -> [Con]
dConsts d = case d of
  DataD _ _ _ cs _ -> cs
  NewtypeD _ _ _ c _ -> [c]
  _ -> error ("Test.AgataTH.dConsts: not a data or newtype declaration: " ++ show d)
-- | Name of a data constructor; for an existentially quantified
-- constructor, the name of the constructor it wraps.
cName :: Con -> Name
cName (NormalC n _) = n
cName (RecC n _) = n
cName (InfixC _ n _) = n
cName (ForallC _ _ inner) = cName inner
-- | The constructor used as an expression.
cId :: Con -> Q Exp
cId c = conE (cName c)
-- | Field types of a data constructor, in declaration order.
--
-- Record constructors contribute the types of their fields. An
-- existentially quantified constructor ('ForallC') contributes the fields of
-- the constructor it wraps. Previously both forms hit a pattern-match
-- failure, so deriving for any record type crashed.
cFields :: Con -> [Type]
cFields c = case c of
  NormalC _ sts -> map snd sts
  RecC _ vsts -> [t | (_, _, t) <- vsts]
  InfixC st _ st' -> [snd st, snd st']
  ForallC _ _ c1 -> cFields c1
-- | Placeholder type used only inside the declaration quotes in 'agatath'.
-- Those quoted instances are pattern-matched to recover resolved names, and
-- their heads and method bodies are replaced before anything is emitted.
data T1 = T1
-- | Render a Template Haskell value with its pretty-printer.
dump :: Ppr a => a -> String
dump x = show (ppr x)
-- | A 'Q' computation whose state is a growing set of @b@ values; used by
-- 'allTypesT' as both its result and its visited set.
type Collecting b a = StateT (Set.Set b) Q a
-- | Whether a value has already been collected.
collected :: (Ord b) => b -> Collecting b Bool
collected b = gets (Set.member b)
-- | Add a value to the collected set.
collect :: (Ord b) => b -> Collecting b ()
collect = modify . Set.insert
-- | Run a collecting computation from an empty set and return the final set.
getCollected :: Collecting b a -> Q (Set.Set b)
getCollected action = execStateT action Set.empty
-- | If @b@ has not been collected yet, collect it and then run @x@;
-- otherwise do nothing. Because @b@ is recorded before @x@ runs, a cycle
-- that reaches @b@ again from inside @x@ stops there.
collectIf :: Ord b => b -> Collecting b () -> Collecting b ()
collectIf b x = do
  seen <- collected b
  if seen
    then return ()
    else collect b >> x
#if MIN_VERSION_template_haskell(2,4,0)
-- | The variable bound by a type-variable binder (any kind annotation is
-- dropped).
unVarBndr :: TyVarBndr -> Name
unVarBndr bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n
-- | An unkinded type-variable binder.
varBndr :: Name -> TyVarBndr
varBndr = PlainTV
-- | One class constraint @n v@ for each type @v@.
allInClass :: Name -> [Type] -> [Pred]
allInClass n vs = [ClassP n [v] | v <- vs]
#else
-- Older template-haskell: binders are plain Names and constraints are
-- ordinary Types.
unVarBndr = id
varBndr = id
allInClass n vs = [AppT (ConT n) v | v <- vs]
#endif
-- | Debugging aid. Reify type constructor @n@, expand it over fresh type
-- variables (see 'expand'), and abort compilation with the pretty-printed
-- expanded type as the error message. It never returns declarations.
--
-- The parameter count comes from 'dParams', as in 'agatath', so newtypes
-- are accepted too (the old @DataD@-only pattern crashed on them). Errors go
-- through 'fail', so they are reported as ordinary splice errors.
topApp :: Name -> Q [Dec]
topApp n = do
  info <- reify n
  nns1 <- case info of
    TyConI d -> replicateM (length (dParams d)) (newName "b")
    _ -> fail ("Test.AgataTH.topApp: not a type constructor: " ++ show n)
  expanded <- expand n nns1
  fail (dump expanded)
-- | Placeholder: currently generates no declarations for any name.
testDimVal :: Name -> Q [Dec]
testDimVal _ = return []