{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
module Data.StackPrism.TH
  ( deriveStackPrisms
  , deriveStackPrismsWith
  , deriveStackPrismsFor
  ) where

import Data.StackPrism
import Language.Haskell.TH
import Control.Applicative
import Control.Monad

-- | Derive stack prisms for a given datatype.
--
-- For example:
--
-- > deriveStackPrisms ''Maybe
--
-- will create
--
-- > _Just :: StackPrism (a :- t) (Maybe a :- t)
-- > _Nothing :: StackPrism t (Maybe a :- t)
--
-- together with their implementations.
--
-- The generated type signatures refer to @StackPrism@ and @:-@ by their
-- unqualified names, so both must be in scope where the splice is used.
deriveStackPrisms :: Name -> Q [Dec]
deriveStackPrisms = deriveStackPrismsWith ('_' :)

-- | Derive stack prisms given a function that derives variable names from
-- constructor names.
deriveStackPrismsWith :: (String -> String) -> Name -> Q [Dec]
deriveStackPrismsWith nameFun =
  deriveStackPrismsWith' (\_ conNm -> return (nameFun conNm))

-- | Derive stack prisms given a list of variable names, one for each
-- constructor (in declaration order).
--
-- Surplus names are ignored. If the list is shorter than the number of
-- constructors, the splice fails with a message naming the first constructor
-- that has no corresponding entry.
deriveStackPrismsFor :: [String] -> Name -> Q [Dec]
deriveStackPrismsFor names = deriveStackPrismsWith' pick
  where
    -- Total replacement for @names !! i@: report a readable error instead of
    -- an index-out-of-range exception.
    pick i conNm = case drop i names of
      nm : _ -> return nm
      [] -> fail $
        "deriveStackPrismsFor: no name supplied for constructor "
          ++ conNm ++ " (constructor #" ++ show (i + 1) ++ "); only "
          ++ show (length names) ++ " name(s) were given."

-- | Shared worker for the exported derivers.
--
-- The naming callback receives the zero-based index of the constructor and
-- its base name, and yields the name of the prism to generate. It runs in 'Q'
-- so that it can abort the splice with a proper error message.
deriveStackPrismsWith' :: (Int -> String -> Q String) -> Name -> Q [Dec]
deriveStackPrismsWith' nameFun name = do
  info <- reify name
  routers <- case info of
    TyConI (DataD _ _ tyArgs cons _) ->
      -- A fallback case is only needed when there are other constructors
      -- that the destructor has to reject.
      mapM (deriveStackPrism name tyArgs (length cons /= 1)) cons
    TyConI (NewtypeD _ _ tyArgs con _) ->
      (: []) <$> deriveStackPrism name tyArgs False con
    _ ->
      fail $ show name ++ " is not a datatype."
  concat <$> zipWithM declare [0 ..] routers
  where
    -- Emit the type signature and the value binding for a single prism.
    declare i (conNm, typeF, router) = do
      nm <- mkName <$> nameFun i (nameBase conNm)
      return
        [ SigD nm typeF
        , ValD (VarP nm) (NormalB router) []
        ]

-- | Build the prism for one constructor.
--
-- Returns the constructor's name, the type of the prism and the expression
-- implementing it. Only plain, record and infix constructors are supported;
-- anything else aborts the splice.
deriveStackPrism :: Name -> [TyVarBndr] -> Bool -> Con -> Q (Name, Type, Exp)
deriveStackPrism resNm tyArgs matchWildcard con =
  case con of
    NormalC name tys -> go name (map snd tys)
    RecC name tys -> go name (map (\(_, _, ty) -> ty) tys)
    InfixC (_, tyl) name (_, tyr) -> go name [tyl, tyr]
    _ -> fail $ "Unsupported constructor " ++ show (conName con)
  where
    go name tys = do
      stackPrismE <- [| stackPrism |]
      stackPrismCon <- deriveConstructor name tys
      stackPrismDes <- deriveDestructor matchWildcard name tys
      tNm <- newName "t"
      let t = VarT tNm
          -- Field types pushed on top of the remaining stack @t@.
          fromType = foldr (-:) t tys
          -- The datatype applied to all of its type variables. Kind-annotated
          -- binders are handled as well as plain ones.
          resType =
            foldl
              (\acc bndr -> AppT acc (VarT (tyVarBndrName bndr)))
              (ConT resNm)
              tyArgs
          toType = resType -: t
      return
        ( name
        , ForallT (PlainTV tNm : tyArgs) [] $
            ConT (mkName "StackPrism") `AppT` fromType `AppT` toType
        , stackPrismE `AppE` stackPrismCon `AppE` stackPrismDes
        )

-- | The name bound by a type variable binder, with or without a kind
-- annotation.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName (PlainTV n) = n
tyVarBndrName (KindedTV n _) = n

-- | Build the type @l :- r@.
(-:) :: Type -> Type -> Type
l -: r = ConT (mkName ":-") `AppT` l `AppT` r

-- | Build the forward direction of a prism: a lambda that pops one value per
-- constructor field off the stack and pushes the constructed value, i.e.
--
-- > \(a1 :- ... :- an :- t) -> Con a1 ... an :- t
deriveConstructor :: Name -> [Type] -> Q Exp
deriveConstructor name tys = do
  -- Introduce some names
  t <- newName "t"
  fieldNames <- replicateM (length tys) (newName "a")
  let cons = '(:-)
      pat = foldr (\f fs -> UInfixP (VarP f) cons fs) (VarP t) fieldNames
      applyCon = foldl (\f x -> f `AppE` VarE x) (ConE name) fieldNames
      body = UInfixE applyCon (ConE cons) (VarE t)
  return $ LamE [pat] body

-- | Build the backward direction of a prism: a lambda-case that matches the
-- constructor on top of the stack and pushes its fields back, i.e.
--
-- > \case
-- >   Con a1 ... an :- r -> Just (a1 :- ... :- an :- r)
-- >   _ -> Nothing
--
-- The wildcard case is only emitted when the first argument is 'True'.
deriveDestructor :: Bool -> Name -> [Type] -> Q Exp
deriveDestructor matchWildcard name tys = do
  -- Introduce some names
  r <- newName "r"
  fieldNames <- replicateM (length tys) (newName "a")
  let cons = '(:-)
      conPat = ConP name (map VarP fieldNames)
      pushed =
        foldr (\h rest -> UInfixE (VarE h) (ConE cons) rest) (VarE r) fieldNames
      okBody = ConE 'Just `AppE` pushed
      okCase = Match (UInfixP conPat cons (VarP r)) (NormalB okBody) []
      failCase = Match WildP (NormalB (ConE 'Nothing)) []
      allCases
        | matchWildcard = [okCase, failCase]
        | otherwise = [okCase]
  return $ LamCaseE allCases

-- | Retrieve the name of a constructor, looking through any quantification.
conName :: Con -> Name
conName con =
  case con of
    NormalC name _ -> name
    RecC name _ -> name
    InfixC _ name _ -> name
    ForallC _ _ con' -> conName con'