module Data.CQRS.EventStore.Backend.PostgreSQLUtils
( run
, sourceQuery
, withTransaction
, SqlValue(..)
) where
import Control.Exception (bracketOnError, catch, onException, SomeException)
import Control.Monad (forM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int16, Int32, Int64)

import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.ByteString.Lex.Double (readDouble)
import Data.ByteString.Lex.Integral (readDecimal, readSigned)
import Data.Conduit (ResourceT, Source, ($$), bracketP, yield)
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8', encodeUtf8)
import Data.Time (Day)
import Data.Time.Format (parseTime)
import Database.PostgreSQL.LibPQ (Connection, Oid(..), Format(..), ExecStatus(..), Column(..), Row(..))
import qualified Database.PostgreSQL.LibPQ as P
import System.Locale (defaultTimeLocale)
-- | A single SQL value, used both for query parameters and for result cells.
--
-- Every constructor wraps 'Nothing' to represent SQL @NULL@. The type OIDs
-- noted below are the ones 'toSqlValue' maps onto each constructor.
data SqlValue
  = SqlByteArray (Maybe ByteString)          -- ^ @bytea@ (OID 17)
  | SqlBlankPaddedString (Maybe ByteString)  -- ^ @bpchar@ (OID 1042), raw bytes
  | SqlBool (Maybe Bool)                     -- ^ @bool@ (OID 16)
  | SqlInt16 (Maybe Int16)                   -- ^ @int2@ (OID 21)
  | SqlInt32 (Maybe Int32)                   -- ^ @int4@ (OID 23)
  | SqlInt64 (Maybe Int64)                   -- ^ @int8@ (OID 20)
  | SqlFloating (Maybe Double)               -- ^ @float4@ / @float8@ (OIDs 700, 701)
  | SqlVarChar (Maybe Text)                  -- ^ @varchar@ (OID 1043)
  | SqlText (Maybe Text)                     -- ^ @text@ (OID 25)
  | SqlDate (Maybe Day)                      -- ^ @date@ (OID 1082)
  | Unmatched (Oid, Maybe ByteString)
    -- ^ Any other column type: its OID plus the raw text-format bytes.
  deriving (Eq, Show)
-- | Run an IO action inside a database transaction on the given connection.
--
-- Issues @BEGIN TRANSACTION@, runs the action, then @COMMIT TRANSACTION@ and
-- returns the action's result. If the action or the commit throws, a
-- @ROLLBACK TRANSACTION@ is attempted and the original exception is
-- re-thrown. Any exception raised by the rollback itself is deliberately
-- discarded so that it cannot mask the original failure.
--
-- 'bracketOnError' installs the rollback handler with asynchronous
-- exceptions masked. An exception arriving right after @BEGIN@ therefore
-- still triggers the rollback instead of leaving the transaction open.
withTransaction :: Connection -> IO a -> IO a
withTransaction connection action =
  bracketOnError begin (const tryRollback) (const runAction)
  where
    runAction = do
      r <- action
      commit
      return r
    -- Best effort: the caller will see the original exception anyway.
    tryRollback = rollback `catch` ignoreException
    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()
    begin = run connection "BEGIN TRANSACTION;" []
    commit = run connection "COMMIT TRANSACTION;" []
    rollback = run connection "ROLLBACK TRANSACTION;" []
-- | Parse PostgreSQL's text form of a @boolean@: exactly @t@ or @f@.
-- Any other input (including the empty string) yields 'Nothing'.
readBoolean :: ByteString -> Maybe Bool
readBoolean s
  | s == B8.pack "t" = Just True
  | s == B8.pack "f" = Just False
  | otherwise = Nothing
-- | Encode a 'SqlValue' as a query parameter for 'P.execParams'.
--
-- Returns 'Nothing' for SQL @NULL@. Otherwise it returns the parameter's
-- type OID, its encoded bytes and the wire format those bytes are in.
--
-- * @bytea@ is sent as raw bytes in binary format, which is bytea's binary
--   wire representation, so no escaping is needed. Output of
--   @PQescapeByteaConn@ is intended for SQL string literals and depends on
--   the server's @standard_conforming_strings@ setting.
-- * Every other value is sent in text format, using PostgreSQL's textual
--   input syntax.
-- * 'Unmatched' values are passed back with their original OID in text
--   format, the same format 'sourceQuery' requests results in. Values read
--   from a query can therefore be bound as parameters again.
--
-- The 'Connection' argument is no longer needed. It is kept so that
-- existing callers remain source-compatible.
fromSqlValue :: Connection -> SqlValue -> IO (Maybe (Oid, ByteString, Format))
fromSqlValue _ value =
  return $ case value of
    SqlByteArray v -> asParam (Oid 17) Binary id v
    SqlBlankPaddedString v -> asParam (Oid 1042) Text id v
    SqlBool v -> asParam (Oid 16) Text encodeBool v
    SqlInt16 v -> asParam (Oid 21) Text showBytes v
    SqlInt32 v -> asParam (Oid 23) Text showBytes v
    SqlInt64 v -> asParam (Oid 20) Text showBytes v
    -- Haskell's 'show' renders Infinity/NaN and exponents (e.g. "1.0e-2")
    -- in forms float8 input accepts.
    SqlFloating v -> asParam (Oid 701) Text showBytes v
    -- varchar's text and binary representations are identical bytes; use
    -- the text format for consistency with 'SqlText'.
    SqlVarChar v -> asParam (Oid 1043) Text encodeUtf8 v
    SqlText v -> asParam (Oid 25) Text encodeUtf8 v
    -- 'show' on a 'Day' yields ISO-8601 (YYYY-MM-DD).
    SqlDate v -> asParam (Oid 1082) Text showBytes v
    Unmatched (oid, v) -> asParam oid Text id v
  where
    -- Tag a possibly-NULL value with its OID and format after encoding it.
    asParam :: Oid -> Format -> (a -> ByteString) -> Maybe a -> Maybe (Oid, ByteString, Format)
    asParam oid format encode = fmap (\x -> (oid, encode x, format))

    encodeBool :: Bool -> ByteString
    encodeBool b = B8.pack (if b then "t" else "f")

    showBytes :: Show s => s -> ByteString
    showBytes = B8.pack . show
-- | Decode one result cell into a 'SqlValue'.
--
-- The input is the column's type OID and the cell's text-format bytes
-- ('Nothing' for SQL @NULL@). Known OIDs are parsed into the matching
-- constructor, and a @NULL@ becomes that constructor applied to 'Nothing'.
-- Unknown OIDs are returned as 'Unmatched' with the raw bytes.
--
-- Throws an 'IOError' ("Conversion failed") when a non-NULL value of a
-- known type cannot be parsed.
toSqlValue :: (Oid, Maybe ByteString) -> IO SqlValue
toSqlValue (oid, mvalue) =
  case oid of
    Oid 17 -> c P.unescapeBytea SqlByteArray
    Oid 16 -> c (return . readBoolean) SqlBool
    Oid 20 -> c (return . readInteger) SqlInt64
    Oid 21 -> c (return . readInteger) SqlInt16
    Oid 23 -> c (return . readInteger) SqlInt32
    Oid 25 -> c (return . readUtf8) SqlText
    Oid 700 -> c (return . readFloating) SqlFloating
    Oid 701 -> c (return . readFloating) SqlFloating
    Oid 1042 -> c (return . Just) SqlBlankPaddedString
    Oid 1043 -> c (return . readUtf8) SqlVarChar
    Oid 1082 -> c (return . parseTime defaultTimeLocale "%F" . B8.unpack) SqlDate
    _ -> return $ Unmatched (oid, mvalue)
  where
    -- Apply a parser to the non-NULL value and wrap the result. Typed at IO
    -- (not a generic Monad) because it uses 'fail', which newer GHCs only
    -- provide through MonadFail.
    c :: (ByteString -> IO (Maybe a)) -> (Maybe a -> SqlValue) -> IO SqlValue
    c convert construct =
      case mvalue of
        Nothing -> return $ construct Nothing
        Just value -> do
          mvalue' <- convert value
          case mvalue' of
            Nothing -> fail "Conversion failed"
            Just _ -> return $ construct mvalue'

    -- 'readDecimal' only accepts unsigned digits, so negative values such as
    -- "-5" would be rejected. 'readSigned' adds the optional leading sign.
    readInteger :: Integral n => ByteString -> Maybe n
    readInteger = fmap fst . readSigned readDecimal

    readUtf8 :: ByteString -> Maybe Text
    readUtf8 = either (const Nothing) Just . decodeUtf8'

    -- PostgreSQL prints non-finite floats as words, which 'readDouble'
    -- does not understand.
    readFloating :: ByteString -> Maybe Double
    readFloating s
      | s == B8.pack "NaN" = Just (0 / 0)
      | s == B8.pack "Infinity" = Just (1 / 0)
      | s == B8.pack "-Infinity" = Just (negate (1 / 0))
      | otherwise = fmap fst (readDouble s)
-- | Execute a statement with the given parameters and discard any rows it
-- produces. Failures surface as the exceptions raised by 'sourceQuery'.
run :: Connection -> ByteString -> [SqlValue] -> IO ()
run connection sql parameters = C.runResourceT drain
  where
    -- Stream every row into a sink that ignores it, so the query runs to
    -- completion and its result is released.
    drain = sourceQuery connection sql parameters $$ CL.sinkNull
-- | Execute a statement and stream its result rows.
--
-- Each row is yielded as a list of 'SqlValue's in column order. Parameters
-- are encoded with 'fromSqlValue', and results are requested in text format
-- and decoded with 'toSqlValue'. The libpq result is acquired and released
-- through 'bracketP'.
--
-- Calls 'error' in two cases:
--
-- * libpq returns no result at all ("No result");
-- * the result status is anything other than @CommandOk@ / @TuplesOk@. The
--   message is the status name, then ": ", then the server's error message.
sourceQuery :: Connection -> ByteString -> [SqlValue] -> Source (ResourceT IO) [SqlValue]
sourceQuery connection sql parameters =
  bracketP open P.unsafeFreeResult go
  where
    open = do
      parameters' <- forM parameters $ fromSqlValue connection
      mr <- P.execParams connection sql parameters' Text
      case mr of
        Nothing -> error "No result"
        Just r -> do
          status <- P.resultStatus r
          case status of
            CommandOk -> return r
            TuplesOk -> return r
            _ -> do
              statusMessage <- P.resStatus status
              errorMessage <- fmap (maybe "(Unknown error)" id) $ P.resultErrorMessage r
              let message = B.concat [statusMessage, ": ", errorMessage]
              -- bracketP only runs its release action when 'open' succeeds,
              -- so free the failed result explicitly. The message is built
              -- (forced) first because it is read from that result.
              message `seq` P.unsafeFreeResult r
              error $ B8.unpack message

    go r = do
      Col nFields <- liftIO $ P.nfields r
      Row nRows <- liftIO $ P.ntuples r
      let columns = map P.toColumn [0 .. nFields - 1]
      -- A column's type is the same for every row, so look it up once.
      types <- liftIO $ mapM (P.ftype r) columns
      let typedColumns = zip columns types
      mapM_ (yieldRow r typedColumns . P.toRow) [0 .. nRows - 1]

    -- Decode every cell of one row and emit the row downstream.
    yieldRow r typedColumns row =
      liftIO (mapM (getSqlVal r row) typedColumns) >>= yield

    getSqlVal r row (col, typ) = do
      mval <- P.getvalue' r row col
      toSqlValue (typ, mval)