{-# LANGUAGE BangPatterns, DeriveGeneric, OverloadedStrings #-}
module Data.CQRS.PostgreSQL.Internal.Utils
       ( SqlValue(..)
       , QueryError(..)
       , badQueryResultMsg
       , execSql
       , execSql'
       , ioQuery
       , ioQuery'
       , isDuplicateKey
       , runQuery
       , withTransaction
       ) where

import           Control.DeepSeq (NFData(..), ($!!))
import           Control.Exception (Exception, throw, throwIO)
import           Control.Exception (SomeException, bracket)
import           Control.Exception.Enclosed (catchAny)
import           Control.Monad (forM)
import           Control.Monad.IO.Class (liftIO)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Data.ByteString.Lex.Double (readDouble)
import           Data.ByteString.Lex.Integral (readDecimal)
import           Data.Int (Int16, Int32, Int64)
import           Data.Pool (Pool, withResource)
import           Data.Text (Text)
import           Data.Text.Encoding (decodeUtf8', encodeUtf8)
import           Data.Typeable (Typeable)
import           Data.UUID.Types (UUID)
import qualified Data.UUID.Types as U
import           Database.PostgreSQL.LibPQ (Connection, Oid(..), Format(..), ExecStatus(..), Column(..), Row(..), FieldCode(..))
import qualified Database.PostgreSQL.LibPQ as P
import           GHC.Generics (Generic)
import           System.IO.Streams (InputStream)
import qualified System.IO.Streams.Combinators as SC
import qualified System.IO.Streams.List as SL

-- | A single query parameter or result field, tagged with its SQL type.
--
-- Every constructor wraps a 'Maybe'; 'Nothing' represents SQL @NULL@.
data SqlValue
  = SqlByteArray (Maybe ByteString)          -- ^ @bytea@ (OID 17)
  | SqlBlankPaddedString (Maybe ByteString)  -- ^ @bpchar@ / @char(n)@ (OID 1042)
  | SqlBool (Maybe Bool)                     -- ^ @boolean@ (OID 16)
  | SqlInt16 (Maybe Int16)                   -- ^ @smallint@ (OID 21)
  | SqlInt32 (Maybe Int32)                   -- ^ @integer@ (OID 23)
  | SqlInt64 (Maybe Int64)                   -- ^ @bigint@ (OID 20)
  | SqlFloating (Maybe Double)               -- ^ @real@ / @double precision@ (OIDs 700, 701)
  | SqlVarChar (Maybe Text)                  -- ^ @varchar@ (OID 1043)
  | SqlText (Maybe Text)                     -- ^ @text@ (OID 25)
  | SqlUUID (Maybe UUID)                     -- ^ @uuid@ (OID 2950)
  | Unmatched (Oid, Maybe ByteString)        -- ^ Any other type: its OID and raw field value.
  deriving (Eq, Show)

-- | Decide whether a 'QueryError' is a duplicate-key error, i.e. whether
-- its SQLSTATE is @23505@ (@unique_violation@).
--
-- The result is @Just ()@ on a match and 'Nothing' otherwise. That shape
-- fits the selector argument of 'Control.Exception.catchJust'.
isDuplicateKey :: QueryError -> Maybe ()
isDuplicateKey qe =
  case qeSqlState qe of
    Just "23505" -> Just ()
    _            -> Nothing

-- | Execute an IO action with an active transaction.
--
-- Runs @START TRANSACTION@, then the action, then @COMMIT TRANSACTION@,
-- and returns the action's result. If the action or the commit throws an
-- exception that 'catchAny' catches, @ROLLBACK TRANSACTION@ is attempted
-- and the original exception is rethrown unchanged. A failure of the
-- rollback itself is ignored so that it cannot replace the original error.
withTransaction :: Connection -> IO a -> IO a
withTransaction connection action = do
  begin
  catchAny runAction rollbackAndRethrow
  where
    runAction = do
      r <- action
      commit
      return r

    rollbackAndRethrow :: SomeException -> IO a
    rollbackAndRethrow e = do
      -- Best-effort rollback; a failure here is deliberately ignored so that
      -- the original exception is what propagates. If the connection came
      -- from a resource pool, the pool is expected to destroy it (rather
      -- than return it to the pool) when the exception escapes.
      catchAny rollback (\_ -> return ())
      throwIO e

    begin = execSql connection "START TRANSACTION;" [ ]
    commit = execSql connection "COMMIT TRANSACTION;" [ ]
    rollback = execSql connection "ROLLBACK TRANSACTION;" [ ]

-- | Parse PostgreSQL's text form of a boolean.
--
-- @t@ maps to 'True' and @f@ to 'False'. Any other input yields 'Nothing'.
readBoolean :: ByteString -> Maybe Bool
readBoolean = flip lookup [("t", True), ("f", False)]

-- | Encode an 'SqlValue' as a parameter for 'P.execParams'.
--
-- A constructor holding 'Nothing' becomes 'Nothing', i.e. SQL @NULL@.
-- Otherwise the result is the parameter's type OID, its encoded value, and
-- the 'Format' of that encoding.
--
-- Every constructor is handled, including 'Unmatched'. Any value produced
-- by 'toSqlValue' can therefore be passed back as a parameter.
--
-- The 'Connection' argument is not used; it is kept so callers need not
-- change.
fromSqlValue :: Connection -> SqlValue -> IO (Maybe (Oid, ByteString, Format))
fromSqlValue _ sqlValue =
  return $ case sqlValue of
    -- bytea is sent in binary format, i.e. as its raw bytes. libpq's
    -- PQescapeByteaConn is meant for data embedded in SQL command text.
    -- Binary-format parameters need no escaping.
    SqlByteArray v         -> param (Oid 17) Binary id v
    SqlBlankPaddedString v -> param (Oid 1042) Text id v
    SqlBool v              -> param (Oid 16) Text encodeBool v
    SqlInt16 v             -> param (Oid 21) Text showBytes v
    SqlInt32 v             -> param (Oid 23) Text showBytes v
    SqlInt64 v             -> param (Oid 20) Text showBytes v
    -- Sent as float8. 'show' gives e.g. "1.5", "1.0e-2", "Infinity", "NaN".
    SqlFloating v          -> param (Oid 701) Text showBytes v
    SqlVarChar v           -> param (Oid 1043) Binary encodeUtf8 v
    SqlText v              -> param (Oid 25) Text encodeUtf8 v
    SqlUUID v              -> param (Oid 2950) Text U.toASCIIBytes v
    -- Unknown types are passed back with their OID in text format. Results
    -- are requested in text format, so this is the form they were read in.
    Unmatched (oid, v)     -> param oid Text id v
  where
    -- Attach OID and format to an encoded value; NULL stays 'Nothing'.
    param :: Oid -> Format -> (a -> ByteString) -> Maybe a -> Maybe (Oid, ByteString, Format)
    param oid format encode = fmap (\x -> (oid, encode x, format))

    encodeBool :: Bool -> ByteString
    encodeBool True  = "t"
    encodeBool False = "f"

    showBytes :: Show a => a -> ByteString
    showBytes = B8.pack . show

-- | Map a result field to an 'SqlValue'.
--
-- The input is the field's type OID and its text-format value ('Nothing'
-- for SQL @NULL@). Fields whose OID is not recognised come back as
-- 'Unmatched'. A non-NULL value of a recognised type that cannot be parsed
-- raises an 'error' ("toSqlValue: Conversion failed").
toSqlValue :: (Oid, Maybe ByteString) -> IO SqlValue
toSqlValue (oid, mvalue) =
  case oid of
    Oid 17 -> c P.unescapeBytea SqlByteArray
    Oid 16 -> c (return . readBoolean) SqlBool
    Oid 20 -> c (return . readSignedDecimal) SqlInt64
    Oid 21 -> c (return . readSignedDecimal) SqlInt16
    Oid 23 -> c (return . readSignedDecimal) SqlInt32
    Oid 25 -> c (return . decodeText) SqlText
    Oid 700 -> c (return . readFloating) SqlFloating
    Oid 701 -> c (return . readFloating) SqlFloating
    Oid 1042 -> c (return . Just) SqlBlankPaddedString
    Oid 1043 -> c (return . decodeText) SqlVarChar
    Oid 2950 -> c (return . U.fromASCIIBytes) SqlUUID

    _ -> return $ Unmatched (oid,mvalue)
  where
    -- Run a converter on a non-NULL value and wrap the result; NULL maps
    -- straight to @construct Nothing@.
    c :: Monad m => (ByteString -> m (Maybe a)) -> (Maybe a -> SqlValue) -> m SqlValue
    c convert construct =
      case mvalue of
        Nothing -> return $ construct Nothing
        Just value -> do
          mvalue' <- convert value
          case mvalue' of
            Nothing -> error "toSqlValue: Conversion failed"
            Just _  -> return $ construct mvalue'

    decodeText :: ByteString -> Maybe Text
    decodeText = either (const Nothing) Just . decodeUtf8'

    -- 'readDecimal' only accepts unsigned digits, but int2/int4/int8
    -- columns can hold negative values, so the sign is handled here. The
    -- whole field must be consumed.
    readSignedDecimal :: Integral a => ByteString -> Maybe a
    readSignedDecimal s =
      case B8.uncons s of
        Just ('-', digits) -> negate <$> readUnsignedDecimal digits
        _                  -> readUnsignedDecimal s

    readUnsignedDecimal :: Integral a => ByteString -> Maybe a
    readUnsignedDecimal s =
      case readDecimal s of
        Just (n, rest) | B.null rest -> Just n
        _                            -> Nothing

    -- Floating-point columns may contain the special values "NaN",
    -- "Infinity" and "-Infinity", which 'readDouble' does not parse.
    readFloating :: ByteString -> Maybe Double
    readFloating s =
      case B8.uncons s of
        Just ('-', rest) -> negate <$> readUnsignedFloating rest
        _                -> readUnsignedFloating s

    readUnsignedFloating :: ByteString -> Maybe Double
    readUnsignedFloating "NaN"      = Just (0 / 0)
    readUnsignedFloating "Infinity" = Just (1 / 0)
    readUnsignedFloating s          = fst <$> readDouble s

-- | Execute a statement for its effect only; any rows it returns are
-- ignored.
execSql :: Connection -> ByteString -> [SqlValue] -> IO ()
execSql connection sql parameters =
  ioQuery connection sql parameters discardRows
  where
    discardRows _ = return ()

-- | Execute a statement and return the number of affected rows, or
-- 'Nothing' when the count is not available.
execSql' :: Connection -> ByteString -> [SqlValue] -> IO (Maybe Int)
execSql' connection sql parameters =
  ioQuery' connection sql parameters keepRowCount
  where
    keepRowCount rowCount _ = return rowCount

-- | Details of a failed query, copied out of the server's result.
data QueryError = QueryError
  { qeSqlState      :: Maybe ByteString  -- ^ SQLSTATE error code, if reported.
  , qeStatusMessage :: ByteString        -- ^ Text form of the result status.
  , qeErrorMessage  :: Maybe ByteString  -- ^ Error message from the result, if any.
  }
  deriving (Generic, Show, Typeable)

-- | Allows the error to be thrown and caught as an exception.
instance Exception QueryError

-- | Generic-derived, so the error can be fully evaluated (e.g. with '$!!')
-- before it is thrown.
instance NFData QueryError

-- | Run a query and fold over its results. The action is given an
-- 'InputStream' over all the rows in the result. The affected-row count
-- reported by 'ioQuery'' is dropped.
ioQuery :: Connection -> ByteString -> [SqlValue] -> (InputStream [SqlValue] -> IO a) -> IO a
ioQuery connection sql parameters f =
  ioQuery' connection sql parameters (const f)

-- | Run a query and fold over the results.
--
-- The action receives the number of rows affected ('Nothing' when the
-- server reports no count) and an 'InputStream' over all the rows in the
-- result. The underlying 'P.Result' is freed when this function returns,
-- so the stream must be consumed inside the action.
--
-- A result status other than 'CommandOk' or 'TuplesOk' is thrown as a
-- 'QueryError'. If no result is returned at all, this calls 'error'.
ioQuery' :: Connection -> ByteString -> [SqlValue] -> (Maybe Int -> InputStream [SqlValue] -> IO a) -> IO a
ioQuery' connection sql parameters f =
  bracket open P.unsafeFreeResult $ \r -> do
    status <- P.resultStatus r
    case status of
      CommandOk -> succeed r
      TuplesOk  -> succeed r
      _         -> failWith r status

  where
    -- Hand the affected-row count and the row stream to the caller.
    succeed r = do
      n <- affectedRows r
      rows <- go r
      f n rows

    -- Extract error information and throw it. We need to be careful to
    -- COPY the values here since freeing the result will cause the
    -- "original" values to become garbage.
    failWith r status = do
      sqlState <- fmap B.copy <$> P.resultErrorField r DiagSqlstate
      statusMessage <- B.copy <$> P.resStatus status
      errorMessage <- fmap B.copy <$> P.resultErrorMessage r
      throwIO $!! QueryError
                    { qeSqlState = sqlState
                    , qeStatusMessage = statusMessage
                    , qeErrorMessage = errorMessage
                    }

    affectedRows r = do
      !cmdTuples <- P.cmdTuples r
      case cmdTuples of
        Nothing -> return Nothing
        Just !x -> return $! fmap fst $! readDecimal $! B.copy x

    open = do
      parameters' <- forM parameters $ fromSqlValue connection
      mr <- P.execParams connection sql parameters' Text
      case mr of
        Nothing -> error "No result set; something is very wrong"
        Just r -> return r

    go :: P.Result -> IO (InputStream [SqlValue])
    go r = do
      Col nFields <- P.nfields r
      Row nRows <- P.ntuples r
      let columns = map P.toColumn [0 .. nFields - 1]
      -- Column types are identical for every row, so look them up once
      -- rather than once per cell.
      types <- mapM (P.ftype r) columns
      let typedColumns = zip columns types
          loop []           = return Nothing
          loop (row : rows) = do
            values <- forM typedColumns $ \(col, typ) -> do
              mval <- P.getvalue' r row col
              toSqlValue (typ, mval)
            return $ Just (values, rows)
      SC.unfoldM loop $ map P.toRow [0 .. nRows - 1]

-- | Run a query on a connection borrowed from the pool and return every
-- row of the result as a list.
runQuery :: Pool Connection -> ByteString -> [SqlValue] -> IO [[SqlValue]]
runQuery connectionPool sql parameters =
  withResource connectionPool runOn
  where
    runOn connection = ioQuery connection sql parameters SL.toList

-- | Build the error message for a query whose result has an unexpected
-- "shape". The message lists the query parameters and the result columns
-- that were received.
badQueryResultMsg :: [String] -> [SqlValue] -> String
badQueryResultMsg params columns =
  "Invalid query result shape. Params: " ++ show params
    ++ ". Result columns: " ++ show columns