{-# OPTIONS_GHC -fno-warn-orphans #-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Database.Beam.Postgres.Connection
  ( PgRowReadError(..), PgError(..)
  , Pg(..), PgF(..)

  , runBeamPostgres, runBeamPostgresDebug

  , pgRenderSyntax, runPgRowReader, getFields

  , withPgDebug

  , postgresUriSyntax ) where

import           Control.Exception (Exception, throwIO)
import           Control.Monad.Free.Church
import           Control.Monad.IO.Class

import           Database.Beam hiding (runDelete, runUpdate, runInsert, insert)
import           Database.Beam.Schema.Tables
import           Database.Beam.Backend.SQL
import           Database.Beam.Backend.SQL.BeamExtensions
import           Database.Beam.Backend.URI
import           Database.Beam.Query.Types (QGenExpr(..))

import           Database.Beam.Postgres.Syntax
import           Database.Beam.Postgres.Full
import           Database.Beam.Postgres.Types

import qualified Database.PostgreSQL.LibPQ as Pg hiding
  (Connection, escapeStringConn, escapeIdentifier, escapeByteaConn, exec)
import qualified Database.PostgreSQL.Simple as Pg
import qualified Database.PostgreSQL.Simple.FromField as Pg
import qualified Database.PostgreSQL.Simple.Internal as Pg
  ( Field(..), RowParser(..)
  , escapeStringConn, escapeIdentifier, escapeByteaConn
  , exec, throwResultError )
import qualified Database.PostgreSQL.Simple.Internal as PgI
import qualified Database.PostgreSQL.Simple.Ok as Pg
import qualified Database.PostgreSQL.Simple.Types as Pg (Null(..), Query(..))

import           Control.Monad.Reader
import           Control.Monad.State

import           Data.ByteString (ByteString)
import           Data.ByteString.Builder (toLazyByteString, byteString)
import qualified Data.ByteString.Lazy as BL
import           Data.Proxy
import           Data.String
import qualified Data.Text as T
import           Data.Text.Encoding (decodeUtf8)
#if !MIN_VERSION_base(4, 11, 0)
import           Data.Semigroup
#endif

import           Foreign.C.Types

import           Network.URI (uriToString)

-- | Errors that may arise while using the 'Pg' monad.
data PgError
  = PgRowParseError PgRowReadError
  | PgInternalError String
  deriving Show
instance Exception PgError

data PgStream a = PgStreamDone     (Either PgError a)
                | PgStreamContinue (Maybe PgI.Row -> IO (PgStream a))

-- | 'BeamURIOpeners' for the standard @postgresql:@ URI scheme. See the
-- postgres documentation for more details on the formatting. See documentation
-- for 'BeamURIOpeners' for more information on how to use this with beam
postgresUriSyntax :: c PgCommandSyntax Postgres Pg.Connection Pg
                  -> BeamURIOpeners c
postgresUriSyntax =
    mkUriOpener "postgresql:"
        (\uri -> do
            let pgConnStr = fromString (uriToString id uri "")
            hdl <- Pg.connectPostgreSQL pgConnStr
            pure (hdl, Pg.close hdl))

-- * Syntax rendering

pgRenderSyntax ::
  Pg.Connection -> PgSyntax -> IO ByteString
pgRenderSyntax conn (PgSyntax mkQuery) =
  renderBuilder <$> runF mkQuery finish step mempty
  where
    renderBuilder = BL.toStrict . toLazyByteString

    step (EmitBuilder b next) a = next (a <> b)
    step (EmitByteString b next) a = next (a <> byteString b)
    step (EscapeString b next) a = do
      res <- wrapError "EscapeString" (Pg.escapeStringConn conn b)
      next (a <> byteString res)
    step (EscapeBytea b next) a = do
      res <- wrapError "EscapeBytea" (Pg.escapeByteaConn conn b)
      next (a <> byteString res)
    step (EscapeIdentifier b next) a = do
      res <- wrapError "EscapeIdentifier" (Pg.escapeIdentifier conn b)
      next (a <> byteString res)

    finish _ = pure

    wrapError step' go = do
      res <- go
      case res of
        Right res' -> pure res'
        Left res' -> fail (step' <> ": " <> show res')

-- * Run row readers

-- | An error that may occur while parsing a row
data PgRowReadError
    = PgRowReadNoMoreColumns !CInt !CInt
      -- ^ We attempted to read more columns than postgres returned. First
      -- argument is the zero-based index of the column we attempted to read,
      -- and the second is the total number of columns
    | PgRowCouldNotParseField !CInt
      -- ^ There was an error while parsing the field. The first argument gives
      -- the zero-based index of the column that could not have been
      -- parsed. This is usually caused by your Haskell schema type being
      -- incompatible with the one in the database.
    deriving Show

instance Exception PgRowReadError

getFields :: Pg.Result -> IO [Pg.Field]
getFields res = do
  Pg.Col colCount <- Pg.nfields res

  let getField col =
        Pg.Field res (Pg.Col col) <$> Pg.ftype res (Pg.Col col)

  mapM getField [0..colCount - 1]

runPgRowReader ::
  Pg.Connection -> Pg.Row -> Pg.Result -> [Pg.Field] -> FromBackendRowM Postgres a -> IO (Either PgRowReadError a)
runPgRowReader conn rowIdx res fields readRow =
  Pg.nfields res >>= \(Pg.Col colCount) ->
  runF readRow finish step 0 colCount fields
  where
    step (ParseOneField _) curCol colCount [] = pure (Left (PgRowReadNoMoreColumns curCol colCount))
    step (ParseOneField _) curCol colCount _
      | curCol >= colCount = pure (Left (PgRowReadNoMoreColumns curCol colCount))
    step (ParseOneField next) curCol colCount remainingFields =
      let next' Nothing _ _ _ = pure (Left (PgRowCouldNotParseField curCol))
          next' (Just {}) _ _ [] = fail "Internal error"
          next' (Just x) curCol' colCount' (_:remainingFields') = next x (curCol' + 1) colCount' remainingFields'
      in step (PeekField next') curCol colCount remainingFields

    step (PeekField next) curCol colCount [] = next Nothing curCol colCount []
    step (PeekField next) curCol colCount remainingFields
      | curCol >= colCount = next Nothing curCol colCount remainingFields
    step (PeekField next) curCol colCount remainingFields@(field:_) =
      do fieldValue <- Pg.getvalue res rowIdx (Pg.Col curCol)
         res' <- Pg.runConversion (Pg.fromField field fieldValue) conn
         case res' of
           Pg.Errors {} -> next Nothing curCol colCount remainingFields
           Pg.Ok x -> next (Just x) curCol colCount remainingFields

    step (CheckNextNNull n next) curCol colCount remainingFields =
      doCheckNextN (fromIntegral n) (curCol :: CInt) (colCount :: CInt) remainingFields >>= \yes ->
      next yes (curCol + if yes then fromIntegral n else 0) colCount (if yes then drop (fromIntegral n) remainingFields else remainingFields)

    doCheckNextN 0 _ _ _ = pure False
    doCheckNextN n curCol colCount remainingFields
      | curCol + n > colCount = pure False
      | otherwise =
        let fieldsInQuestion = zip [curCol..] (take (fromIntegral n) remainingFields)
        in readAndCheck fieldsInQuestion

    readAndCheck [] = pure True
    readAndCheck ((i, field):xs) =
      do fieldValue <- Pg.getvalue res rowIdx (Pg.Col i)
         res' <- Pg.runConversion (Pg.fromField field fieldValue) conn
         case res' of
           Pg.Errors _ -> pure False
           Pg.Ok Pg.Null -> readAndCheck xs

    finish x _ _ _ = pure (Right x)

withPgDebug :: (String -> IO ()) -> Pg.Connection -> Pg a -> IO (Either PgError a)
withPgDebug dbg conn (Pg action) =
  let finish x = pure (Right x)
      step (PgLiftIO io next) = io >>= next
      step (PgLiftWithHandle withConn next) = withConn conn >>= next
      step (PgFetchNext next) = next Nothing
      step (PgRunReturning (PgCommandSyntax PgCommandTypeQuery syntax)
                           (mkProcess :: Pg (Maybe x) -> Pg a')
                           next) =
        do query <- pgRenderSyntax conn syntax
           let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
           dbg (T.unpack (decodeUtf8 query))
           action' <- runF process finishProcess stepProcess Nothing
           case action' of
             PgStreamDone (Right x) -> Pg.execute_ conn (Pg.Query query) >> next x
             PgStreamDone (Left err) -> pure (Left err)
             PgStreamContinue nextStream ->
               let finishUp (PgStreamDone (Right x)) = next x
                   finishUp (PgStreamDone (Left err)) = pure (Left err)
                   finishUp (PgStreamContinue next') = next' Nothing >>= finishUp

                   columnCount = fromIntegral $ valuesNeeded (Proxy @Postgres) (Proxy @x)
               in Pg.foldWith_ (Pg.RP (put columnCount >> ask)) conn (Pg.Query query) (PgStreamContinue nextStream) runConsumer >>= finishUp
      step (PgRunReturning (PgCommandSyntax PgCommandTypeDataUpdateReturning syntax) mkProcess next) =
        do query <- pgRenderSyntax conn syntax
           dbg (T.unpack (decodeUtf8 query))

           res <- Pg.exec conn query
           sts <- Pg.resultStatus res
           case sts of
             Pg.TuplesOk -> do
               let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
               runF process (\x _ -> Pg.unsafeFreeResult res >> next x) (stepReturningList res) 0
             _ -> Pg.throwResultError "No tuples returned to Postgres update/insert returning"
                                      res sts
      step (PgRunReturning (PgCommandSyntax _ syntax) mkProcess next) =
        do query <- pgRenderSyntax conn syntax
           dbg (T.unpack (decodeUtf8 query))
           _ <- Pg.execute_ conn (Pg.Query query)

           let Pg process = mkProcess (Pg (liftF (PgFetchNext id)))
           runF process next stepReturningNone

      stepReturningNone :: forall a. PgF (IO (Either PgError a)) -> IO (Either PgError a)
      stepReturningNone (PgLiftIO action' next) = action' >>= next
      stepReturningNone (PgLiftWithHandle withConn next) = withConn conn >>= next
      stepReturningNone (PgFetchNext next) = next Nothing
      stepReturningNone (PgRunReturning _ _ _) = pure (Left (PgInternalError "Nested queries not allowed"))

      stepReturningList :: forall a. Pg.Result -> PgF (CInt -> IO (Either PgError a)) -> CInt -> IO (Either PgError a)
      stepReturningList _   (PgLiftIO action' next) rowIdx = action' >>= \x -> next x rowIdx
      stepReturningList res (PgFetchNext next) rowIdx =
        do fields <- getFields res
           Pg.Row rowCount <- Pg.ntuples res
           if rowIdx >= rowCount
             then next Nothing rowIdx
             else runPgRowReader conn (Pg.Row rowIdx) res fields fromBackendRow >>= \case
                    Left err -> pure (Left (PgRowParseError err))
                    Right r -> next (Just r) (rowIdx + 1)
      stepReturningList _   (PgRunReturning _ _ _) _ = pure (Left (PgInternalError "Nested queries not allowed"))
      stepReturningList _   (PgLiftWithHandle {}) _ = pure (Left (PgInternalError "Nested queries not allowed"))

      finishProcess :: forall a. a -> Maybe PgI.Row -> IO (PgStream a)
      finishProcess x _ = pure (PgStreamDone (Right x))

      stepProcess :: forall a. PgF (Maybe PgI.Row -> IO (PgStream a)) -> Maybe PgI.Row -> IO (PgStream a)
      stepProcess (PgLiftIO action' next) row = action' >>= flip next row
      stepProcess (PgFetchNext next) Nothing =
        pure . PgStreamContinue $ \res ->
        case res of
          Nothing -> next Nothing Nothing
          Just (PgI.Row rowIdx res') ->
            getFields res' >>= \fields ->
            runPgRowReader conn rowIdx res' fields fromBackendRow >>= \case
              Left err -> pure (PgStreamDone (Left (PgRowParseError err)))
              Right r -> next r Nothing
      stepProcess (PgFetchNext next) (Just (PgI.Row rowIdx res)) =
        getFields res >>= \fields ->
        runPgRowReader conn rowIdx res fields fromBackendRow >>= \case
          Left err -> pure (PgStreamDone (Left (PgRowParseError err)))
          Right r -> pure (PgStreamContinue (next (Just r)))
      stepProcess (PgRunReturning _ _ _) _ = pure (PgStreamDone (Left (PgInternalError "Nested queries not allowed")))
      stepProcess (PgLiftWithHandle _ _) _ = pure (PgStreamDone (Left (PgInternalError "Nested queries not allowed")))

      runConsumer :: forall a. PgStream a -> PgI.Row -> IO (PgStream a)
      runConsumer s@(PgStreamDone {}) _ = pure s
      runConsumer (PgStreamContinue next) row = next (Just row)
  in runF action finish step

-- * Beam Monad class

data PgF next where
    PgLiftIO :: IO a -> (a -> next) -> PgF next
    PgRunReturning ::
        FromBackendRow Postgres x =>
        PgCommandSyntax -> (Pg (Maybe x) -> Pg a) -> (a -> next) -> PgF next
    PgFetchNext ::
        FromBackendRow Postgres x =>
        (Maybe x -> next) -> PgF next
    PgLiftWithHandle :: (Pg.Connection -> IO a) -> (a -> next) -> PgF next
deriving instance Functor PgF

-- | 'MonadBeam' in which we can run Postgres commands. See the documentation
-- for 'MonadBeam' on examples of how to use.
--
-- @beam-postgres@ also provides functions that let you run queries without
-- 'MonadBeam'. These functions may be more efficient and offer a conduit
-- API. See "Database.Beam.Postgres.Conduit" for more information.
newtype Pg a = Pg { runPg :: F PgF a }
    deriving (Monad, Applicative, Functor, MonadFree PgF)

instance MonadIO Pg where
    liftIO x = liftF (PgLiftIO x id)

runBeamPostgresDebug :: (String -> IO ()) -> Pg.Connection -> Pg a -> IO a
runBeamPostgresDebug dbg conn action =
    withPgDebug dbg conn action >>= either throwIO pure

runBeamPostgres :: Pg.Connection -> Pg a -> IO a
runBeamPostgres = runBeamPostgresDebug (\_ -> pure ())

instance MonadBeam PgCommandSyntax Postgres Pg.Connection Pg where
    withDatabase = runBeamPostgres
    withDatabaseDebug = runBeamPostgresDebug

    runReturningMany cmd consume =
        liftF (PgRunReturning cmd consume id)

instance MonadBeamInsertReturning PgCommandSyntax Postgres Pg.Connection Pg where
    runInsertReturningList tbl values = do
        let insertReturningCmd' =
                insertReturning tbl values onConflictDefault
                                (Just (changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax PostgresInaccessible) ty) ->
                                                              Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax ()) ty)))

        -- Make savepoint
        case insertReturningCmd' of
          PgInsertReturningEmpty ->
            pure []
          PgInsertReturning insertReturningCmd ->
            runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning insertReturningCmd)

instance MonadBeamUpdateReturning PgCommandSyntax Postgres Pg.Connection Pg where
    runUpdateReturningList tbl mkAssignments mkWhere = do
        let updateReturningCmd' =
                updateReturning tbl mkAssignments mkWhere
                                (changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax PostgresInaccessible) ty) ->
                                                        Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax ()) ty))

        case updateReturningCmd' of
          PgUpdateReturningEmpty ->
            pure []
          PgUpdateReturning updateReturningCmd ->
            runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning updateReturningCmd)

instance MonadBeamDeleteReturning PgCommandSyntax Postgres Pg.Connection Pg where
    runDeleteReturningList tbl mkWhere = do
        let PgDeleteReturning deleteReturningCmd =
                deleteReturning tbl mkWhere
                                (changeBeamRep (\(Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax PostgresInaccessible) ty) ->
                                                        Columnar' (QExpr s) :: Columnar' (QExpr PgExpressionSyntax ()) ty))

        runReturningList (PgCommandSyntax PgCommandTypeDataUpdateReturning deleteReturningCmd)