{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
--module PostgREST.App where
module PostgREST.App (
  postgrest
) where

import           Control.Applicative
import           Control.Arrow             ((***))
import           Control.Monad             (join)
import           Data.Bifunctor            (first)
import           Data.List                 (find, sortBy, delete)
import           Data.Maybe                (isJust, fromMaybe, fromJust, mapMaybe)
import           Data.Ord                  (comparing)
import           Data.Ranged.Ranges        (emptyRange)
import           Data.String.Conversions   (cs)
import           Data.Text                 (Text, replace, strip)
import           Data.Tree

import qualified Hasql.Pool                as P
import qualified Hasql.Transaction         as HT

import           Text.Parsec.Error
import           Text.ParserCombinators.Parsec (parse)

import           Network.HTTP.Base         (urlEncodeVars)
import           Network.HTTP.Types.Header
import           Network.HTTP.Types.Status
import           Network.HTTP.Types.URI    (parseSimpleQuery)
import           Network.Wai
import           Network.Wai.Middleware.RequestLogger (logStdout)

import           Data.Aeson
import           Data.Aeson.Types          (emptyArray)
import           Data.Monoid
import           Data.Time.Clock.POSIX     (getPOSIXTime)
import qualified Data.Vector               as V
import qualified Hasql.Transaction         as H

import           PostgREST.ApiRequest      ( ApiRequest(..), ContentType(..)
                                           , Action(..), Target(..)
                                           , PreferRepresentation (..)
                                           , userApiRequest)
import           PostgREST.Auth            (tokenJWT)
import           PostgREST.Config          (AppConfig (..))
import           PostgREST.DbStructure
import           PostgREST.Error           (errResponse, pgErrResponse)
import           PostgREST.Parsers
import           PostgREST.RangeQuery
import           PostgREST.Middleware
import           PostgREST.QueryBuilder    ( callProc
                                           , addJoinConditions
                                           , sourceCTEName
                                           , requestToQuery
                                           , requestToCountQuery
                                           , addRelations
                                           , createReadStatement
                                           , createWriteStatement
                                           , ResultsWithCount
                                           )
import           PostgREST.Types

import           Prelude

-- | The WAI application.
--
-- For every request it reads the whole body strictly, runs 'app' (wrapped by
-- 'runWithClaims') as a single ReadCommitted / Write transaction on the pool,
-- and converts pool/database failures into responses with 'pgErrResponse'.
-- Requests are logged to stdout unless 'configQuiet' is set.
postgrest :: AppConfig -> DbStructure -> P.Pool -> Application
postgrest conf dbStructure pool =
  middle $ \ req respond -> do
    time <- getPOSIXTime
    body <- strictRequestBody req

    let handleReq = runWithClaims conf time (app dbStructure conf body) req
    resp <- either pgErrResponse id <$> P.use pool
      (HT.run handleReq HT.ReadCommitted HT.Write)
    respond resp
  where
    middle = (if configQuiet conf then id else logStdout) . defaultMiddle

-- | Build the HTTP response for one API request, inside the request's
-- database transaction.
--
-- Dispatches on the (action, target, payload) triple parsed by
-- 'userApiRequest'. SQL-building failures become 400 responses carrying the
-- error text, and unknown targets become 404.
app :: DbStructure -> AppConfig -> RequestBody -> Request -> H.Transaction Response
app dbStructure conf reqBody req =
  let
    -- TODO: blow up for Left values (there is a middleware that checks the headers)
    contentType = either (const ApplicationJSON) id (iAccepts apiRequest)
    contentTypeH = (hContentType, cs $ show contentType)
  in
  case (iAction apiRequest, iTarget apiRequest, iPayload apiRequest) of

    (ActionRead, TargetIdent qi, Nothing) ->
      case readSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (q, cq) -> do
          let singular = iPreferSingular apiRequest
              stm = createReadStatement q cq range singular shouldCount (contentType == TextCSV)
          respondToRange $ do
            row <- H.query () stm
            let (tableTotal, queryTotal, _ , body) = row
            if singular
              then return $ if queryTotal <= 0
                then responseLBS status404 [] ""
                else responseLBS status200 [contentTypeH] (cs body)
              else do
                let (status, contentRange) = rangeHeader queryTotal tableTotal
                    -- should this be moved to the dbStructure (location)?
                    -- Canonical form of the query string: parameters sorted by name.
                    canonical = urlEncodeVars
                      . sortBy (comparing fst)
                      . map (join (***) cs)
                      . parseSimpleQuery
                      $ rawQueryString req
                return $ responseLBS status
                  [ contentTypeH
                  , contentRange
                  , ( "Content-Location"
                    , "/" <> cs (qiName qi) <>
                        if Prelude.null canonical then "" else "?" <> cs canonical
                    )
                  ]
                  (cs body)

    (ActionCreate, TargetIdent qi@(QualifiedIdentifier _ table),
      Just payload@(PayloadJSON uniform@(UniformObjects rows))) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          let isSingle = (==1) $ V.length rows
              -- would it be ok to move primary key detection in the query itself?
              pKeys = map pkName $ filter (filterPk schema table) allPrKeys
              stm = createWriteStatement qi sq mq isSingle (iPreferRepresentation apiRequest) pKeys (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, _, location, body) = extractQueryResult row
          return $ responseLBS status201
            [ contentTypeH
            , (hLocation, "/" <> cs table <> "?" <> cs location)
            ]
            $ if iPreferRepresentation apiRequest == Full
                then cs body
                else ""

    (ActionUpdate, TargetIdent qi, Just payload@(PayloadJSON uniform)) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          let stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, queryTotal, _, body) = extractQueryResult row
              r = contentRangeH 0 (toInteger $ queryTotal-1) (toInteger <$> Just queryTotal)
              s = case () of
                _ | queryTotal == 0 -> status404
                  | iPreferRepresentation apiRequest == Full -> status200
                  | otherwise -> status204
          return $ responseLBS s [contentTypeH, r] $
            if iPreferRepresentation apiRequest == Full then cs body else ""

    (ActionDelete, TargetIdent qi, Nothing) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          -- DELETE carries no body, so the write statement gets an empty payload.
          let emptyUniform = UniformObjects V.empty
              fakeload = PayloadJSON emptyUniform
              stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) fakeload
          row <- H.query emptyUniform stm
          let (_, queryTotal, _, _) = extractQueryResult row
          return $ if queryTotal == 0
            then notFound
            else responseLBS status204 [("Content-Range", "*/"<> cs (show queryTotal))] ""

    (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable), Nothing) ->
      if isJust $ find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure)
        then
          let cols = filter (filterCol tSchema tTable) $ dbColumns dbStructure
              pkeys = map pkName $ filter (filterPk tSchema tTable) allPrKeys
              body = encode (TableOptions cols pkeys)
              filterCol :: Schema -> TableName -> Column -> Bool
              filterCol sc tb Column{colTable=Table{tableSchema=s, tableName=t}} = s==sc && t==tb
              filterCol _ _ _ = False
          in return $ responseLBS status200 [jsonH, allOrigins] $ cs body
        else return notFound

    (ActionInvoke, TargetProc qi, Just (PayloadJSON (UniformObjects payload))) -> do
      exists <- H.query qi doesProcExist
      if not exists
        then return notFound
        else case payload V.!? 0 of
          -- An empty JSON array carries no arguments object; reject it with a
          -- 400 instead of indexing into an empty vector.
          Nothing -> return $ responseLBS status400 [jsonH] $
            cs (formatGeneralError "Cannot call function" "the request payload contains no arguments object")
          -- Only the first object of the payload is passed to the function.
          Just p -> do
            let jwtSecret = configJwtSecret conf
            respondToRange $ do
              row <- H.query () (callProc qi p range shouldCount)
              returnJWT <- H.query qi doesProcReturnJWT
              let (tableTotal, queryTotal, body) = fromMaybe (Just 0, 0, emptyArray) row
                  (status, contentRange) = rangeHeader queryTotal tableTotal
              return $ responseLBS status [jsonH, contentRange]
                (if returnJWT
                  then "{\"token\":\"" <> cs (tokenJWT jwtSecret body) <> "\"}"
                  else cs $ encode body)

    (ActionRead, TargetRoot, Nothing) -> do
      body <- encode <$> H.query schema accessibleTables
      return $ responseLBS status200 [jsonH] $ cs body

    (ActionInappropriate, _, _) -> return $ responseLBS status405 [] ""

    (_, _, Just (PayloadParseError e)) ->
      return $ responseLBS status400 [jsonH] $
        cs (formatGeneralError "Cannot parse request payload" (cs e))

    (_, TargetUnknown _, _) -> return notFound

    (_, _, _) -> return notFound

  where
    notFound = responseLBS status404 [] ""
    filterPk sc table pk = sc == (tableSchema . pkTable) pk && table == (tableName . pkTable) pk
    allPrKeys = dbPrimaryKeys dbStructure
    allOrigins = ("Access-Control-Allow-Origin", "*") :: Header
    schema = cs $ configSchema conf
    apiRequest = userApiRequest schema req reqBody
    shouldCount = iPreferCount apiRequest
    range = restrictRange (configMaxRows conf) $ iRange apiRequest
    readDbRequest = DbRead <$> buildReadRequest (dbRelations dbStructure) apiRequest
    mutateDbRequest = DbMutate <$> buildMutateRequest apiRequest
    selectQuery = requestToQuery schema <$> readDbRequest
    countQuery = requestToCountQuery schema <$> readDbRequest
    mutateQuery = requestToQuery schema <$> mutateDbRequest
    readSqlParts = (,) <$> selectQuery <*> countQuery
    mutateSqlParts = (,) <$> selectQuery <*> mutateQuery
    -- An empty (unsatisfiable) range short-circuits to 416 without running the query.
    respondToRange response =
      if range == emptyRange
        then return $ errResponse status416 "HTTP Range error"
        else response
    -- Status and Content-Range header for a window of queryTotal rows
    -- starting at the requested offset.
    rangeHeader queryTotal tableTotal =
      let frm = rangeOffset range
          to = frm + toInteger queryTotal - 1
          contentRange = contentRangeH frm to (toInteger <$> tableTotal)
          status = rangeStatus frm to (toInteger <$> tableTotal)
      in (status, contentRange)

-- | Status for a ranged response covering rows @frm..to@.
--
-- 200 when the total is unknown or the window covers everything, 206 when the
-- window is smaller than the total, and 416 when the offset lies past the total.
rangeStatus :: Integer -> Integer -> Maybe Integer -> Status
rangeStatus _ _ Nothing = status200
rangeStatus frm to (Just total)
  | frm > total            = status416
  | (1 + to - frm) < total = status206
  | otherwise              = status200

-- | A @Content-Range@ header of the form @frm-to/total@.
--
-- The range part is @*@ when the total is zero or @frm > to@. The total part
-- is @*@ when the total is unknown.
contentRangeH :: Integer -> Integer -> Maybe Integer -> Header
contentRangeH frm to total =
    ("Content-Range", cs headerValue)
  where
    headerValue = rangeString <> "/" <> totalString
    rangeString
      | totalNotZero && fromInRange = show frm <> "-" <> show to
      | otherwise = "*"
    totalString = fromMaybe "*" (show <$> total)
    totalNotZero = fromMaybe True ((/=) 0 <$> total)
    fromInRange = frm <= to

-- | @Content-Type@ header for UTF-8 JSON responses.
jsonH :: Header
jsonH = (hContentType, "application/json; charset=utf-8")

-- | JSON error body for a failure to relate two embedded entities.
formatRelationError :: Text -> Text
formatRelationError = formatGeneralError
  "could not find foreign keys between these entities"

-- | JSON error body for a Parsec failure.
--
-- The message is the error position. The details are Parsec's messages, with
-- newlines collapsed to spaces.
formatParserError :: ParseError -> Text
formatParserError e = formatGeneralError message details
  where
    message = cs $ show (errorPos e)
    details = strip $ replace "\n" " " $ cs
      $ showErrorMessages "or" "unknown parse error" "expecting" "unexpected" "end of input" (errorMessages e)

-- | Encode an error as the JSON object @{"message": ..., "details": ...}@.
formatGeneralError :: Text -> Text -> Text
formatGeneralError message details = cs $ encode $ object [
  "message" .= message,
  "details" .= details]

-- | Attach relation information to a read request, then add its join conditions.
augmentRequestWithJoin :: Schema ->  [Relation] ->  ReadRequest -> Either Text ReadRequest
augmentRequestWithJoin schema allRels request =
  (first formatRelationError . addRelations schema allRels Nothing) request
  >>= addJoinConditions schema

-- | Parse the select, order and filter parameters of an API request into a
-- 'ReadRequest' tree with relations and joins attached.
--
-- For insert/update, the root is the source CTE ('sourceCTEName') rather than
-- the table, and only filters on embedded paths (containing a @.@) are kept.
--
-- Partial: 'fromJust' fails unless the target is a 'TargetIdent'. Inside this
-- module the result is only forced on request paths whose target is a
-- 'TargetIdent'.
buildReadRequest :: [Relation] -> ApiRequest -> Either Text ReadRequest
buildReadRequest allRels apiRequest =
  augmentRequestWithJoin schema rels =<<
    first formatParserError (foldr addFilter <$> (addOrder <$> readRequest <*> ord) <*> flts)
  where
    selStr = iSelect apiRequest
    orderS = iOrder apiRequest
    action = iAction apiRequest
    target = iTarget apiRequest
    (schema, rootTableName) = fromJust $ -- TODO: Make it safe
      case target of
        (TargetIdent (QualifiedIdentifier s t) ) -> Just (s, t)
        _ -> Nothing
    rootName = if action == ActionRead
      then rootTableName
      else sourceCTEName
    -- there can be no filters on the root table when we are doing insert/update
    filters = if action == ActionRead
      then iFilters apiRequest
      else filter (( '.' `elem` ) . fst) $ iFilters apiRequest
    rels = case action of
        ActionCreate -> fakeSourceRelations ++ allRels
        ActionUpdate -> fakeSourceRelations ++ allRels
        _       -> allRels
      where
        -- see comment on toSourceRelation
        fakeSourceRelations = mapMaybe (toSourceRelation rootTableName) allRels
    readRequest = parse (pRequestSelect rootName) ("failed to parse select parameter <<"++selStr++">>") selStr
    addOrder (Node (q,i) f) o = Node (q{order=o}, i) f
    flts = mapM pRequestFilter filters
    ord = traverse (parse pOrder ("failed to parse order parameter <<"++fromMaybe "" orderS++">>")) orderS

-- | Build the 'MutateRequest' (insert, update or delete) for an API request.
--
-- Only filters on the root table (no @.@ in the name) become update/delete
-- conditions. A non-table target or a missing payload yields a 'Left' JSON
-- error instead of failing at runtime.
buildMutateRequest :: ApiRequest -> Either Text MutateRequest
buildMutateRequest apiRequest = mutateApiRequest
  where
    action = iAction apiRequest
    rootTableName = case iTarget apiRequest of
      TargetIdent (QualifiedIdentifier _ t) -> Right t
      _ -> Left $ formatGeneralError "Cannot build mutation" "the request target is not a table"
    payload = maybe
      (Left $ formatGeneralError "Cannot build mutation" "the request has no payload")
      Right
      (iPayload apiRequest)
    mutateApiRequest = case action of
      ActionCreate -> Insert <$> rootTableName <*> payload
      ActionUpdate -> Update <$> rootTableName <*> payload <*> cond
      ActionDelete -> Delete <$> rootTableName <*> cond
      _        -> Left "Unsupported HTTP verb"
    -- update/delete filters can be only on the root table
    mutateFilters = filter (not . ( '.' `elem` ) . fst) $ iFilters apiRequest
    cond = first formatParserError $ map snd <$> mapM pRequestFilter mutateFilters

-- | Add a filter to the node of a 'ReadRequest' tree reached by following the
-- filter's path of node names from the root.
--
-- The filter is silently dropped if the request does not contain the required
-- path.
addFilter :: (Path, Filter) -> ReadRequest -> ReadRequest
addFilter ([], flt) (Node (q@Select {flt_=flts}, i) forest) = Node (q {flt_=flt:flts}, i) forest
addFilter (path, flt) (Node rn forest) =
  case targetNode of
    Nothing -> Node rn forest
    Just tn -> Node rn (addFilter (remainingPath, flt) tn:restForest)
  where
    targetNodeName:remainingPath = path
    (targetNode,restForest) = splitForest targetNodeName forest
    -- Pull the first child with the given name out of the forest.
    splitForest name forst =
      case maybeNode of
        Nothing -> (Nothing, forst)
        Just node -> (Just node, delete node forst)
      where maybeNode = find ((name==).fst.snd.rootLabel) forst

-- | In a relation where one of the tables matches the given 'TableName',
-- replace that table's name with 'sourceCTEName'.
--
-- These "fake" relations are needed in a mutate query. They let the
-- @returning *@ part, which is wrapped in a @with@, be treated as just another
-- table that has relations with other tables.
toSourceRelation :: TableName -> Relation -> Maybe Relation
toSourceRelation mt r@(Relation t _ ft _ _ rt _ _)
  | mt == tableName t = Just $ r {relTable=toSource t}
  | mt == tableName ft = Just $ r {relFTable=toSource ft}
  | Just mt == (tableName <$> rt) = Just $ r {relLTable=toSource <$> rt}
  | otherwise = Nothing
  where
    toSource tbl = tbl {tableName=sourceCTEName}

-- | Response body for 'ActionInfo' on a table.
data TableOptions = TableOptions {
  tblOptcolumns :: [Column]   -- ^ columns of the table
, tblOptpkey :: [Text]        -- ^ names of its primary-key columns
}

-- Serialised as {"columns": [...], "pkey": [...]}.
instance ToJSON TableOptions where
  toJSON t = object [
      "columns" .= tblOptcolumns t
    , "pkey"   .= tblOptpkey t ]

-- | Substitute an empty result for a write statement that returned no row.
--
-- The empty result has no table total, a count of 0, and an empty location
-- and body.
extractQueryResult :: Maybe ResultsWithCount -> ResultsWithCount
extractQueryResult = fromMaybe (Nothing, 0, "", "")