{-| Implementation of a PostgreSQL-based event store.
-}
module Data.CQRS.EventStore.Backend.PostgreSQLUtils
       ( run
       , sourceQuery
       , withTransaction
       , SqlValue(..)
       ) where

import Control.Monad (forM)
import Control.Monad.IO.Class (liftIO)
import Control.Exception (catch, onException, SomeException)
import qualified Data.ByteString.Char8 as B8
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Lex.Integral (readDecimal)
import Data.ByteString.Lex.Double (readDouble)
import Data.Conduit (ResourceT, Source, ($$), bracketP, yield)
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import Data.Int (Int16, Int32, Int64)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8', encodeUtf8)
import Data.Time (Day)
import Data.Time.Format (parseTime)
import Database.PostgreSQL.LibPQ (Connection, Oid(..), Format(..), ExecStatus(..), Column(..), Row(..))
import qualified Database.PostgreSQL.LibPQ as P
import System.Locale (defaultTimeLocale)

-- | Known field types.
--
-- Every constructor wraps a 'Maybe'; a 'Nothing' payload is what the
-- result decoder in this module produces for a field that has no value.
-- 'Unmatched' keeps the type OID together with the undecoded bytes for
-- any field whose type has no dedicated constructor.
data SqlValue
  = SqlByteArray (Maybe ByteString)
  | SqlBlankPaddedString (Maybe ByteString)
  | SqlBool (Maybe Bool)
  | SqlInt16 (Maybe Int16)
  | SqlInt32 (Maybe Int32)
  | SqlInt64 (Maybe Int64)
  | SqlFloating (Maybe Double)
  | SqlVarChar (Maybe Text)
  | SqlText (Maybe Text)
  | SqlDate (Maybe Day)
  | Unmatched (Oid, Maybe ByteString)
  deriving (Eq, Show)

-- | Execute an IO action with an active transaction.
--
-- The transaction is begun first. If the action and the subsequent commit
-- both succeed, the action's result is returned. If either of them throws,
-- a rollback is attempted and the original exception propagates.
withTransaction :: Connection -> IO a -> IO a
withTransaction connection action = do
    exec "BEGIN TRANSACTION;"
    actionThenCommit `onException` quietRollback
  where
    -- Run a parameterless statement on the connection.
    exec :: ByteString -> IO ()
    exec statement = run connection statement []

    actionThenCommit = do
      result <- action
      exec "COMMIT TRANSACTION;"
      return result

    -- A failure of the rollback itself is discarded on purpose, so that
    -- the exception which triggered the rollback is the one callers see.
    quietRollback = exec "ROLLBACK TRANSACTION;" `catch` discard

    discard :: SomeException -> IO ()
    discard _ = return ()

-- | Read a boolean.
-- Only the one-character forms @t@ and @f@ are recognised; any other input
-- gives 'Nothing'.
readBoolean :: ByteString -> Maybe Bool
readBoolean "t" = Just True
readBoolean "f" = Just False
readBoolean _   = Nothing

-- | Map an 'SqlValue' to a parameter for 'P.execParams'.
--
-- A 'Nothing' payload yields 'Nothing' (no parameter value). Otherwise the
-- result is the type OID, the encoded bytes, and the format those bytes
-- are in. The OIDs are the same ones 'toSqlValue' decodes, so every value
-- that 'toSqlValue' can produce can also be sent back as a parameter.
--
-- The connection argument is not needed for any conversion; it is kept so
-- the signature stays as it was.
fromSqlValue :: Connection -> SqlValue -> IO (Maybe (Oid, ByteString, Format))
fromSqlValue _ value =
    return $ case value of
      -- bytea goes out in binary format, i.e. as the raw bytes. Escaping
      -- with 'P.escapeByteaConn' targets SQL string literals and can
      -- double backslashes depending on server settings, which is wrong
      -- for an out-of-line parameter.
      SqlByteArray v         -> fmap (asBinary 17) v
      SqlBlankPaddedString v -> fmap (asText 1042) v
      SqlBool v              -> fmap (asText 16 . showBool) v
      SqlInt16 v             -> fmap (asText 21 . showBytes) v
      SqlInt32 v             -> fmap (asText 23 . showBytes) v
      SqlInt64 v             -> fmap (asText 20 . showBytes) v
      -- Always sent as float8 (OID 701), since the payload is a 'Double'.
      SqlFloating v          -> fmap (asText 701 . showBytes) v
      -- varchar is sent in binary format, as it was before this change.
      SqlVarChar v           -> fmap (asBinary 1043 . encodeUtf8) v
      SqlText v              -> fmap (asText 25 . encodeUtf8) v
      SqlDate v              -> fmap (asText 1082 . showBytes) v
      -- Pass the bytes through untouched under their original OID. This
      -- assumes they are the text representation, which is what
      -- 'toSqlValue' stores because 'sourceQuery' requests text results.
      Unmatched (oid, v)     -> fmap (\bytes -> (oid, bytes, Text)) v
  where
    asText n bytes = (Oid n, bytes, Text)
    asBinary n bytes = (Oid n, bytes, Binary)

    showBytes :: Show a => a -> ByteString
    showBytes = B8.pack . show

    showBool :: Bool -> ByteString
    showBool True  = "t"
    showBool False = "f"

-- | Map a field (type OID plus optional text-format bytes) to an 'SqlValue'.
--
-- A field without a value becomes the matching constructor wrapping
-- 'Nothing'. A field whose OID is not handled is returned as 'Unmatched'
-- with its bytes untouched. If the bytes of a handled type cannot be
-- parsed, the action fails via 'fail' with the message
-- @Conversion failed@.
toSqlValue :: (Oid, Maybe ByteString) -> IO SqlValue
toSqlValue (oid, mvalue) =
    case oid of
      Oid 17   -> decodeWith P.unescapeBytea SqlByteArray
      Oid 16   -> parseWith readBoolean SqlBool
      Oid 20   -> parseWith readInteger SqlInt64
      Oid 21   -> parseWith readInteger SqlInt16
      Oid 23   -> parseWith readInteger SqlInt32
      Oid 25   -> parseWith readUtf8 SqlText
      Oid 700  -> parseWith readFloating SqlFloating
      Oid 701  -> parseWith readFloating SqlFloating
      Oid 1042 -> parseWith Just SqlBlankPaddedString
      Oid 1043 -> parseWith readUtf8 SqlVarChar
      Oid 1082 -> parseWith readDay SqlDate
      _        -> return $ Unmatched (oid, mvalue)
  where
    -- Run an effectful decoder on the field's bytes, if there are any,
    -- and wrap the outcome with the given constructor.
    decodeWith :: (ByteString -> IO (Maybe a)) -> (Maybe a -> SqlValue) -> IO SqlValue
    decodeWith convert construct =
      case mvalue of
        Nothing    -> return $ construct Nothing
        Just bytes -> do
          decoded <- convert bytes
          case decoded of
            Nothing -> fail "Conversion failed"
            Just _  -> return $ construct decoded

    -- Same as 'decodeWith', for decoders that need no IO.
    parseWith :: (ByteString -> Maybe a) -> (Maybe a -> SqlValue) -> IO SqlValue
    parseWith parse = decodeWith (return . parse)

    -- 'readDecimal' only understands unsigned digits, so a leading minus
    -- sign has to be peeled off here; otherwise every negative integer
    -- column would be reported as a conversion failure.
    readInteger :: Integral a => ByteString -> Maybe a
    readInteger bytes =
      case B8.uncons bytes of
        Just ('-', digits) -> fmap (negate . fst) (readDecimal digits)
        _                  -> fmap fst (readDecimal bytes)

    -- The special floating-point values are spelled out as words in text
    -- output, so they are matched before falling back to 'readDouble'.
    readFloating :: ByteString -> Maybe Double
    readFloating "NaN"       = Just (0 / 0)
    readFloating "Infinity"  = Just (1 / 0)
    readFloating "-Infinity" = Just (negate (1 / 0))
    readFloating bytes       = fmap fst (readDouble bytes)

    readUtf8 :: ByteString -> Maybe Text
    readUtf8 = either (const Nothing) Just . decodeUtf8'

    readDay :: ByteString -> Maybe Day
    readDay = parseTime defaultTimeLocale "%F" . B8.unpack

-- | Execute a query with no result.
--
-- Any rows the statement does produce are read and discarded.
run :: Connection -> ByteString -> [SqlValue] -> IO ()
run connection sql parameters =
    C.runResourceT (sourceQuery connection sql parameters $$ CL.sinkNull)

-- | Source for traversing all the results of a PostgreSQL query.
--
-- The statement is executed with the given parameters when the source is
-- first run, and results are requested in text format. One list of
-- 'SqlValue's is yielded per result row, with one element per column.
-- The result is released through 'bracketP' when the source finishes.
--
-- Calls 'error' with the message @No result@ if no result is returned,
-- and with the status and error message if the result status is
-- neither 'CommandOk' nor 'TuplesOk'.
sourceQuery :: Connection -> ByteString -> [SqlValue] -> Source (ResourceT IO) [SqlValue]
sourceQuery connection sql parameters = bracketP open P.unsafeFreeResult go
  where
    open = do
      parameters' <- forM parameters (fromSqlValue connection)
      mresult <- P.execParams connection sql parameters' Text
      case mresult of
        Nothing     -> error "No result"
        Just result -> do
          status <- P.resultStatus result
          case status of
            CommandOk -> return result
            TuplesOk  -> return result
            _         -> do
              statusMessage <- P.resStatus status
              errorMessage <- fmap (maybe "(Unknown error)" id) $ P.resultErrorMessage result
              error $ B8.unpack $ B.concat [statusMessage, ": ", errorMessage]

    go result = do
        Col nFields <- liftIO $ P.nfields result
        Row nRows <- liftIO $ P.ntuples result
        let columns = map P.toColumn [0 .. nFields - 1]
            rows = map P.toRow [0 .. nRows - 1]
        mapM_ (emitRow columns) rows
      where
        -- Decode every column of one row and hand the row downstream.
        emitRow columns row =
          forM columns (liftIO . readField row) >>= yield

        -- Fetch the bytes and type OID of one field, then decode it.
        readField row column = do
          mval <- P.getvalue' result row column
          typ <- P.ftype result column
          toSqlValue (typ, mval)