{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
module PostgREST.DbRequestBuilder (
readRequest
, mutateRequest
, returningCols
) where
import qualified Data.HashMap.Strict as M
import qualified Data.Set as S
import Control.Arrow ((***))
import Data.Either.Combinators (mapLeft)
import Data.Foldable (foldr1)
import Data.List (delete)
import Data.Text (isInfixOf)
import Control.Applicative
import Data.Tree
import Network.Wai
import PostgREST.ApiRequest (Action (..), ApiRequest (..))
import PostgREST.Error (ApiRequestError (..), errorResponseFor)
import PostgREST.Parsers
import PostgREST.RangeQuery (NonnegRange, allRange, restrictRange)
import PostgREST.Types
import Protolude hiding (from)
-- | Build the 'ReadRequest' tree for a request against @rootTableName@.
--
-- The steps run in this order:
--
-- 1. Parse the filters, orders, ranges and logic trees.
-- 2. Parse the @select@ parameter (defaulting to @"*"@) and seed a tree from it.
-- 3. Resolve the relation behind every embedded node and add its join conditions.
-- 4. Clamp every node's range with @maxRows@.
--
-- Any 'ApiRequestError' is turned into a 'Response' by 'errorResponseFor'.
readRequest :: Schema -> TableName -> Maybe Integer -> [Relation] -> ApiRequest -> Either Response ReadRequest
readRequest schema rootTableName maxRows allRels apiRequest =
  mapLeft errorResponseFor $ do
    -- Bound before the select is parsed, so its errors take precedence.
    applyParams <- addFiltersOrdersRanges apiRequest
    selectForest <- pRequestSelect selectParam
    let seeded = applyParams (initReadRequest rootName selectForest)
    joined <- augmentRequestWithJoin schema rootRels seeded
    treeRestrictRange maxRows joined
  where
    selectParam = fromMaybe "*" (iSelect apiRequest)
    (rootName, rootRels) = rootWithRels schema rootTableName allRels (iAction apiRequest)
-- | Choose the root 'QualifiedIdentifier' and the relations available from it.
--
-- An 'ActionRead' reads @schema.rootTableName@ and uses @allRels@ unchanged.
-- Every other action reads from the unqualified 'sourceCTEName'. In that case,
-- each relation whose source table is @rootTableName@ is copied with its
-- source table renamed to 'sourceCTEName', and the copies are placed before
-- @allRels@.
rootWithRels :: Schema -> TableName -> [Relation] -> Action -> (QualifiedIdentifier, [Relation])
rootWithRels schema rootTableName allRels = \case
  ActionRead _ -> (QualifiedIdentifier schema rootTableName, allRels)
  _ -> (QualifiedIdentifier mempty sourceCTEName, cteRels <> allRels)
  where
    cteRels = mapMaybe retargetToCTE allRels
    -- Only the table name is rewritten; the table's schema is kept.
    retargetToCTE :: Relation -> Maybe Relation
    retargetToCTE rel@Relation{relTable = tbl}
      | tableName tbl /= rootTableName = Nothing
      | otherwise = Just (rel {relTable = tbl {tableName = sourceCTEName}})
-- | Seed a 'ReadRequest' from the parsed @select@ forest.
--
-- The root node reads @rootQi@ at depth 0.
--
-- * A select item with no children becomes a selected field of the current
--   node.
-- * A select item with children becomes a new child node. The child is named
--   after the field, reads that name in the root's schema, sits one level
--   deeper, and carries the item's alias and embed hint. Its fields come from
--   the item's children.
--
-- Every node starts with no filters, orders or joins, and with 'allRange'.
initReadRequest :: QualifiedIdentifier -> [Tree SelectItem] -> ReadRequest
initReadRequest rootQi selectItems =
  foldr (insertItem rootDepth) rootNode selectItems
  where
    rootDepth :: Depth
    rootDepth = 0
    rootNode = blankNode rootQi (qiName rootQi) Nothing Nothing rootDepth
    -- A node that selects nothing yet and has no relation attached.
    blankNode qi name alias hint depth =
      Node (Select [] qi Nothing [] [] [] [] allRange, (name, Nothing, alias, hint, depth)) []
    insertItem :: Depth -> Tree SelectItem -> ReadRequest -> ReadRequest
    insertItem depth (Node item@((fieldName, _), _, alias, hint) children) (Node (query, info) embeds)
      | null children = Node (query {select = item : select query}, info) embeds
      | otherwise = Node (query, info) (embedded : embeds)
      where
        childDepth = succ depth
        childQi = QualifiedIdentifier (qiSchema rootQi) fieldName
        embedded =
          foldr (insertItem childDepth) (blankNode childQi fieldName alias hint childDepth) children
-- | Clamp the range of every node in the tree using 'restrictRange' and
-- @maxRows@. Always returns 'Right'.
treeRestrictRange :: Maybe Integer -> ReadRequest -> Either ApiRequestError ReadRequest
treeRestrictRange maxRows = Right . fmap clampNode
  where
    clampNode :: ReadNode -> ReadNode
    clampNode (query@Select{range_ = current}, info) =
      (query {range_ = restrictRange maxRows current}, info)
-- | Attach to every embedded node the 'Relation' linking it to its parent
-- ('addRels'), then add the join conditions derived from those relations
-- ('addJoinConditions'). The root starts with no parent and no alias.
augmentRequestWithJoin :: Schema -> [Relation] -> ReadRequest -> Either ApiRequestError ReadRequest
augmentRequestWithJoin schema allRels request = do
  withRels <- addRels schema allRels Nothing request
  addJoinConditions Nothing withRels
-- | Annotate each node of the tree with the 'Relation' linking it to its parent.
--
-- The root (when @parentNode@ is 'Nothing') gets no relation. For every other
-- node, 'findRel' looks up the relation from the parent's table name to the
-- node name, using the node's embed hint. If the node reads from a table whose
-- name equals the node name, its @from@ is switched to the relation's foreign
-- table.
--
-- The embed-hint slot is cleared on every rebuilt node. Children are processed
-- with the rebuilt node (and its original children) as their parent. The first
-- 'findRel' failure aborts the traversal.
addRels :: Schema -> [Relation] -> Maybe ReadRequest -> ReadRequest -> Either ApiRequestError ReadRequest
addRels schema allRels parentNode (Node (query@Select{from=tbl}, (nodeName, _, alias, hint, depth)) forest) =
  case parentNode of
    Just (Node (Select{from=parentNodeQi}, _) _) -> do
      rel <- findRel schema allRels (qiName parentNodeQi) nodeName hint
      let source = if qiName tbl == nodeName then tableQi (relFTable rel) else tbl
      withChildren (query{from=source}, (nodeName, Just rel, alias, Nothing, depth))
    _ ->
      withChildren (query, (nodeName, Nothing, alias, Nothing, depth))
  where
    -- Rebuild this node and recurse into its children, with the rebuilt node
    -- passed down as their parent.
    withChildren :: ReadNode -> Either ApiRequestError ReadRequest
    withChildren node =
      Node node <$> mapM (addRels schema allRels (Just (Node node forest))) forest
-- | Find the single 'Relation' used to embed @target@ under @origin@.
--
-- A relation is a candidate when all of the following hold:
--
-- * both of its tables are in @schema@;
-- * its source table is @origin@;
-- * @target@ names the other side, either as the foreign table name, as the
--   constraint name, or as the name of a single-column FK on the source side.
--
-- When a @hint@ is given, a candidate must also match it in one of these ways:
-- as the constraint name, as a single-column FK on either side, or (for 'M2M')
-- as the junction table's name.
--
-- Result:
--
-- * no candidate: 'NoRelBetween';
-- * exactly one candidate: that relation;
-- * exactly two candidates sharing constraint, table and foreign table: the
--   'O2M' one, or 'NoRelBetween' if neither is 'O2M';
-- * otherwise: 'AmbiguousRelBetween' with all candidates.
findRel :: Schema -> [Relation] -> NodeName -> NodeName -> Maybe EmbedHint -> Either ApiRequestError Relation
findRel schema allRels origin target hint =
  case candidates of
    [] -> Left $ NoRelBetween origin target
    [r] -> Right r
    -- Matched structurally, so no partial pattern or length check is needed.
    rs@[rel0, rel1]
      | relConstraint rel0 == relConstraint rel1
          && relTable rel0 == relTable rel1
          && relFTable rel0 == relFTable rel1 ->
          note (NoRelBetween origin target) (find (\r -> relType r == O2M) rs)
    rs -> Left $ AmbiguousRelBetween origin target rs
  where
    -- True when the list holds exactly one column and its name equals the hint.
    matchFKSingleCol hint_ = \case
      [col] -> hint_ == Just (colName col)
      _ -> False
    candidates = filter isCandidate allRels
    isCandidate Relation{relTable, relColumns, relConstraint, relFTable, relFColumns, relType, relJunction} =
      schema == tableSchema relTable
        && schema == tableSchema relFTable
        && origin == tableName relTable
        && ( target == tableName relFTable
               || Just target == relConstraint
               || matchFKSingleCol (Just target) relColumns )
        && ( isNothing hint
               || hint == relConstraint
               || matchFKSingleCol hint relColumns
               || matchFKSingleCol hint relFColumns
               || (relType == M2M && hint == (tableName . junTable <$> relJunction)) )
-- | Fill in join conditions, and the implicit junction join for 'M2M', from
-- each node's relation, recursing through the whole tree.
--
-- @previousAlias@ is the alias chosen for the parent node, if any. A non-root
-- node whose relation satisfies 'isSelfReference' is aliased as
-- @<table>_<depth>@, and that alias is passed to its children. Nodes without a
-- relation are left as they are, apart from their children.
--
-- Fails with 'UnknownRelation' for an 'M2M' relation that has no junction.
addJoinConditions :: Maybe Alias -> ReadRequest -> Either ApiRequestError ReadRequest
addJoinConditions previousAlias (Node node@(query@Select{from=tbl}, nodeProps@(_, rel, _, _, depth)) forest) =
  case rel of
    Nothing -> Node node <$> updatedForest
    Just r -> do
      joined <- joinedQuery r
      Node (joined, nodeProps) <$> updatedForest
  where
    newAlias
      | maybe False isSelfReference rel && depth /= 0 = Just (qiName tbl <> "_" <> show depth)
      | otherwise = Nothing
    updatedForest = mapM (addJoinConditions newAlias) forest
    -- The query with this node's alias set, and with the relation's join
    -- conditions placed in front of any existing ones.
    augmentQuery r = case query{fromAlias=newAlias} of
      aliased@Select{joinConditions=existing} ->
        aliased{joinConditions = getJoinConditions previousAlias newAlias r ++ existing}
    joinedQuery :: Relation -> Either ApiRequestError ReadQuery
    joinedQuery r@Relation{relType=O2M} = Right (augmentQuery r)
    joinedQuery r@Relation{relType=M2O} = Right (augmentQuery r)
    joinedQuery r@Relation{relType=M2M, relJunction=Just Junction{junTable}} =
      let q = augmentQuery r
      in Right q{implicitJoins = tableQi junTable : implicitJoins q}
    joinedQuery Relation{relType=M2M, relJunction=Nothing} = Left UnknownRelation
-- | Join conditions linking a relation's source table to its target.
--
-- * 'O2M' / 'M2O': one condition per pair of source and foreign columns.
-- * 'M2M': source-to-junction conditions followed by target-to-junction
--   conditions, or none when the junction is missing.
--
-- The left side of each condition is qualified with @previousAlias@ when one
-- is given, and the right side with @newAlias@. Otherwise each side uses its
-- table name qualified with the source table's schema
-- (see 'removeSourceCTESchema').
getJoinConditions :: Maybe Alias -> Maybe Alias -> Relation -> [JoinCondition]
getJoinConditions previousAlias newAlias (Relation Table{tableSchema=tSchema, tableName=tN} cols _ Table{tableName=ftN} fCols typ jun) =
  case typ of
    O2M -> direct
    M2O -> direct
    M2M -> maybe [] viaJunction jun
  where
    direct = zipWith (joinOn tN ftN) cols fCols
    viaJunction (Junction jt _ jc1 _ jc2) =
      let jtn = tableName jt
      in zipWith (joinOn tN jtn) cols jc1 ++ zipWith (joinOn ftN jtn) fCols jc2
    -- An alias, when present, replaces the schema-qualified table name.
    qualify table alias = maybe (removeSourceCTESchema tSchema table) (QualifiedIdentifier mempty) alias
    joinOn :: Text -> Text -> Column -> Column -> JoinCondition
    joinOn leftTbl rightTbl leftCol rightCol =
      JoinCondition (qualify leftTbl previousAlias, colName leftCol)
                    (qualify rightTbl newAlias, colName rightCol)
-- | Qualify @tbl@ with @schema@. The table named 'sourceCTEName' is the
-- exception and gets an empty schema.
removeSourceCTESchema :: Schema -> TableName -> QualifiedIdentifier
removeSourceCTESchema schema tbl
  | tbl == sourceCTEName = QualifiedIdentifier mempty tbl
  | otherwise = QualifiedIdentifier schema tbl
-- | Parse the request's filters, orders, ranges and logic trees. Returns a
-- function that attaches each one to the node addressed by its embed path.
--
-- For 'ActionRead' and 'ActionInvoke', every filter and logic tree is used.
-- For any other action, only entries whose key contains a @.@ are kept here.
--
-- Parse errors are reported in this order: filters, orders, ranges, logic
-- trees.
addFiltersOrdersRanges :: ApiRequest -> Either ApiRequestError (ReadRequest -> ReadRequest)
addFiltersOrdersRanges apiRequest =
  -- Composed with (.), so logic trees are applied to the tree first and
  -- filters last.
  foldr1 (liftA2 (.))
    [ applyEach addFilter <$> filters
    , applyEach addOrder <$> orders
    , applyEach addRange <$> ranges
    , applyEach addLogicTree <$> logicForest
    ]
  where
    applyEach :: (a -> ReadRequest -> ReadRequest) -> [a] -> ReadRequest -> ReadRequest
    applyEach step items request = foldr step request items
    filters :: Either ApiRequestError [(EmbedPath, Filter)]
    filters = mapM pRequestFilter flts
    orders :: Either ApiRequestError [(EmbedPath, [OrderTerm])]
    orders = mapM pRequestOrder (iOrder apiRequest)
    ranges :: Either ApiRequestError [(EmbedPath, NonnegRange)]
    ranges = mapM pRequestRange (M.toList (iRange apiRequest))
    logicForest :: Either ApiRequestError [(EmbedPath, LogicTree)]
    logicForest = mapM pRequestLogicTree logFrst
    usesAllFilters = case iAction apiRequest of
      ActionInvoke _ -> True
      ActionRead _ -> True
      _ -> False
    (flts, logFrst)
      | usesAllFilters = (iFilters apiRequest, iLogic apiRequest)
      | otherwise = join (***) (filter (("." `isInfixOf`) . fst)) (iFilters apiRequest, iLogic apiRequest)
-- | Add a filter to the top node's @where_@ forest via 'addFilterToLogicForest'.
addFilterToNode :: Filter -> ReadRequest -> ReadRequest
addFilterToNode flt (Node (query@Select{where_ = conditions}, info) children) =
  let filtered = query {where_ = addFilterToLogicForest flt conditions} :: ReadQuery
  in Node (filtered, info) children
-- | Add a filter to the node addressed by the embed path (see 'addProperty').
addFilter :: (EmbedPath, Filter) -> ReadRequest -> ReadRequest
addFilter entry request = addProperty addFilterToNode entry request
-- | Replace the top node's order terms.
addOrderToNode :: [OrderTerm] -> ReadRequest -> ReadRequest
addOrderToNode terms (Node (query, info) children) =
  let ordered = query {order = terms}
  in Node (ordered, info) children
-- | Set the order terms of the node addressed by the embed path (see 'addProperty').
addOrder :: (EmbedPath, [OrderTerm]) -> ReadRequest -> ReadRequest
addOrder entry request = addProperty addOrderToNode entry request
-- | Replace the top node's range.
addRangeToNode :: NonnegRange -> ReadRequest -> ReadRequest
addRangeToNode rng (Node (query, info) children) =
  let ranged = query {range_ = rng}
  in Node (ranged, info) children
-- | Set the range of the node addressed by the embed path (see 'addProperty').
addRange :: (EmbedPath, NonnegRange) -> ReadRequest -> ReadRequest
addRange entry request = addProperty addRangeToNode entry request
-- | Prepend a logic tree to the top node's @where_@ forest.
addLogicTreeToNode :: LogicTree -> ReadRequest -> ReadRequest
addLogicTreeToNode tree (Node (query@Select{where_ = conditions}, info) children) =
  let extended = query {where_ = tree : conditions} :: ReadQuery
  in Node (extended, info) children
-- | Add a logic tree to the node addressed by the embed path (see 'addProperty').
addLogicTree :: (EmbedPath, LogicTree) -> ReadRequest -> ReadRequest
addLogicTree entry request = addProperty addLogicTreeToNode entry request
-- | Apply @f@ to the node reached by following the embed path from the top node.
--
-- Each path segment selects the first child whose node name or alias equals
-- that segment. Each child along the path is moved to the front of its
-- parent's forest. If some segment matches no child, @f@ is never applied and
-- the property is silently dropped.
addProperty :: (a -> ReadRequest -> ReadRequest) -> (EmbedPath, a) -> ReadRequest -> ReadRequest
addProperty f ([], a) rr = f a rr
addProperty f (targetNodeName:remainingPath, a) rr@(Node rn forest) =
  maybe rr descend (find matchesTarget forest)
  where
    matchesTarget (Node (_, (nodeName, _, alias, _, _)) _) =
      nodeName == targetNodeName || alias == Just targetNodeName
    descend tn = Node rn (addProperty f (remainingPath, a) tn : delete tn forest)
-- | Build the 'MutateRequest' for the request's action on @schema.tName@.
--
-- Only filters and logic trees whose key contains no @.@ become the
-- mutation's conditions. The returned columns are empty unless
-- 'iPreferRepresentation' is something other than 'None', in which case they
-- come from 'returningCols'.
--
-- * 'ActionCreate': an insert. Its conflict target is the parsed
--   'iOnConflict' value, or @pkCols@ when that is absent. The target is only
--   used when 'iPreferResolution' is set.
-- * 'ActionUpdate' / 'ActionDelete': a filtered update or delete.
-- * 'ActionSingleUpsert': an insert merging duplicates on @pkCols@. It is
--   allowed only when all of these hold:
--
--     * there are no logic trees;
--     * the set of filter keys equals the non-empty set of @pkCols@;
--     * every root filter matches @OpExpr False (Op "eq" _)@.
--
--   Otherwise the result is 'InvalidFilters'.
-- * Any other action: 'UnsupportedVerb'.
mutateRequest :: Schema -> TableName -> ApiRequest -> S.Set FieldName -> [FieldName] -> ReadRequest -> Either Response MutateRequest
mutateRequest schema tName apiRequest cols pkCols readReq = mapLeft errorResponseFor $
  case iAction apiRequest of
    ActionCreate -> do
      confCols <- maybe (pure pkCols) pRequestOnConflict (iOnConflict apiRequest)
      let onConflict = (\resolution -> (resolution, confCols)) <$> iPreferResolution apiRequest
      pure $ Insert qi cols onConflict [] returnings
    ActionUpdate -> Update qi cols <$> combinedLogic <*> pure returnings
    ActionSingleUpsert -> do
      rootFilters <- filters
      if isPkPut rootFilters
        then Insert qi cols (Just (MergeDuplicates, pkCols)) <$> combinedLogic <*> pure returnings
        else Left InvalidFilters
    ActionDelete -> Delete qi <$> combinedLogic <*> pure returnings
    _ -> Left UnsupportedVerb
  where
    qi = QualifiedIdentifier schema tName
    returnings
      | iPreferRepresentation apiRequest == None = []
      | otherwise = returningCols readReq
    rootOnly = filter (not . ("." `isInfixOf`) . fst)
    filters = map snd <$> mapM pRequestFilter (rootOnly (iFilters apiRequest))
    logic = map snd <$> mapM pRequestLogicTree (rootOnly (iLogic apiRequest))
    combinedLogic = foldr addFilterToLogicForest <$> logic <*> filters
    isPkPut flts =
      null (iLogic apiRequest)
        && S.fromList (fst <$> iFilters apiRequest) == S.fromList pkCols
        && not (null pkCols)
        && all isPlainEq flts
    isPlainEq = \case
      Filter _ (OpExpr False (Op "eq" _)) -> True
      _ -> False
-- | Column names a mutation should return for this read tree.
--
-- These are the names given by 'fstFieldNames', plus the 'relColumns' of every
-- direct child that carries a relation. The list is de-duplicated and sorted
-- by passing it through 'S.Set'.
returningCols :: ReadRequest -> [FieldName]
returningCols rr@(Node _ forest) =
  S.toList (S.fromList (fstFieldNames rr ++ map colName (concatMap relCols forest)))
  where
    relCols :: ReadRequest -> [Column]
    relCols (Node (_, (_, Just Relation{relColumns = cols, relType = relTyp}, _, _, _)) _) =
      case relTyp of
        O2M -> cols
        M2O -> cols
        M2M -> cols
    relCols _ = []
-- | Prepend a filter, wrapped as a 'Stmnt' leaf, to a logic forest.
addFilterToLogicForest :: Filter -> [LogicTree] -> [LogicTree]
addFilterToLogicForest flt = (Stmnt flt :)