module Database.Groundhog.Generic.Sql
( renderCond
, defaultShowPrim
, renderArith
, renderOrders
, renderUpdates
, renderFields
, renderChain
, intercalateS
, RenderS(..)
, StringLike(..)
, fromString
, (<>)
, parens
) where
import Database.Groundhog.Core
import Database.Groundhog.Instances ()
import Data.List (foldl')
import Data.Maybe (mapMaybe)
import Data.Monoid
import Data.String
-- | Types that SQL text can be rendered into: concatenated through 'Monoid',
-- written as literals through 'IsString', and built from a single character.
class (Monoid s, IsString s) => StringLike s where
  -- | The one-character string holding the given character.
  fromChar :: Char -> s
-- | A fragment of rendered SQL paired with the values bound to its @?@
-- placeholders.
data RenderS s = RenderS
  { getQuery :: s
    -- ^ The SQL text of the fragment.
  , getValues :: [PersistValue] -> [PersistValue]
    -- ^ The fragment's values as a difference list: applied to a list, it
    -- prepends this fragment's values to it, so applying it to @[]@ yields
    -- just this fragment's values.
  }
-- | Fragments concatenate left to right: query texts are appended and the
-- value difference lists are composed, so the combined values keep the
-- order in which their placeholders appear in the text.
instance Monoid s => Monoid (RenderS s) where
  mempty = RenderS { getQuery = mempty, getValues = id }
  mappend (RenderS query1 values1) (RenderS query2 values2) =
    RenderS (mappend query1 query2) (values1 . values2)
-- | Plain 'String': a character becomes a one-element list.
instance StringLike String where
  fromChar = (: [])
-- | Wrap a fragment in parentheses when its own priority is lower than the
-- priority of the context it appears in (larger numbers bind tighter).
parens :: StringLike s => Int -> Int -> RenderS s -> RenderS s
parens own context expr
  | own < context = char '(' <> expr <> char ')'
  | otherwise = expr
-- Compatibility shim: base 4.5 introduced (<>) in Data.Monoid. On older
-- versions it is defined here as a synonym for 'mappend', so this module
-- (which also exports it) can use (<>) unconditionally.
#if !MIN_VERSION_base(4, 5, 0)
(<>) :: Monoid m => m -> m -> m
(<>) = mappend
#endif
-- | A fragment holding the given text and no bound values.
string :: StringLike s => String -> RenderS s
string str = RenderS { getQuery = fromString str, getValues = id }
-- | A fragment holding a single character and no bound values.
char :: StringLike s => Char -> RenderS s
char ch = RenderS { getQuery = fromChar ch, getValues = id }
-- | Render an arithmetic expression.
--
-- Column names are passed through @escape@. Every literal is emitted as a
-- @?@ placeholder, and its values are added to the fragment's value list.
--
-- Priorities (larger binds tighter): @+@ and @-@ are 6, @*@ is 7 and
-- @ABS(...)@ is 9. An operand is parenthesized only when its priority is
-- below its context's (see 'parens').
--
-- The right operand of @-@ is rendered in a priority-7 context, so
-- @a - (b + c)@ and @a - (b - c)@ keep their parentheses. SQL's @-@ is
-- left-associative, and without them they would be evaluated as @(a-b)+c@
-- and @(a-b)-c@.
renderArith :: (PersistEntity v, Constructor c, StringLike s, DbDescriptor db) => Proxy db -> (s -> s) -> Arith v c a -> RenderS s
renderArith proxy escape arith = go arith 0 where
  go (Plus a b) p = parens 6 p $ go a 6 <> char '+' <> go b 6
  go (Minus a b) p = parens 6 p $ go a 6 <> char '-' <> go b 7
  go (Mult a b) p = parens 7 p $ go a 7 <> char '*' <> go b 7
  go (Abs a) p = parens 9 p $ string "ABS(" <> go a 0 <> char ')'
  -- Only the first column of the field is used.
  go (ArithField f) _ = case renderField escape f [] of
    (column:_) -> RenderS column id
    [] -> error "renderArith: arithmetic field has no columns"
  go (Lit a) _ = RenderS (fromChar '?') (toPurePersistValues proxy a)
-- | Render a condition as an SQL boolean expression.
--
-- Arguments:
--
-- * a proxy selecting the database, used to convert literal values;
-- * a function escaping column names;
-- * a function combining two operands into an equality test (used for 'Eq');
-- * a function combining two operands into an inequality test (used for 'Ne').
--
-- The result is 'Nothing' when nothing is left to render. A comparison whose
-- operands expand to no columns yields 'Nothing', and an AND/OR keeps only
-- the sides that rendered. Otherwise the result carries the SQL text and the
-- values for its @?@ placeholders, in order.
--
-- Parentheses come from a priority scheme in which a larger number binds
-- tighter (OR = 2, AND = 3, NOT = 4). A sub-expression is parenthesized only
-- when its own priority is below that of its context (see 'parens').
renderCond :: forall v c s db . (PersistEntity v, Constructor c, StringLike s, DbDescriptor db)
  => Proxy db
  -> (s -> s)
  -> (s -> s -> s)
  -> (s -> s -> s)
  -> Cond v c -> Maybe (RenderS s)
renderCond proxy esc rendEq rendNotEq (cond :: Cond v c) = go cond 0 where
  go (And a b) p = perhaps 3 p " AND " a b
  go (Or a b) p = perhaps 2 p " OR " a b
  -- NOT binds tighter than AND and OR in SQL, so it has priority 4 and
  -- renders its operand in a priority-4 context. A compound operand is then
  -- parenthesized, giving NOT (x AND y). Otherwise it would be NOT x AND y,
  -- which SQL reads as (NOT x) AND y.
  go (Not a) p = fmap (\a' -> parens 4 p $ string "NOT " <> a') $ go a 4
  -- Multi-column operands are compared column by column. Eq joins the
  -- per-column comparisons with AND; every other operator joins them with OR.
  -- NOTE(review): for Gt/Lt/Ge/Le that is an "any column" test, not a
  -- lexicographic comparison — confirm this is intended.
  go (Compare op f1 f2) p = case op of
    Eq -> renderComp 3 p " AND " rendEq f1 f2
    Ne -> renderComp 2 p " OR " rendNotEq f1 f2
    Gt -> renderComp 2 p " OR " (\a b -> a <> fromChar '>' <> b) f1 f2
    Lt -> renderComp 2 p " OR " (\a b -> a <> fromChar '<' <> b) f1 f2
    Ge -> renderComp 2 p " OR " (\a b -> a <> ">=" <> b) f1 f2
    Le -> renderComp 2 p " OR " (\a b -> a <> "<=" <> b) f1 f2

  -- Render one comparison.
  --   p       - priority of logicOp, which joins the per-column comparisons
  --   pOuter  - priority of the surrounding context
  --   op      - renders a single-column comparison
  renderComp :: Int -> Int -> s -> (s -> s -> s) -> Expr v c a -> Expr v c b -> Maybe (RenderS s)
  renderComp p pOuter logicOp op expr1 expr2 = case expr1 of
      ExprField field -> case expr2 of
          ExprPure a -> guard (map (\f -> f `op` fromChar '?') fs) (toPurePersistValues proxy a)
          ExprField a -> guard (zipWith op fs $ renderField esc a []) id
          ExprArith a -> case fs of
            [f] -> let RenderS q v = renderArith proxy esc a in Just $ RenderS (f `op` q) v
            _ -> error $ "renderComp: expected one column field, found " ++ show (length fs)
        where
          fs = renderField esc field []
      ExprPure val -> case expr2 of
          -- Both sides are values: one "? op ?" per value, with the values of
          -- the two sides interleaved to match the placeholders.
          ExprPure a -> guard (replicate (length fs) $ fromChar '?' `op` fromChar '?') (interleave fs $ toPurePersistValues proxy a [])
          ExprField a -> guard (map (\f -> fromChar '?' `op` f) $ renderField esc a []) (toPurePersistValues proxy val)
          ExprArith a -> case fs of
            [_] -> let RenderS q v = renderArith proxy esc a in Just $ RenderS (fromChar '?' `op` q) (toPurePersistValues proxy val . v)
            _ -> error $ "renderComp: expected one column field, found " ++ show (length fs)
        where
          fs = toPurePersistValues proxy val []
      ExprArith arith -> case expr2 of
          ExprPure a -> Just $ RenderS (q `op` fromChar '?') (v . toPurePersistValues proxy a)
          ExprField a -> Just $ RenderS (q `op` head (renderField esc a [])) v
          ExprArith a -> let RenderS q2 v2 = renderArith proxy esc a in Just $ RenderS (q `op` q2) (v . v2)
        where
          RenderS q v = renderArith proxy esc arith
    where
      -- Join the per-column clauses with logicOp. Nothing when there are no
      -- clauses. Parenthesized only when there are several and the context
      -- binds tighter than logicOp.
      guard :: [s] -> ([PersistValue] -> [PersistValue]) -> Maybe (RenderS s)
      guard clauses values = case clauses of
        [] -> Nothing
        [clause] -> Just $ RenderS clause values
        clauses' -> Just $ parens p pOuter $ RenderS (intercalateS logicOp clauses') values
      interleave [] [] acc = acc
      interleave (x:xs) (y:ys) acc = x:y:interleave xs ys acc
      interleave _ _ _ = error "renderComp: pure values lists must have the same size"

  -- Render an AND/OR node.
  --
  -- If both sides render, they are joined with op at priority p. If only
  -- one side renders, it is returned alone and rendered in the outer
  -- context.
  --
  -- `priority` is computed from the very results it is passed to (a lazy
  -- knot). This works because whether `go` returns Just or Nothing never
  -- depends on the priority argument; keep it that way.
  perhaps :: Int -> Int -> s -> Cond v c -> Cond v c -> Maybe (RenderS s)
  perhaps p pOuter op a b = result where
    (priority, result) = case (go a priority, go b priority) of
      (Just a', Just b') -> (p, Just $ parens p pOuter $ a' <> RenderS op id <> b')
      (Just a', Nothing) -> (pOuter, Just a')
      (Nothing, Just b') -> (pOuter, Just b')
      (Nothing, Nothing) -> (pOuter, Nothing)
-- | The column names of a field, obtained from its 'fieldChain' through
-- 'renderChain' and prepended to the given accumulator.
renderField :: (PersistEntity v, Constructor c, FieldLike f (RestrictionHolder v c) a', StringLike s) => (s -> s) -> f -> [s] -> [s]
renderField esc = renderChain esc . fieldChain
-- | Expand a field chain into its column names, prepended to @acc@.
--
-- Names come from the leading run of prefix entries whose 'EmbeddedDef'
-- flag is 'False'. They are joined with 'delim' in reverse order (the first
-- entry's name ends up last), and the result becomes the column-name prefix
-- for 'flattenP'. When the first entry does not qualify, or there is no
-- prefix, the field is expanded with 'flatten' and no prefix.
renderChain :: StringLike s => (s -> s) -> FieldChain -> [s] -> [s]
renderChain esc (f, prefix) acc = case prefixNames prefix of
    [] -> flatten esc f acc
    first : others -> flattenP esc (foldl' extend (fromString first) others) f acc
  where
    prefixNames ((name, EmbeddedDef False _) : rest) = name : prefixNames rest
    prefixNames _ = []
    extend built name = fromString name <> fromChar delim <> built
-- | Render a value as an SQL literal.
--
-- Text is wrapped in single quotes with every embedded single quote doubled
-- (@it's@ becomes @'it''s'@). This is the standard SQL escape, so a quote
-- inside the value cannot terminate the literal early. The byte-string case
-- quotes its 'show' form the same way. Booleans become @1@ / @0@,
-- 'PersistNull' becomes @NULL@, and the remaining values use 'show'.
defaultShowPrim :: PersistValue -> String
defaultShowPrim value = case value of
    PersistString x -> quote x
    PersistByteString x -> quote (show x)
    PersistInt64 x -> show x
    PersistDouble x -> show x
    PersistBool x -> if x then "1" else "0"
    PersistDay x -> show x
    PersistTimeOfDay x -> show x
    PersistUTCTime x -> show x
    PersistZonedTime x -> show x
    PersistNull -> "NULL"
  where
    quote str = "'" ++ concatMap escapeQuote str ++ "'"
    escapeQuote '\'' = "''"
    escapeQuote ch = [ch]
-- | Render an @ORDER BY@ clause, including its leading space.
--
-- Each ordering expands to every column of its field, and descending
-- columns get a @ DESC@ suffix. The result is empty when there are no
-- columns to order by.
renderOrders :: forall v c s . (PersistEntity v, Constructor c, StringLike s) => (s -> s) -> [Order v c] -> s
renderOrders esc xs = case foldr addColumns [] xs of
    [] -> mempty
    columns -> " ORDER BY " <> commasJoin columns
  where
    addColumns (Asc a) acc = renderField esc a acc
    addColumns (Desc a) acc = renderField descending a acc
    descending column = esc column <> " DESC"
-- | Render named fields as a comma-separated column list, expanding
-- embedded fields into their individual columns.
renderFields :: StringLike s => (s -> s) -> [(String, DbType)] -> s
renderFields esc fields = commasJoin columns where
  columns = foldr (flatten esc) [] fields
-- | Expand a named field into its column names, prepended to @acc@.
--
-- 'DbMaybe' is looked through. An embedded definition (either directly, or
-- the one carried by a 'DbEntity') expands into its own fields:
--
-- * when its flag is 'False', they are prefixed with this field's name via
--   'flattenP';
-- * when the flag is 'True', they are expanded without a prefix.
--
-- Any other type is a single column, passed through @esc@.
flatten :: StringLike s => (s -> s) -> (String, DbType) -> ([s] -> [s])
flatten esc (fname, typ) acc = expand typ where
  columnName = fromString fname
  expand (DbMaybe inner) = expand inner
  expand (DbEmbedded def) = expandEmbedded def
  expand (DbEntity (Just (def, _)) _) = expandEmbedded def
  expand _ = esc columnName : acc
  expandEmbedded (EmbeddedDef unprefixed fields)
    | unprefixed = foldr (flatten esc) acc fields
    | otherwise = foldr (flattenP esc columnName) acc fields
-- | Like 'flatten', but the field's column name is @prefix@, then 'delim',
-- then the field name. That combined name is also the prefix passed down to
-- non-flat ('False') embedded fields.
flattenP :: StringLike s => (s -> s) -> s -> (String, DbType) -> ([s] -> [s])
flattenP esc prefix (fname, typ) acc = expand typ where
  columnName = prefix <> fromChar delim <> fromString fname
  expand (DbMaybe inner) = expand inner
  expand (DbEmbedded def) = expandEmbedded def
  expand (DbEntity (Just (def, _)) _) = expandEmbedded def
  expand _ = esc columnName : acc
  expandEmbedded (EmbeddedDef unprefixed fields)
    | unprefixed = foldr (flatten esc) acc fields
    | otherwise = foldr (flattenP esc columnName) acc fields
-- | Join strings with a single comma and no surrounding spaces.
commasJoin :: StringLike s => [s] -> s
commasJoin parts = intercalateS (fromChar ',') parts
-- | Concatenate the strings, putting the separator between consecutive ones.
-- An empty list gives 'mempty'. This is a 'StringLike' analogue of
-- 'Data.List.intercalate'.
intercalateS :: StringLike s => s -> [s] -> s
intercalateS sep parts = case parts of
  [] -> mempty
  first : rest -> first <> foldr (\part tailStr -> sep <> part <> tailStr) mempty rest
-- | Join rendered fragments with commas, keeping their bound values in the
-- same left-to-right order. Returns 'Nothing' for an empty list.
commasJoinRenders :: StringLike s => [RenderS s] -> Maybe (RenderS s)
commasJoinRenders [] = Nothing
commasJoinRenders (x:xs) = Just $ foldl' f x xs where
  -- getValues are difference lists, so they must be composed with (.).
  -- Combining them with (<>) picks the function Monoid instead, which gives
  -- \acc -> vals1 acc ++ vals2 acc. That repeats whatever tail the result is
  -- later applied to, e.g. after this fragment is concatenated with another
  -- RenderS via its Monoid instance.
  f (RenderS str1 vals1) (RenderS str2 vals2) = RenderS (str1 <> comma <> str2) (vals1 . vals2)
  comma = fromChar ','
-- | Render updates as comma-separated @column=value@ assignments (the shape
-- of an SQL SET list), together with their bound values.
--
-- An update whose target field expands to no columns is dropped. The result
-- is 'Nothing' when no assignments remain. An arithmetic right-hand side
-- requires a single-column target.
renderUpdates :: (PersistEntity v, Constructor c, StringLike s, DbDescriptor db) => Proxy db -> (s -> s) -> [Update v c] -> Maybe (RenderS s)
renderUpdates p esc updates = commasJoinRenders (mapMaybe renderOne updates) where
  renderOne (Update field expr) = case expr of
      ExprPure a -> whenColumns $ RenderS (commasJoin [col <> "=?" | col <- cols]) (toPurePersistValues p a)
      ExprField a -> whenColumns $ RenderS (commasJoin $ zipWith assign cols (renderField esc a [])) id
      ExprArith a -> case cols of
        [col] -> Just $ RenderS (col <> fromChar '=') id <> renderArith p esc a
        _ -> error $ "renderUpdates: expected one column field, found " ++ show (length cols)
    where
      cols = renderField esc field []
      assign lhs rhs = lhs <> fromChar '=' <> rhs
      whenColumns rendered
        | null cols = Nothing
        | otherwise = Just rendered