{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}
module PostgREST.App (
postgrest
) where
import Control.Applicative
import Data.Aeson as JSON
import qualified Data.ByteString.Char8 as BS
import Data.Maybe
import Data.IORef (IORef, readIORef)
import Data.Text (intercalate)
import Data.Time.Clock (UTCTime)
import qualified Data.Set as S
import qualified Hasql.Pool as P
import qualified Hasql.Transaction as HT
import qualified Hasql.Transaction.Sessions as HT
import Network.HTTP.Types.Header
import Network.HTTP.Types.Status
import Network.HTTP.Types.URI (renderSimpleQuery)
import Network.Wai
import Network.Wai.Middleware.RequestLogger (logStdout)
import qualified Hasql.Transaction as H
import qualified Data.HashMap.Strict as M
import PostgREST.ApiRequest ( ApiRequest(..), ContentType(..)
, Action(..), Target(..)
, PreferRepresentation (..)
, mutuallyAgreeable
, userApiRequest
)
import PostgREST.Auth (jwtClaims, containsRole, parseSecret)
import PostgREST.Config (AppConfig (..))
import PostgREST.DbStructure
import PostgREST.DbRequestBuilder( readRequest
, mutateRequest
, fieldNames
)
import PostgREST.Error ( simpleError, pgError
, apiRequestError
, singularityError, binaryFieldError
, connectionLostError, gucHeadersError
)
import PostgREST.RangeQuery (allRange, rangeOffset)
import PostgREST.Middleware
import PostgREST.QueryBuilder ( callProc
, requestToQuery
, requestToCountQuery
, createReadStatement
, createWriteStatement
, ResultsWithCount
)
import PostgREST.Types
import PostgREST.OpenAPI
import Data.Function (id)
import Protolude hiding (intercalate, Proxy)
-- | Build the WAI 'Application' that serves the REST API.
--
-- Every request is wrapped in 'defaultMiddle'. Stdout request logging is
-- added unless 'configQuiet' is set. The cached 'DbStructure' is read from
-- @refDbStructure@ on each request. While it holds 'Nothing', the request is
-- answered with 'connectionLostError'.
--
-- Otherwise the request body is parsed into an 'ApiRequest' and the JWT
-- claims are computed. 'app' then runs inside a READ COMMITTED transaction
-- taken from @pool@.
--
-- Whenever the final response has status 503, @worker@ is run before the
-- response is sent.
postgrest :: AppConfig -> IORef (Maybe DbStructure) -> P.Pool -> IO UTCTime -> IO () -> Application
postgrest conf refDbStructure pool getTime worker =
  withMiddleware $ \req respond -> do
    currentTime <- getTime
    rawBody <- strictRequestBody req
    cachedStructure <- readIORef refDbStructure
    case cachedStructure of
      Nothing -> respond connectionLostError
      Just dbStructure -> do
        finalResponse <-
          either (return . apiRequestError) (serveApiRequest dbStructure currentTime) $
            userApiRequest (configSchema conf) req rawBody
        when (responseStatus finalResponse == status503) worker
        respond finalResponse
  where
    withMiddleware = (if configQuiet conf then id else logStdout) . defaultMiddle
    jwtSecret = parseSecret <$> configJwtSecret conf

    -- Authenticate an already-parsed request and run it in a database transaction.
    serveApiRequest dbStructure currentTime apiRequest = do
      eClaims <- jwtClaims jwtSecret (configJwtAudience conf) (toS $ iJWT apiRequest) currentTime (rightToMaybe $ configRoleClaimKey conf)
      let authed = containsRole eClaims
          -- Only RPC targets with a JSON payload resolve to a stored procedure.
          proc = case (iTarget apiRequest, iPayload apiRequest, iPreferSingleObjectParameter apiRequest) of
            (TargetProc qi, Just PayloadJSON{pjKeys}, s) -> findProc qi pjKeys s $ dbProcs dbStructure
            _ -> Nothing
          handleReq = runWithClaims conf eClaims (app dbStructure proc conf) apiRequest
          txMode = transactionMode proc (iAction apiRequest)
      dbResult <- P.use pool $ HT.transaction HT.ReadCommitted txMode handleReq
      -- A pool/database failure becomes an HTTP error response; pgError is told
      -- whether the claims contained a role.
      return $ either (pgError authed) identity dbResult
-- | Pick the stored procedure an RPC request should call.
--
-- Candidates are looked up by the procedure name only ('qiName'). When exactly
-- one candidate exists, it is returned without further checks. Otherwise the
-- first candidate (in list order) whose arguments fit the request wins:
--
-- * with @paramsAsSingleObject@, it must take exactly one argument;
-- * otherwise, every payload key must name one of its arguments.
findProc :: QualifiedIdentifier -> S.Set Text -> Bool -> M.HashMap Text [ProcDescription] -> Maybe ProcDescription
findProc qi payloadKeys paramsAsSingleObject allProcs =
  case M.lookup (qiName qi) allProcs of
    Nothing -> Nothing
    Just [onlyCandidate] -> Just onlyCandidate
    Just overloads -> find fitsPayload overloads
  where
    fitsPayload candidate
      | paramsAsSingleObject = length (pdArgs candidate) == 1
      | otherwise = payloadKeys `S.isSubsetOf` S.fromList (pgaName <$> pdArgs candidate)
-- | Choose the transaction access mode for a request.
--
-- Reads, 'ActionInfo' and 'ActionInspect' run read-only. An RPC call runs
-- read-only when the request marks it so, or when the resolved procedure's
-- volatility is 'Stable' or 'Immutable'. An unresolved procedure counts as
-- 'Volatile'. Every other action runs read-write.
transactionMode :: Maybe ProcDescription -> Action -> HT.Mode
transactionMode proc action =
  case action of
    ActionInvoke{isReadOnly = readOnlyCall}
      | readOnlyCall -> HT.Read
      | volatility == Stable || volatility == Immutable -> HT.Read
      | otherwise -> HT.Write
    ActionRead -> HT.Read
    ActionInfo -> HT.Read
    ActionInspect -> HT.Read
    _ -> HT.Write
  where
    volatility = maybe Volatile pdVolatility proc
-- | Execute a parsed API request inside a database transaction and build the
-- HTTP 'Response'.
--
-- The response content type is negotiated first with
-- 'responseContentTypeOrError'. An unacceptable Accept header short-circuits
-- with that error response. Otherwise the request is dispatched on its
-- (action, target, payload) triple. Any combination not matched below yields
-- an empty 404.
--
-- @proc@ is the stored procedure resolved for RPC targets ('Nothing'
-- otherwise). It is consulted for read planning and for the RPC branch.
app :: DbStructure -> Maybe ProcDescription -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure proc conf apiRequest =
  case responseContentTypeOrError (iAccepts apiRequest) (iAction apiRequest) of
    Left errorResponse -> return errorResponse
    Right contentType ->
      case (iAction apiRequest, iTarget apiRequest, iPayload apiRequest) of
        -- Read rows from a table or view.
        (ActionRead, TargetIdent qi, Nothing) ->
          let partsField = (,) <$> readSqlParts
                             <*> (binaryField contentType =<< fldNames) in
          case partsField of
            Left errorResponse -> return errorResponse
            Right ((q, cq), bField) -> do
              let stm = createReadStatement q cq (contentType == CTSingularJSON) shouldCount
                          (contentType == CTTextCSV) bField
              row <- H.statement () stm
              let (tableTotal, queryTotal, _ , body) = row
                  (status, contentRange) = rangeHeader queryTotal tableTotal
                  canonical = iCanonicalQS apiRequest
              return $
                -- A singular-object response requires exactly one row.
                if contentType == CTSingularJSON && queryTotal /= 1
                  then singularityError (toInteger queryTotal)
                  else responseLBS status
                    [toHeader contentType, contentRange,
                      -- Content-Location echoes the table path plus the
                      -- canonical query string, when there is one.
                      ("Content-Location",
                        "/" <> toS (qiName qi) <>
                          if BS.null canonical then "" else "?" <> toS canonical
                      )
                    ] (toS body)
        -- Insert one row (JSON object) or many rows (JSON array).
        (ActionCreate, TargetIdent (QualifiedIdentifier tSchema tName), Just PayloadJSON{pjRaw, pjType}) ->
          case mutateSqlParts tSchema tName of
            Left errorResponse -> return errorResponse
            Right (sq, mq) -> do
              let (isSingle, nRows) = case pjType of
                    PJArray len -> (len == 1, len)
                    PJObject -> (True, 1)
              -- Rejected before any statement runs: a singular response was
              -- requested with return=representation, but the payload has
              -- more or fewer than one row.
              if contentType == CTSingularJSON
                && not isSingle
                && iPreferRepresentation apiRequest == Full
                then return $ singularityError (toInteger nRows)
                else do
                  let pkCols = tablePKCols dbStructure tSchema tName
                      stm = createWriteStatement sq mq
                        (contentType == CTSingularJSON) isSingle
                        (contentType == CTTextCSV) (iPreferRepresentation apiRequest) pkCols
                  row <- H.statement (toS pjRaw) stm
                  let (_, _, fs, body) = extractQueryResult row
                      headers = catMaybes [
                        -- Location is built from the returned key=value fields, if any.
                        if null fs
                          then Nothing
                          else Just (hLocation, "/" <> toS tName <> renderLocationFields fs)
                        , if iPreferRepresentation apiRequest == Full
                          then Just $ toHeader contentType
                          else Nothing
                        -- lower > upper, so the range part renders as "*";
                        -- the total is the payload row count when counting.
                        , Just . contentRangeH 1 0 $
                          toInteger <$> if shouldCount then Just nRows else Nothing
                        -- Preference-Applied is only sent for tables with a primary key.
                        , if null pkCols
                          then Nothing
                          else (\x -> ("Preference-Applied", show x)) <$> iPreferResolution apiRequest
                        ]
                  return . responseLBS status201 headers $
                    if iPreferRepresentation apiRequest == Full
                      then toS body else ""
        -- Update the rows selected by the request filters.
        (ActionUpdate, TargetIdent (QualifiedIdentifier tSchema tName), Just p@PayloadJSON{pjRaw}) ->
          case (mutateSqlParts tSchema tName, pjIsEmpty p, iPreferRepresentation apiRequest == Full) of
            (Left errorResponse, _, _) -> return errorResponse
            -- An empty payload runs no statement at all.
            (_, True, True) -> return $ responseLBS status200 [contentRangeH 1 0 Nothing] "[]"
            (_, True, False) -> return $ responseLBS status204 [contentRangeH 1 0 Nothing] ""
            (Right (sq, mq), _, _) -> do
              let stm = createWriteStatement sq mq
                    (contentType == CTSingularJSON) False (contentType == CTTextCSV)
                    (iPreferRepresentation apiRequest) []
              row <- H.statement (toS pjRaw) stm
              let (_, queryTotal, _, body) = extractQueryResult row
              if contentType == CTSingularJSON
                && queryTotal /= 1
                && iPreferRepresentation apiRequest == Full
                then do
                  -- The update has already run; condemn the transaction
                  -- (hasql-transaction rolls it back) before reporting the error.
                  HT.condemn
                  return $ singularityError (toInteger queryTotal)
                else do
                  let r = contentRangeH 0 (toInteger $ queryTotal-1)
                        (toInteger <$> if shouldCount then Just queryTotal else Nothing)
                      s = if iPreferRepresentation apiRequest == Full
                        then status200
                        else status204
                  return $ if iPreferRepresentation apiRequest == Full
                    then responseLBS s [toHeader contentType, r] (toS body)
                    else responseLBS s [r] ""
        -- Upsert a single row (PUT).
        (ActionSingleUpsert, TargetIdent (QualifiedIdentifier tSchema tName), Just PayloadJSON{pjRaw, pjType, pjKeys}) ->
          case mutateSqlParts tSchema tName of
            Left errorResponse -> return errorResponse
            Right (sq, mq) -> do
              let isSingle = case pjType of
                    PJArray len -> len == 1
                    PJObject -> True
                  colNames = colName <$> tableCols dbStructure tSchema tName
              -- Validations, each answered with a 400 before any statement runs:
              -- no range/limit, exactly one row, and exactly the table's columns.
              if topLevelRange /= allRange
                then return $ simpleError status400 [] "Range header and limit/offset querystring parameters are not allowed for PUT"
                else if not isSingle
                  then return $ simpleError status400 [] "PUT payload must contain a single row"
                  else if S.fromList colNames /= pjKeys
                    then return $ simpleError status400 [] "You must specify all columns in the payload when using PUT"
                    else do
                      row <- H.statement (toS pjRaw) $
                        createWriteStatement sq mq (contentType == CTSingularJSON) False
                          (contentType == CTTextCSV) (iPreferRepresentation apiRequest) []
                      let (_, queryTotal, _, body) = extractQueryResult row
                      -- Any affected-row count other than 1 condemns the
                      -- transaction and reports a 400.
                      if queryTotal /= 1
                        then do
                          HT.condemn
                          return $ simpleError status400 [] "Payload values do not match URL in primary key column(s)"
                        else
                          return $ if iPreferRepresentation apiRequest == Full
                            then responseLBS status200 [toHeader contentType] (toS body)
                            else responseLBS status204 [] ""
        -- Delete the rows selected by the request filters.
        (ActionDelete, TargetIdent (QualifiedIdentifier tSchema tName), Nothing) ->
          case mutateSqlParts tSchema tName of
            Left errorResponse -> return errorResponse
            Right (sq, mq) -> do
              let stm = createWriteStatement sq mq
                    (contentType == CTSingularJSON) False
                    (contentType == CTTextCSV)
                    (iPreferRepresentation apiRequest) []
              row <- H.statement mempty stm
              let (_, queryTotal, _, body) = extractQueryResult row
                  r = contentRangeH 1 0 $
                    toInteger <$> if shouldCount then Just queryTotal else Nothing
              if contentType == CTSingularJSON
                && queryTotal /= 1
                && iPreferRepresentation apiRequest == Full
                then do
                  -- The delete has already run; condemn the transaction before erroring.
                  HT.condemn
                  return $ singularityError (toInteger queryTotal)
                else
                  return $ if iPreferRepresentation apiRequest == Full
                    then responseLBS status200 [toHeader contentType, r] (toS body)
                    else responseLBS status204 [r] ""
        -- Table metadata (presumably the OPTIONS verb; confirm in ApiRequest).
        -- Allow lists write methods only for insertable tables.
        (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable), Nothing) ->
          let mTable = find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure) in
          case mTable of
            Nothing -> return notFound
            Just table ->
              let acceptH = (hAllow, if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET") in
              return $ responseLBS status200 [allOrigins, acceptH] ""
        -- Call a stored procedure (RPC) with the JSON payload as arguments.
        (ActionInvoke _, TargetProc qi, Just PayloadJSON{pjRaw, pjType, pjKeys}) ->
          let returnsScalar = case proc of
                Just ProcDescription{pdReturnType = (Single (Scalar _))} -> True
                _ -> False
              -- A scalar result never needs a binary field to be chosen.
              rpcBinaryField = if returnsScalar
                then Right Nothing
                else binaryField contentType =<< fldNames
              parts = (,) <$> readSqlParts <*> rpcBinaryField in
          case parts of
            Left errorResponse -> return errorResponse
            Right ((q, cq), bField) -> do
              let isObject = case pjType of
                    PJObject -> True
                    PJArray _ -> False
                  singular = contentType == CTSingularJSON
                  -- Only the procedure arguments actually present in the payload are passed.
                  specifiedPgArgs = filter ((`S.member` pjKeys) . pgaName) $ maybe [] pdArgs proc
              row <- H.statement (toS pjRaw) $
                callProc qi specifiedPgArgs returnsScalar q cq shouldCount
                  singular (iPreferSingleObjectParameter apiRequest)
                  (contentType == CTTextCSV)
                  (contentType == CTOctetStream) bField isObject
                  (pgVersion dbStructure)
              -- No result row is treated as an empty JSON array with no extra headers.
              let (tableTotal, queryTotal, body, jsonHeaders) =
                    fromMaybe (Just 0, 0, "[]", "[]") row
                  (status, contentRange) = rangeHeader queryTotal tableTotal
                  decodedHeaders = first toS $ JSON.eitherDecode $ toS jsonHeaders :: Either Text [GucHeader]
              case decodedHeaders of
                -- Headers set by the procedure must decode as a list of GucHeader.
                Left _ -> return gucHeadersError
                Right hs ->
                  if singular && queryTotal /= 1
                    then do
                      HT.condemn
                      return $ singularityError (toInteger queryTotal)
                    else return $ responseLBS status ([toHeader contentType, contentRange] ++ toHeaders hs) (toS body)
        -- OpenAPI description of the schema, served at the root.
        (ActionInspect, TargetRoot, Nothing) -> do
          let host = configHost conf
              port = toInteger $ configPort conf
              proxy = pickProxy $ toS <$> configProxyUri conf
              -- Advertised base URI: the configured proxy if any, else http://host:port/.
              uri Nothing = ("http", host, port, "/")
              uri (Just Proxy { proxyScheme = s, proxyHost = h, proxyPort = p, proxyPath = b }) = (s, h, p, b)
              uri' = uri proxy
              toTableInfo :: [Table] -> [(Table, [Column], [Text])]
              toTableInfo = map (\t -> let (s, tn) = (tableSchema t, tableName t) in (t, tableCols dbStructure s tn, tablePKCols dbStructure s tn))
              encodeApi ti sd procs = encodeOpenAPI (concat $ M.elems procs) (toTableInfo ti) uri' sd $ dbPrimaryKeys dbStructure
          body <- encodeApi <$> H.statement schema accessibleTables <*> H.statement schema schemaDescription <*> H.statement schema accessibleProcs
          return $ responseLBS status200 [toHeader CTOpenAPI] $ toS body
        _ -> return notFound
  where
    notFound = responseLBS status404 [] ""
    allOrigins = ("Access-Control-Allow-Origin", "*") :: Header
    shouldCount = iPreferCount apiRequest
    schema = toS $ configSchema conf
    -- The range stored under the "limit" key, or allRange when absent.
    topLevelRange = fromMaybe allRange $ M.lookup "limit" $ iRange apiRequest
    -- HTTP status and Content-Range for a read whose first row sits at the
    -- top-level range offset.
    rangeHeader queryTotal tableTotal =
      let lower = rangeOffset topLevelRange
          upper = lower + toInteger queryTotal - 1
          contentRange = contentRangeH lower upper (toInteger <$> tableTotal)
          status = rangeStatus lower upper (toInteger <$> tableTotal)
      in (status, contentRange)
    readReq = readRequest (configMaxRows conf) (dbRelations dbStructure) proc apiRequest
    fldNames = fieldNames <$> readReq
    readDbRequest = DbRead <$> readReq
    selectQuery = requestToQuery schema False <$> readDbRequest
    countQuery = requestToCountQuery schema <$> readDbRequest
    readSqlParts = (,) <$> selectQuery <*> countQuery
    -- The select query plus the mutation query for table t in schema s.
    mutateSqlParts s t =
      (,) <$> selectQuery
          <*> (requestToQuery schema False . DbMutate <$> (mutateRequest apiRequest t (tablePKCols dbStructure s t) =<< fldNames))
-- | Negotiate the response content type for an action.
--
-- Each action has a fixed list of content types it can produce. The first one
-- agreed on with the client's Accept list (via 'mutuallyAgreeable') is
-- returned. When there is none, a 415 error listing the client's requested
-- media types is returned instead.
responseContentTypeOrError :: [ContentType] -> Action -> Either Response ContentType
responseContentTypeOrError accepts action =
  maybe (Left noAgreement) Right $ mutuallyAgreeable producible accepts
  where
    producible = case action of
      ActionRead -> [CTApplicationJSON, CTSingularJSON, CTTextCSV, CTOctetStream]
      ActionCreate -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionUpdate -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionDelete -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionInvoke _ -> [CTApplicationJSON, CTSingularJSON, CTTextCSV, CTOctetStream]
      ActionInspect -> [CTOpenAPI, CTApplicationJSON]
      ActionInfo -> [CTTextCSV]
      ActionSingleUpsert -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
    noAgreement =
      simpleError status415 [] $
        "None of these Content-Types are available: "
          <> intercalate ", " (toS . toMime <$> accepts)
-- | Decide which single column, if any, is sent as a raw binary body.
--
-- Only 'CTOctetStream' responses need one. The request must select exactly
-- one field, and that field must not be @*@; otherwise 'binaryFieldError' is
-- returned. Every other content type gets @Right Nothing@.
binaryField :: ContentType -> [FieldName] -> Either Response (Maybe FieldName)
binaryField CTOctetStream selected =
  case selected of
    [single] | single /= "*" -> Right (Just single)
    _ -> Left binaryFieldError
binaryField _ _ = Right Nothing
-- | Split a @key=value@ location field at its first @=@.
--
-- The value keeps any further @=@ characters: @"a=b=c"@ gives @("a", "b=c")@.
-- A field without @=@ yields the whole input as the key and an empty value.
-- The previous version crashed there, because it called 'BS.tail' on an
-- empty remainder.
splitKeyValue :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitKeyValue kv = (key, BS.drop 1 rest)
  where
    -- rest is either empty or starts with '='; dropping one byte removes the
    -- separator without being partial.
    (key, rest) = BS.break (== '=') kv
-- | Render @key=value@ fields as a query string with a leading @?@.
-- Each field is split at its first @=@ by 'splitKeyValue'.
renderLocationFields :: [BS.ByteString] -> BS.ByteString
renderLocationFields = renderSimpleQuery True . fmap splitKeyValue
-- | HTTP status for a response covering rows @lower..upper@ out of an
-- optional total.
--
-- * Without a known total: 200.
-- * If @lower@ is past the total: 416.
-- * If fewer rows than the total are covered: 206.
-- * Otherwise: 200.
rangeStatus :: Integer -> Integer -> Maybe Integer -> Status
rangeStatus lower upper = maybe status200 classify
  where
    classify total
      | lower > total = status416
      | upper - lower + 1 < total = status206
      | otherwise = status200
-- | Build a @Content-Range@ header of the form @lower-upper/total@.
--
-- The range part becomes @*@ when the total is known to be zero or when
-- @lower > upper@. The total part is @*@ when it is unknown.
contentRangeH :: Integer -> Integer -> Maybe Integer -> Header
contentRangeH lower upper total = ("Content-Range", rangePart <> "/" <> totalPart)
  where
    totalPart = maybe "*" show total
    rangePart =
      if total /= Just 0 && lower <= upper
        then show lower <> "-" <> show upper
        else "*"
-- | Unwrap a statement result. A missing row becomes an empty result: no
-- total, zero rows, no location fields, and an empty body.
extractQueryResult :: Maybe ResultsWithCount -> ResultsWithCount
extractQueryResult (Just result) = result
extractQueryResult Nothing = (Nothing, 0, [], "")