module PostgREST.App (
postgrest
) where
import Control.Applicative
import Data.Bifunctor (first)
import qualified Data.ByteString.Char8 as BS
import Data.IORef (IORef, readIORef)
import Data.List (find, delete)
import Data.Maybe (fromMaybe, fromJust, mapMaybe)
import Data.Ranged.Ranges (emptyRange)
import Data.String.Conversions (cs)
import Data.Text (Text, replace, strip)
import Data.Tree
import qualified Hasql.Pool as P
import qualified Hasql.Transaction as HT
import Text.Parsec.Error
import Text.ParserCombinators.Parsec (parse)
import Network.HTTP.Types.Header
import Network.HTTP.Types.Status
import Network.HTTP.Types.URI (renderSimpleQuery)
import Network.Wai
import Network.Wai.Middleware.RequestLogger (logStdout)
import Data.Aeson
import Data.Aeson.Types (emptyArray)
import Data.Monoid
import Data.Time.Clock.POSIX (getPOSIXTime)
import qualified Data.Vector as V
import qualified Hasql.Transaction as H
import qualified Data.HashMap.Strict as M
import PostgREST.ApiRequest (ApiRequest(..), ContentType(..)
, Action(..), Target(..)
, PreferRepresentation (..)
, userApiRequest)
import PostgREST.Auth (tokenJWT, jwtClaims, containsRole)
import PostgREST.Config (AppConfig (..))
import PostgREST.DbStructure
import PostgREST.Error (errResponse, pgErrResponse)
import PostgREST.Parsers
import PostgREST.RangeQuery (NonnegRange, allRange, rangeOffset, restrictRange)
import PostgREST.Middleware
import PostgREST.QueryBuilder ( callProc
, addJoinConditions
, sourceCTEName
, requestToQuery
, requestToCountQuery
, addRelations
, createReadStatement
, createWriteStatement
, ResultsWithCount
)
import PostgREST.Types
import Prelude
-- | The PostgREST WAI application.
--
-- For every request it reads the whole body, takes the current schema cache
-- from @refDbStructure@, decodes the JWT claims and runs 'app' with those
-- claims inside a read-committed transaction executed on @pool@. A 'Left'
-- coming back from the pool is rendered with 'pgErrResponse' (given
-- @containsRole@ of the claims). Requests are logged to stdout unless
-- 'configQuiet' is set.
postgrest :: AppConfig -> IORef DbStructure -> P.Pool -> Application
postgrest conf refDbStructure pool = middleware $ \req respond -> do
  now <- getPOSIXTime
  reqBody <- strictRequestBody req
  structure <- readIORef refDbStructure
  let apiReq = userApiRequest (cs $ configSchema conf) req reqBody
      claims = jwtClaims (configJwtSecret conf) (iJWT apiReq) now
      transaction = runWithClaims conf claims (app structure conf) apiReq
      mode = transactionMode (iAction apiReq)
  result <- P.use pool (HT.run transaction HT.ReadCommitted mode)
  respond $ either (pgErrResponse (containsRole claims)) id result
  where
    -- The default middleware always applies; stdout logging is optional.
    middleware
      | configQuiet conf = defaultMiddle
      | otherwise = logStdout . defaultMiddle
-- | Transaction mode for an action: read-only for 'ActionRead' and
-- 'ActionInfo', read-write for everything else.
transactionMode :: Action -> H.Mode
transactionMode action = case action of
  ActionRead -> HT.Read
  ActionInfo -> HT.Read
  _          -> HT.Write
-- | Answer one API request inside the current database transaction.
--
-- Dispatches on the (action, target, payload) triple of the request:
--
-- * read of a table/view: runs the generated select and count statement and
--   answers with the body plus Content-Range / Content-Location headers, or
--   with a single row (404 when there is none) for singular requests;
-- * create / update / delete of a table: runs the generated write statement;
-- * info on a table: describes its columns and primary keys;
-- * invoke of a procedure: calls it, wrapping the result in a token when
--   'doesProcReturnJWT' says so;
-- * read of the root: returns 'accessibleTables' for the configured schema.
--
-- Unparseable payloads get 400, inappropriate actions 405, anything else 404.
app :: DbStructure -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure conf apiRequest =
  case (iAction apiRequest, iTarget apiRequest, iPayload apiRequest) of
    (ActionRead, TargetIdent qi, Nothing) ->
      case readSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (q, cq) -> do
          let singular = iPreferSingular apiRequest
              stm = createReadStatement q cq singular shouldCount (contentType == TextCSV)
          respondToRange $ do
            row <- H.query () stm
            let (tableTotal, queryTotal, _, body) = row
            if singular
              then return $
                if queryTotal <= 0
                  then responseLBS status404 [] ""
                  else responseLBS status200 [contentTypeH] (cs body)
              else do
                let (status, contentRange) = rangeHeader queryTotal tableTotal
                    canonical = iCanonicalQS apiRequest
                return $ responseLBS status
                  [ contentTypeH
                  , contentRange
                  , ( "Content-Location"
                    , "/" <> cs (qiName qi)
                        <> (if Prelude.null canonical then "" else "?" <> cs canonical)
                    )
                  ]
                  (cs body)

    (ActionCreate, TargetIdent qi@(QualifiedIdentifier _ table),
        Just payload@(PayloadJSON uniform@(UniformObjects rows))) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq, mq) -> do
          let isSingle = (== 1) $ V.length rows
              pKeys = map pkName $ filter (filterPk schema table) allPrKeys
              stm = createWriteStatement qi sq mq isSingle (iPreferRepresentation apiRequest) pKeys (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, _, fs, body) = extractQueryResult row
              -- Location points at the created row when the statement reports
              -- its key fields.
              header =
                if null fs
                  then []
                  else [(hLocation, "/" <> cs table <> renderLocationFields fs)]
          return $ if iPreferRepresentation apiRequest == Full
            then responseLBS status201 (contentTypeH : header) (cs body)
            else responseLBS status201 header ""

    (ActionUpdate, TargetIdent qi, Just payload@(PayloadJSON uniform)) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq, mq) -> do
          let stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, queryTotal, _, body) = extractQueryResult row
              -- The n affected rows are reported as the range 0..(n-1) of n.
              -- Convert before subtracting so the bound cannot wrap.
              r = contentRangeH 0 (toInteger queryTotal - 1) (toInteger <$> Just queryTotal)
              s | queryTotal == 0 = status404
                | iPreferRepresentation apiRequest == Full = status200
                | otherwise = status204
          return $ if iPreferRepresentation apiRequest == Full
            then responseLBS s [contentTypeH, r] (cs body)
            else responseLBS s [r] ""

    (ActionDelete, TargetIdent qi, Nothing) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq, mq) -> do
          -- A delete has no body, but the write statement expects a payload.
          let emptyUniform = UniformObjects V.empty
              fakeload = PayloadJSON emptyUniform
              stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) fakeload
          row <- H.query emptyUniform stm
          let (_, queryTotal, _, body) = extractQueryResult row
              -- frm > to makes contentRangeH render the range part as "*".
              r = contentRangeH 1 0 (toInteger <$> Just queryTotal)
          return $ if queryTotal == 0
            then notFound
            else if iPreferRepresentation apiRequest == Full
              then responseLBS status200 [contentTypeH, r] (cs body)
              else responseLBS status204 [r] ""

    (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable), Nothing) ->
      let mTable = find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure)
      in case mTable of
        Nothing -> return notFound
        Just table ->
          let cols = filter (filterCol tSchema tTable) $ dbColumns dbStructure
              pkeys = map pkName $ filter (filterPk tSchema tTable) allPrKeys
              body = encode (TableOptions cols pkeys)
              filterCol :: Schema -> TableName -> Column -> Bool
              filterCol sc tb Column{colTable=Table{tableSchema=s, tableName=t}} = s==sc && t==tb
              filterCol _ _ _ = False
              acceptH = (hAllow, if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET")
          in return $ responseLBS status200 [jsonH, allOrigins, acceptH] $ cs body

    (ActionInvoke, TargetProc qi,
        Just (PayloadJSON (UniformObjects payload))) -> do
      exists <- H.query qi doesProcExist
      if exists
        then do
          -- NOTE(review): V.head is partial; if the payload vector can be
          -- empty (e.g. a JSON array body) this fails. Confirm with ApiRequest.
          let p = V.head payload
              jwtSecret = configJwtSecret conf
          respondToRange $ do
            row <- H.query () (callProc qi p topLevelRange shouldCount)
            returnJWT <- H.query qi doesProcReturnJWT
            let (tableTotal, queryTotal, body) = fromMaybe (Just 0, 0, emptyArray) row
                (status, contentRange) = rangeHeader queryTotal tableTotal
            return $ responseLBS status [jsonH, contentRange] $
              if returnJWT
                then "{\"token\":\"" <> cs (tokenJWT jwtSecret body) <> "\"}"
                else cs $ encode body
        else return notFound

    (ActionRead, TargetRoot, Nothing) -> do
      body <- encode <$> H.query schema accessibleTables
      return $ responseLBS status200 [jsonH] $ cs body

    (ActionInappropriate, _, _) -> return $ responseLBS status405 [] ""

    (_, _, Just (PayloadParseError e)) ->
      return $ responseLBS status400 [jsonH] $
        cs (formatGeneralError "Cannot parse request payload" (cs e))

    (_, TargetUnknown _, _) -> return notFound

    (_, _, _) -> return notFound

  where
    -- Response content type; falls back to JSON when iAccepts is a Left.
    contentType = either (const ApplicationJSON) id (iAccepts apiRequest)
    contentTypeH = (hContentType, cs $ show contentType)
    notFound = responseLBS status404 [] ""
    filterPk sc table pk = sc == (tableSchema . pkTable) pk && table == (tableName . pkTable) pk
    allPrKeys = dbPrimaryKeys dbStructure
    allOrigins = ("Access-Control-Allow-Origin", "*") :: Header
    schema = cs $ configSchema conf
    shouldCount = iPreferCount apiRequest
    -- Range requested for the top-level resource (key "limit"), or everything.
    topLevelRange = fromMaybe allRange $ M.lookup "limit" $ iRange apiRequest
    readDbRequest = DbRead <$> buildReadRequest (configMaxRows conf) (dbRelations dbStructure) apiRequest
    mutateDbRequest = DbMutate <$> buildMutateRequest apiRequest
    selectQuery = requestToQuery schema <$> readDbRequest
    countQuery = requestToCountQuery schema <$> readDbRequest
    mutateQuery = requestToQuery schema <$> mutateDbRequest
    readSqlParts = (,) <$> selectQuery <*> countQuery
    mutateSqlParts = (,) <$> selectQuery <*> mutateQuery
    -- An empty requested range is answered with 416 without running anything.
    respondToRange response = if topLevelRange == emptyRange
      then return $ errResponse status416 "HTTP Range error"
      else response
    -- Status and Content-Range for queryTotal rows starting at the range
    -- offset; `to` is the inclusive index of the last row returned.
    rangeHeader queryTotal tableTotal =
      let frm = rangeOffset topLevelRange
          to = frm + toInteger queryTotal - 1
          contentRange = contentRangeH frm to (toInteger <$> tableTotal)
          status = rangeStatus frm to (toInteger <$> tableTotal)
      in (status, contentRange)
-- | Split a @key=value@ field at its first @=@.
--
-- The value keeps any later @=@ characters. A field without @=@ yields an
-- empty value rather than failing (the previous 'BS.tail' crashed on it).
splitKeyValue :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitKeyValue kv = (k, BS.drop 1 v)
  where (k, v) = BS.break (== '=') kv
-- | Render @key=value@ fields as a query string.
--
-- Used for the Location header of a created row. The 'True' flag asks
-- 'renderSimpleQuery' to prepend @?@.
renderLocationFields :: [BS.ByteString] -> BS.ByteString
renderLocationFields = renderSimpleQuery True . map splitKeyValue
-- | HTTP status for a ranged response.
--
-- @frm@ and @to@ are the inclusive bounds of the rows returned, so
-- @1 + to - frm@ is the number of rows. @total@ is the full count when known.
--
-- * unknown total: 200
-- * range starting beyond the total: 416
-- * fewer rows returned than the total: 206
-- * otherwise: 200
rangeStatus :: Integer -> Integer -> Maybe Integer -> Status
rangeStatus _ _ Nothing = status200
rangeStatus frm to (Just total)
  | frm > total            = status416
  | (1 + to - frm) < total = status206
  | otherwise              = status200
-- | Build a @Content-Range@ header of the form @frm-to/total@.
--
-- The range part becomes @*@ when the total is known to be 0 or when
-- @frm > to@. The total part becomes @*@ when the total is unknown.
contentRangeH :: Integer -> Integer -> Maybe Integer -> Header
contentRangeH frm to total = ("Content-Range", cs $ rangePart <> "/" <> totalPart)
  where
    rangePart
      | total /= Just 0 && frm <= to = show frm <> "-" <> cs (show to)
      | otherwise = "*"
    totalPart = maybe "*" show total
-- | @Content-Type@ header for UTF-8 JSON responses.
jsonH :: Header
jsonH = (hContentType, jsonMime)
  where jsonMime = "application/json; charset=utf-8"
-- | JSON error body for a failure to relate two embedded entities.
-- The argument becomes the @details@ field.
formatRelationError :: Text -> Text
formatRelationError details =
  formatGeneralError "could not find foreign keys between these entities" details
-- | JSON error body for a parsec failure.
--
-- The message is the error position. The details are parsec's rendered
-- messages, joined onto one line and trimmed.
formatParserError :: ParseError -> Text
formatParserError err = formatGeneralError position flattened
  where
    position = cs . show $ errorPos err
    rendered = showErrorMessages "or" "unknown parse error" "expecting" "unexpected" "end of input" (errorMessages err)
    flattened = strip . replace "\n" " " $ cs rendered
-- | Encode an error as the JSON object @{"message": ..., "details": ...}@.
formatGeneralError :: Text -> Text -> Text
formatGeneralError message details =
  cs . encode $ object ["message" .= message, "details" .= details]
-- | Attach relations to a read tree with 'addRelations', then add the join
-- conditions.
--
-- Relation lookup errors are rewrapped with 'formatRelationError'.
augumentRequestWithJoin :: Schema -> [Relation] -> ReadRequest -> Either Text ReadRequest
augumentRequestWithJoin schema allRels request = do
  related <- first formatRelationError $ addRelations schema allRels Nothing request
  addJoinConditions schema related
-- | Parse the request's filters, orders and ranges into one tree transformer.
--
-- Ranges are applied first, then orders, then filters. The first parse error
-- is returned, checked in that same order: filters, orders, ranges. For
-- non-read actions, only filters whose key contains a @.@ are kept here.
addFiltersOrdersRanges :: ApiRequest -> Either ParseError (ReadRequest -> ReadRequest)
addFiltersOrdersRanges apiRequest =
  liftA3 (\withFilters withOrders withRanges -> withFilters . withOrders . withRanges)
    (applyAll addFilter <$> filters)
    (applyAll addOrder <$> orders)
    (applyAll addRange <$> ranges)
  where
    -- Fold every parsed property into the tree, right to left.
    applyAll :: (a -> r -> r) -> [a] -> r -> r
    applyAll add props request = foldr add request props

    filters :: Either ParseError [(Path, Filter)]
    filters = mapM pRequestFilter flts
      where
        action = iAction apiRequest
        flts = if action == ActionRead
          then iFilters apiRequest
          else filter (('.' `elem`) . fst) $ iFilters apiRequest

    orders :: Either ParseError [(Path, [OrderTerm])]
    orders = mapM pRequestOrder $ iOrder apiRequest

    ranges :: Either ParseError [(Path, NonnegRange)]
    ranges = mapM pRequestRange $ M.toList $ iRange apiRequest
-- | Cap the range of every node with 'restrictRange' and the given maximum.
-- Always succeeds.
treeRestrictRange :: Maybe Integer -> ReadRequest -> Either Text ReadRequest
treeRestrictRange maxRows_ request = Right (restrictNode <$> request)
  where
    restrictNode :: ReadNode -> ReadNode
    restrictNode (q@Select {range_ = r}, i) = (q {range_ = restrictRange maxRows_ r}, i)
-- | Build the read tree for a request that targets a table or view.
--
-- Steps:
--
-- 1. Parse the @select@ parameter.
-- 2. Apply filters, orders and ranges.
-- 3. Resolve joins against the relations.
-- 4. Cap every range at @maxRows@.
--
-- For create/update/delete the root node is 'sourceCTEName', and the
-- relations are extended with copies that point at it. A non-table target
-- yields 'Left'; 'fromJust' used to crash on it.
buildReadRequest :: Maybe Integer -> [Relation] -> ApiRequest -> Either Text ReadRequest
buildReadRequest maxRows allRels apiRequest =
  case iTarget apiRequest of
    TargetIdent (QualifiedIdentifier schema rootTableName) ->
      treeRestrictRange maxRows
        =<< augumentRequestWithJoin schema (relations rootTableName)
        =<< first formatParserError (readRequest rootTableName)
    _ -> Left "Unsupported target"
  where
    action :: Action
    action = iAction apiRequest

    -- Parse the select parameter, rooted at the table for reads and at the
    -- source CTE otherwise, then apply filters, orders and ranges.
    readRequest :: TableName -> Either ParseError ReadRequest
    readRequest rootTableName =
      addFiltersOrdersRanges apiRequest <*>
        parse (pRequestSelect rootName) ("failed to parse select parameter <<" ++ selStr ++ ">>") selStr
      where
        selStr = iSelect apiRequest
        rootName = if action == ActionRead then rootTableName else sourceCTEName

    relations :: TableName -> [Relation]
    relations rootTableName = case action of
        ActionCreate -> fakeSourceRelations ++ allRels
        ActionUpdate -> fakeSourceRelations ++ allRels
        ActionDelete -> fakeSourceRelations ++ allRels
        _ -> allRels
      where fakeSourceRelations = mapMaybe (toSourceRelation rootTableName) allRels
-- | Build the 'Insert', 'Update' or 'Delete' described by a request.
--
-- Only filters whose key contains no @.@ become the mutation's filters. The
-- dotted ones are applied to the read request by 'addFiltersOrdersRanges'.
-- The following yield 'Left' instead of crashing (previously 'undefined' /
-- 'fromJust'):
--
-- * a non-table target;
-- * a missing payload for insert/update.
buildMutateRequest :: ApiRequest -> Either Text MutateRequest
buildMutateRequest apiRequest = case iAction apiRequest of
  ActionCreate -> Insert <$> rootTableName <*> payload
  ActionUpdate -> Update <$> rootTableName <*> payload <*> filters
  ActionDelete -> Delete <$> rootTableName <*> filters
  _ -> Left "Unsupported HTTP verb"
  where
    rootTableName = case iTarget apiRequest of
      TargetIdent (QualifiedIdentifier _ t) -> Right t
      _ -> Left "Unsupported target"
    payload = maybe (Left "Missing request payload") Right (iPayload apiRequest)
    filters = first formatParserError $ map snd <$> mapM pRequestFilter mutateFilters
    mutateFilters = filter (not . ('.' `elem`) . fst) $ iFilters apiRequest
-- | Prepend a filter to the filters of a tree's root node.
addFilterToNode :: Filter -> ReadRequest -> ReadRequest
addFilterToNode newFilter (Node (sel@Select {flt_ = existing}, info) kids) =
  Node (sel {flt_ = newFilter : existing}, info) kids

-- | Add a filter to the node at the given path.
addFilter :: (Path, Filter) -> ReadRequest -> ReadRequest
addFilter prop request = addProperty addFilterToNode prop request
-- | Set (replace) the ordering of a tree's root node.
addOrderToNode :: [OrderTerm] -> ReadRequest -> ReadRequest
addOrderToNode terms (Node (sel, info) kids) = Node (sel {order = Just terms}, info) kids

-- | Set the ordering of the node at the given path.
addOrder :: (Path, [OrderTerm]) -> ReadRequest -> ReadRequest
addOrder prop request = addProperty addOrderToNode prop request
-- | Set (replace) the range of a tree's root node.
addRangeToNode :: NonnegRange -> ReadRequest -> ReadRequest
addRangeToNode rng (Node (sel, info) kids) = Node (sel {range_ = rng}, info) kids

-- | Set the range of the node at the given path.
addRange :: (Path, NonnegRange) -> ReadRequest -> ReadRequest
addRange prop request = addProperty addRangeToNode prop request
-- | Apply a property setter to the node addressed by a path.
--
-- The path is a list of child node names; an empty path targets the root.
-- If a child along the path is missing, that level is left unchanged. The
-- child that is descended into is moved to the front of its parent's
-- children.
addProperty :: (a -> ReadRequest -> ReadRequest) -> (Path, a) -> ReadRequest -> ReadRequest
addProperty setter ([], prop) request = setter prop request
addProperty setter (targetName:rest, prop) (Node rootNode kids) =
  maybe (Node rootNode kids) descend (find isTarget kids)
  where
    isTarget (Node (_, (nodeName, _, _)) _) = nodeName == targetName
    descend child = Node rootNode (addProperty setter (rest, prop) child : delete child kids)
-- | Re-point a relation at the mutation's source CTE.
--
-- When the table, foreign table or link table of the relation is named @mt@,
-- that table is renamed to 'sourceCTEName'; otherwise 'Nothing'. Each branch
-- renames the table it matched. The foreign-table branch used to rebuild
-- 'relFTable' from the relation's own table @t@.
toSourceRelation :: TableName -> Relation -> Maybe Relation
toSourceRelation mt r@(Relation t _ ft _ _ rt _ _)
  | mt == tableName t = Just $ r {relTable = renamed t}
  | mt == tableName ft = Just $ r {relFTable = renamed ft}
  | Just mt == (tableName <$> rt) = Just $ r {relLTable = renamed <$> rt}
  | otherwise = Nothing
  where
    renamed tbl = tbl {tableName = sourceCTEName}
-- | Body of the 'ActionInfo' response: the columns of a table and its
-- primary keys (built from 'pkName').
data TableOptions = TableOptions {
    tblOptcolumns :: [Column] -- ^ columns belonging to the table
  , tblOptpkey :: [Text]      -- ^ primary-key names
  }
-- | Encoded as @{"columns": [...], "pkey": [...]}@.
instance ToJSON TableOptions where
  toJSON opts =
    let cols = tblOptcolumns opts
        pk = tblOptpkey opts
    in object ["columns" .= cols, "pkey" .= pk]
-- | Unwrap a write statement's result.
-- A missing row defaults to @(Nothing, 0, [], "")@.
extractQueryResult :: Maybe ResultsWithCount -> ResultsWithCount
extractQueryResult mResult = case mResult of
  Just result -> result
  Nothing -> (Nothing, 0, [], "")