{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TupleSections #-} {-# OPTIONS_GHC -fno-warn-orphans #-} {-| Module : PostgREST.QueryBuilder Description : PostgREST SQL generating functions. This module provides functions to consume data types that represent database objects (e.g. Relation, Schema, SqlQuery) and produces SQL Statements. Any function that outputs a SQL fragment should be in this module. -} module PostgREST.QueryBuilder ( addRelations , addJoinConditions , asJson , callProc , createReadStatement , createWriteStatement , operators , pgFmtIdent , pgFmtLit , requestToQuery , requestToCountQuery , sourceCTEName , unquoted ) where import qualified Hasql as H import qualified Hasql.Backend as B import qualified Hasql.Postgres as P import qualified Data.Aeson as JSON import PostgREST.RangeQuery (NonnegRange, rangeLimit, rangeOffset) import Control.Error (note, fromMaybe, mapMaybe) import qualified Data.HashMap.Strict as HM import Data.List (find, (\\)) import Data.Monoid ((<>)) import Data.Text (Text, intercalate, unwords, replace, isInfixOf, toLower, split) import qualified Data.Text as T (map, takeWhile) import Data.String.Conversions (cs) import Control.Applicative (empty, (<|>)) import Control.Monad (join) import Data.Tree (Tree(..)) import qualified Data.Vector as V import PostgREST.Types import qualified Data.Map as M import Text.Regex.TDFA ((=~)) import qualified Data.ByteString.Char8 as BS import Data.Scientific ( FPFormat (..) , formatScientific , isInteger ) import Prelude hiding (unwords) import PostgREST.ApiRequest (PreferRepresentation (..)) type PStmt = H.Stmt P.Postgres instance Monoid PStmt where mappend (B.Stmt query params prep) (B.Stmt query' params' prep') = B.Stmt (query <> query') (params <> params') (prep && prep') mempty = B.Stmt "" empty True type StatementT = PStmt -> PStmt createReadStatement :: SqlQuery -> SqlQuery -> NonnegRange -> Bool -> Bool -> Bool -> B.Stmt P.Postgres createReadStatement selectQuery countQuery range isSingle countTotal asCsv = B.Stmt ( "WITH " <> sourceCTEName <> " AS (" <> selectQuery <> ") " <> "SELECT " <> intercalate ", " [ countResultF <> " AS total_result_set", "pg_catalog.count(t) AS page_total", "null AS header", bodyF <> " AS body" ] <> " FROM ( SELECT * FROM " <> sourceCTEName <> " " <> limitF range <> ") t" ) V.empty True where countResultF = if countTotal then "("<>countQuery<>")" else "null" bodyF | asCsv = asCsvF | isSingle = asJsonSingleF | otherwise = asJsonF createWriteStatement :: QualifiedIdentifier -> SqlQuery -> SqlQuery -> Bool -> PreferRepresentation -> [Text] -> Bool -> Payload -> B.Stmt P.Postgres createWriteStatement _ _ _ _ _ _ _ (PayloadParseError _) = undefined createWriteStatement _ _ mutateQuery _ None _ _ (PayloadJSON (UniformObjects rows)) = B.Stmt ( "WITH " <> sourceCTEName <> " AS (" <> mutateQuery <> ") " <> "SELECT null, 0, null, null" ) (V.singleton . B.encodeValue . JSON.Array . V.map JSON.Object $ rows) True createWriteStatement qi _ mutateQuery isSingle HeadersOnly pKeys _ (PayloadJSON (UniformObjects rows)) = B.Stmt ( "WITH " <> sourceCTEName <> " AS (" <> mutateQuery <> " RETURNING " <> fromQi qi <> ".*" <> ") " <> "SELECT " <> intercalate ", " [ "null AS total_result_set", "pg_catalog.count(t) AS page_total", if isSingle then locationF pKeys else "null", "null" ] <> " FROM (SELECT 1 FROM " <> sourceCTEName <> ") t" ) (V.singleton . B.encodeValue . JSON.Array . V.map JSON.Object $ rows) True createWriteStatement qi selectQuery mutateQuery isSingle Full pKeys asCsv (PayloadJSON (UniformObjects rows)) = B.Stmt ( "WITH " <> sourceCTEName <> " AS (" <> mutateQuery <> " RETURNING " <> fromQi qi <> ".*" <> ") " <> "SELECT " <> intercalate ", " [ "null AS total_result_set", -- when updateing it does not make sense "pg_catalog.count(t) AS page_total", if isSingle then locationF pKeys else "null" <> " AS header", bodyF <> " AS body" ] <> " FROM ( "<>selectQuery<>") t" ) (V.singleton . B.encodeValue . JSON.Array . V.map JSON.Object $ rows) True where bodyF | asCsv = asCsvF | isSingle = asJsonSingleF | otherwise = asJsonF addRelations :: Schema -> [Relation] -> Maybe ReadRequest -> ReadRequest -> Either Text ReadRequest addRelations schema allRelations parentNode node@(Node readNode@(query, (name, _)) forest) = case parentNode of (Just (Node (Select{from=[parentTable]}, (_, _)) _)) -> Node <$> (addRel readNode <$> rel) <*> updatedForest where rel = note ("no relation between " <> parentTable <> " and " <> name) $ findRelationByTable schema name parentTable <|> findRelationByTable schema parentTable name <|> findRelationByColumn schema parentTable name addRel :: (ReadQuery, (NodeName, Maybe Relation)) -> Relation -> (ReadQuery, (NodeName, Maybe Relation)) addRel (q, (n, _)) r = (q {from=fromRelation}, (n, Just r)) where fromRelation = map (\t -> if t == n then tableName (relTable r) else t) (from q) _ -> Node (query, (name, Nothing)) <$> updatedForest where updatedForest = mapM (addRelations schema allRelations (Just node)) forest -- Searches through all the relations and returns a match given the parameter conditions. -- Will only find a relation where both schemas are in the PostgREST schema. -- `findRelationByColumn` also does a ducktype check to see if the column name has any variation of `id` or `fk`. If so then the relation is returned as a match. findRelationByTable s t1 t2 = find (\r -> s == tableSchema (relTable r) && s == tableSchema (relFTable r) && t1 == tableName (relTable r) && t2 == tableName (relFTable r)) allRelations findRelationByColumn s t c = find (\r -> s == tableSchema (relTable r) && s == tableSchema (relFTable r) && t == tableName (relFTable r) && length (relFColumns r) == 1 && c `colMatches` (colName . head . relFColumns) r) allRelations where n `colMatches` rc = (cs ("^" <> rc <> "_?(?:|[iI][dD]|[fF][kK])$") :: BS.ByteString) =~ (cs n :: BS.ByteString) addJoinConditions :: Schema -> ReadRequest -> Either Text ReadRequest addJoinConditions schema (Node (query, (n, r)) forest) = case r of Nothing -> Node (updatedQuery, (n,r)) <$> updatedForest -- this is the root node Just rel@(Relation{relType=Child}) -> Node (addCond updatedQuery (getJoinConditions rel),(n,r)) <$> updatedForest Just (Relation{relType=Parent}) -> Node (updatedQuery, (n,r)) <$> updatedForest Just rel@(Relation{relType=Many, relLTable=(Just linkTable)}) -> Node (qq, (n, r)) <$> updatedForest where q = addCond updatedQuery (getJoinConditions rel) qq = q{from=tableName linkTable : from q} _ -> Left "unknown relation" where -- add parentTable and parentJoinConditions to the query updatedQuery = foldr (flip addCond) query parentJoinConditions where parentJoinConditions = map (getJoinConditions . snd) parents parents = mapMaybe (getParents . rootLabel) forest getParents (_, (tbl, Just rel@(Relation{relType=Parent}))) = Just (tbl, rel) getParents _ = Nothing updatedForest = mapM (addJoinConditions schema) forest addCond q con = q{flt_=con ++ flt_ q} asJson :: StatementT asJson s = s { B.stmtTemplate = "array_to_json(coalesce(array_agg(row_to_json(t)), '{}'))::character varying from (" <> B.stmtTemplate s <> ") t" } callProc :: QualifiedIdentifier -> JSON.Object -> PStmt callProc qi params = do let args = intercalate "," $ map assignment (HM.toList params) B.Stmt ("select * from " <> fromQi qi <> "(" <> args <> ")") empty True where assignment (n,v) = pgFmtIdent n <> ":=" <> insertableValue v operators :: [(Text, SqlFragment)] operators = [ ("eq", "="), ("gte", ">="), -- has to be before gt (parsers) ("gt", ">"), ("lte", "<="), -- has to be before lt (parsers) ("lt", "<"), ("neq", "<>"), ("like", "like"), ("ilike", "ilike"), ("in", "in"), ("notin", "not in"), ("isnot", "is not"), -- has to be before is (parsers) ("is", "is"), ("@@", "@@"), ("@>", "@>"), ("<@", "<@") ] pgFmtIdent :: SqlFragment -> SqlFragment pgFmtIdent x = "\"" <> replace "\"" "\"\"" (trimNullChars $ cs x) <> "\"" pgFmtLit :: SqlFragment -> SqlFragment pgFmtLit x = let trimmed = trimNullChars x escaped = "'" <> replace "'" "''" trimmed <> "'" slashed = replace "\\" "\\\\" escaped in if "\\\\" `isInfixOf` escaped then "E" <> slashed else slashed requestToCountQuery :: Schema -> DbRequest -> SqlQuery requestToCountQuery _ (DbMutate _) = undefined requestToCountQuery schema (DbRead (Node (Select _ _ conditions _, (mainTbl, _)) _)) = unwords [ "SELECT pg_catalog.count(1)", "FROM ", fromQi $ QualifiedIdentifier schema mainTbl, ("WHERE " <> intercalate " AND " ( map (pgFmtCondition (QualifiedIdentifier schema mainTbl)) localConditions )) `emptyOnNull` localConditions ] where fn (Filter{value=VText _}) = True fn (Filter{value=VForeignKey _ _}) = False localConditions = filter fn conditions requestToQuery :: Schema -> DbRequest -> SqlQuery requestToQuery _ (DbMutate (Insert _ (PayloadParseError _))) = undefined requestToQuery _ (DbMutate (Update _ (PayloadParseError _) _)) = undefined requestToQuery schema (DbRead (Node (Select colSelects tbls conditions ord, (nodeName, maybeRelation)) forest)) = query where -- TODO! the folloing helper functions are just to remove the "schema" part when the table is "source" which is the name -- of our WITH query part mainTbl = fromMaybe nodeName (tableName . relTable <$> maybeRelation) tblSchema tbl = if tbl == sourceCTEName then "" else schema qi = QualifiedIdentifier (tblSchema mainTbl) mainTbl toQi t = QualifiedIdentifier (tblSchema t) t query = unwords [ "SELECT ", intercalate ", " (map (pgFmtSelectItem qi) colSelects ++ selects), "FROM ", intercalate ", " (map (fromQi . toQi) tbls), unwords (map joinStr joins), ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) localConditions )) `emptyOnNull` localConditions, orderF (fromMaybe [] ord) ] orderF ts = if null ts then "" else "ORDER BY " <> clause where clause = intercalate "," (map queryTerm ts) queryTerm :: OrderTerm -> Text queryTerm t = " " <> cs (pgFmtColumn qi $ otTerm t) <> " " <> (cs.show) (otDirection t) <> " " <> maybe "" (cs.show) (otNullOrder t) <> " " (joins, selects) = foldr getQueryParts ([],[]) forest parentTables = map snd joins parentConditions = join $ map (( `filter` conditions ) . filterParentConditions) parentTables localConditions = conditions \\ parentConditions joinStr :: (SqlFragment, TableName) -> SqlFragment joinStr (sql, t) = "LEFT OUTER JOIN " <> sql <> " ON " <> intercalate " AND " ( map (pgFmtCondition qi ) joinConditions ) where joinConditions = filter (filterParentConditions t) conditions filterParentConditions parentTable (Filter _ _ (VForeignKey (QualifiedIdentifier "" t) _)) = parentTable == t filterParentConditions _ _ = False getQueryParts :: Tree ReadNode -> ([(SqlFragment, TableName)], [SqlFragment]) -> ([(SqlFragment,TableName)], [SqlFragment]) getQueryParts (Node n@(_, (name, Just (Relation {relType=Child,relTable=Table{tableName=table}}))) forst) (j,s) = (j,sel:s) where sel = "COALESCE((" <> "SELECT array_to_json(array_agg(row_to_json("<>table<>"))) " <> "FROM (" <> subquery <> ") " <> table <> "), '[]') AS " <> pgFmtIdent name where subquery = requestToQuery schema (DbRead (Node n forst)) getQueryParts (Node n@(_, (name, Just (Relation {relType=Parent,relTable=Table{tableName=table}}))) forst) (j,s) = (joi:j,sel:s) where sel = "row_to_json(" <> table <> ".*) AS "<>pgFmtIdent name --TODO must be singular joi = ("( " <> subquery <> " ) AS " <> table, table) where subquery = requestToQuery schema (DbRead (Node n forst)) getQueryParts (Node n@(_, (name, Just (Relation {relType=Many,relTable=Table{tableName=table}}))) forst) (j,s) = (j,sel:s) where sel = "COALESCE ((" <> "SELECT array_to_json(array_agg(row_to_json("<>table<>"))) " <> "FROM (" <> subquery <> ") " <> table <> "), '[]') AS " <> pgFmtIdent name where subquery = requestToQuery schema (DbRead (Node n forst)) --the following is just to remove the warning --getQueryParts is not total but requestToQuery is called only after addJoinConditions which ensures the only --posible relations are Child Parent Many getQueryParts (Node (_,(_,Nothing)) _) _ = undefined requestToQuery schema (DbMutate (Insert mainTbl (PayloadJSON (UniformObjects rows)))) = let qi = QualifiedIdentifier schema mainTbl cols = map pgFmtIdent $ fromMaybe [] (HM.keys <$> (rows V.!? 0)) colsString = intercalate ", " cols in unwords [ "INSERT INTO ", fromQi qi, " (" <> colsString <> ")" <> " SELECT " <> colsString <> " FROM json_populate_recordset(null::" , fromQi qi, ", ?)" ] requestToQuery schema (DbMutate (Update mainTbl (PayloadJSON (UniformObjects rows)) conditions)) = case rows V.!? 0 of Just obj -> let assignments = map (\(k,v) -> pgFmtIdent k <> "=" <> insertableValue v) $ HM.toList obj in unwords [ "UPDATE ", fromQi qi, " SET " <> intercalate "," assignments <> " ", ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) conditions )) `emptyOnNull` conditions ] Nothing -> undefined where qi = QualifiedIdentifier schema mainTbl requestToQuery schema (DbMutate (Delete mainTbl conditions)) = query where qi = QualifiedIdentifier schema mainTbl query = unwords [ "DELETE FROM ", fromQi qi, ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) conditions )) `emptyOnNull` conditions ] sourceCTEName :: SqlFragment sourceCTEName = "pg_source" unquoted :: JSON.Value -> Text unquoted (JSON.String t) = t unquoted (JSON.Number n) = cs $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n unquoted (JSON.Bool b) = cs . show $ b unquoted v = cs $ JSON.encode v -- private functions asCsvF :: SqlFragment asCsvF = asCsvHeaderF <> " || '\n' || " <> asCsvBodyF where asCsvHeaderF = "(SELECT string_agg(a.k, ',')" <> " FROM (" <> " SELECT json_object_keys(r)::TEXT as k" <> " FROM ( " <> " SELECT row_to_json(hh) as r from " <> sourceCTEName <> " as hh limit 1" <> " ) s" <> " ) a" <> ")" asCsvBodyF = "coalesce(string_agg(substring(t::text, 2, length(t::text) - 2), '\n'), '')" asJsonF :: SqlFragment asJsonF = "array_to_json(array_agg(row_to_json(t)))::character varying" asJsonSingleF :: SqlFragment --TODO! unsafe when the query actually returns multiple rows, used only on inserting and returning single element asJsonSingleF = "string_agg(row_to_json(t)::text, ',')::character varying " locationF :: [Text] -> SqlFragment locationF pKeys = "(" <> " WITH s AS (SELECT row_to_json(ss) as r from " <> sourceCTEName <> " as ss limit 1)" <> " SELECT string_agg(json_data.key || '=' || coalesce( 'eq.' || json_data.value, 'is.null'), '&')" <> " FROM s, json_each_text(s.r) AS json_data" <> ( if null pKeys then "" else " WHERE json_data.key IN ('" <> intercalate "','" pKeys <> "')" ) <> ")" limitF :: NonnegRange -> SqlFragment limitF r = "LIMIT " <> limit <> " OFFSET " <> offset where limit = maybe "ALL" (cs . show) $ rangeLimit r offset = cs . show $ rangeOffset r fromQi :: QualifiedIdentifier -> SqlFragment fromQi t = (if s == "" then "" else pgFmtIdent s <> ".") <> pgFmtIdent n where n = qiName t s = qiSchema t getJoinConditions :: Relation -> [Filter] getJoinConditions (Relation t cols ft fcs typ lt lc1 lc2) = case typ of Child -> zipWith (toFilter tN ftN) cols fcs Parent -> zipWith (toFilter tN ftN) cols fcs Many -> zipWith (toFilter tN ltN) cols (fromMaybe [] lc1) ++ zipWith (toFilter ftN ltN) fcs (fromMaybe [] lc2) where s = if typ == Parent then "" else tableSchema t tN = tableName t ftN = tableName ft ltN = fromMaybe "" (tableName <$> lt) toFilter :: Text -> Text -> Column -> Column -> Filter toFilter tb ftb c fc = Filter (colName c, Nothing) "=" (VForeignKey (QualifiedIdentifier s tb) (ForeignKey fc{colTable=(colTable fc){tableName=ftb}})) emptyOnNull :: Text -> [a] -> Text emptyOnNull val x = if null x then "" else val insertableValue :: JSON.Value -> SqlFragment insertableValue JSON.Null = "null" insertableValue v = (<> "::unknown") . pgFmtLit $ unquoted v whiteList :: Text -> SqlFragment whiteList val = fromMaybe (cs (pgFmtLit val) <> "::unknown ") (find ((==) . toLower $ val) ["null","true","false"]) pgFmtColumn :: QualifiedIdentifier -> Text -> SqlFragment pgFmtColumn table "*" = fromQi table <> ".*" pgFmtColumn table c = fromQi table <> "." <> pgFmtIdent c pgFmtField :: QualifiedIdentifier -> Field -> SqlFragment pgFmtField table (c, jp) = pgFmtColumn table c <> pgFmtJsonPath jp pgFmtSelectItem :: QualifiedIdentifier -> SelectItem -> SqlFragment pgFmtSelectItem table (f@(_, jp), Nothing) = pgFmtField table f <> pgFmtAsJsonPath jp pgFmtSelectItem table (f@(_, jp), Just cast ) = "CAST (" <> pgFmtField table f <> " AS " <> cast <> " )" <> pgFmtAsJsonPath jp pgFmtCondition :: QualifiedIdentifier -> Filter -> SqlFragment pgFmtCondition table (Filter (col,jp) ops val) = notOp <> " " <> sqlCol <> " " <> pgFmtOperator opCode <> " " <> if opCode `elem` ["is","isnot"] then whiteList (getInner val) else sqlValue where headPredicate:rest = split (=='.') ops hasNot caseTrue caseFalse = if headPredicate == "not" then caseTrue else caseFalse opCode = hasNot (head rest) headPredicate notOp = hasNot headPredicate "" sqlCol = case val of VText _ -> pgFmtColumn table col <> pgFmtJsonPath jp VForeignKey qi _ -> pgFmtColumn qi col sqlValue = valToStr val getInner v = case v of VText s -> s _ -> "" valToStr v = case v of VText s -> pgFmtValue opCode s VForeignKey (QualifiedIdentifier s _) (ForeignKey Column{colTable=Table{tableName=ft}, colName=fc}) -> pgFmtColumn qi fc where qi = QualifiedIdentifier (if ft == sourceCTEName then "" else s) ft _ -> "" pgFmtValue :: Text -> Text -> SqlFragment pgFmtValue opCode val = case opCode of "like" -> unknownLiteral $ T.map star val "ilike" -> unknownLiteral $ T.map star val "in" -> "(" <> intercalate ", " (map unknownLiteral $ split (==',') val) <> ") " "notin" -> "(" <> intercalate ", " (map unknownLiteral $ split (==',') val) <> ") " "@@" -> "to_tsquery(" <> unknownLiteral val <> ") " _ -> unknownLiteral val where star c = if c == '*' then '%' else c unknownLiteral = (<> "::unknown ") . pgFmtLit pgFmtOperator :: Text -> SqlFragment pgFmtOperator opCode = fromMaybe "=" $ M.lookup opCode operatorsMap where operatorsMap = M.fromList operators pgFmtJsonPath :: Maybe JsonPath -> SqlFragment pgFmtJsonPath (Just [x]) = "->>" <> pgFmtLit x pgFmtJsonPath (Just (x:xs)) = "->" <> pgFmtLit x <> pgFmtJsonPath ( Just xs ) pgFmtJsonPath _ = "" pgFmtAsJsonPath :: Maybe JsonPath -> SqlFragment pgFmtAsJsonPath Nothing = "" pgFmtAsJsonPath (Just xx) = " AS " <> last xx trimNullChars :: Text -> Text trimNullChars = T.takeWhile (/= '\x0')