{-# LANGUAGE CPP #-}
module Data.Comp.Derive.Utils where
import Control.Monad
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.ExpandSyns
#if __GLASGOW_HASKELL__ < 706
-- | Compatibility shim for template-haskell versions that do not provide
-- 'reportError': forwards the message to 'report' with the error flag set.
reportError :: String -> Q ()
reportError msg = report True msg
#endif
-- | A uniform view of a @data@ or @newtype@ declaration: its context, type
-- constructor name, type variable binders, data constructors and deriving
-- information. The representation of the last (deriving) field follows the
-- 'DataD' / 'NewtypeD' representation of the template-haskell version in use.
#if __GLASGOW_HASKELL__ < 800
data DataInfo = DataInfo Cxt Name [TyVarBndr] [Con] [Name]
#elif __GLASGOW_HASKELL__ < 802
data DataInfo = DataInfo Cxt Name [TyVarBndr] [Con] Cxt
#else
data DataInfo = DataInfo Cxt Name [TyVarBndr] [Con] [DerivClause]
#endif
-- | Monadic wrapper around 'abstractNewtype': runs the given 'Info'
-- computation and abstracts over its result.
abstractNewtypeQ :: Q Info -> Q (Maybe DataInfo)
abstractNewtypeQ infoQ = fmap abstractNewtype infoQ
-- | Views the reified info of a @data@ or @newtype@ type constructor as a
-- 'DataInfo'. A @newtype@ contributes its single constructor as a singleton
-- list. Any kind signature carried by the declaration (GHC >= 8.0) is
-- dropped. Returns 'Nothing' for every other kind of 'Info'.
abstractNewtype :: Info -> Maybe DataInfo
abstractNewtype info = case info of
#if __GLASGOW_HASKELL__ < 800
  TyConI (NewtypeD cx tyName vars con derivs) ->
    Just (DataInfo cx tyName vars [con] derivs)
  TyConI (DataD cx tyName vars cons derivs) ->
    Just (DataInfo cx tyName vars cons derivs)
#else
  TyConI (NewtypeD cx tyName vars _kind con derivs) ->
    Just (DataInfo cx tyName vars [con] derivs)
  TyConI (DataD cx tyName vars _kind cons derivs) ->
    Just (DataInfo cx tyName vars cons derivs)
#endif
  _ -> Nothing
-- | Normalises a data constructor into its name, its fields (strictness and
-- type of each, in declaration order; record field names are dropped), and
-- the declared result type for GADT-style constructors ('Nothing' for
-- Haskell98-style constructors). 'ForallC' wrappers are looked through.
-- For GADT constructors that declare several names at once, only the first
-- name is returned. Calls 'error' on any constructor form not listed here.
normalCon :: Con -> (Name,[StrictType], Maybe Type)
normalCon (NormalC constr args) = (constr, args, Nothing)
normalCon (RecC constr args) = (constr, map (\(_, s, t) -> (s, t)) args, Nothing)
normalCon (InfixC a constr b) = (constr, [a, b], Nothing)
normalCon (ForallC _ _ constr) = normalCon constr
#if __GLASGOW_HASKELL__ >= 800
normalCon (GadtC (constr:_) args typ) = (constr, args, Just typ)
-- Record-syntax GADT constructors are treated like 'GadtC', with the field
-- names dropped just as for 'RecC'.
normalCon (RecGadtC (constr:_) args typ) =
  (constr, map (\(_, s, t) -> (s, t)) args, Just typ)
#endif
normalCon _ = error "missing case for 'normalCon'"
-- | Like 'normalCon', but discards the strictness annotations and keeps only
-- the field types.
normalCon' :: Con -> (Name,[Type], Maybe Type)
normalCon' con =
  let (conName, fields, resultTy) = normalCon con
  in (conName, map snd fields, resultTy)
-- | 'normalCon'' lifted into 'Q'. Note that, unlike 'normalConStrExp', this
-- does not expand type synonyms in the field types.
normalConExp :: Con -> Q (Name,[Type], Maybe Type)
normalConExp con = return (normalCon' con)
-- | Like 'normalCon', but expands type synonyms (via 'expandSyns') in every
-- field type. The GADT result type, if any, is returned unexpanded.
normalConStrExp :: Con -> Q (Name,[StrictType], Maybe Type)
normalConStrExp con = do
  let (conName, fields, resultTy) = normalCon con
  expanded <- forM fields $ \(strictness, fieldTy) -> do
    fieldTy' <- expandSyns fieldTy
    return (strictness, fieldTy')
  return (conName, expanded, resultTy)
-- | Given an optional type of the shape @h a b@, returns @a@ (the first of
-- the two arguments). Falls back to the supplied default when the type is
-- absent or does not have that shape.
getBinaryFArg :: Type -> Maybe Type -> Type
getBinaryFArg def mty = case mty of
  Just (AppT (AppT _ firstArg) _) -> firstArg
  _ -> def
-- | Given an optional type of the shape @h a@, returns the argument @a@.
-- Falls back to the supplied default when the type is absent or is not a
-- type application.
getUnaryFArg :: Type -> Maybe Type -> Type
getUnaryFArg def = maybe def pickArg
  where
    pickArg (AppT _ arg) = arg
    pickArg _ = def
-- | Returns a constructor's name and its number of fields. 'ForallC'
-- wrappers are looked through. An infix constructor always has two fields.
-- For GADT constructors that declare several names at once, only the first
-- name is returned. Calls 'error' on any constructor form not listed here.
abstractConType :: Con -> (Name,Int)
abstractConType (NormalC constr args) = (constr, length args)
abstractConType (RecC constr args) = (constr, length args)
abstractConType (InfixC _ constr _) = (constr, 2)
abstractConType (ForallC _ _ constr) = abstractConType constr
#if __GLASGOW_HASKELL__ >= 800
abstractConType (GadtC (constr:_) args _typ) = (constr, length args)
-- Record-syntax GADT constructors are counted the same way as 'GadtC'.
abstractConType (RecGadtC (constr:_) args _typ) = (constr, length args)
#endif
abstractConType _ = error "missing case for 'abstractConType'"
-- | Returns the name bound by a type variable binder, ignoring any kind
-- annotation (and, on template-haskell >= 2.17, the binder's flag).
#if __GLASGOW_HASKELL__ < 900
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName (PlainTV n) = n
tyVarBndrName (KindedTV n _) = n
#else
tyVarBndrName :: TyVarBndr flag -> Name
tyVarBndrName (PlainTV n _) = n
tyVarBndrName (KindedTV n _ _) = n
#endif
-- | @containsType s t@ holds if @t@ occurs (syntactically, via '==') within
-- @s@. The search descends into the body of 'ForallT' (not its binders or
-- context), both sides of 'AppT', and the type inside 'SigT'. Every other
-- constructor is treated as a leaf.
containsType :: Type -> Type -> Bool
containsType s t = s == t || inside s
  where
    inside (ForallT _ _ body) = containsType body t
    inside (AppT fun arg) = containsType fun t || containsType arg t
    inside (SigT ty _) = containsType ty t
    inside _ = False
-- | Like 'containsType', but returns one entry per occurrence of @t@ in @s@.
-- Each entry is the argument-nesting depth of that occurrence. Descending
-- into the argument of an 'AppT' increases the depth by one; descending into
-- its function part, a 'ForallT' body or a 'SigT' does not. For example,
-- @a@ occurs at depth 2 in @L (L a)@, and at depths @[1,1]@ in @E a a@.
containsType' :: Type -> Type -> [Int]
containsType' s0 t = go 0 s0
  where
    go depth s
      | s == t = [depth]
    go depth (ForallT _ _ body) = go depth body
    go depth (AppT fun arg) = go depth fun ++ go (depth + 1) arg
    go depth (SigT ty _) = go depth ty
    go _ _ = []
-- | Generates @n@ fresh names, each created by 'newName' from the same base
-- string.
newNames :: Int -> String -> Q [Name]
newNames n name = mapM newName (replicate n name)
-- | Names of the tuple type constructors of arities @n@ through @m@
-- (inclusive), in increasing order. Empty when @n > m@.
tupleTypes :: Int -> Int -> [Name]
tupleTypes n m = map tupleTypeName [n..m]
-- | Applies every derivation function to every name and concatenates the
-- generated declarations. The outer loop is over the derivers, so all
-- output of the first deriver (for each name in order) comes first.
derive :: [Name -> Q [Dec]] -> [Name] -> Q [Dec]
derive ders names = fmap concat (sequence actions)
  where
    actions = concatMap (\der -> map der names) ders
-- | Builds the class constraint @cls t1 ... tn@. Before GHC 7.10 this is a
-- 'ClassP' predicate; afterwards constraints are ordinary types, so the
-- class constructor is applied to each argument from left to right.
#if __GLASGOW_HASKELL__ < 710
mkClassP :: Name -> [Type] -> Pred
mkClassP cls args = ClassP cls args
#else
mkClassP :: Name -> [Type] -> Type
mkClassP cls = go (ConT cls)
  where
    go acc [] = acc
    go acc (arg:rest) = go (acc `AppT` arg) rest
#endif
-- | Recognises an equality constraint @lhs ~ rhs@ and returns its two sides.
-- Before GHC 7.10 this matches 'EqualP'; afterwards it matches 'EqualityT'
-- applied to two types. Any other constraint yields 'Nothing'.
#if __GLASGOW_HASKELL__ < 710
isEqualP :: Pred -> Maybe (Type, Type)
isEqualP p = case p of
  EqualP lhs rhs -> Just (lhs, rhs)
  _ -> Nothing
#else
isEqualP :: Type -> Maybe (Type, Type)
isEqualP p = case p of
  AppT (AppT EqualityT lhs) rhs -> Just (lhs, rhs)
  _ -> Nothing
#endif
-- | Version-independent 'InstanceD' constructor. From GHC 8.0 on, the
-- instance is built with no overlap pragma ('Nothing').
mkInstanceD :: Cxt -> Type -> [Dec] -> Dec
#if __GLASGOW_HASKELL__ < 800
mkInstanceD = InstanceD
#else
mkInstanceD = InstanceD Nothing
#endif
-- | @liftSumGen caseName sumName fname@ generates an instance of the class
-- @fname@ for the binary type constructor @sumName@. If the class parameters
-- are @ts1 ++ [p] ++ ts2@, where @p@ is the parameter that heads the first
-- argument of some method signature, the generated declaration is
--
-- > instance (C ts1 f ts2, C ts1 g ts2) => C ts1 (sumName f g) ts2
--
-- Every method @m@ of the class is defined as @m x = caseName m m x@.
-- NOTE(review): @caseName@ is presumably a case analysis over the two
-- summands of @sumName@ — confirm against callers.
--
-- Fails if @fname@ does not name a class. Reports an error and generates no
-- declarations if no method signature identifies the parameter @p@.
liftSumGen :: Name -> Name -> Name -> Q [Dec]
liftSumGen caseName sumName fname = do
  info <- reify fname
  case info of
    ClassI (ClassD _ name targs_ _ decs) _ -> liftClass name targs_ decs
    _ -> fail $ "liftSumGen: " ++ show fname ++ " is not a type class"
  where
    -- Builds the sum instance for class @name@ with the given parameters and
    -- member declarations.
    liftClass name targs_ decs = do
      let targs = map tyVarBndrName targs_
      splitM <- findSig targs decs
      case splitM of
        Nothing -> do
          reportError $ "Class " ++ show name ++ " cannot be lifted to sums!"
          return []
        Just (ts1_, ts2_) -> do
          let f = VarT $ mkName "f"
              g = VarT $ mkName "g"
              ts1 = map VarT ts1_
              ts2 = map VarT ts2_
              ctx = [ mkClassP name (ts1 ++ f : ts2)
                    , mkClassP name (ts1 ++ g : ts2) ]
              tp = (ConT sumName `AppT` f) `AppT` g
              complType = foldl AppT (foldl AppT (ConT name) ts1 `AppT` tp) ts2
          decs' <- sequence $ concatMap decl decs
          return [mkInstanceD ctx complType decs']

    -- Each method signature in the class body becomes a one-clause method
    -- definition; all other members are skipped.
    decl :: Dec -> [DecQ]
    decl (SigD meth _) = [funD meth [mkClause meth]]
    decl _ = []

    -- Builds the clause @meth x = caseName meth meth x@ with a fresh @x@.
    mkClause :: Name -> ClauseQ
    mkClause meth = do
      x <- newName "x"
      let body = NormalB (VarE caseName `AppE` VarE meth `AppE` VarE meth `AppE` VarE x)
      return $ Clause [VarP x] body []

    -- Splits the class parameters around the one that heads the first
    -- argument of a method signature. Members are tried in order and the
    -- first one that yields a class parameter wins, so a leading member that
    -- does not identify the parameter no longer makes the class un-liftable.
    findSig :: [Name] -> [Dec] -> Q (Maybe ([Name],[Name]))
    findSig _ [] = return Nothing
    findSig targs (d:ds) = do
      mn <- sigHead d
      case mn >>= (\n -> splitNames n targs) of
        Just split -> return (Just split)
        Nothing -> findSig targs ds

    -- For a method signature, expands type synonyms and returns the type
    -- variable heading its first argument, if any.
    sigHead :: Dec -> Q (Maybe Name)
    sigHead (SigD _ ty) = do
      ty' <- expandSyns ty
      return $ getSig False ty'
    sigHead _ = return Nothing

    -- The flag records whether we are already inside the first argument
    -- (True) or still looking for the first arrow (False). Inside the
    -- argument, type applications are peeled off down to their head.
    getSig :: Bool -> Type -> Maybe Name
    getSig t (ForallT _ _ ty) = getSig t ty
    getSig False (AppT (AppT ArrowT ty) _) = getSig True ty
    getSig True (AppT ty _) = getSig True ty
    getSig True (VarT n) = Just n
    getSig _ _ = Nothing

    -- Splits a list at the first occurrence of @y@, excluding @y@ itself.
    splitNames :: Name -> [Name] -> Maybe ([Name],[Name])
    splitNames y xs = case break (== y) xs of
      (before, _ : after) -> Just (before, after)
      _ -> Nothing