{-# LANGUAGE TypeSynonymInstances, FlexibleInstances, MultiWayIf #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | Builders that turn pieces of a REST request (tables, query-string
-- filters, ordering, ranges, JSON payloads) into hasql 'PStmt' values for
-- PostgreSQL.
--
-- Every statement built here has an empty parameter vector: request values
-- and identifiers are spliced into the SQL text through 'pgFmtLit' and
-- 'pgFmtIdent', so those two functions are the escaping boundary.
--
-- NOTE(review): string literals are used as 'T.Text' and 'BS.ByteString'
-- throughout, but @OverloadedStrings@ is not in the LANGUAGE pragma above;
-- presumably it is enabled project-wide (cabal default-extensions) — confirm.
module PostgREST.PgQuery where

import PostgREST.RangeQuery
import qualified Hasql as H
import qualified Hasql.Postgres as P
import qualified Hasql.Backend as B
import qualified Data.Text as T
import qualified Data.HashMap.Strict as H
import Text.Regex.TDFA ( (=~) )
import qualified Network.HTTP.Types.URI as Net
import qualified Data.ByteString.Char8 as BS
import Data.Monoid
import Data.Vector (empty)
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Functor
import Control.Monad (join)
import Data.String.Conversions (cs)
import qualified Data.Aeson as JSON
import qualified Data.List as L
import qualified Data.Vector as V
import Data.Scientific (isInteger, formatScientific, FPFormat(..))
import Prelude

-- | A hasql statement for the PostgreSQL backend.
type PStmt = H.Stmt P.Postgres

-- | Statements compose by concatenation: templates and parameter vectors
-- are appended left to right, and the result is preparable only when both
-- operands are. 'mempty' is the empty, preparable statement.
instance Monoid PStmt where
  mappend (B.Stmt query params prep) (B.Stmt query' params' prep') =
    B.Stmt (query <> query') (params <> params') (prep && prep')
  mempty = B.Stmt "" empty True

-- | A rewrite applied to a statement under construction.
type StatementT = PStmt -> PStmt

-- | A schema-qualified name, rendered by 'fromQi' as @schema.name@.
data QualifiedIdentifier = QualifiedIdentifier
  { qiSchema :: T.Text
  , qiName   :: T.Text
  } deriving (Show)

-- | One parsed term of the @order@ query parameter (see 'orderParseTerm').
data OrderTerm = OrderTerm
  { otTerm      :: T.Text              -- ^ column name, rendered via 'pgFmtIdent'
  , otDirection :: BS.ByteString       -- ^ @"asc"@ or @"desc"@
  , otNullOrder :: Maybe BS.ByteString -- ^ @"nulls first"@ / @"nulls last"@, if given
  }

-- | Append @LIMIT .. OFFSET ..@. The limit is @ALL@ when there is no range
-- or its 'rangeLimit' is 'Nothing'; the offset is 0 when there is no range.
limitT :: Maybe NonnegRange -> StatementT
limitT r q =
  q <> B.Stmt (" LIMIT " <> limit <> " OFFSET " <> offset <> " ") empty True
  where
    limit = maybe "ALL" (cs . show) $ join $ rangeLimit <$> r
    offset = cs . show $ fromMaybe 0 $ rangeOffset <$> r

-- | Append a @where@ clause AND-ing one 'wherePred' per query-string item,
-- skipping the @order@ item. Leaves the statement unchanged if no filters
-- remain.
whereT :: QualifiedIdentifier -> Net.Query -> StatementT
whereT table params q =
  if L.null cols
    then q
    else q <> B.Stmt " where " empty True <> conjunction
  where
    cols = [ col | col <- params, fst col `notElem` ["order"] ]
    wherePredTable = wherePred table
    conjunction = mconcat $ L.intersperse andq (map wherePredTable cols)

-- | @withT e v w@ renders @WITH v AS (e) w from v@, concatenating both
-- statements' parameters. @w@ is expected to be a select-list fragment.
withT :: PStmt -> T.Text -> StatementT
withT (B.Stmt eq ep epre) v (B.Stmt wq wp wpre) =
  B.Stmt ("WITH " <> v <> " AS (" <> eq <> ") " <> wq <> " from " <> v)
    (ep <> wp)
    (epre && wpre)

-- | Append an @order by@ clause built from the terms, comma-separated.
-- Leaves the statement unchanged for an empty list.
orderT :: [OrderTerm] -> StatementT
orderT ts q =
  if L.null ts
    then q
    else q <> B.Stmt " order by " empty True <> clause
  where
    clause = mconcat $ L.intersperse commaq (map queryTerm ts)
    queryTerm :: OrderTerm -> PStmt
    queryTerm t = B.Stmt
      (" " <> cs (pgFmtIdent $ otTerm t) <> " "
           <> cs (otDirection t) <> " "
           <> maybe "" cs (otNullOrder t) <> " ")
      empty True

-- | Wrap the statement's SQL text in parentheses.
parentheticT :: StatementT
parentheticT s = s { B.stmtTemplate = " (" <> B.stmtTemplate s <> ") " }

-- | @iffNotT a b@ runs @a@ (with @returning *@) in a CTE named @aaa@ and
-- appends @WHERE NOT EXISTS (SELECT * FROM aaa)@ to @b@, so @b@ yields rows
-- only when @a@ returned none.
-- NOTE(review): assumes @b@ has no WHERE clause of its own.
iffNotT :: PStmt -> StatementT
iffNotT (B.Stmt aq ap apre) (B.Stmt bq bp bpre) =
  B.Stmt
    ("WITH aaa AS (" <> aq <> " returning *) " <> bq
       <> " WHERE NOT EXISTS (SELECT * FROM aaa)")
    (ap <> bp)
    (apre && bpre)

-- | Wrap a statement in a CTE and select the number of rows it produces.
countT :: StatementT
countT s = s
  { B.stmtTemplate =
      "WITH qqq AS (" <> B.stmtTemplate s <> ") SELECT pg_catalog.count(1) FROM qqq"
  }

-- | Count all rows of a table.
countRows :: QualifiedIdentifier -> PStmt
countRows t = B.Stmt ("select pg_catalog.count(1) from " <> fromQi t) empty True

-- | 'asCsv' preceded by a row count column (see 'withCount').
asCsvWithCount :: QualifiedIdentifier -> StatementT
asCsvWithCount table = withCount . asCsv table

-- | Turn a row-producing statement into a select-list (no SELECT keyword)
-- yielding CSV text: a header of the table's column names in ordinal
-- order (each passed through @quote_ident@), then each row's text form
-- with its outer parentheses stripped, all separated by carriage returns.
--
-- The table is looked up in @information_schema.columns@ by schema and
-- name using escaped string literals, so names that need identifier
-- quoting (mixed case, punctuation, quotes) are matched correctly and
-- cannot break out of the SQL literal.
asCsv :: QualifiedIdentifier -> StatementT
asCsv table s = s
  { B.stmtTemplate =
      "(select string_agg(quote_ident(column_name::text), ',') from "
      <> "(select column_name from information_schema.columns where table_schema = "
      <> pgFmtLit (qiSchema table) <> " and table_name = " <> pgFmtLit (qiName table)
      <> " order by ordinal_position) h) || '\r' || "
      <> "coalesce(string_agg(substring(t::text, 2, length(t::text) - 2), '\r'), '') from ("
      <> B.stmtTemplate s <> ") t"
  }

-- | 'asJson' preceded by a row count column (see 'withCount').
asJsonWithCount :: StatementT
asJsonWithCount = withCount . asJson

-- | Turn a row-producing statement into a select-list (no SELECT keyword)
-- yielding all rows as one JSON array cast to @character varying@.
asJson :: StatementT
asJson s = s
  { B.stmtTemplate =
      "array_to_json(array_agg(row_to_json(t)))::character varying from ("
      <> B.stmtTemplate s <> ") t"
  }

-- | Prefix a select-list with @pg_catalog.count(t), @.
withCount :: StatementT
withCount s = s { B.stmtTemplate = "pg_catalog.count(t), " <> B.stmtTemplate s }

-- | Turn a row-producing statement into a select-list (no SELECT keyword)
-- yielding each row as a JSON object.
asJsonRow :: StatementT
asJsonRow s = s { B.stmtTemplate = "row_to_json(t) from (" <> B.stmtTemplate s <> ") t" }

-- | @select * from@ the given table.
selectStar :: QualifiedIdentifier -> PStmt
selectStar t = B.Stmt ("select * from " <> fromQi t) empty True

-- | Append @ RETURNING *@.
returningStarT :: StatementT
returningStarT s = s { B.stmtTemplate = B.stmtTemplate s <> " RETURNING *" }

-- | @delete from@ the given table (combine with 'whereT' to filter).
deleteFrom :: QualifiedIdentifier -> PStmt
deleteFrom t = B.Stmt ("delete from " <> fromQi t) empty True

-- | Multi-row insert. With no columns a single @default values@ row is
-- inserted @returning *@; otherwise one VALUES tuple is emitted per inner
-- vector and each inserted row is returned as @row_to_json@.
-- NOTE(review): the two branches return different shapes, and an empty
-- @vals@ with non-empty @cols@ yields invalid SQL (@values @ with no
-- tuples) — confirm callers never pass that.
insertInto :: QualifiedIdentifier -> V.Vector T.Text -> V.Vector (V.Vector JSON.Value) -> PStmt
insertInto t cols vals
  | V.null cols =
      B.Stmt ("insert into " <> fromQi t <> " default values returning *") empty True
  | otherwise =
      B.Stmt
        ("insert into " <> fromQi t
           <> " (" <> T.intercalate ", " (V.toList $ V.map pgFmtIdent cols) <> ") values "
           <> T.intercalate ", " (V.toList $ V.map rowTuple vals)
           <> " returning row_to_json(" <> fromQi t <> ".*)")
        empty True
  where
    -- One parenthesised VALUES tuple.
    rowTuple v = "(" <> T.intercalate ", " (V.toList $ V.map insertableValue v) <> ")"

-- | Single-row insert via @insert into t (cols) select vals@. With no
-- columns a @default values@ row is inserted @returning *@.
insertSelect :: QualifiedIdentifier -> [T.Text] -> [JSON.Value] -> PStmt
insertSelect t [] _ =
  B.Stmt ("insert into " <> fromQi t <> " default values returning *") empty True
insertSelect t cols vals =
  B.Stmt
    ("insert into " <> fromQi t <> " (" <> T.intercalate ", " (map pgFmtIdent cols)
       <> ") select " <> T.intercalate ", " (map insertableValue vals))
    empty True

-- | @update t set (cols) = (vals)@ with no WHERE clause of its own.
update :: QualifiedIdentifier -> [T.Text] -> [JSON.Value] -> PStmt
update t cols vals =
  B.Stmt
    ("update " <> fromQi t <> " set (" <> T.intercalate ", " (map pgFmtIdent cols)
       <> ") = (" <> T.intercalate ", " (map insertableValue vals) <> ")")
    empty True

-- | Call a function with named-notation arguments (@name:=value@), one per
-- key of the JSON object, via @select * from fn(...)@.
callProc :: QualifiedIdentifier -> JSON.Object -> PStmt
callProc qi params =
  B.Stmt ("select * from " <> fromQi qi <> "(" <> args <> ")") empty True
  where
    args = T.intercalate "," $ map assignment (H.toList params)
    assignment (n, v) = pgFmtIdent n <> ":=" <> insertableValue v

-- | Render one query-string filter @(column, "[not.]op.value")@ as a SQL
-- predicate on @table@. A leading @not@ segment prefixes the predicate with
-- @not@; the column may be a JSON path understood by 'pgFmtJsonbPath'.
--
-- Known operators are @eq gt lt gte lte neq like ilike in notin is isnot@
-- and the full-text operator; anything else renders as @=@. For
-- @like@/@ilike@ each @*@ in the value becomes @%@; @in@/@notin@ split the
-- value on commas; @is@/@isnot@ emit @null@, @true@ or @false@ verbatim
-- (matched case-insensitively) and quote any other value.
wherePred :: QualifiedIdentifier -> Net.QueryItem -> PStmt
wherePred table (col, predicate) =
  B.Stmt
    ( notOp <> " " <> pgFmtJsonbPath table (cs col) <> " " <> op <> " "
        <> if opCode `elem` ["is", "isnot"] then whiteList value else cs sqlValue )
    empty True
  where
    notOp, opCode :: T.Text
    valueParts :: [T.Text]
    -- Total split of "[not.]op.value". A bare "not" with no operator after
    -- it gets an empty operator (rendered as "=") instead of failing on an
    -- empty list.
    (notOp, opCode, valueParts) =
      case T.split (== '.') $ cs $ fromMaybe "." predicate of
        "not" : o : vs -> ("not", o, vs)
        ["not"]        -> ("not", "", [])
        o : vs         -> ("", o, vs)
        []             -> ("", "", [])
    -- The value may itself contain dots; rejoin everything after the operator.
    value = T.intercalate "." valueParts
    whiteList val = fromMaybe (cs (pgFmtLit val) <> "::unknown ")
                      (L.find ((==) . T.toLower $ val) ["null", "true", "false"])
    star c = if c == '*' then '%' else c
    unknownLiteral = (<> "::unknown ") . pgFmtLit
    sqlValue = case opCode of
      "like"  -> unknownLiteral $ T.map star value
      "ilike" -> unknownLiteral $ T.map star value
      "in"    -> "(" <> T.intercalate ", " (map unknownLiteral $ T.split (== ',') value) <> ") "
      "notin" -> "(" <> T.intercalate ", " (map unknownLiteral $ T.split (== ',') value) <> ") "
      "@@"    -> "to_tsquery(" <> unknownLiteral value <> ") "
      _       -> unknownLiteral value
    op = case opCode of
      "eq"    -> "="
      "gt"    -> ">"
      "lt"    -> "<"
      "gte"   -> ">="
      "lte"   -> "<="
      "neq"   -> "<>"
      "like"  -> "like"
      "ilike" -> "ilike"
      "in"    -> "in"
      "notin" -> "not in"
      "is"    -> "is"
      "isnot" -> "is not"
      "@@"    -> "@@"
      _       -> "="

-- | Parse the comma-separated @order@ query parameter; terms that do not
-- parse (see 'orderParseTerm') are dropped.
orderParse :: Net.Query -> [OrderTerm]
orderParse q =
  mapMaybe orderParseTerm . T.split (== ',') $ cs order
  where
    order = fromMaybe "" $ join (lookup "order" q)

-- | Parse @column.direction[.nullsfirst|.nullslast]@, where direction is
-- @asc@ or @desc@. Any other direction, or a term without one, gives
-- 'Nothing'. An unrecognised or repeated null-ordering suffix is ignored.
orderParseTerm :: T.Text -> Maybe OrderTerm
orderParseTerm s =
  case T.split (== '.') s of
    c : d : nls | d `elem` ["asc", "desc"] ->
      Just $ OrderTerm c (if d == "asc" then "asc" else "desc") (nullOrder nls)
    _ -> Nothing
  where
    nullOrder :: [T.Text] -> Maybe BS.ByteString
    nullOrder [n]
      | n == "nullsfirst" = Just "nulls first"
      | n == "nullslast"  = Just "nulls last"
    nullOrder _ = Nothing

-- | A bare @", "@ separator statement.
commaq :: PStmt
commaq = B.Stmt ", " empty True

-- | A bare @" and "@ separator statement.
andq :: PStmt
andq = B.Stmt " and " empty True

-- | A column reference optionally followed by JSON key accesses.
data JsonbPath =
    ColIdentifier T.Text
  | KeyIdentifier T.Text
  | SingleArrow JsonbPath JsonbPath
  | DoubleArrow JsonbPath JsonbPath
  deriving (Show)

-- | Parse @col->k1->k2->>k3@. Exactly one @->>@ is required (it must be
-- the last access); otherwise 'Nothing' is returned.
parseJsonbPath :: T.Text -> Maybe JsonbPath
parseJsonbPath p =
  case T.splitOn "->>" p of
    [a, b] ->
      case T.splitOn "->" a of
        i : is -> Just $ DoubleArrow
          (L.foldl' SingleArrow (ColIdentifier i) (map KeyIdentifier is))
          (KeyIdentifier b)
        [] -> Nothing
    _ -> Nothing

-- | Render a column name, or a JSON path if 'parseJsonbPath' accepts it,
-- qualified by the table: the column via 'pgFmtIdent', keys via 'pgFmtLit'.
pgFmtJsonbPath :: QualifiedIdentifier -> T.Text -> T.Text
pgFmtJsonbPath table p =
  pgFmtJsonbPath' $ fromMaybe (ColIdentifier p) (parseJsonbPath p)
  where
    pgFmtJsonbPath' (ColIdentifier i) = fromQi table <> "." <> pgFmtIdent i
    pgFmtJsonbPath' (KeyIdentifier i) = pgFmtLit i
    pgFmtJsonbPath' (SingleArrow a b) = pgFmtJsonbPath' a <> "->" <> pgFmtJsonbPath' b
    pgFmtJsonbPath' (DoubleArrow a b) = pgFmtJsonbPath' a <> "->>" <> pgFmtJsonbPath' b

-- | Render an identifier. Text from the first NUL onward is dropped and
-- double quotes are doubled. The result is wrapped in double quotes
-- unless it is non-empty, starts with @[a-z_]@ and contains only
-- @[a-z_0-9]@.
-- NOTE(review): reserved keywords are not quoted by this rule.
pgFmtIdent :: T.Text -> T.Text
pgFmtIdent x
  | (cs escaped :: BS.ByteString) =~ danger = "\"" <> escaped <> "\""
  | otherwise = escaped
  where
    escaped = T.replace "\"" "\"\"" (trimNullChars $ cs x)
    danger = "^$|^[^a-z_]|[^a-z_0-9]" :: BS.ByteString

-- | Render text as a SQL string literal. Text from the first NUL onward is
-- dropped and single quotes are doubled. If the text contains any
-- backslash, it is emitted as an escape string (@E'...'@) with every
-- backslash doubled, so it denotes the original text whatever the
-- server's @standard_conforming_strings@ setting is.
pgFmtLit :: T.Text -> T.Text
pgFmtLit x
  | T.isInfixOf "\\" escaped = "E" <> slashed
  | otherwise = slashed
  where
    trimmed = trimNullChars x
    escaped = "'" <> T.replace "'" "''" trimmed <> "'"
    slashed = T.replace "\\" "\\\\" escaped

-- | Keep the text up to, not including, the first NUL character.
trimNullChars :: T.Text -> T.Text
trimNullChars = T.takeWhile (/= '\x0')

-- | Render @schema.name@ with each part passed through 'pgFmtIdent'.
fromQi :: QualifiedIdentifier -> T.Text
fromQi t = pgFmtIdent (qiSchema t) <> "." <> pgFmtIdent (qiName t)

-- | Text form of a JSON value without JSON quoting for strings. Integral
-- numbers print with no fractional part; booleans use 'show'; other
-- values (null, arrays, objects) are JSON-encoded.
unquoted :: JSON.Value -> T.Text
unquoted (JSON.String t) = t
unquoted (JSON.Number n) =
  cs $ formatScientific Fixed (if isInteger n then Just 0 else Nothing) n
unquoted (JSON.Bool b) = cs . show $ b
unquoted v = cs $ JSON.encode v

-- | A SQL literal of type @unknown@, letting PostgreSQL infer its type.
insertableText :: T.Text -> T.Text
insertableText = (<> "::unknown") . pgFmtLit

-- | A JSON value as a SQL value: @null@ for JSON null, otherwise an
-- @unknown@-typed literal of its 'unquoted' text.
insertableValue :: JSON.Value -> T.Text
insertableValue JSON.Null = "null"
insertableValue v = insertableText $ unquoted v

-- | A filter string (in the format 'wherePred' parses) matching the value:
-- @is.null@ for JSON null, otherwise @eq.<value>@.
paramFilter :: JSON.Value -> T.Text
paramFilter JSON.Null = "is.null"
paramFilter v = "eq." <> unquoted v