{-# LANGUAGE TypeSynonymInstances, FlexibleInstances, MultiWayIf #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module PostgREST.PgQuery where

import PostgREST.RangeQuery

import qualified Hasql as H
import qualified Hasql.Postgres as P
import qualified Hasql.Backend as B

import qualified Data.Text as T
import qualified Data.HashMap.Strict as H
import Text.Regex.TDFA ( (=~) )
import qualified Network.HTTP.Types.URI as Net
import qualified Data.ByteString.Char8 as BS
import Data.Monoid
import Data.Vector (empty)
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Functor
import Control.Monad (join)
import Data.String.Conversions (cs)
import qualified Data.Aeson as JSON
import qualified Data.List as L
import qualified Data.Vector as V
import Data.Scientific (isInteger, formatScientific, FPFormat(..))

import Prelude

-- | A Hasql statement specialised to the Postgres backend.
type PStmt = H.Stmt P.Postgres

-- | Statements combine by concatenating their SQL templates and their
-- parameter vectors; the third flag (bound as @prep@ below) of the result
-- is the conjunction of both inputs' flags. The identity is the empty
-- template with no parameters and the flag set to 'True'.
instance Monoid PStmt where
  mempty = B.Stmt "" empty True
  mappend (B.Stmt lhsSql lhsArgs lhsPrep) (B.Stmt rhsSql rhsArgs rhsPrep) =
    B.Stmt
      (lhsSql <> rhsSql)
      (lhsArgs <> rhsArgs)
      (lhsPrep && rhsPrep)

-- | A statement transformer: takes a statement and returns a modified one.
type StatementT = PStmt -> PStmt

-- | A two-part identifier: a schema and a name inside it. Rendered as
-- @schema.name@ by 'fromQi'.
data QualifiedIdentifier = QualifiedIdentifier
  { qiSchema :: T.Text -- ^ Schema part of the identifier.
  , qiName   :: T.Text -- ^ Object name part of the identifier.
  } deriving (Show)

-- | One element of an @order by@ clause, as produced by 'orderParseTerm'.
data OrderTerm = OrderTerm
  { otTerm      :: T.Text              -- ^ Column to order by (unescaped).
  , otDirection :: BS.ByteString       -- ^ SQL direction keyword.
  , otNullOrder :: Maybe BS.ByteString -- ^ Optional SQL nulls-ordering phrase.
  }

-- | Append a @LIMIT … OFFSET …@ clause to a statement.
--
-- With no range, or a range whose 'rangeLimit' is 'Nothing', the limit is
-- @ALL@. With no range the offset is @0@.
limitT :: Maybe NonnegRange -> StatementT
limitT r q = q <> B.Stmt clause empty True
  where
    clause    = " LIMIT " <> limitSql <> " OFFSET " <> offsetSql <> " "
    limitSql  = maybe "ALL" (cs . show) (r >>= rangeLimit)
    offsetSql = cs (show (maybe 0 rangeOffset r))

-- | Append a @where@ clause built from the query-string parameters.
--
-- Every parameter except @order@ becomes one predicate (see 'wherePred');
-- predicates are joined with @and@. If no filter parameters remain, the
-- statement is returned unchanged.
whereT :: QualifiedIdentifier -> Net.Query -> StatementT
whereT table params q
  | L.null filters = q
  | otherwise      = q <> B.Stmt " where " empty True <> predicates
  where
    filters    = filter ((/= "order") . fst) params
    predicates = mconcat . L.intersperse andq $ map (wherePred table) filters

-- | Wrap the first statement in a CTE named @v@ and make the second
-- statement read from it: @WITH v AS (<first>) <second> from v@.
-- Parameters are concatenated (first statement's first) and the flags are
-- and-ed together.
withT :: PStmt -> T.Text -> StatementT
withT (B.Stmt cteSql cteArgs ctePrep) v (B.Stmt outerSql outerArgs outerPrep) =
  B.Stmt template (cteArgs <> outerArgs) (ctePrep && outerPrep)
  where
    template =
      T.concat ["WITH ", v, " AS (", cteSql, ") ", outerSql, " from ", v]

-- | Append an @order by@ clause for the given terms; an empty list leaves
-- the statement untouched. Column names are escaped with 'pgFmtIdent'.
orderT :: [OrderTerm] -> StatementT
orderT ts q
  | L.null ts = q
  | otherwise = q <> B.Stmt " order by " empty True <> terms
  where
    terms = mconcat (L.intersperse commaq (map render ts))

    -- One "<ident> <direction> <nulls order>" fragment, space padded.
    render :: OrderTerm -> PStmt
    render (OrderTerm term direction nullOrder) =
      B.Stmt
        (T.concat
          [ " ", pgFmtIdent term
          , " ", cs direction
          , " ", maybe "" cs nullOrder
          , " "
          ])
        empty True

-- | Surround the statement's SQL template with parentheses (and a space on
-- each side); parameters and flag are kept as they are.
parentheticT :: StatementT
parentheticT (B.Stmt sql args prep) =
  B.Stmt (" (" <> sql <> ") ") args prep

-- | Combine two statements so the second only takes effect when the first
-- returns no rows: the first is placed (with @returning *@ appended) in a
-- CTE named @aaa@, and the second gets @WHERE NOT EXISTS (SELECT * FROM aaa)@.
iffNotT :: PStmt -> StatementT
iffNotT (B.Stmt firstSql firstArgs firstPrep) (B.Stmt thenSql thenArgs thenPrep) =
  B.Stmt template (firstArgs <> thenArgs) (firstPrep && thenPrep)
  where
    template =
      "WITH aaa AS (" <> firstSql <> " returning *) "
        <> thenSql
        <> " WHERE NOT EXISTS (SELECT * FROM aaa)"

-- | Turn a statement into one that counts its rows, by wrapping it in a
-- CTE named @qqq@ and selecting @pg_catalog.count(1)@ from it.
countT :: StatementT
countT (B.Stmt sql args prep) =
  B.Stmt
    ("WITH qqq AS (" <> sql <> ") SELECT pg_catalog.count(1) FROM qqq")
    args
    prep

-- | A parameterless statement counting all rows of the given relation.
countRows :: QualifiedIdentifier  -> PStmt
countRows t = B.Stmt sql empty True
  where
    sql = "select pg_catalog.count(1) from " <> fromQi t

-- | 'asCsv' followed by 'withCount': the CSV aggregation with a row count
-- column prepended.
asCsvWithCount :: QualifiedIdentifier -> StatementT
asCsvWithCount table s = withCount (asCsv table s)

-- | Wrap a statement so that its rows are aggregated into a single CSV
-- text value.
--
-- The header line lists the table's columns (from
-- @information_schema.columns@, in ordinal order, each passed through
-- @quote_ident@); the body is every row's text form with its first and last
-- character removed. Header and rows are separated by a carriage return.
--
-- The column lookup compares @table_schema@ and @table_name@ separately
-- against escaped string literals ('pgFmtLit'). Comparing against the
-- identifier-quoted @schema.name@ form fails for any table whose name needs
-- quoting (the catalog stores the raw name), and splicing it unescaped
-- into a single-quoted literal is unsafe.
--
-- NOTE(review): the produced template has no leading @select@ keyword —
-- presumably the caller prepends it; confirm against the call sites.
asCsv :: QualifiedIdentifier -> StatementT
asCsv table s = s { B.stmtTemplate =
    "(select string_agg(quote_ident(column_name::text), ',') from "
    <> "(select column_name from information_schema.columns where table_schema = "
    <> pgFmtLit (qiSchema table)
    <> " and table_name = "
    <> pgFmtLit (qiName table)
    <> " order by ordinal_position) h) || '\r' || "
    <> "coalesce(string_agg(substring(t::text, 2, length(t::text) - 2), '\r'), '') from ("
    <> B.stmtTemplate s <> ") t" }

-- | 'asJson' followed by 'withCount': the JSON aggregation with a row count
-- column prepended.
asJsonWithCount :: StatementT
asJsonWithCount s = withCount (asJson s)

-- | Wrap a statement as a subquery aliased @t@ and aggregate all of its
-- rows into one JSON array, cast to @character varying@.
asJson :: StatementT
asJson (B.Stmt sql args prep) = B.Stmt wrapped args prep
  where
    wrapped =
      "array_to_json(array_agg(row_to_json(t)))::character varying from ("
        <> sql
        <> ") t"

-- | Prepend a @pg_catalog.count(t)@ column to the statement's template.
-- The @t@ refers to the subquery alias used by 'asJson' and 'asCsv'.
withCount :: StatementT
withCount (B.Stmt sql args prep) =
  B.Stmt ("pg_catalog.count(t), " <> sql) args prep

-- | Wrap a statement as a subquery aliased @t@ and convert each of its rows
-- to JSON with @row_to_json@.
asJsonRow :: StatementT
asJsonRow (B.Stmt sql args prep) =
  B.Stmt ("row_to_json(t) from (" <> sql <> ") t") args prep

-- | A parameterless @select *@ over the given relation.
selectStar :: QualifiedIdentifier -> PStmt
selectStar t = B.Stmt sql empty True
  where
    sql = "select * from " <> fromQi t

-- | Append @RETURNING *@ to the statement's template.
returningStarT :: StatementT
returningStarT (B.Stmt sql args prep) =
  B.Stmt (sql <> " RETURNING *") args prep

-- | A parameterless @delete from@ over the given relation (no @where@
-- clause; callers add one, e.g. with 'whereT').
deleteFrom :: QualifiedIdentifier -> PStmt
deleteFrom t = B.Stmt sql empty True
  where
    sql = "delete from " <> fromQi t

-- | Build a multi-row @insert@ statement.
--
-- * With no columns: @insert into t default values returning *@ (the value
--   rows are ignored).
-- * Otherwise: @insert into t (cols) values (row), (row), … returning
--   row_to_json(t.*)@, with column names escaped by 'pgFmtIdent' and every
--   value rendered inline by 'insertableValue'.
insertInto :: QualifiedIdentifier
              -> V.Vector T.Text
              -> V.Vector (V.Vector JSON.Value)
              -> PStmt
insertInto t cols vals
  | V.null cols = B.Stmt ("insert into " <> target <> " default values returning *") empty True
  | otherwise   = B.Stmt sql empty True
  where
    target  = fromQi t
    columns = T.intercalate ", " (map pgFmtIdent (V.toList cols))
    -- One parenthesised tuple of literals per input row.
    tuple r = "(" <> T.intercalate ", " (map insertableValue (V.toList r)) <> ")"
    tuples  = T.intercalate ", " (map tuple (V.toList vals))
    sql     = T.concat
      [ "insert into ", target, " (", columns, ") values "
      , tuples
      , " returning row_to_json(", target, ".*)"
      ]

-- | Build a single-row insert written as @insert into t (cols) select
-- v1, v2, …@ (no @returning@ clause, so further clauses can be appended).
-- With an empty column list it yields
-- @insert into t default values returning *@ instead.
insertSelect :: QualifiedIdentifier -> [T.Text] -> [JSON.Value] -> PStmt
insertSelect t cols vals
  | L.null cols = B.Stmt ("insert into " <> fromQi t <> " default values returning *") empty True
  | otherwise   = B.Stmt sql empty True
  where
    columns = T.intercalate ", " (map pgFmtIdent cols)
    values  = T.intercalate ", " (map insertableValue vals)
    sql     = T.concat ["insert into ", fromQi t, " (", columns, ") select ", values]

-- | Build @update t set (cols) = (vals)@ with escaped column names and
-- inline literal values. No @where@ clause is included.
update :: QualifiedIdentifier -> [T.Text] -> [JSON.Value] -> PStmt
update t cols vals = B.Stmt sql empty True
  where
    columns = T.intercalate ", " (map pgFmtIdent cols)
    values  = T.intercalate ", " (map insertableValue vals)
    sql     = T.concat ["update ", fromQi t, " set (", columns, ") = (", values, ")"]

-- | Build @select * from proc(name:=value, …)@, passing each entry of the
-- JSON object as a named argument. Argument names are escaped with
-- 'pgFmtIdent' and values rendered with 'insertableValue'.
callProc :: QualifiedIdentifier -> JSON.Object -> PStmt
callProc qi params =
  B.Stmt ("select * from " <> fromQi qi <> "(" <> args <> ")") empty True
  where
    args = T.intercalate ","
      [ pgFmtIdent n <> ":=" <> insertableValue v | (n, v) <- H.toList params ]

-- | Translate one query-string filter into a SQL predicate on @table@.
--
-- The item's key is the column (or a jsonb path, see 'pgFmtJsonbPath'); its
-- value has the shape @[not.]<operator>.<operand>@, for example @eq.5@,
-- @not.like.foo*@ or @in.1,2,3@. A missing value is treated as @"."@ (empty
-- operator, empty operand) and unrecognised operators fall back to @=@.
--
-- Operands are emitted as escaped literals cast to @unknown@, except for
-- @is@ / @isnot@, where @null@, @true@ and @false@ (matched
-- case-insensitively) are emitted as bare lower-case keywords.
--
-- The predicate is decomposed with a total pattern match: a value that is
-- just @not@ (no operator) yields a negated default-operator comparison
-- with an empty operand instead of crashing on @head []@.
wherePred :: QualifiedIdentifier -> Net.QueryItem -> PStmt
wherePred table (col, predicate) =
  B.Stmt (notOp <> " " <> pgFmtJsonbPath table (cs col) <> " " <> op <> " " <> operand)
      empty True

  where
    -- (negation keyword, operator code, raw operand). Dots inside the
    -- operand are preserved by re-joining the remaining segments.
    (notOp, opCode, value) =
      case T.split (=='.') $ cs $ fromMaybe "." predicate of
        "not" : code : segments -> ("not", code, T.intercalate "." segments)
        ["not"]                 -> ("not", "", "")
        code : segments         -> ("", code, T.intercalate "." segments)
        []                      -> ("", "", "")

    operand
      | opCode `elem` ["is","isnot"] = whiteList value
      | otherwise                    = sqlValue

    -- Only these keywords may appear unquoted after IS / IS NOT.
    whiteList val = fromMaybe (unknownLiteral val)
                              (L.find (== T.toLower val)
                                      ["null","true","false"])
    -- '*' in the request is the wildcard; SQL LIKE uses '%'.
    star c = if c == '*' then '%' else c
    unknownLiteral v = pgFmtLit v <> "::unknown "

    sqlValue
      | opCode `elem` ["like","ilike"] = unknownLiteral $ T.map star value
      | opCode `elem` ["in","notin"]   =
          "(" <> T.intercalate ", " (map unknownLiteral $ T.split (==',') value) <> ") "
      | opCode == "@@"                 = "to_tsquery(" <> unknownLiteral value <> ") "
      | otherwise                      = unknownLiteral value

    op = case opCode of
         "eq"  -> "="
         "gt"  -> ">"
         "lt"  -> "<"
         "gte" -> ">="
         "lte" -> "<="
         "neq" -> "<>"
         "like"-> "like"
         "ilike"-> "ilike"
         "in"  -> "in"
         "notin" -> "not in"
         "is"    -> "is"
         "isnot" -> "is not"
         "@@" -> "@@"
         _     -> "="

-- | Extract the order terms from the @order@ query parameter, which is a
-- comma-separated list of terms. Terms that 'orderParseTerm' rejects are
-- dropped; a missing or valueless parameter yields no terms.
orderParse :: Net.Query -> [OrderTerm]
orderParse q = mapMaybe orderParseTerm terms
  where
    raw   = fromMaybe "" (join (lookup "order" q))
    terms = T.split (== ',') (cs raw)

-- | Parse a single order term of the form @column.direction[.nullorder]@.
--
-- The direction must be @asc@ or @desc@, otherwise the term is rejected.
-- The null ordering is recognised only when exactly one extra segment is
-- present and it is @nullsfirst@ or @nullslast@; anything else is silently
-- treated as "no null ordering".
orderParseTerm :: T.Text -> Maybe OrderTerm
orderParseTerm s =
  case T.split (== '.') s of
    (c : d : nls)
      | d == "asc"  -> Just (OrderTerm c "asc" (nullOrder nls))
      | d == "desc" -> Just (OrderTerm c "desc" (nullOrder nls))
    _ -> Nothing
  where
    nullOrder :: [T.Text] -> Maybe BS.ByteString
    nullOrder ["nullsfirst"] = Just "nulls first"
    nullOrder ["nullslast"]  = Just "nulls last"
    nullOrder _              = Nothing

-- | The separator @", "@ as a parameterless statement fragment.
commaq :: PStmt
commaq = B.Stmt ", " V.empty True

-- | The connective @" and "@ as a parameterless statement fragment.
andq :: PStmt
andq = B.Stmt " and " V.empty True

-- | Syntax tree for a column reference with optional JSON navigation,
-- as built by 'parseJsonbPath' and rendered by 'pgFmtJsonbPath'.
data JsonbPath
  = ColIdentifier T.Text            -- ^ A column name (rendered as an identifier).
  | KeyIdentifier T.Text            -- ^ A JSON key (rendered as a string literal).
  | SingleArrow JsonbPath JsonbPath -- ^ @lhs->rhs@
  | DoubleArrow JsonbPath JsonbPath -- ^ @lhs->>rhs@
  deriving (Show)

-- | Parse a path of the form @col->k1->…->kn->>leaf@.
--
-- The input must contain exactly one @->>@; everything before it is split
-- on @->@ into a column followed by keys (nested left to right), and
-- everything after it is the final key. Any other shape gives 'Nothing'.
parseJsonbPath :: T.Text -> Maybe JsonbPath
parseJsonbPath p
  | [prefix, leaf] <- T.splitOn "->>" p =
      -- T.splitOn is documented to return at least one element, so this
      -- pattern is expected to always match.
      let column : keys = T.splitOn "->" prefix
          navigated     = foldl SingleArrow (ColIdentifier column) (map KeyIdentifier keys)
      in Just (DoubleArrow navigated (KeyIdentifier leaf))
  | otherwise = Nothing

-- | Render a filter key as SQL. If the key parses as a jsonb path it is
-- rendered with its arrows; otherwise the whole key is treated as a plain
-- column. Columns are qualified with the table and escaped as identifiers;
-- JSON keys are escaped as string literals.
pgFmtJsonbPath :: QualifiedIdentifier -> T.Text -> T.Text
pgFmtJsonbPath table p = render (fromMaybe (ColIdentifier p) (parseJsonbPath p))
  where
    render path = case path of
      ColIdentifier column -> fromQi table <> "." <> pgFmtIdent column
      KeyIdentifier key    -> pgFmtLit key
      SingleArrow lhs rhs  -> render lhs <> "->" <> render rhs
      DoubleArrow lhs rhs  -> render lhs <> "->>" <> render rhs

-- | Escape a SQL identifier.
--
-- The text is cut at the first NUL character and embedded double quotes are
-- doubled. The result is wrapped in double quotes only when it is empty,
-- starts with something other than @a-z@ or @_@, or contains any character
-- outside @a-z@, @_@ and @0-9@; otherwise it is returned bare.
pgFmtIdent :: T.Text -> T.Text
pgFmtIdent x
  | needsQuoting = "\"" <> escaped <> "\""
  | otherwise    = escaped
  where
    escaped = T.replace "\"" "\"\"" (trimNullChars x)

    needsQuoting :: Bool
    needsQuoting = (cs escaped :: BS.ByteString) =~ danger

    danger = "^$|^[^a-z_]|[^a-z_0-9]" :: BS.ByteString

-- | Escape a text value as a SQL string literal.
--
-- The text is cut at the first NUL character, single quotes are doubled and
-- the result is wrapped in single quotes. If the value contains any
-- backslash, every backslash is doubled and the literal is emitted in
-- escape-string form (@E'…'@) so the server collapses the doubled
-- backslashes back to the original ones.
--
-- The @E@ prefix must be added whenever at least one backslash is present:
-- in a plain @'…'@ literal (with @standard_conforming_strings@ on) doubled
-- backslashes are taken literally and the stored value would be corrupted.
pgFmtLit :: T.Text -> T.Text
pgFmtLit x =
  let trimmed = trimNullChars x
      escaped = "'" <> T.replace "'" "''" trimmed <> "'"
      slashed = T.replace "\\" "\\\\" escaped in
  -- "\\" is a single backslash character in Haskell source.
  if T.isInfixOf "\\" escaped
    then "E" <> slashed
    else slashed

-- | Keep only the part of the text before the first NUL character
-- (the NUL and everything after it are dropped).
trimNullChars :: T.Text -> T.Text
trimNullChars = fst . T.break (== '\NUL')

-- | Render a qualified identifier as @schema.name@, escaping each part
-- with 'pgFmtIdent'.
fromQi :: QualifiedIdentifier -> T.Text
fromQi (QualifiedIdentifier schema name) =
  T.intercalate "." [pgFmtIdent schema, pgFmtIdent name]

-- | Convert a JSON value to raw (unescaped, unquoted) text.
--
-- Strings are returned as-is; numbers use fixed notation (with no decimal
-- places when the number is an integer); booleans use Haskell's 'show'
-- spelling; every other value is rendered as its JSON encoding.
unquoted :: JSON.Value -> T.Text
unquoted value = case value of
  JSON.String t -> t
  JSON.Number n ->
    let decimals = if isInteger n then Just 0 else Nothing
    in cs (formatScientific Fixed decimals n)
  JSON.Bool b   -> cs (show b)
  other         -> cs (JSON.encode other)

-- | Escape text as a SQL string literal ('pgFmtLit') and cast it to
-- @unknown@.
insertableText :: T.Text -> T.Text
insertableText t = pgFmtLit t <> "::unknown"

-- | Render a JSON value as an inline SQL expression: JSON null becomes the
-- bare keyword @null@, anything else becomes an escaped literal cast to
-- @unknown@.
insertableValue :: JSON.Value -> T.Text
insertableValue value = case value of
  JSON.Null -> "null"
  other     -> insertableText (unquoted other)

-- | Express a JSON value as a filter predicate in the query-string syntax
-- understood by 'wherePred': @is.null@ for JSON null, @eq.<value>@ for
-- anything else.
paramFilter :: JSON.Value -> T.Text
paramFilter value = case value of
  JSON.Null -> "is.null"
  other     -> "eq." <> unquoted other