module PostgREST.App (
postgrest
) where
import Control.Applicative
import Data.Aeson (toJSON, eitherDecode)
import qualified Data.ByteString.Char8 as BS
import Data.Maybe
import Data.IORef (IORef, readIORef)
import Data.Text (intercalate)
import qualified Hasql.Pool as P
import qualified Hasql.Transaction as HT
import qualified Hasql.Transaction.Sessions as HT
import Network.HTTP.Types.Header
import Network.HTTP.Types.Status
import Network.HTTP.Types.URI (renderSimpleQuery)
import Network.Wai
import Network.Wai.Middleware.RequestLogger (logStdout)
import qualified Data.Vector as V
import qualified Hasql.Transaction as H
import qualified Data.HashMap.Strict as M
import PostgREST.ApiRequest ( ApiRequest(..), ContentType(..)
, Action(..), Target(..)
, PreferRepresentation (..)
, mutuallyAgreeable
, userApiRequest
)
import PostgREST.Auth (jwtClaims, containsRole, parseJWK)
import PostgREST.Config (AppConfig (..))
import PostgREST.DbStructure
import PostgREST.DbRequestBuilder( readRequest
, mutateRequest
, readRpcRequest
, fieldNames
)
import PostgREST.Error ( simpleError, pgError
, apiRequestError
, singularityError, binaryFieldError
, connectionLostError, gucHeadersError
)
import PostgREST.RangeQuery (allRange, rangeOffset)
import PostgREST.Middleware
import PostgREST.QueryBuilder ( callProc
, requestToQuery
, requestToCountQuery
, createReadStatement
, createWriteStatement
, ResultsWithCount
)
import PostgREST.Types
import PostgREST.OpenAPI
import Data.Function (id)
import Protolude hiding (intercalate, Proxy)
import Safe (headMay)
-- | Build the WAI 'Application' serving the API.
--
-- For every request the body is read strictly and the cached schema
-- ('DbStructure') is read from @refDbStructure@.
--
-- * No cached schema: answers 'connectionLostError'.
-- * Request fails to parse into an 'ApiRequest': answers 'apiRequestError'.
-- * Otherwise: the JWT claims are checked and 'app' runs through
--   'runWithClaims' inside one transaction at READ COMMITTED isolation, in
--   the mode chosen by 'transactionMode'. Database errors are turned into
--   responses by 'pgError', told whether the claims contain a role.
--
-- If the final response has status 503, the caller-supplied @worker@ action
-- is run before the response is sent.
postgrest :: AppConfig -> IORef (Maybe DbStructure) -> P.Pool -> IO () -> Application
postgrest conf refDbStructure pool worker =
  middleware $ \req respond -> do
    reqBody <- strictRequestBody req
    cachedStructure <- readIORef refDbStructure
    case cachedStructure of
      Nothing -> respond connectionLostError
      Just dbStructure -> do
        resp <- either (return . apiRequestError) (handleApiRequest dbStructure)
                  (userApiRequest (configSchema conf) req reqBody)
        when (responseStatus resp == status503) worker
        respond resp
  where
    -- Request logging is skipped entirely when the config asks for quiet mode.
    middleware = (if configQuiet conf then id else logStdout) . defaultMiddle
    -- Parsed once, outside the per-request lambda, and shared by all requests.
    jwtSecret = parseJWK <$> configJwtSecret conf

    -- Authenticate and run one parsed request inside a pooled transaction.
    handleApiRequest dbStructure apiRequest = do
      eClaims <- jwtClaims jwtSecret (configJwtAudience conf) (toS $ iJWT apiRequest)
      let txMode = transactionMode dbStructure (iTarget apiRequest) (iAction apiRequest)
          handler = runWithClaims conf eClaims (app dbStructure conf) apiRequest
      result <- P.use pool $ HT.transaction HT.ReadCommitted txMode handler
      return $ either (pgError (containsRole eClaims)) identity result
-- | Decide whether a request's transaction is read-only or read-write.
--
-- * Reads, 'ActionInfo' and 'ActionInspect' always run as 'HT.Read'.
-- * An RPC invocation flagged read-only also runs as 'HT.Read'.
-- * Any other RPC invocation runs as 'HT.Read' only when its target procedure
--   is in the schema cache and is declared 'Stable' or 'Immutable'. A missing
--   procedure (or a non-procedure target) counts as 'Volatile' and gets
--   'HT.Write'.
-- * Every remaining action runs as 'HT.Write'.
transactionMode :: DbStructure -> Target -> Action -> H.Mode
transactionMode structure target action =
  case action of
    ActionRead    -> HT.Read
    ActionInfo    -> HT.Read
    ActionInspect -> HT.Read
    ActionInvoke{isReadOnly = True} -> HT.Read
    ActionInvoke{isReadOnly = False}
      | procVolatility `elem` [Stable, Immutable] -> HT.Read
      | otherwise -> HT.Write
    _ -> HT.Write
  where
    -- Volatility of the targeted procedure, defaulting to the safest choice.
    procVolatility = case target of
      TargetProc qi -> maybe Volatile pdVolatility (M.lookup (qiName qi) (dbProcs structure))
      _ -> Volatile
-- | Produce the response for one parsed 'ApiRequest'; runs in a transaction.
--
-- The response content type is negotiated first with
-- 'responseContentTypeOrError'. If that fails, its error response is returned
-- unchanged. Otherwise the request is dispatched on (action, target, payload):
--
-- * 'ActionRead' on an identifier: runs the read statement and sets
--   @Content-Range@ and @Content-Location@.
-- * 'ActionCreate': inserts the JSON rows and answers 201, with a @Location@
--   header built from the returned @key=value@ fields (if any).
-- * 'ActionUpdate' / 'ActionDelete': answer 200 with a body when
--   'iPreferRepresentation' is 'Full', otherwise 204 without one.
-- * 'ActionInfo': answers with an @Allow@ header listing the permitted methods.
-- * 'ActionInvoke' on a procedure: calls it through 'callProc' and forwards
--   the GUC headers it reports.
-- * 'ActionInspect' on the root: returns the OpenAPI description.
--
-- Any other combination is a 404.
--
-- When 'CTSingularJSON' was negotiated and the row count is not exactly one,
-- a 'singularityError' is returned instead. For update, delete and RPC the
-- transaction is condemned first so the changes are not committed.
app :: DbStructure -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure conf apiRequest =
  case responseContentTypeOrError (iAccepts apiRequest) (iAction apiRequest) of
    Left errorResponse -> return errorResponse
    Right contentType ->
      case (iAction apiRequest, iTarget apiRequest, iPayload apiRequest) of

        (ActionRead, TargetIdent qi, Nothing) ->
          case (,) <$> readSqlParts <*> (binaryField contentType =<< fldNames) of
            Left errorResponse -> return errorResponse
            Right ((q, cq), bField) -> do
              let stm = createReadStatement q cq (contentType == CTSingularJSON)
                          shouldCount (contentType == CTTextCSV) bField
              row <- H.query () stm
              let (tableTotal, queryTotal, _, body) = row
                  (status, contentRange) = rangeHeader queryTotal tableTotal
                  canonical = iCanonicalQS apiRequest
                  contentLocation =
                    "/" <> toS (qiName qi)
                      <> (if BS.null canonical then "" else "?" <> toS canonical) :: BS.ByteString
              return $
                if contentType == CTSingularJSON && queryTotal /= 1
                  then singularityError (toInteger queryTotal)
                  else responseLBS status
                         [toHeader contentType, contentRange, ("Content-Location", contentLocation)]
                         (toS body)

        (ActionCreate, TargetIdent (QualifiedIdentifier _ table), Just payload@(PayloadJSON rows)) ->
          case mutateSqlParts of
            Left errorResponse -> return errorResponse
            Right (sq, mq) -> do
              let isSingle = V.length rows == 1
              -- The singular check happens before writing, so nothing needs condemning.
              if contentType == CTSingularJSON
                 && not isSingle
                 && iPreferRepresentation apiRequest == Full
                then return $ singularityError (toInteger $ V.length rows)
                else do
                  let pKeys = map pkName $ filter (filterPk schema table) allPrKeys
                      stm = createWriteStatement sq mq
                              (contentType == CTSingularJSON) isSingle
                              (contentType == CTTextCSV) (iPreferRepresentation apiRequest)
                              pKeys
                  row <- H.query payload stm
                  let (_, _, fs, body) = extractQueryResult row
                      headers = catMaybes
                        [ if null fs
                            then Nothing
                            else Just (hLocation, "/" <> toS table <> renderLocationFields fs)
                        , if iPreferRepresentation apiRequest == Full
                            then Just $ toHeader contentType
                            else Nothing
                        , Just . contentRangeH 1 0 $
                            toInteger <$> if shouldCount then Just (V.length rows) else Nothing
                        ]
                  return . responseLBS status201 headers $
                    if iPreferRepresentation apiRequest == Full
                      then toS body
                      else ""

        (ActionUpdate, TargetIdent _, Just payload@(PayloadJSON rows)) ->
          case (mutateSqlParts, null <$> rows V.!? 0, iPreferRepresentation apiRequest == Full) of
            (Left errorResponse, _, _) -> return errorResponse
            -- An empty first object means there is nothing to update.
            (_, Just True, True) -> return $ responseLBS status200 [contentRangeH 1 0 Nothing] "[]"
            (_, Just True, False) -> return $ responseLBS status204 [contentRangeH 1 0 Nothing] ""
            (Right (sq, mq), _, _) -> do
              let stm = createWriteStatement sq mq
                          (contentType == CTSingularJSON) False (contentType == CTTextCSV)
                          (iPreferRepresentation apiRequest) []
              row <- H.query payload stm
              let (_, queryTotal, _, body) = extractQueryResult row
              if contentType == CTSingularJSON
                 && queryTotal /= 1
                 && iPreferRepresentation apiRequest == Full
                then do
                  HT.condemn
                  return $ singularityError (toInteger queryTotal)
                else do
                  -- Updated rows are reported as the inclusive range 0..(n-1).
                  let r = contentRangeH 0 (toInteger queryTotal - 1)
                            (toInteger <$> if shouldCount then Just queryTotal else Nothing)
                      s = if iPreferRepresentation apiRequest == Full
                            then status200
                            else status204
                  return $
                    if iPreferRepresentation apiRequest == Full
                      then responseLBS s [toHeader contentType, r] (toS body)
                      else responseLBS s [r] ""

        (ActionDelete, TargetIdent _, Nothing) ->
          case mutateSqlParts of
            Left errorResponse -> return errorResponse
            Right (sq, mq) -> do
              let emptyPayload = PayloadJSON V.empty
                  stm = createWriteStatement sq mq
                          (contentType == CTSingularJSON) False
                          (contentType == CTTextCSV)
                          (iPreferRepresentation apiRequest) []
              row <- H.query emptyPayload stm
              let (_, queryTotal, _, body) = extractQueryResult row
                  r = contentRangeH 1 0 $
                        toInteger <$> if shouldCount then Just queryTotal else Nothing
              if contentType == CTSingularJSON
                 && queryTotal /= 1
                 && iPreferRepresentation apiRequest == Full
                then do
                  HT.condemn
                  return $ singularityError (toInteger queryTotal)
                else
                  return $
                    if iPreferRepresentation apiRequest == Full
                      then responseLBS status200 [toHeader contentType, r] (toS body)
                      else responseLBS status204 [r] ""

        (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable), Nothing) ->
          case find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure) of
            Nothing -> return notFound
            Just table ->
              let allowH = (hAllow, if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET")
              in return $ responseLBS status200 [allOrigins, allowH] ""

        (ActionInvoke readOnly, TargetProc qi, payload) ->
          let proc = M.lookup (qiName qi) allProcs
              returnsScalar = case proc of
                Just ProcDescription{pdReturnType = Single (Scalar _)} -> True
                _ -> False
              -- A scalar result has no columns to pick a binary field from.
              rpcBinaryField =
                if returnsScalar
                  then Right Nothing
                  else binaryField contentType =<< fldNames
              parts = (,,) <$> readSqlParts <*> rpcBinaryField <*> rpcQParams
          in case parts of
            Left errorResponse -> return errorResponse
            Right ((q, cq), bField, params) -> do
              -- Arguments come from the JSON body when present, else from the
              -- query-string parameters. Only the first body object is used; an
              -- empty body array means "no arguments" ('V.head' would crash here).
              let prms = case payload of
                    Just (PayloadJSON pld) -> fromMaybe M.empty (pld V.!? 0)
                    Nothing -> M.fromList $ second toJSON <$> params
                  singular = contentType == CTSingularJSON
                  paramsAsSingleObject = iPreferSingleObjectParameter apiRequest
              row <- H.query () $
                callProc qi prms returnsScalar q cq shouldCount
                  singular paramsAsSingleObject
                  (contentType == CTTextCSV)
                  (contentType == CTOctetStream) readOnly bField
                  (pgVersion dbStructure)
              let (tableTotal, queryTotal, body, jsonHeaders) =
                    fromMaybe (Just 0, 0, "[]", "[]") row
                  (status, contentRange) = rangeHeader queryTotal tableTotal
                  decodedHeaders = first toS $ eitherDecode $ toS jsonHeaders :: Either Text [GucHeader]
              case decodedHeaders of
                Left _ -> return gucHeadersError
                Right hs ->
                  if singular && queryTotal /= 1
                    then do
                      HT.condemn
                      return $ singularityError (toInteger queryTotal)
                    else return $
                      responseLBS status ([toHeader contentType, contentRange] ++ toHeaders hs) (toS body)

        (ActionInspect, TargetRoot, Nothing) -> do
          let host = configHost conf
              port = toInteger $ configPort conf
              proxy = pickProxy $ toS <$> configProxyUri conf
              -- (scheme, host, port, path) advertised in the OpenAPI document.
              uri Nothing = ("http", host, port, "/")
              uri (Just Proxy { proxyScheme = s, proxyHost = h, proxyPort = p, proxyPath = b }) = (s, h, p, b)
              uri' = uri proxy
              encodeApi ti sd procs =
                encodeOpenAPI (M.elems procs) (toTableInfo ti) uri' sd (dbPrimaryKeys dbStructure)
          body <- encodeApi <$> H.query schema accessibleTables
                            <*> H.query schema schemaDescription
                            <*> H.query schema accessibleProcs
          return $ responseLBS status200 [toHeader CTOpenAPI] $ toS body

        _ -> return notFound
  where
    -- Pair each table with its columns and primary-key column names.
    toTableInfo :: [Table] -> [(Table, [Column], [Text])]
    toTableInfo = map $ \t ->
      let tSchema = tableSchema t
          tTable = tableName t
          cols = filter (filterCol tSchema tTable) $ dbColumns dbStructure
          pkeys = map pkName $ filter (filterPk tSchema tTable) allPrKeys
      in (t, cols, pkeys)
    notFound = responseLBS status404 [] ""
    filterPk sc table pk = sc == (tableSchema . pkTable) pk && table == (tableName . pkTable) pk
    filterCol :: Schema -> TableName -> Column -> Bool
    filterCol sc tb Column{colTable = Table{tableSchema = s, tableName = t}} = s == sc && t == tb
    allPrKeys = dbPrimaryKeys dbStructure
    allProcs = dbProcs dbStructure
    allOrigins = ("Access-Control-Allow-Origin", "*") :: Header
    shouldCount = iPreferCount apiRequest
    schema = toS $ configSchema conf
    -- The range requested for the top-level resource (key "limit"), if any.
    topLevelRange = fromMaybe allRange $ M.lookup "limit" $ iRange apiRequest
    -- Status and Content-Range for @queryTotal@ rows returned starting at the
    -- top-level range offset; @tableTotal@ is the optional overall count.
    rangeHeader queryTotal tableTotal =
      let lower = rangeOffset topLevelRange
          upper = lower + toInteger queryTotal - 1
          contentRange = contentRangeH lower upper (toInteger <$> tableTotal)
          status = rangeStatus lower upper (toInteger <$> tableTotal)
      in (status, contentRange)
    readReq = readRequest (configMaxRows conf) (dbRelations dbStructure) allProcs apiRequest
    fldNames = fieldNames <$> readReq
    readDbRequest = DbRead <$> readReq
    mutateDbRequest = DbMutate <$> (mutateRequest apiRequest =<< fldNames)
    rpcQParams = readRpcRequest apiRequest
    selectQuery = requestToQuery schema False <$> readDbRequest
    mutateQuery = requestToQuery schema False <$> mutateDbRequest
    countQuery = requestToCountQuery schema <$> readDbRequest
    readSqlParts = (,) <$> selectQuery <*> countQuery
    mutateSqlParts = (,) <$> selectQuery <*> mutateQuery
-- | Negotiate the response content type for an action.
--
-- Each action has a fixed list of content types it can produce. The client's
-- accepted types are matched against that list by 'mutuallyAgreeable'. If no
-- agreement is found, a 415 response naming the client's requested MIME types
-- is returned as 'Left'.
responseContentTypeOrError :: [ContentType] -> Action -> Either Response ContentType
responseContentTypeOrError accepts action =
  maybe (Left unsupported) Right (mutuallyAgreeable producible accepts)
  where
    producible = case action of
      ActionRead     -> [CTApplicationJSON, CTSingularJSON, CTTextCSV, CTOctetStream]
      ActionCreate   -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionUpdate   -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionDelete   -> [CTApplicationJSON, CTSingularJSON, CTTextCSV]
      ActionInvoke _ -> [CTApplicationJSON, CTSingularJSON, CTTextCSV, CTOctetStream]
      ActionInspect  -> [CTOpenAPI, CTApplicationJSON]
      ActionInfo     -> [CTTextCSV]
    requested = intercalate ", " (map (toS . toMime) accepts)
    unsupported = simpleError status415 [] $
      "None of these Content-Types are available: " <> requested
-- | Select the field whose raw value is the body of a 'CTOctetStream' response.
--
-- For 'CTOctetStream', exactly one field must be selected and it must not be
-- @*@; anything else yields 'binaryFieldError'. Other content types never use
-- a binary field and yield @Right Nothing@.
binaryField :: ContentType -> [FieldName] -> Either Response (Maybe FieldName)
binaryField CTOctetStream [onlyField]
  | onlyField /= "*" = Right (Just onlyField)
binaryField CTOctetStream _ = Left binaryFieldError
binaryField _ _ = Right Nothing
-- | Split a @key=value@ byte string at its first @=@, dropping the @=@.
--
-- If there is no @=@, the whole input becomes the key and the value is empty.
-- 'BS.drop' is used instead of the partial 'BS.tail' so such input cannot
-- throw.
splitKeyValue :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitKeyValue kv = (key, BS.drop 1 rest)
  where
    (key, rest) = BS.break (== '=') kv
-- | Render @key=value@ fields as a query string for a @Location@ header.
--
-- Each field is split with 'splitKeyValue', and the resulting pairs are
-- rendered by 'renderSimpleQuery' with its leading-@?@ flag set.
renderLocationFields :: [BS.ByteString] -> BS.ByteString
renderLocationFields = renderSimpleQuery True . fmap splitKeyValue
-- | HTTP status for a response covering rows @lower..upper@ (inclusive).
--
-- * 200 when the total count is unknown.
-- * 416 when @lower@ lies beyond the known total.
-- * 206 when the inclusive range holds fewer rows (@1 + upper - lower@) than
--   the total.
-- * 200 otherwise.
rangeStatus :: Integer -> Integer -> Maybe Integer -> Status
rangeStatus _ _ Nothing = status200
rangeStatus lower upper (Just total)
  | lower > total = status416
  | (1 + upper - lower) < total = status206
  | otherwise = status200
-- | Build a @Content-Range@ header with value @lower-upper/total@.
--
-- The range part is @*@ when the total is known to be 0 or when
-- @lower > upper@ (an empty range). The total part is @*@ when the total is
-- unknown.
contentRangeH :: Integer -> Integer -> Maybe Integer -> Header
contentRangeH lower upper total = ("Content-Range", rangePart <> "/" <> totalPart)
  where
    emptyRange = total == Just 0 || lower > upper
    rangePart
      | emptyRange = "*"
      | otherwise = show lower <> "-" <> show upper
    totalPart = maybe "*" show total
-- | Unwrap an optional statement result.
--
-- A missing result is replaced by one with no total, zero rows, no location
-- fields and an empty body.
extractQueryResult :: Maybe ResultsWithCount -> ResultsWithCount
extractQueryResult mResult =
  case mResult of
    Just result -> result
    Nothing -> (Nothing, 0, [], "")