{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
--module PostgREST.App where
module PostgREST.App (
  postgrest
) where

import           Control.Applicative
import           Data.Bifunctor            (first)
import qualified Data.ByteString.Char8   as BS
import           Data.IORef                (IORef, readIORef)
import           Data.List                 (find, delete)
import           Data.Maybe                (fromMaybe, fromJust, mapMaybe)
import           Data.Ranged.Ranges        (emptyRange)
import           Data.String.Conversions   (cs)
import           Data.Text                 (Text, replace, strip)
import           Data.Tree

import qualified Hasql.Pool                as P
import qualified Hasql.Transaction         as HT

import           Text.Parsec.Error
import           Text.ParserCombinators.Parsec (parse)

import           Network.HTTP.Types.Header
import           Network.HTTP.Types.Status
import           Network.HTTP.Types.URI (renderSimpleQuery)
import           Network.Wai
import           Network.Wai.Middleware.RequestLogger (logStdout)

import           Data.Aeson
import           Data.Aeson.Types (emptyArray)
import           Data.Monoid
import           Data.Time.Clock.POSIX     (getPOSIXTime)
import qualified Data.Vector               as V
import qualified Hasql.Transaction         as H

import qualified Data.HashMap.Strict       as M

import           PostgREST.ApiRequest   (ApiRequest(..), ContentType(..)
                                            , Action(..), Target(..)
                                            , PreferRepresentation (..)
                                            , userApiRequest)
import           PostgREST.Auth            (tokenJWT, jwtClaims, containsRole)
import           PostgREST.Config          (AppConfig (..))
import           PostgREST.DbStructure
import           PostgREST.Error           (errResponse, pgErrResponse)
import           PostgREST.Parsers
import           PostgREST.RangeQuery      (NonnegRange, allRange, rangeOffset, restrictRange)
import           PostgREST.Middleware
import           PostgREST.QueryBuilder ( callProc
                                        , addJoinConditions
                                        , sourceCTEName
                                        , requestToQuery
                                        , requestToCountQuery
                                        , addRelations
                                        , createReadStatement
                                        , createWriteStatement
                                        , ResultsWithCount
                                        )
import           PostgREST.Types

import           Prelude


postgrest :: AppConfig -> IORef DbStructure -> P.Pool -> Application
postgrest conf refDbStructure pool =
  let middle = (if configQuiet conf then id else logStdout) . defaultMiddle in

  middle $ \ req respond -> do
    time <- getPOSIXTime
    body <- strictRequestBody req
    dbStructure <- readIORef refDbStructure

    let schema = cs $ configSchema conf
        apiRequest = userApiRequest schema req body
        eClaims = jwtClaims (configJwtSecret conf) (iJWT apiRequest) time
        authed = containsRole eClaims
        handleReq = runWithClaims conf eClaims (app dbStructure conf) apiRequest
        txMode = transactionMode $ iAction apiRequest

    resp <- either (pgErrResponse authed) id <$> P.use pool
      (HT.run handleReq HT.ReadCommitted txMode)
    respond resp

transactionMode :: Action -> H.Mode
transactionMode ActionRead = HT.Read
transactionMode ActionInfo = HT.Read
transactionMode _ = HT.Write

app :: DbStructure -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure conf apiRequest =
  let
      -- TODO: blow up for Left values (there is a middleware that checks the headers)
      contentType = either (const ApplicationJSON) id (iAccepts apiRequest)
      contentTypeH = (hContentType, cs $ show contentType) in

  case (iAction apiRequest, iTarget apiRequest, iPayload apiRequest) of

    (ActionRead, TargetIdent qi, Nothing) ->
      case readSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (q, cq) -> do
          let singular = iPreferSingular apiRequest
              stm = createReadStatement q cq singular
                    shouldCount (contentType == TextCSV)
          respondToRange $ do
            row <- H.query () stm
            let (tableTotal, queryTotal, _ , body) = row
            if singular
            then return $ if queryTotal <= 0
              then responseLBS status404 [] ""
              else responseLBS status200 [contentTypeH] (cs body)
            else do
              let (status, contentRange) = rangeHeader queryTotal tableTotal
                  canonical = iCanonicalQS apiRequest
              return $ responseLBS status
                [contentTypeH, contentRange,
                  ("Content-Location",
                    "/" <> cs (qiName qi) <>
                      if Prelude.null canonical then "" else "?" <> cs canonical
                  )
                ] (cs body)

    (ActionCreate, TargetIdent qi@(QualifiedIdentifier _ table),
     Just payload@(PayloadJSON uniform@(UniformObjects rows))) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          let isSingle = (==1) $ V.length rows
          let pKeys = map pkName $ filter (filterPk schema table) allPrKeys -- would it be ok to move primary key detection in the query itself?
          let stm = createWriteStatement qi sq mq isSingle (iPreferRepresentation apiRequest) pKeys (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, _, fs, body) = extractQueryResult row
              header =
                if null fs then []
                else [(hLocation, "/" <> cs table <> renderLocationFields fs)]

          return $ if iPreferRepresentation apiRequest == Full
            then responseLBS status201 (contentTypeH : header) (cs body)
            else responseLBS status201 header ""

    (ActionUpdate, TargetIdent qi, Just payload@(PayloadJSON uniform)) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          let stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) payload
          row <- H.query uniform stm
          let (_, queryTotal, _, body) = extractQueryResult row
              r = contentRangeH 0 (toInteger $ queryTotal-1) (toInteger <$> Just queryTotal)
              s = case () of _ | queryTotal == 0 -> status404
                               | iPreferRepresentation apiRequest == Full -> status200
                               | otherwise -> status204
          return $ if iPreferRepresentation apiRequest == Full
            then responseLBS s [contentTypeH, r] (cs body)
            else responseLBS s [r] ""

    (ActionDelete, TargetIdent qi, Nothing) ->
      case mutateSqlParts of
        Left e -> return $ responseLBS status400 [jsonH] $ cs e
        Right (sq,mq) -> do
          let emptyUniform = UniformObjects V.empty
              fakeload = PayloadJSON emptyUniform
              stm = createWriteStatement qi sq mq False (iPreferRepresentation apiRequest) [] (contentType == TextCSV) fakeload
          row <- H.query emptyUniform stm
          let (_, queryTotal, _, body) = extractQueryResult row
              r = contentRangeH 1 0 (toInteger <$> Just queryTotal)
          return $ if queryTotal == 0
            then notFound
            else if iPreferRepresentation apiRequest == Full
              then responseLBS status200 [contentTypeH, r] (cs body)
              else responseLBS status204 [r] ""

    (ActionInfo, TargetIdent (QualifiedIdentifier tSchema tTable), Nothing) ->
      let mTable = find (\t -> tableName t == tTable && tableSchema t == tSchema) (dbTables dbStructure) in
      case mTable of
        Nothing -> return notFound
        Just table ->
          let cols = filter (filterCol tSchema tTable) $ dbColumns dbStructure
              pkeys = map pkName $ filter (filterPk tSchema tTable) allPrKeys
              body = encode (TableOptions cols pkeys)
              filterCol :: Schema -> TableName -> Column -> Bool
              filterCol sc tb Column{colTable=Table{tableSchema=s, tableName=t}} = s==sc && t==tb
              filterCol _ _ _ =  False
              acceptH = (hAllow, if tableInsertable table then "GET,POST,PATCH,DELETE" else "GET") in
          return $ responseLBS status200 [jsonH, allOrigins, acceptH] $ cs body

    (ActionInvoke, TargetProc qi,
     Just (PayloadJSON (UniformObjects payload))) -> do
      exists <- H.query qi doesProcExist
      if exists
        then do
          let p = V.head payload
              jwtSecret = configJwtSecret conf
          respondToRange $ do
            row <- H.query () (callProc qi p topLevelRange shouldCount)
            returnJWT <- H.query qi doesProcReturnJWT
            let (tableTotal, queryTotal, body) = fromMaybe (Just 0, 0, emptyArray) row
                (status, contentRange) = rangeHeader queryTotal tableTotal
              in
              return $ responseLBS status [jsonH, contentRange]
                      (if returnJWT
                      then "{\"token\":\"" <> cs (tokenJWT jwtSecret body) <> "\"}"
                      else cs $ encode body)
        else return notFound

    (ActionRead, TargetRoot, Nothing) -> do
      body <- encode <$> H.query schema accessibleTables
      return $ responseLBS status200 [jsonH] $ cs body

    (ActionInappropriate, _, _) -> return $ responseLBS status405 [] ""

    (_, _, Just (PayloadParseError e)) ->
      return $ responseLBS status400 [jsonH] $
        cs (formatGeneralError "Cannot parse request payload" (cs e))

    (_, TargetUnknown _, _) -> return notFound

    (_, _, _) -> return notFound

 where
  notFound = responseLBS status404 [] ""
  filterPk sc table pk = sc == (tableSchema . pkTable) pk && table == (tableName . pkTable) pk
  allPrKeys = dbPrimaryKeys dbStructure
  allOrigins = ("Access-Control-Allow-Origin", "*") :: Header
  schema = cs $ configSchema conf
  shouldCount = iPreferCount apiRequest
  topLevelRange = fromMaybe allRange $ M.lookup "limit" $ iRange apiRequest
  readDbRequest = DbRead <$> buildReadRequest (configMaxRows conf) (dbRelations dbStructure) apiRequest
  mutateDbRequest = DbMutate <$> buildMutateRequest apiRequest
  selectQuery = requestToQuery schema <$> readDbRequest
  countQuery = requestToCountQuery schema <$> readDbRequest
  mutateQuery = requestToQuery schema <$> mutateDbRequest
  readSqlParts = (,) <$> selectQuery <*> countQuery
  mutateSqlParts = (,) <$> selectQuery <*> mutateQuery
  respondToRange response = if topLevelRange == emptyRange
                            then return $ errResponse status416 "HTTP Range error"
                            else response
  rangeHeader queryTotal tableTotal = let frm = rangeOffset topLevelRange
                                          to = frm + toInteger queryTotal - 1
                                          contentRange = contentRangeH frm to (toInteger <$> tableTotal)
                                          status = rangeStatus frm to (toInteger <$> tableTotal)
                                      in (status, contentRange)

splitKeyValue :: BS.ByteString -> (BS.ByteString, BS.ByteString)
splitKeyValue kv = (k, BS.tail v)
  where (k, v) = BS.break (== '=') kv

renderLocationFields :: [BS.ByteString] -> BS.ByteString
renderLocationFields fields =
  renderSimpleQuery True $ map splitKeyValue fields

rangeStatus :: Integer -> Integer -> Maybe Integer -> Status
rangeStatus _ _ Nothing = status200
rangeStatus frm to (Just total)
  | frm > total            = status416
  | (1 + to - frm) < total = status206
  | otherwise               = status200

contentRangeH :: Integer -> Integer -> Maybe Integer -> Header
contentRangeH frm to total =
    ("Content-Range", cs headerValue)
    where
      headerValue   = rangeString <> "/" <> totalString
      rangeString
        | totalNotZero && fromInRange = show frm <> "-" <> cs (show to)
        | otherwise = "*"
      totalString   = fromMaybe "*" (show <$> total)
      totalNotZero  = fromMaybe True ((/=) 0 <$> total)
      fromInRange   = frm <= to

jsonH :: Header
jsonH = (hContentType, "application/json; charset=utf-8")

formatRelationError :: Text -> Text
formatRelationError = formatGeneralError
  "could not find foreign keys between these entities"

formatParserError :: ParseError -> Text
formatParserError e = formatGeneralError message details
  where
     message = cs $ show (errorPos e)
     details = strip $ replace "\n" " " $ cs
       $ showErrorMessages "or" "unknown parse error" "expecting" "unexpected" "end of input" (errorMessages e)

formatGeneralError :: Text -> Text -> Text
formatGeneralError message details = cs $ encode $ object [
  "message" .= message,
  "details" .= details]

augumentRequestWithJoin :: Schema ->  [Relation] ->  ReadRequest -> Either Text ReadRequest
augumentRequestWithJoin schema allRels request =
  (first formatRelationError . addRelations schema allRels Nothing) request
  >>= addJoinConditions schema

addFiltersOrdersRanges :: ApiRequest -> Either ParseError (ReadRequest -> ReadRequest)
addFiltersOrdersRanges apiRequest = foldr1 (liftA2 (.)) [
    flip (foldr addFilter) <$> filters,
    flip (foldr addOrder) <$> orders,
    flip (foldr addRange) <$> ranges
  ]
  {-
  The esence of what is going on above is that we are composing tree functions
  of type (ReadRequest->ReadRequest) that are in (Either ParseError a) context
  -}
  where
    filters :: Either ParseError [(Path, Filter)]
    filters = mapM pRequestFilter flts
      where
        action = iAction apiRequest
        flts = if action == ActionRead
          then iFilters apiRequest
          else filter (( '.' `elem` ) . fst) $ iFilters apiRequest -- there can be no filters on the root table whre we are doing insert/update
    orders :: Either ParseError [(Path, [OrderTerm])]
    orders = mapM pRequestOrder $ iOrder apiRequest
    ranges :: Either ParseError [(Path, NonnegRange)]
    ranges = mapM pRequestRange $ M.toList $ iRange apiRequest

treeRestrictRange :: Maybe Integer -> ReadRequest -> Either Text ReadRequest
treeRestrictRange maxRows_ request = pure $ nodeRestrictRange maxRows_ `fmap` request
  where
    nodeRestrictRange :: Maybe Integer -> ReadNode -> ReadNode
    nodeRestrictRange m (q@Select {range_=r}, i) = (q{range_=restrictRange m r }, i)

buildReadRequest :: Maybe Integer -> [Relation] -> ApiRequest -> Either Text ReadRequest
buildReadRequest maxRows allRels apiRequest  =
  treeRestrictRange maxRows =<<
  augumentRequestWithJoin schema relations =<<
  first formatParserError readRequest
  where
    (schema, rootTableName) = fromJust $ -- Make it safe
      let target = iTarget apiRequest in
      case target of
        (TargetIdent (QualifiedIdentifier s t) ) -> Just (s, t)
        _ -> Nothing

    action :: Action
    action = iAction apiRequest

    readRequest :: Either ParseError ReadRequest
    readRequest = addFiltersOrdersRanges apiRequest <*>
      parse (pRequestSelect rootName) ("failed to parse select parameter <<"++selStr++">>") selStr
      where
        selStr = iSelect apiRequest
        rootName = if action == ActionRead
          then rootTableName
          else sourceCTEName

    relations :: [Relation]
    relations = case action of
      ActionCreate -> fakeSourceRelations ++ allRels
      ActionUpdate -> fakeSourceRelations ++ allRels
      ActionDelete -> fakeSourceRelations ++ allRels
      _       -> allRels
      where fakeSourceRelations = mapMaybe (toSourceRelation rootTableName) allRels -- see comment in toSourceRelation

buildMutateRequest :: ApiRequest -> Either Text MutateRequest
buildMutateRequest apiRequest = case action of
  ActionCreate -> Insert rootTableName <$> pure payload
  ActionUpdate -> Update rootTableName <$> pure payload <*> filters
  ActionDelete -> Delete rootTableName <$> filters
  _        -> Left "Unsupported HTTP verb"
  where
    action = iAction apiRequest
    payload = fromJust $ iPayload apiRequest
    rootTableName = -- TODO: Make it safe
      let target = iTarget apiRequest in
      case target of
        (TargetIdent (QualifiedIdentifier _ t) ) -> t
        _ -> undefined
    filters = first formatParserError $ map snd <$> mapM pRequestFilter mutateFilters
      where mutateFilters = filter (not . ( '.' `elem` ) . fst) $ iFilters apiRequest -- update/delete filters can be only on the root table

addFilterToNode :: Filter -> ReadRequest -> ReadRequest
addFilterToNode flt (Node (q@Select {flt_=flts}, i) f) = Node (q {flt_=flt:flts}, i) f

addFilter :: (Path, Filter) -> ReadRequest -> ReadRequest
addFilter = addProperty addFilterToNode

addOrderToNode :: [OrderTerm] -> ReadRequest -> ReadRequest
addOrderToNode o (Node (q,i) f) = Node (q{order=Just o}, i) f

addOrder :: (Path, [OrderTerm]) -> ReadRequest -> ReadRequest
addOrder = addProperty addOrderToNode

addRangeToNode :: NonnegRange -> ReadRequest -> ReadRequest
addRangeToNode r (Node (q,i) f) = Node (q{range_=r}, i) f

addRange :: (Path, NonnegRange) -> ReadRequest -> ReadRequest
addRange = addProperty addRangeToNode

addProperty :: (a -> ReadRequest -> ReadRequest) -> (Path, a) -> ReadRequest -> ReadRequest
addProperty f ([], a) n = f a n
addProperty f (path, a) (Node rn forest) =
  case targetNode of
    Nothing -> Node rn forest -- the property is silenty dropped in the Request does not contain the required path
    Just tn -> Node rn (addProperty f (remainingPath, a) tn:restForest)
  where
    targetNodeName:remainingPath = path
    (targetNode,restForest) = splitForest targetNodeName forest
    splitForest :: NodeName -> Forest ReadNode -> (Maybe ReadRequest, Forest ReadNode)
    splitForest name forst =
      case maybeNode of
        Nothing -> (Nothing,forest)
        Just node -> (Just node, delete node forest)
      where
        maybeNode :: Maybe ReadRequest
        maybeNode = find fnd forst
          where
            fnd :: ReadRequest -> Bool
            fnd (Node (_,(n,_,_)) _) = n == name

-- in a relation where one of the tables mathces "TableName"
-- replace the name to that table with pg_source
-- this "fake" relations is needed so that in a mutate query
-- we can look a the "returning *" part which is wrapped with a "with"
-- as just another table that has relations with other tables
toSourceRelation :: TableName -> Relation -> Maybe Relation
toSourceRelation mt r@(Relation t _ ft _ _ rt _ _)
  | mt == tableName t = Just $ r {relTable=t {tableName=sourceCTEName}}
  | mt == tableName ft = Just $ r {relFTable=t {tableName=sourceCTEName}}
  | Just mt == (tableName <$> rt) = Just $ r {relLTable=(\tbl -> tbl {tableName=sourceCTEName}) <$> rt}
  | otherwise = Nothing

data TableOptions = TableOptions {
  tblOptcolumns :: [Column]
, tblOptpkey    :: [Text]
}

instance ToJSON TableOptions where
  toJSON t = object [
      "columns" .= tblOptcolumns t
    , "pkey"   .= tblOptpkey t ]


extractQueryResult :: Maybe ResultsWithCount -> ResultsWithCount
extractQueryResult = fromMaybe (Nothing, 0, [], "")