{-# LANGUAGE OverloadedStrings, RecordWildCards, GADTs, CPP #-}
-- | PostgreSQL backend for Selda.
module Database.Selda.PostgreSQL
  ( PGConnectInfo (..)
  , withPostgreSQL, on, auth
    -- 'pgBackend' and 'pgConnString' are only defined when not building
    -- with Haste, so they may only be exported in that configuration.
#ifndef __HASTE__
  , pgBackend
  , pgConnString
#endif
  ) where
import Data.Monoid
import qualified Data.Text as T
import Data.Text.Encoding
import Database.Selda.Backend
import Control.Monad.Catch
#ifndef __HASTE__
import Database.PostgreSQL.LibPQ hiding (user, pass, db, host)
import qualified Data.ByteString.Char8 as BS
#endif

-- | PostgreSQL connection information.
data PGConnectInfo = PGConnectInfo
  { -- | Host to connect to.
    pgHost     :: T.Text
    -- | Port to connect to.
  , pgPort     :: Int
    -- | Name of database to use.
  , pgDatabase :: T.Text
    -- | Username for authentication, if necessary.
  , pgUsername :: Maybe T.Text
    -- | Password for authentication, if necessary.
  , pgPassword :: Maybe T.Text
  }

-- | Connect to the given database on the given host, on the default PostgreSQL
--   port (5432):
--
-- > withPostgreSQL ("my_db" `on` "example.com") $ do
-- >   ...
--
--   The resulting connection information carries no username or password;
--   use 'auth' to add them.
on :: T.Text -> T.Text -> PGConnectInfo
on db host = PGConnectInfo
  { pgHost     = host
  , pgPort     = 5432
  , pgDatabase = db
  , pgUsername = Nothing
  , pgPassword = Nothing
  }
-- Binds tighter than `auth`, so that
-- @db `on` host `auth` creds@ parses as @(db `on` host) `auth` creds@.
infixl 7 `on`

-- | Add the given username and password to the given connection information:
--
-- > withPostgreSQL ("my_db" `on` "example.com" `auth` ("user", "pass")) $ do
-- >   ...
--
--   For more precise control over the connection options, you should modify
--   the 'PGConnectInfo' directly.
auth :: PGConnectInfo -> (T.Text, T.Text) -> PGConnectInfo
auth ci (user, pass) = ci
  { pgUsername = Just user
  , pgPassword = Just pass
  }
infixl 4 `auth`

-- | Perform the given computation over a PostgreSQL database.
--   The database connection is guaranteed to be closed when the computation
--   terminates.
withPostgreSQL :: (MonadIO m, MonadThrow m, MonadMask m)
               => PGConnectInfo -> SeldaT m a -> m a
#ifdef __HASTE__
withPostgreSQL _ _ = return $ error "withPostgreSQL called in JS context"
#else
withPostgreSQL ci m =
    -- 'bracket' guarantees that 'finish' runs on every exit path: when the
    -- connection attempt yields a non-OK status, when the initial SET
    -- statement throws, and when the user computation throws or completes.
    bracket (liftIO $ connectdb connstr) (liftIO . finish) $ \conn -> do
      st <- liftIO $ status conn
      case st of
        ConnectionOk -> do
          let backend = pgBackend (decodeUtf8 connstr) conn
          -- Silence NOTICE-level chatter from the server for this session.
          _ <- liftIO $ runStmt backend "SET client_min_messages TO WARNING;" []
          runSeldaT m backend
        nope ->
          connFailed nope
  where
    connstr = pgConnString ci
    -- Throws a 'DbError' describing the failed connection status.
    connFailed f = throwM $ DbError $
      "unable to connect to postgres server: " ++ show f

-- | Create a `SeldaBackend` for PostgreSQL `Connection`
pgBackend :: T.Text       -- ^ Unique database identifier. Preferably the
                          --   connection string used to open the connection.
          -> Connection   -- ^ PostgreSQL connection object.
          -> SeldaBackend
pgBackend ident c = SeldaBackend
    { runStmt        = \q ps -> right <$> pgQueryRunner c False q ps
    , runStmtWithPK  = \q ps -> left <$> pgQueryRunner c True q ps
    , customColType  = pgColType
    , defaultKeyword = "DEFAULT"
    , dbIdentifier   = ident
    }
  where
    -- 'pgQueryRunner' returns 'Left' exactly when asked for the last
    -- inserted id and 'Right' otherwise, so the other cases are unreachable.
    left (Left x) = x
    left _        = error "pgBackend: impossible, expected a last inserted id"
    right (Right x) = x
    right _         = error "pgBackend: impossible, expected a result set"

-- | Convert `PGConnectInfo` into `ByteString`
--
--   Host, database name, username and password are single-quoted and
--   escaped, so values containing spaces, quotes or backslashes can neither
--   corrupt the connection string nor inject additional parameters.
pgConnString :: PGConnectInfo -> BS.ByteString
pgConnString PGConnectInfo{..} = mconcat
  [ "host=", pgQuoteConnValue pgHost, " "
  , "port=", BS.pack (show pgPort), " "
  , "dbname=", pgQuoteConnValue pgDatabase, " "
  , case pgUsername of
      Just user -> "user=" <> pgQuoteConnValue user <> " "
      _         -> ""
  , case pgPassword of
      Just pass -> "password=" <> pgQuoteConnValue pass <> " "
      _         -> ""
  , "connect_timeout=10", " "
  , "client_encoding=UTF8"
  ]

-- | Quote a value for use in a libpq connection string: wrap it in single
--   quotes and backslash-escape any single quotes and backslashes in it.
pgQuoteConnValue :: T.Text -> BS.ByteString
pgQuoteConnValue v = "'" <> encodeUtf8 (T.concatMap escape v) <> "'"
  where
    escape '\'' = "\\'"
    escape '\\' = "\\\\"
    escape ch   = T.singleton ch

-- | Run a query with the given parameters over the given connection.
--
--   When the flag is @True@, @RETURNING LASTVAL();@ is appended to the query
--   and the result is @Left lastId@; otherwise the result is
--   @Right (affectedRows, rows)@.
--
--   Throws 'DbError' if the query could not be submitted, and 'SqlError' if
--   the server reports a bad response or an error status, or if the last
--   inserted id could not be read back.
pgQueryRunner :: Connection -> Bool -> T.Text -> [Param]
              -> IO (Either Int (Int, [[SqlValue]]))
pgQueryRunner c return_lastid q ps = do
    mres <- execParams c (encodeUtf8 q') [fromSqlValue p | Param p <- ps] Text
    case mres of
      Nothing ->
        throwM $ DbError "unable to submit query to server"
      Just res -> do
        st <- resultStatus res
        case st of
          BadResponse   -> throwM $ SqlError "bad response"
          FatalError    -> throwM $ SqlError errmsg
          NonfatalError -> throwM $ SqlError errmsg
          _ | return_lastid -> Left <$> getLastId res
            | otherwise     -> Right <$> getRows res
  where
    errmsg = "error executing query `" ++ T.unpack q' ++ "'"
    q' | return_lastid = q <> " RETURNING LASTVAL();"
       | otherwise     = q

    -- Read the id from the first cell of the result. A missing or malformed
    -- value is reported immediately as an 'SqlError' instead of leaving a
    -- lazily failing 'read' inside the returned value.
    getLastId res = do
      mval <- getvalue res 0 0
      case mval >>= pgReadMaybe of
        Just lastid -> pure lastid
        Nothing     -> throwM $ SqlError $
          "unable to read last inserted id for query `" ++ T.unpack q' ++ "'"

    -- Fetch all rows, together with the number of affected rows.
    getRows res = do
      rows <- ntuples res
      cols <- nfields res
      types <- mapM (ftype res) [0..cols-1]
      affected <- cmdTuples res
      result <- mapM (getRow res types cols) [0..rows-1]
      pure (affectedCount affected, result)

    -- An absent, empty or unparsable count is treated as zero affected rows.
    affectedCount mbs = maybe 0 id (mbs >>= pgReadMaybe)

    getRow res types cols row =
      sequence $ zipWith (getCol res row) [0..cols-1] types

    -- A missing value is an SQL NULL.
    getCol res row col t = do
      mval <- getvalue res row col
      pure $ maybe SqlNull (toSqlValue t) mval

-- | Convert the given postgres return value and type to an @SqlValue@.
--   Calls 'error' if the type OID is not one used by Selda, or if a numeric
--   value cannot be parsed.
--   TODO: use binary format instead of text.
toSqlValue :: Oid -> BS.ByteString -> SqlValue
toSqlValue t val
  | t == boolType    = SqlBool $ readBool val
  | t == intType     = SqlInt $ pgReadOrFail "integer" val
  | t == doubleType  = SqlFloat $ pgReadOrFail "double" val
  | t `elem` textish = SqlString (decodeUtf8 val)
  | otherwise        = error $ "BUG: result with unknown type oid: " ++ show t
  where
    -- Types that are passed through as text.
    textish = [textType, timestampType, timeType, dateType]
    -- Anything that is not a recognised "false" spelling counts as true.
    readBool "f"   = False
    readBool "0"   = False
    readBool "0.0" = False
    readBool "F"   = False
    readBool _     = True

-- | Parse a complete textual value; 'Nothing' if the input is empty,
--   malformed, or followed by anything other than whitespace.
pgReadMaybe :: Read a => BS.ByteString -> Maybe a
pgReadMaybe bs =
  case reads (BS.unpack bs) of
    [(x, rest)] | null (words rest) -> Just x
    _                               -> Nothing

-- | Parse a complete textual value, failing with a message that names the
--   expected kind of value and the offending input.
pgReadOrFail :: Read a => String -> BS.ByteString -> a
pgReadOrFail what bs =
  case pgReadMaybe bs of
    Just x  -> x
    Nothing -> error $
      "BUG: unable to parse " ++ what ++ " result value: " ++ show bs

-- | Convert a parameter into a postgres parameter triple.
fromSqlValue :: Lit a -> Maybe (Oid, BS.ByteString, Format)
fromSqlValue lit =
    case lit of
      LitB b     -> textParam boolType (if b then "true" else "false")
      LitI n     -> textParam intType (BS.pack (show n))
      LitD f     -> textParam doubleType (BS.pack (show f))
      LitS s     -> textParam textType (encodeUtf8 s)
      LitTS s    -> textParam timestampType (encodeUtf8 s)
      LitTime s  -> textParam timeType (encodeUtf8 s)
      LitDate s  -> textParam dateType (encodeUtf8 s)
      -- A NULL parameter is passed as no value at all.
      LitNull    -> Nothing
      -- A present nullable value is encoded like the value it wraps.
      LitJust x  -> fromSqlValue x
  where
    -- Every parameter is sent in text format, tagged with its type OID.
    textParam oid payload = Just (oid, payload, Text)

-- | Custom column types for postgres: auto-incrementing primary keys need to
--   be @BIGSERIAL@, and ints need to be @INT8@. Doubles become @FLOAT8@ and
--   datetimes become @TIMESTAMP@; any other type yields 'Nothing'.
pgColType :: T.Text -> [ColAttr] -> Maybe T.Text
pgColType ty attrs =
    case ty of
      "INTEGER"
        | AutoIncrement `elem` attrs -> Just "BIGSERIAL PRIMARY KEY NOT NULL"
        | otherwise                  -> withAttrs "INT8"
      "DOUBLE"   -> withAttrs "FLOAT8"
      "DATETIME" -> withAttrs "TIMESTAMP"
      _          -> Nothing
  where
    -- The postgres type name followed by the compiled column attributes.
    withAttrs pgType = Just $ T.unwords (pgType : map compileColAttr attrs)

-- OIDs for all types used by Selda.

-- | OID of the boolean type.
boolType :: Oid
boolType = Oid 16

-- | OID of the integer type.
intType :: Oid
intType = Oid 20

-- | OID of the text type.
textType :: Oid
textType = Oid 25

-- | OID of the double type.
doubleType :: Oid
doubleType = Oid 701

-- | OID of the date type.
dateType :: Oid
dateType = Oid 1082

-- | OID of the time type.
timeType :: Oid
timeType = Oid 1083

-- | OID of the timestamp type.
timestampType :: Oid
timestampType = Oid 1114
#endif