module Database.Groundhog.TH
( deriveEntity
, setDbEntityName
, setConstructor
, setPhantomName
, setDbConstrName
, setConstraints
, setField
, setDbFieldName
, setExprFieldName
) where
import Database.Groundhog.Core(PersistEntity(..), PersistField(..), NeverNull, Primitive(toPrim), PersistBackend(..), DbType(DbEntity), Constraint, Constructor(..), namedType, EntityDef(..), ConstructorDef(..), PersistValue(..))
import Language.Haskell.TH
import Language.Haskell.TH.Syntax(StrictType, VarStrictType)
import Control.Monad(liftM, forM, forM_)
import Control.Monad.Trans.State(State, runState, modify)
import Data.Char(toUpper, toLower, isSpace)
import Data.List(nub, (\\))
-- | Compile-time description of the datatype an entity is derived from.
-- Built by 'mkTHEntityDef' and then customised through the @set*@ functions.
data THEntityDef = THEntityDef {
    dataName :: Name                      -- ^ name of the Haskell datatype
  , dbEntityName :: String                -- ^ entity name used for the database; initialised to the datatype's base name
  , thTypeParams :: [TyVarBndr]           -- ^ type variables bound by the datatype
  , thConstructors :: [THConstructorDef]  -- ^ one description per data constructor, in declaration order
  } deriving Show
-- | Compile-time description of one data constructor of an entity.
data THConstructorDef = THConstructorDef {
    thConstrName :: Name             -- ^ name of the Haskell data constructor
  , thPhantomConstrName :: String    -- ^ name of the generated phantom datatype; initialised to the constructor name ++ "Constructor"
  , dbConstrName :: String           -- ^ constructor name returned by 'phantomConstrName'; initialised to the constructor name
  , thConstrParams :: [FieldDef]     -- ^ the constructor's fields, in positional order
  , thConstrConstrs :: [Constraint]  -- ^ constraints of this constructor; each pairs a constraint name with database field names (see getConstraints generation)
  } deriving Show
-- | Compile-time description of one constructor field.
data FieldDef = FieldDef {
    fieldName :: String    -- ^ record selector name, or a generated name (lowercased constructor + index) for positional fields
  , dbFieldName :: String  -- ^ field name used for the database; initialised to 'fieldName'
  , exprName :: String     -- ^ name of the generated 'Fields' constructor; initialised to the capitalised 'fieldName' ++ "Field"
  , fieldType :: Type      -- ^ declared type of the field
  } deriving Show
-- | Overrides the database name of the entity.
setDbEntityName :: String -> State THEntityDef ()
setDbEntityName n = modify rename
  where
    rename e = e { dbEntityName = n }
-- | Applies a modification to the single constructor with the given name.
-- Fails at compile time (via 'replaceOne') if no constructor, or more than
-- one, has that name.
setConstructor :: Name -> State THConstructorDef () -> State THEntityDef ()
setConstructor name f = modify updateConstrs
  where
    updateConstrs e = e { thConstructors = replaceOne thConstrName name f (thConstructors e) }
-- | Overrides the name of the phantom datatype generated for a constructor.
setPhantomName :: String -> State THConstructorDef ()
setPhantomName n = modify rename
  where
    rename cdef = cdef { thPhantomConstrName = n }
-- | Overrides the database name of a constructor.
setDbConstrName :: String -> State THConstructorDef ()
setDbConstrName n = modify rename
  where
    rename cdef = cdef { dbConstrName = n }
-- | Replaces the list of constraints declared for a constructor.
setConstraints :: [Constraint] -> State THConstructorDef ()
setConstraints cs = modify replaceConstrs
  where
    replaceConstrs cdef = cdef { thConstrConstrs = cs }
-- | Applies a modification to the single field whose Haskell name matches.
-- Fails at compile time (via 'replaceOne') if the name is missing or ambiguous.
setField :: String -> State FieldDef () -> State THConstructorDef ()
setField name f = modify updateFields
  where
    updateFields cdef = cdef { thConstrParams = replaceOne fieldName name f (thConstrParams cdef) }
-- | Overrides the database name of a field.
setDbFieldName :: String -> State FieldDef ()
setDbFieldName n = modify rename
  where
    rename fdef = fdef { dbFieldName = n }
-- | Overrides the name of the 'Fields' constructor generated for a field.
setExprFieldName :: String -> State FieldDef ()
setExprFieldName n = modify rename
  where
    rename fdef = fdef { exprName = n }
-- | Runs the state modification on the unique element whose key (@p x@)
-- equals @a@, leaving all other elements untouched. Calls 'error' when no
-- element or several elements carry that key.
replaceOne :: (Eq b, Show b) => (a -> b) -> b -> State a () -> [a] -> [a]
replaceOne p a f xs
  | matches == 1 = map replace xs
  | matches == 0 = error $ "Element with name " ++ show a ++ " not found"
  | otherwise = error $ "Found more than one element with name " ++ show a
  where
    matches = length [x | x <- xs, p x == a]
    replace x
      | p x == a = runModify f x
      | otherwise = x
-- | Executes a state action purely for its effect on the state and returns
-- the final state.
runModify :: State a () -> a -> a
runModify m = snd . runState m
-- | Entry point: reifies the named datatype, applies the optional
-- customisation, validates the result and generates all declarations.
-- Newtypes and non-datatype names are rejected with a compile-time error.
deriveEntity :: Name -> Maybe (State THEntityDef ()) -> Q [Dec]
deriveEntity name f = reify name >>= process
  where
    customise = maybe id runModify f
    process (TyConI dec) = case dec of
      DataD _ _ _ _ _ -> mkDecs $ either error id $ validate $ customise $ mkTHEntityDef dec
      NewtypeD _ _ _ _ _ -> error "Newtypes are not supported"
      _ -> error "Unknown declaration type"
    process _ = error "Only datatypes can be processed"
-- | Checks that a (possibly user-customised) entity description is usable:
--
-- * phantom constructor names and db constructor names are unique;
-- * phantom constructor names contain no spaces;
-- * within each constructor, expr field names and db field names are unique;
-- * expr field names contain no spaces.
--
-- Returns @Left@ with a description of the first violated rule, otherwise
-- the input unchanged. 'Left' is used explicitly rather than 'fail': the
-- @Either@ monad's 'fail' is not a @Left@ (it is 'error', or missing
-- entirely under MonadFail).
validate :: THEntityDef -> Either String THEntityDef
validate def = do
  let constrs = thConstructors def
  assertUnique thPhantomConstrName constrs "constructor phantom name"
  assertUnique dbConstrName constrs "constructor db name"
  forM_ constrs $ \cdef -> do
    let fields = thConstrParams cdef
    assertSpaceFree (thPhantomConstrName cdef) "constructor phantom name"
    assertUnique exprName fields "expr field name in a constructor"
    assertUnique dbFieldName fields "db field name in a constructor"
    forM_ fields $ \fdef -> assertSpaceFree (exprName fdef) "field expr name"
  return def
  where
    -- Removing one occurrence of every distinct key leaves exactly the
    -- duplicated keys; nub reports each duplicate once.
    notUniqueBy f xs = let xs' = map f xs in nub $ xs' \\ nub xs'
    assertUnique f xs what = case notUniqueBy f xs of
      [] -> Right ()
      ys -> Left $ "All " ++ what ++ " must be unique: " ++ show ys
    assertSpaceFree s what
      | not (any isSpace s) = Right ()
      | otherwise = Left $ "Spaces in " ++ what ++ " are not allowed: " ++ show s
-- | Builds the default entity description from a reified @data@ declaration.
-- Positional fields are named after the lowercased constructor name plus
-- their index; record fields keep their selector names. Infix and
-- existentially quantified constructors are rejected.
mkTHEntityDef :: Dec -> THEntityDef
mkTHEntityDef (DataD _ dname typeVars cons _) =
  THEntityDef dname (nameBase dname) typeVars (map toConstr cons)
  where
    toConstr (NormalC cname params) =
      buildConstr cname [ fromStrict (firstLetter toLower (nameBase cname) ++ show i) p
                        | (i, p) <- zip [0 :: Int ..] params ]
    toConstr (RecC cname params) = buildConstr cname (map fromVarStrict params)
    toConstr (InfixC _ _ _) = error "Types with infix constructors are not supported"
    toConstr (ForallC _ _ _) = error "Types with existential quantification are not supported"
    buildConstr cname params =
      THConstructorDef cname (nameBase cname ++ "Constructor") (nameBase cname) params []
    fromStrict :: String -> StrictType -> FieldDef
    fromStrict fname (_, t) = newField fname t
    fromVarStrict :: VarStrictType -> FieldDef
    fromVarStrict (fname, _, t) = newField (nameBase fname) t
    newField fname t = FieldDef fname fname (firstLetter toUpper fname ++ "Field") t
mkTHEntityDef _ = error "Only datatypes can be processed"
-- | Applies a function to the first character of a string, leaving the rest
-- untouched. Total: the empty string is returned unchanged instead of
-- crashing on @head@/@tail@.
firstLetter :: (Char -> Char) -> String -> String
firstLetter _ [] = []
firstLetter f (c:cs) = f c : cs
-- | Generates, in order: the phantom constructor datatypes, their
-- 'Constructor' instances, the 'PersistField' instance and the
-- 'PersistEntity' instance.
mkDecs :: THEntityDef -> Q [Dec]
mkDecs def = liftM concat $ mapM ($ def)
  [ mkPhantomConstructors
  , mkPhantomConstructorInstances
  , mkPersistFieldInstance
  , mkPersistEntityInstance
  ]
-- | Declares one empty datatype per constructor, named by its phantom name.
mkPhantomConstructors :: THEntityDef -> Q [Dec]
mkPhantomConstructors def = forM (thConstructors def) $ \cdef ->
  dataD (cxt []) (mkName (thPhantomConstrName cdef)) [] [] []
-- | Declares a 'Constructor' instance for every phantom datatype.
-- 'phantomConstrName' returns the constructor's db name and
-- 'phantomConstrNum' its zero-based position in the declaration.
mkPhantomConstructorInstances :: THEntityDef -> Q [Dec]
mkPhantomConstructorInstances def = forM (zip [0 :: Int ..] (thConstructors def)) mkInstance
  where
    mkInstance (cNum, c) = instanceD (cxt []) instanceHead [nameDec, numDec]
      where
        instanceHead = appT (conT ''Constructor) (conT $ mkName $ thPhantomConstrName c)
        nameDec = funD 'phantomConstrName [clause [wildP] (normalB $ stringE $ dbConstrName c) []]
        numDec = funD 'phantomConstrNum [clause [wildP] (normalB [| cNum |]) []]
-- | Generates the 'PersistEntity' instance for the entity. It contains:
--
-- * a 'Fields' data instance with one constructor per field, constrained by
--   equalities that fix the phantom constructor type and the field type;
-- * 'entityDef', describing type parameters, constructors, fields and constraints;
-- * 'toPersistValues': the constructor number followed by every field value;
-- * 'fromPersistValues': the inverse, failing with "Invalid values" on any
--   list that does not start with a known constructor number followed by
--   exactly that constructor's number of values;
-- * 'getConstraints': the constructor number plus, for each constraint, its
--   name and the primitive values of the fields it names;
-- * 'showField' / 'eqField' over the 'Fields' constructors.
mkPersistEntityInstance :: THEntityDef -> Q [Dec]
mkPersistEntityInstance def = do
  let entity = foldl AppT (ConT (dataName def)) $ map getType $ thTypeParams def
  fields' <- do
    cParam <- newName "c"
    fParam <- newName "f"
    -- Each field becomes: forall . (c ~ Phantom, f ~ FieldType) => ExprName
    let mkField name field = ForallC [] ([EqualP (VarT cParam) (ConT name), EqualP (VarT fParam) (fieldType field)]) $ NormalC (mkName $ exprName field) []
    let f cdef = map (mkField $ mkName $ thPhantomConstrName cdef) $ thConstrParams cdef
    let constrs = concatMap f $ thConstructors def
    return $ DataInstD [] ''Fields [entity, VarT cParam, VarT fParam] constrs []
  entityDef' <- do
    v <- newName "v"
    -- An 'undefined' function from the entity to a type mentioning its type
    -- variables, applied to the argument so the variables are instantiated.
    let mkLambda t = [|undefined :: $(forallT (thTypeParams def) (cxt []) [t| $(return entity) -> $(return t) |]) |]
    let typeParams' = listE $ map (\t -> [| namedType ($(mkLambda $ getType t) $(varE v)) |]) $ thTypeParams def
    let mkField c fNum f = do
          a <- newName "a"
          let fname = dbFieldName f
          -- Bind only the fNum-th constructor argument; the remaining
          -- (total - fNum - 1) trailing positions are wildcards.
          let pats = replicate fNum wildP ++ [varP a] ++ replicate (length (thConstrParams c) - fNum - 1) wildP
          -- A field type with free type variables must be taken from the
          -- argument value so those variables are fixed; otherwise a typed
          -- 'undefined' is enough.
          let nvar = case hasFreeVars (fieldType f) of
                True -> appE (lamE [conP (thConstrName c) pats] (varE a)) (varE v)
                False -> [| undefined :: $(return $ fieldType f) |]
          [| (fname, namedType $nvar) |]
    let constrs = listE $ zipWith (\cNum c@(THConstructorDef _ _ name params conss) -> [| ConstructorDef cNum name $(listE $ zipWith (mkField c) [0..] params) conss |]) [0..] $ thConstructors def
    let body = normalB [| EntityDef $(stringE $ dbEntityName def) $typeParams' $constrs |]
    funD 'entityDef $ [ clause [varP v] body [] ]
  toPersistValues' <- liftM (FunD 'toPersistValues) $ forM (zip [0..] $ thConstructors def) $ \(cNum, c) -> do
    names <- mapM (const $ newName "f") $ thConstrParams c
    let pat = conP (thConstrName c) (map varP names)
    -- The constructor number is stored first, then every field in order.
    let body = normalB $ [| sequence $(listE $ map (appE (varE 'toPersistValue)) ([|cNum::Int|]:map varE names) ) |]
    clause [pat] body []
  fromPersistValues' <- do
    clauses <- forM (zip [0..] (thConstructors def)) $ \(cNum, c) -> do
      names <- mapM (const $ newName "x") $ thConstrParams c
      names' <- mapM (const $ newName "x'") $ thConstrParams c
      -- Matches: PersistInt64 cNum : [x1, ..., xn]
      let pat = conP '(:) [conP 'PersistInt64 [litP $ integerL cNum], listP $ map varP names]
      let result = noBindS (appE (varE 'return) ( foldl (\a -> appE a . varE) (conE (thConstrName c)) names'))
      let getField name name' f = bindS (varP name') [| fromPersistValue $(varE name) |]
      let body = normalB $ doE $ (zipWith3 getField names names' (thConstrParams c)) ++ [result]
      clause [pat] body []
    unexpected <- newName "xs" >>= \xs -> clause [varP xs] (normalB [| fail $ "Invalid values: " ++ show $(varE xs) |]) []
    return $ FunD 'fromPersistValues $ clauses ++ [unexpected]
  getConstraints' <- let
      hasConstraints = not . null . thConstrConstrs
      clauses = zipWith mkClause [0::Int ..] (thConstructors def)
      mkClause cNum cdef | not (hasConstraints cdef) = clause [conP (thConstrName cdef) pats] (normalB [| (cNum, []) |]) [] where
        pats = map (const wildP) $ thConstrParams cdef
      mkClause cNum cdef = clause [conP (thConstrName cdef) (map varP names)] body [] where
        -- Constraints refer to fields by db name; map that back to the
        -- variable bound for the field in the constructor pattern.
        getFieldName n = case filter ((==n).dbFieldName) $ thConstrParams cdef of
          [f] -> varE $ mkName $ fieldName f
          [] -> error $ "Database field name " ++ show n ++ " declared in constraint not found"
          _ -> error $ "It can never happen. Found several fields with one database name " ++ show n
        body = normalB $ [| (cNum, $(listE $ map (\(cname, fnames) -> [|(cname, $(listE $ map (\fname -> [| (fname, toPrim $(getFieldName fname)) |] ) fnames )) |] ) $ thConstrConstrs cdef)) |]
        names = map (mkName . fieldName) $ thConstrParams cdef
    in funD 'getConstraints clauses
  showField' <- do
    let fields = concatMap thConstrParams $ thConstructors def
    funD 'showField $ map (\f -> clause [conP (mkName $ exprName f) []] (normalB $ stringE $ dbFieldName f)[] ) fields
  eqField' <- let
      fieldNames = thConstructors def >>= thConstrParams >>= return.mkName.exprName
      clauses = map (\n -> clause [conP n [], conP n []] (normalB [| True |]) []) fieldNames
    -- With a single field the catch-all clause would be redundant.
    in funD 'eqField $ if length clauses > 1
      then clauses ++ [clause [wildP, wildP] (normalB [| False |]) []]
      else clauses
  let context = paramsContext def
  let decs = [fields', entityDef', toPersistValues', fromPersistValues', getConstraints', showField', eqField']
  return $ [InstanceD context (AppT (ConT ''PersistEntity) entity) decs]
-- | Generates the 'PersistField' instance for the entity. The persisted
-- name is the db entity name, followed by the persisted names of the type
-- arguments, all separated by "$". Values are stored by key: 'toPersistValue'
-- goes through 'insertBy' and 'fromPersistValue' loads the row back with
-- 'get', failing with "No data with id ..." when it is missing.
mkPersistFieldInstance :: THEntityDef -> Q [Dec]
mkPersistFieldInstance def = do
  persistName' <- mkPersistName
  toPersistValue' <- funD 'toPersistValue [clause [] (normalB [| liftM (either toPrim toPrim) . insertBy |]) []]
  fromPersistValue' <- do
    x <- newName "x"
    funD 'fromPersistValue [clause [varP x] (normalB [| fromPersistValue $(varE x) >>= get >>= maybe (fail $ "No data with id " ++ show $(varE x)) return |]) []]
  dbType' <- funD 'dbType [clause [] (normalB [| DbEntity . entityDef |]) []]
  return [InstanceD (paramsContext def) (AppT (ConT ''PersistField) entity) [persistName', toPersistValue', fromPersistValue', dbType']]
  where
    types = map getType $ thTypeParams def
    entity = foldl AppT (ConT (dataName def)) types
    mkPersistName = do
      v <- newName "v"
      let mkLambda t = [|undefined :: $(forallT (thTypeParams def) (cxt []) [t| $(return entity) -> $(return t) |]) |]
          -- Only forced when there is at least one type parameter (foldr1).
          paramNames = foldr1 (\p xs -> [| $p ++ "$" ++ $xs |] ) $ map (\t -> [| persistName ($(mkLambda t) $(varE v)) |]) types
          namesList = if null types
            then stringE (dbEntityName def)
            else [| $(stringE $ dbEntityName def) ++ "$" ++ $(paramNames) |]
      funD 'persistName [clause [varP v] (normalB namesList) []]
-- | Instance context: every type parameter must be a 'PersistField', and
-- every type variable appearing as @Maybe a@ inside a field type must be
-- 'NeverNull'.
paramsContext :: THEntityDef -> Cxt
paramsContext def =
  [ClassP ''PersistField [t] | t <- params] ++ [ClassP ''NeverNull [t] | t <- maybys]
  where
    params = map getType (thTypeParams def)
    maybys = nub [ t | cdef <- thConstructors def
                     , fdef <- thConstrParams cdef
                     , t <- insideMaybe (fieldType fdef) ]
-- | Turns a type-variable binder into the type variable it binds, ignoring
-- any kind annotation.
getType :: TyVarBndr -> Type
getType binder = VarT $ case binder of
  PlainTV n -> n
  KindedTV n _ -> n
-- | Folds over a type: @f@ is applied to every node and the results are
-- combined left-to-right (node, then function part, then argument part for
-- applications; node, then inner type for kind signatures). Types containing
-- @forall@ are rejected with 'error'.
foldType :: (Type -> a) -> (a -> a -> a) -> Type -> a
foldType f combine = walk
  where
    walk t = case t of
      ForallT _ _ _ -> error "forall'ed fields are not allowed"
      AppT a b -> (f t `combine` walk a) `combine` walk b
      SigT inner _ -> f t `combine` walk inner
      _ -> f t
-- | True when a type variable occurs anywhere inside the type.
hasFreeVars :: Type -> Bool
hasFreeVars = foldType isTyVar (||)
  where
    isTyVar t = case t of
      VarT _ -> True
      _ -> False
-- | Collects every type variable that appears directly as the argument of
-- 'Maybe' somewhere inside the type (duplicates included).
insideMaybe :: Type -> [Type]
insideMaybe = foldType maybeArg (++)
  where
    maybeArg t = case t of
      AppT (ConT c) v@(VarT _) | c == ''Maybe -> [v]
      _ -> []