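{-| Low-level PostgreSQL helpers for the event store backend: running statements, streaming query results, transactions, and conversion between libpq values and 'SqlValue'. -}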
module Data.CQRS.EventStore.Backend.PostgreSQLUtils
       ( run
       , sourceQuery
       , withTransaction
       , SqlValue(..)
       ) where

import           Control.Exception (SomeException, catch, mask, onException)
import           Control.Monad (forM)
import           Control.Monad.IO.Class (liftIO)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Data.ByteString.Lex.Double (readDouble)
import           Data.ByteString.Lex.Integral (readDecimal, readSigned)
import           Data.Conduit (ResourceT, Source, ($$), bracketP, yield)
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import           Data.Int (Int16, Int32, Int64)
import           Data.Text (Text)
import           Data.Text.Encoding (decodeUtf8', encodeUtf8)
import           Data.Time (Day)
import           Data.Time.Format (parseTime)
import           Database.PostgreSQL.LibPQ (Connection, Oid(..), Format(..), ExecStatus(..), Column(..), Row(..))
import qualified Database.PostgreSQL.LibPQ as P
import           System.Locale (defaultTimeLocale)

-- | A single SQL value tagged with the PostgreSQL type it was read as, or
-- is to be sent as. A 'Nothing' payload stands for SQL NULL.
data SqlValue
  = SqlByteArray (Maybe ByteString)          -- ^ @bytea@ (OID 17), decoded bytes
  | SqlBlankPaddedString (Maybe ByteString)  -- ^ @bpchar@ (OID 1042), raw bytes
  | SqlBool (Maybe Bool)                     -- ^ @bool@ (OID 16)
  | SqlInt16 (Maybe Int16)                   -- ^ @int2@ (OID 21)
  | SqlInt32 (Maybe Int32)                   -- ^ @int4@ (OID 23)
  | SqlInt64 (Maybe Int64)                   -- ^ @int8@ (OID 20)
  | SqlFloating (Maybe Double)               -- ^ @float4@ / @float8@ (OIDs 700 / 701)
  | SqlVarChar (Maybe Text)                  -- ^ @varchar@ (OID 1043)
  | SqlText (Maybe Text)                     -- ^ @text@ (OID 25)
  | SqlDate (Maybe Day)                      -- ^ @date@ (OID 1082)
  | Unmatched (Oid, Maybe ByteString)        -- ^ Any other type: its 'Oid' and raw value
  deriving (Eq, Show)

-- | Execute an IO action with an active transaction.
--
-- Issues @BEGIN TRANSACTION@, runs the action, then @COMMIT TRANSACTION@
-- and returns the action's result. If the action or the commit throws, a
-- @ROLLBACK TRANSACTION@ is attempted and the original exception is
-- re-thrown. If @BEGIN@ itself fails, the exception propagates without a
-- rollback.
--
-- Asynchronous exceptions are masked everywhere except inside the user
-- action. This stops one from arriving after @BEGIN@ but before the
-- rollback handler is in place, which would leave the connection inside an
-- open transaction.
withTransaction :: Connection -> IO a -> IO a
withTransaction connection action =
  mask $ \restore -> do
    begin
    r <- restore action `onException` tryRollback
    commit `onException` tryRollback
    return r
  where
    tryRollback =
      -- Try rollback while discarding exception; we want to preserve
      -- original exception.
      catch rollback (\(_::SomeException) -> return ())

    begin = run connection "BEGIN TRANSACTION;" []
    commit = run connection "COMMIT TRANSACTION;" []
    rollback = run connection "ROLLBACK TRANSACTION;" []

-- | Parse PostgreSQL's text representation of a boolean.
--
-- Only the exact strings @t@ and @f@ are recognised; anything else
-- yields 'Nothing'.
readBoolean :: ByteString -> Maybe Bool
readBoolean raw
  | raw == "t" = Just True
  | raw == "f" = Just False
  | otherwise  = Nothing

-- | Map an 'SqlValue' to a query parameter for libpq.
--
-- Returns 'Nothing' for SQL NULL. Otherwise returns the parameter's type
-- 'Oid', its encoded value, and the 'Format' that encoding uses. Every
-- constructor is handled, so any value produced by 'toSqlValue' can be sent
-- back as a parameter.
--
-- The connection argument is no longer used; it is kept so the signature
-- stays unchanged for callers.
fromSqlValue :: Connection -> SqlValue -> IO (Maybe (Oid, ByteString, Format))
fromSqlValue _ value =
  return $ case value of
    -- bytea is sent in binary format, which on the wire is just the raw
    -- bytes. PQescapeByteaConn is intended for splicing literals into SQL
    -- text (it doubles backslashes when standard_conforming_strings is off),
    -- so its output is not a reliable parameter value.
    SqlByteArray a         -> param (Oid 17) Binary id a
    SqlBlankPaddedString a -> param (Oid 1042) Text id a
    SqlBool a              -> param (Oid 16) Text (\b -> if b then "t" else "f") a
    SqlInt16 a             -> param (Oid 21) Text showBytes a
    SqlInt32 a             -> param (Oid 23) Text showBytes a
    SqlInt64 a             -> param (Oid 20) Text showBytes a
    -- Always sent as float8; 'SqlFloating' does not record float4 vs float8.
    SqlFloating a          -> param (Oid 701) Text showBytes a
    SqlVarChar a           -> param (Oid 1043) Binary encodeUtf8 a
    SqlText a              -> param (Oid 25) Text encodeUtf8 a
    SqlDate a              -> param (Oid 1082) Text showBytes a
    -- Unrecognised types are passed through as text-format bytes tagged
    -- with their own Oid.
    Unmatched (oid, a)     -> param oid Text id a
  where
    -- Encode a possibly-NULL payload; NULL maps to 'Nothing'.
    param :: Oid -> Format -> (b -> ByteString) -> Maybe b -> Maybe (Oid, ByteString, Format)
    param oid format encode = fmap (\x -> (oid, encode x, format))

    -- Text-format rendering via 'show' (decimal numbers, ISO dates).
    showBytes :: Show b => b -> ByteString
    showBytes = B8.pack . show

-- | Map a result field to an 'SqlValue'.
--
-- The field is given as its type 'Oid' and its raw value ('Nothing' for
-- SQL NULL). Values of unrecognised types are returned unchanged as
-- 'Unmatched'.
--
-- Fails in 'IO' (via 'fail', message "Conversion failed") when a non-NULL
-- value of a recognised type cannot be parsed.
toSqlValue :: (Oid, Maybe ByteString) -> IO SqlValue
toSqlValue (oid, mvalue) =
  case oid of
    Oid 17 -> c P.unescapeBytea SqlByteArray
    Oid 16 -> c (return . readBoolean) SqlBool
    Oid 20 -> c (return . readIntegral) SqlInt64
    Oid 21 -> c (return . readIntegral) SqlInt16
    Oid 23 -> c (return . readIntegral) SqlInt32
    Oid 25 -> c (return . readUtf8) SqlText
    Oid 700 -> c (return . readFloating) SqlFloating
    Oid 701 -> c (return . readFloating) SqlFloating
    Oid 1042 -> c (return . Just) SqlBlankPaddedString
    Oid 1043 -> c (return . readUtf8) SqlVarChar
    Oid 1082 -> c (return . parseTime defaultTimeLocale "%F" . B8.unpack) SqlDate

    _ -> return $ Unmatched (oid,mvalue)
  where
    -- 'readDecimal' only accepts unsigned digits; 'readSigned' adds the
    -- optional leading sign so negative integers such as "-5" parse.
    readIntegral :: Integral a => ByteString -> Maybe a
    readIntegral = fmap fst . readSigned readDecimal

    -- Wrapping with 'readSigned' accepts a leading sign whether or not the
    -- float parser handles one itself.
    -- NOTE(review): special values such as NaN/Infinity are presumably not
    -- accepted by 'readDouble' -- confirm.
    readFloating :: ByteString -> Maybe Double
    readFloating = fmap fst . readSigned readDouble

    readUtf8 :: ByteString -> Maybe Text
    readUtf8 = either (const Nothing) Just . decodeUtf8'

    -- Apply a converter to a non-NULL value and wrap the result; NULL
    -- passes through as @construct Nothing@. Specialised to IO (its only
    -- use) so 'fail' needs no MonadFail constraint.
    c :: (ByteString -> IO (Maybe a)) -> (Maybe a -> SqlValue) -> IO SqlValue
    c convert construct =
      case mvalue of
        Nothing -> return $ construct Nothing
        Just value -> do
          mvalue' <- convert value
          case mvalue' of
            Nothing -> fail "Conversion failed"
            Just _  -> return $ construct mvalue'

-- | Execute a statement whose rows, if any, are discarded.
--
-- Every row is pulled through 'sourceQuery' into a null sink, so the
-- statement runs to completion and its result is released before this
-- returns.
run :: Connection -> ByteString -> [SqlValue] -> IO ()
run connection sql parameters =
  C.runResourceT drainAll
  where
    drainAll = sourceQuery connection sql parameters $$ CL.sinkNull

-- | Source for traversing all the results of a PostgreSQL query.
--
-- Executes @sql@ with the given parameters, requesting results in text
-- format. Yields one list of 'SqlValue's per result row, in column order.
-- The libpq result is freed when the source completes or when 'ResourceT'
-- cleans up.
--
-- Raises an 'error' if libpq returns no result, or if the result status is
-- anything other than 'CommandOk' or 'TuplesOk'.
sourceQuery :: Connection -> ByteString -> [SqlValue] -> Source (ResourceT IO) [SqlValue]
sourceQuery connection sql parameters =
  bracketP open done go
  where
    done r = P.unsafeFreeResult r

    open = do
      parameters' <- forM parameters $ fromSqlValue connection
      mr <- P.execParams connection sql parameters' Text
      case mr of
        Nothing -> do
          -- No result object at all: the reason, if any, is recorded on
          -- the connection.
          connectionError <- P.errorMessage connection
          error $ B8.unpack $ B.concat ["No result: ", maybe "(Unknown error)" id connectionError]
        Just r -> do
          status <- P.resultStatus r
          case status of
            CommandOk -> return r
            TuplesOk -> return r
            _ -> do
              statusMessage <- P.resStatus status
              errorMessage <- fmap (maybe "(Unknown error)" id) $ P.resultErrorMessage r
              -- 'bracketP' only releases what 'open' returns, so free this
              -- failed result here; nothing below reads it again.
              P.unsafeFreeResult r
              error $ B8.unpack $ B.concat [statusMessage, ": ", errorMessage]

    go r = do
      Col nFields <- liftIO $ P.nfields r
      Row nRows <- liftIO $ P.ntuples r
      let columns = map P.toColumn [0 .. nFields - 1]
      -- Every row shares the same column types, so look them up once
      -- rather than once per cell.
      types <- liftIO $ mapM (P.ftype r) columns
      let typedColumns = zip columns types
          emitRow i = do
            values <- liftIO $ forM typedColumns $ \(column, typ) -> do
              mval <- P.getvalue' r (P.toRow i) column
              toSqlValue (typ, mval)
            yield values
      mapM_ emitRow [0 .. nRows - 1]