{-# LANGUAGE CPP, BangPatterns, DoAndIfThenElse, RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable, GeneralizedNewtypeDeriving #-}

------------------------------------------------------------------------------
-- |
-- Module:      Database.PostgreSQL.Simple.Internal
-- Copyright:   (c) 2011-2012 Leon P Smith
-- License:     BSD3
-- Maintainer:  Leon P Smith
-- Stability:   experimental
--
-- Internal bits.  This interface is less stable and can change at any time.
-- In particular this means that while the rest of the postgresql-simple
-- package endeavors to follow the package versioning policy,  this module
-- does not.  Also, at the moment there are things in here that aren't
-- particularly internal and are exported elsewhere;  these will eventually
-- disappear from this module.
--
------------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Internal where

import           Control.Applicative
import           Control.Exception
import           Control.Concurrent.MVar
import           Control.Monad(MonadPlus(..))
import           Data.ByteString(ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Data.ByteString.Builder ( Builder, byteString )
import           Data.Char (ord)
import           Data.Int (Int64)
import qualified Data.IntMap as IntMap
import           Data.IORef
import           Data.Maybe(fromMaybe)
import           Data.Monoid
import           Data.String
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import           Data.Typeable
import           Data.Word
import           Database.PostgreSQL.LibPQ(Oid(..))
import qualified Database.PostgreSQL.LibPQ as PQ
import           Database.PostgreSQL.LibPQ(ExecStatus(..))
import           Database.PostgreSQL.Simple.Compat ( toByteString )
import           Database.PostgreSQL.Simple.Ok
import           Database.PostgreSQL.Simple.ToField (Action(..), inQuotes)
import           Database.PostgreSQL.Simple.Types (Query(..))
import           Database.PostgreSQL.Simple.TypeInfo.Types(TypeInfo)
import           Control.Monad.Trans.State.Strict
import           Control.Monad.Trans.Reader
import           Control.Monad.Trans.Class
import           GHC.IO.Exception
#if !defined(mingw32_HOST_OS)
import           Control.Concurrent(threadWaitRead, threadWaitWrite)
#endif

-- | A Field represents metadata about a particular field
--
-- You don't particularly want to retain these structures for a long
-- period of time,  as they will retain the entire query result,  not
-- just the field metadata
data Field = Field
    { result  :: !PQ.Result
      -- ^ The query result this field belongs to.
    , column  :: {-# UNPACK #-} !PQ.Column
      -- ^ The column of 'result' this field describes.
    , typeOid :: {-# UNPACK #-} !PQ.Oid
      -- ^ This returns the type oid associated with the column.  Analogous
      --   to libpq's @PQftype@.
    }

-- | Cache of 'TypeInfo' values held per connection, keyed by 'Int'
--   (presumably a type oid converted with 'oid2int' -- confirm against
--   the code that populates it).
type TypeInfoCache = IntMap.IntMap TypeInfo

-- | A database connection together with its per-connection state.
data Connection = Connection
    { connectionHandle          :: {-# UNPACK #-} !(MVar PQ.Connection)
      -- ^ The libpq handle.  'withConnection' holds this 'MVar' while an
      --   action runs; 'close' swaps in a null connection.
    , connectionObjects         :: {-# UNPACK #-} !(MVar TypeInfoCache)
      -- ^ Per-connection 'TypeInfo' cache.
    , connectionTempNameCounter :: {-# UNPACK #-} !(IORef Int64)
      -- ^ Counter incremented by 'newTempName' to generate fresh names.
    } deriving (Typeable)

-- | Two connections are equal when they share the same handle 'MVar'.
instance Eq Connection where
    x == y = connectionHandle x == connectionHandle y

-- | Exception describing an error reported for a query result;  see
--   'throwResultError' for how the fields are filled in from the libpq
--   diagnostic fields.
data SqlError = SqlError
    { sqlState       :: ByteString
    , sqlExecStatus  :: ExecStatus
    , sqlErrorMsg    :: ByteString
    , sqlErrorDetail :: ByteString
    , sqlErrorHint   :: ByteString
    } deriving (Show, Typeable)

-- | Build a 'SqlError' with status 'FatalError' carrying only the given
--   message;  the state, detail and hint fields are left empty.
fatalError :: ByteString -> SqlError
fatalError msg = SqlError "" FatalError msg "" ""

instance Exception SqlError

-- | Exception thrown if 'query' is used to perform an @INSERT@-like
-- operation, or 'execute' is used to perform a @SELECT@-like operation.
data QueryError = QueryError
    { qeMessage :: String
    , qeQuery   :: Query
    } deriving (Eq, Show, Typeable)

instance Exception QueryError

-- | Exception thrown if a 'Query' could not be formatted correctly.
-- This may occur if the number of \'@?@\' characters in the query
-- string does not match the number of parameters provided.
data FormatError = FormatError
    { fmtMessage :: String
    , fmtQuery   :: Query
    , fmtParams  :: [ByteString]
    } deriving (Eq, Show, Typeable)

instance Exception FormatError

-- | Connection parameters understood by 'connect'.
--
-- 'postgreSQLConnectionString' leaves out any field that is an empty
-- string (or a zero port),  in which case libpq falls back to its own
-- default for that parameter.
data ConnectInfo = ConnectInfo
    { connectHost     :: String
    , connectPort     :: Word16
    , connectUser     :: String
    , connectPassword :: String
    , connectDatabase :: String
    } deriving (Eq,Read,Show,Typeable)

-- | Default information for setting up a connection.
--
-- Defaults are as follows:
--
-- * Server on @127.0.0.1@ (the IPv4 loopback address)
--
-- * Port on @5432@
--
-- * User @postgres@
--
-- * No password
--
-- * No explicit database:  the @dbname@ parameter is omitted from the
--   connection string,  so libpq chooses the database (according to the
--   libpq documentation,  the one named after the user)
--
-- Use as in the following example:
--
-- > connect defaultConnectInfo { connectHost = "db.example.com" }
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo
    { connectHost     = "127.0.0.1"
    , connectPort     = 5432
    , connectUser     = "postgres"
    , connectPassword = ""
    , connectDatabase = ""
    }

-- | Connect with the given username to the given database. Will throw
--   an exception if it cannot connect.
connect :: ConnectInfo -> IO Connection
connect = connectPostgreSQL . postgreSQLConnectionString

-- | Attempt to make a connection based on a libpq connection string.
--   See <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>
--   for more information.  Here is an example with some
--   of the most commonly used parameters:
--
-- > host='db.somedomain.com' port=5432 ...
--
--   This attempts to connect to @db.somedomain.com:5432@.  Omitting the port
--   will normally default to 5432.
--
--   On systems that provide unix domain sockets,  omitting the host parameter
--   will cause libpq to attempt to connect via unix domain sockets.
--   The default filesystem path to the socket is constructed from the
--   port number and the @DEFAULT_PGSOCKET_DIR@ constant defined in the
--   @pg_config_manual.h@ header file.  Connecting via unix sockets tends
--   to use the @peer@ authentication method, which is very secure and
--   does not require a password.
--
--   On Windows and other systems without unix domain sockets, omitting
--   the host will default to @localhost@.
--
-- > ... dbname='postgres' user='postgres' password='secret \' \\ pw'
--
--   This attempts to connect to a database named @postgres@ with
--   user @postgres@ and password @secret \' \\ pw@.  Backslash
--   characters will have to be escaped (doubled) in literal Haskell
--   strings, of course.  Omitting @dbname@ and @user@ will both default
--   to the system username that the client process is running as.
--
--   Omitting @password@ will default to an appropriate password found
--   in the @pgpass@ file,  or no password at all if a matching line is
--   not found.   See
--   <https://www.postgresql.org/docs/current/libpq-pgpass.html> for
--   more information regarding this file.
--
--   As all parameters are optional and the defaults are sensible,  the
--   empty connection string can be useful for development and
--   exploratory use,  assuming your system is set up appropriately.
--
--   On Unix,  such a setup would typically consist of a local
--   postgresql server listening on port 5432,  as well as a system user,
--   database user, and database sharing a common name,  with permissions
--   granted to the user on the database.
--
--   On Windows,  in addition you will either need @pg_hba.conf@
--   to specify the use of the @trust@ authentication method for
--   the connection,  which may not be appropriate for multiuser
--   or production machines, or you will need to use a @pgpass@ file
--   with the @password@ or @md5@ authentication methods.
--
--   See <https://www.postgresql.org/docs/current/client-authentication.html>
--   for more information regarding the authentication process.
--
--   SSL/TLS will typically "just work" if your postgresql server supports or
--   requires it.  However,  note that libpq is trivially vulnerable to a MITM
--   attack without setting additional SSL parameters in the connection string.
--   In particular,  @sslmode@ needs to be set to @require@, @verify-ca@, or
--   @verify-full@ in order to perform certificate validation.  When @sslmode@
--   is @require@,  then you will also need to specify a @sslrootcert@ file,
--   otherwise no validation of the server's identity will be performed.
--   Client authentication via certificates is also possible via the
--   @sslcert@ and @sslkey@ parameters.
--
--   After a successful connection the session is configured with
--   @SET datestyle TO ISO@ and @SET client_encoding TO UTF8@,  plus
--   @SET standard_conforming_strings TO on@ when the reported server
--   version is at least 80200.  If the connection status is anything
--   other than @ConnectionOk@,  a 'SqlError' built by 'fatalError' is
--   thrown with libpq's error message (or a generic one if none is set).
connectPostgreSQL :: ByteString -> IO Connection
connectPostgreSQL connstr = do
    conn <- connectdb connstr
    stat <- PQ.status conn
    case stat of
      PQ.ConnectionOk -> do
          connectionHandle          <- newMVar conn
          connectionObjects         <- newMVar (IntMap.empty)
          connectionTempNameCounter <- newIORef 0
          let wconn = Connection{..}
          version <- PQ.serverVersion conn
          let settings
                | version < 80200 = "SET datestyle TO ISO;SET client_encoding TO UTF8"
                | otherwise       = "SET datestyle TO ISO;SET client_encoding TO UTF8;SET standard_conforming_strings TO on"
          _ <- execute_ wconn settings
          return wconn
      _ -> do
          msg <- fromMaybe "connectPostgreSQL error" <$> PQ.errorMessage conn
          throwIO $ fatalError msg

-- | Open a raw libpq connection.
--
--   On Windows this is just 'PQ.connectdb'.  Elsewhere the connection is
--   started with 'PQ.connectStart' and driven with 'PQ.connectPoll',
--   waiting on the connection's socket via 'threadWaitRead' /
--   'threadWaitWrite' between polls.  Throws an 'IOError' if polling
--   fails or the socket descriptor cannot be obtained.
connectdb :: ByteString -> IO PQ.Connection
#if defined(mingw32_HOST_OS)
connectdb = PQ.connectdb
#else
connectdb conninfo = do
    conn <- PQ.connectStart conninfo
    loop conn
  where
    funcName = "Database.PostgreSQL.Simple.connectPostgreSQL"

    loop conn = do
        status <- PQ.connectPoll conn
        case status of
          PQ.PollingFailed  -> throwLibPQError conn "connection failed"
          PQ.PollingReading -> waitThenPoll threadWaitRead
          PQ.PollingWriting -> waitThenPoll threadWaitWrite
          PQ.PollingOk      -> return conn
      where
        -- Block on the connection's socket with the given wait
        -- primitive,  then poll again.
        waitThenPoll wait = do
            mfd <- PQ.socket conn
            case mfd of
              Nothing -> throwIO $! fdError funcName
              Just fd -> do
                  wait fd
                  loop conn
#endif

-- | Turns a 'ConnectInfo' data structure into a libpq connection string.
--
--   Empty string fields and a zero port are omitted.  String values are
--   wrapped in single quotes,  with backslashes and single quotes inside
--   them escaped by a backslash.
--
--   NOTE(review): the final conversion goes through 'ByteString''s
--   'IsString' instance,  which is believed to truncate characters above
--   U+00FF -- confirm whether non-Latin-1 credentials need to be supported.
postgreSQLConnectionString :: ConnectInfo -> ByteString
postgreSQLConnectionString connectInfo = fromString connstr
  where
    connstr = str "host="     connectHost
            $ num "port="     connectPort
            $ str "user="     connectUser
            $ str "password=" connectPassword
            $ str "dbname="   connectDatabase
            $ []

    -- Emit @name'value' @ for a non-empty string field.
    str name field
      | null value = id
      | otherwise  = showString name . quoteValue value . space
      where
        value = field connectInfo

    -- Emit @namevalue @ for a positive numeric field.
    num name field
      | value <= 0 = id
      | otherwise  = showString name . shows value . space
      where
        value = field connectInfo

    -- Single-quote a value,  escaping embedded backslashes and quotes.
    quoteValue val rest = '\'' : foldr escape ('\'' : rest) val
      where
        escape c cs = case c of
            '\\' -> '\\' : '\\' : cs
            '\'' -> '\\' : '\'' : cs
            _    -> c : cs

    -- Separate parameters with a space,  but never leave a trailing one.
    space [] = []
    space xs = ' ':xs

-- | Convert an 'Oid' to an 'Int' with 'fromIntegral'.
oid2int :: Oid -> Int
oid2int (Oid x) = fromIntegral x
{-# INLINE oid2int #-}

-- | Run a query string on the connection and return a libpq result.
--
--   On Windows this calls 'PQ.exec' directly.  Elsewhere the query is
--   sent with 'PQ.sendQuery' and results are collected with
--   'PQ.getResult',  waiting on the socket with 'threadWaitRead' whenever
--   libpq reports it is busy.  The last result obtained is returned,
--   except that a result with status @CopyOut@ or @CopyIn@ is returned
--   immediately.  Throws an 'IOError' if the query cannot be sent,  no
--   result is produced,  or the socket descriptor cannot be obtained.
exec :: Connection
     -> ByteString
     -> IO PQ.Result
#if defined(mingw32_HOST_OS)
exec conn sql =
    withConnection conn $ \h -> do
        mres <- PQ.exec h sql
        case mres of
          Nothing  -> throwLibPQError h "PQexec returned no results"
          Just res -> return res
#else
exec conn sql =
    withConnection conn $ \h -> do
        success <- PQ.sendQuery h sql
        if success
          then awaitResult h Nothing
          else throwLibPQError h "PQsendQuery failed"
  where
    -- Wait until the socket is readable,  feed libpq,  then try again.
    awaitResult h mres = do
        mfd <- PQ.socket h
        case mfd of
          Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
          Just fd -> do
              threadWaitRead fd
              _ <- PQ.consumeInput h  -- FIXME?
              getResult h mres

    -- @mres@ is the most recent result seen so far,  if any.
    getResult h mres = do
        isBusy <- PQ.isBusy h
        if isBusy
          then awaitResult h mres
          else do
            mres' <- PQ.getResult h
            case mres' of
              Nothing ->
                  case mres of
                    Nothing  -> throwLibPQError h "PQgetResult returned no results"
                    Just res -> return res
              Just res -> do
                  status <- PQ.resultStatus res
                  case status of
                    PQ.EmptyQuery    -> getResult h mres'
                    PQ.CommandOk     -> getResult h mres'
                    PQ.TuplesOk      -> getResult h mres'
                    PQ.CopyOut       -> return res
                    PQ.CopyIn        -> return res
                    PQ.BadResponse   -> getResult h mres'
                    PQ.NonfatalError -> getResult h mres'
                    PQ.FatalError    -> getResult h mres'
#endif

-- | A version of 'execute' that does not perform query substitution.
execute_ :: Connection -> Query -> IO Int64
execute_ conn q@(Query stmt) = do
    result <- exec conn stmt
    finishExecute conn q result

-- | Interpret the result of a statement run via 'execute_'.
--
--   Returns the affected-row count parsed from 'PQ.cmdTuples' (0 when
--   libpq reports none) for a @CommandOk@ result with no columns.
--   Throws 'QueryError' for an empty query,  a result that has columns,
--   or a @COPY@ in either direction,  and a 'SqlError' (via
--   'throwResultError') for the error statuses.  The connection
--   argument is not used.
finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
finishExecute _conn q result = do
    status <- PQ.resultStatus result
    case status of
      PQ.EmptyQuery -> throwIO $ QueryError "execute: Empty query" q
      PQ.CommandOk -> do
          ncols <- PQ.nfields result
          if ncols /= 0
            then unexpectedColumns ncols
            else do
              nstr <- PQ.cmdTuples result
              -- is defaulting to 0 appropriate when libpq gives no count?
              return $ maybe 0 parseCount nstr
      PQ.TuplesOk      -> PQ.nfields result >>= unexpectedColumns
      PQ.CopyOut       -> throwIO $ QueryError "execute: COPY TO is not supported" q
      PQ.CopyIn        -> throwIO $ QueryError "execute: COPY FROM is not supported" q
      PQ.BadResponse   -> throwResultError "execute" result status
      PQ.NonfatalError -> throwResultError "execute" result status
      PQ.FatalError    -> throwResultError "execute" result status
  where
    -- The statement produced a result set where none was expected.
    unexpectedColumns ncols =
        throwIO $ QueryError
            ("execute resulted in " ++ show ncols ++ "-column result") q

    -- Parse a string of ASCII decimal digits;  any other character is
    -- reported with 'error' when the count is forced.
    parseCount str = B8.foldl' step 0 str
      where
        step acc c
          | '0' <= c && c <= '9' = 10 * acc + fromIntegral (ord c - ord '0')
          | otherwise            = error ("finishExecute: not an int: " ++ B8.unpack str)

-- | Throw a 'SqlError' assembled from the diagnostic fields of a result
--   (primary message,  detail,  hint and SQLSTATE;  each defaults to the
--   empty string when absent) together with the given status.  The first
--   argument is currently ignored.
throwResultError :: ByteString -> PQ.Result -> PQ.ExecStatus -> IO a
throwResultError _ result status = do
    errormsg <- diagField PQ.DiagMessagePrimary
    detail   <- diagField PQ.DiagMessageDetail
    hint     <- diagField PQ.DiagMessageHint
    state    <- diagField PQ.DiagSqlstate
    throwIO $ SqlError { sqlState       = state
                       , sqlExecStatus  = status
                       , sqlErrorMsg    = errormsg
                       , sqlErrorDetail = detail
                       , sqlErrorHint   = hint
                       }
  where
    diagField field = fromMaybe "" <$> PQ.resultErrorField result field

-- | The 'SqlError' thrown by 'withConnection' when the handle is a null
--   connection.
disconnectedError :: SqlError
disconnectedError = fatalError "connection disconnected"

-- | Atomically perform an action with the database handle, if there is one.
withConnection :: Connection -> (PQ.Connection -> IO a) -> IO a
withConnection Connection{..} m =
    -- The handle's MVar is held for the duration of the action.  A null
    -- handle causes 'disconnectedError' to be thrown instead.
    withMVar connectionHandle $ \conn ->
        if PQ.isNullConnection conn
          then throwIO disconnectedError
          else m conn

-- | Finish the libpq connection and replace the handle with a null
--   connection,  so that later uses through 'withConnection' throw
--   'disconnectedError'.
close :: Connection -> IO ()
close Connection{..} =
    mask $ \restore -> do
        -- Take the handle *before* installing the cleanup handler:  if
        -- this blocking take is interrupted we have not emptied the MVar,
        -- so we must not put anything back into it (doing so could leave
        -- the thread that currently holds the handle blocked forever when
        -- it tries to return it).
        conn <- takeMVar connectionHandle
        restore (PQ.finish conn)
            `finally` (putMVar connectionHandle =<< PQ.newNullConnection)

-- | Create a 'Connection' whose handle is a null connection,  with an
--   empty type cache and the temporary-name counter at 0.
newNullConnection :: IO Connection
newNullConnection = do
    connectionHandle          <- newMVar =<< PQ.newNullConnection
    connectionObjects         <- newMVar IntMap.empty
    connectionTempNameCounter <- newIORef 0
    return Connection{..}

-- | A row index paired with the query result it refers to.
data Row = Row
    { row       :: {-# UNPACK #-} !PQ.Row
    , rowresult :: !PQ.Result
    }

-- | A parser for one 'Row':  reads the 'Row',  threads a 'PQ.Column' as
--   state,  and runs in 'Conversion'.
newtype RowParser a = RP { unRP :: ReaderT Row (StateT PQ.Column Conversion) a }
    deriving ( Functor, Applicative, Alternative, Monad )

-- | Run an 'IO' action inside a 'RowParser'.
liftRowParser :: IO a -> RowParser a
liftRowParser = RP . lift . lift . liftConversion

-- | A computation that is given the 'Connection' and yields either a
--   value or errors,  as an 'Ok'.
newtype Conversion a
    = Conversion { runConversion :: Connection -> IO (Ok a) }

-- | Run an 'IO' action inside a 'Conversion';  its result is always
--   wrapped in 'Ok'.
liftConversion :: IO a -> Conversion a
liftConversion m = Conversion (\_ -> Ok <$> m)

instance Functor Conversion where
    fmap f m = Conversion $ \conn -> (fmap . fmap) f (runConversion m conn)

instance Applicative Conversion where
    pure a = Conversion $ \_conn -> pure (pure a)
    -- The argument is only run if the function conversion succeeded.
    mf <*> ma = Conversion $ \conn -> do
        okf <- runConversion mf conn
        case okf of
          Ok f        -> (fmap . fmap) f (runConversion ma conn)
          Errors errs -> return (Errors errs)

instance Alternative Conversion where
    empty = Conversion $ \_conn -> pure empty
    -- The second conversion is only run if the first one failed;  the
    -- two outcomes are then combined with 'Ok''s own '<|>'.
    ma <|> mb = Conversion $ \conn -> do
        oka <- runConversion ma conn
        case oka of
          Ok _     -> return oka
          Errors _ -> (oka <|>) <$> runConversion mb conn

instance Monad Conversion where
    return a = Conversion $ \_conn -> return (return a)
    m >>= f = Conversion $ \conn -> do
        oka <- runConversion m conn
        case oka of
          Ok a       -> runConversion (f a) conn
          Errors err -> return (Errors err)

instance MonadPlus Conversion where
    mzero = empty
    mplus = (<|>)

-- | Transform the 'Ok' outcome of a conversion.
conversionMap :: (Ok a -> Ok b) -> Conversion a -> Conversion b
conversionMap f m = Conversion $ \conn -> f <$> runConversion m conn

-- | A conversion that fails with the given exception.
conversionError :: Exception err => err -> Conversion a
conversionError err = Conversion $ \_ -> return (Errors [SomeException err])

-- | Produce a fresh name of the form @temp@/N/ by atomically
--   incrementing the connection's counter.
newTempName :: Connection -> IO Query
newTempName Connection{..} = do
    !n <- atomicModifyIORef connectionTempNameCounter
              (\old -> let !new = old + 1 in (new, new))
    return $! Query $ B8.pack $ "temp" ++ show n

-- FIXME?  What error should getNotification and getCopyData throw?
-- | The 'IOError' used when the connection's socket descriptor cannot
--   be obtained;  the argument names the reporting function.
fdError :: ByteString -> IOError
fdError funcName = IOError
    { ioe_handle      = Nothing
    , ioe_type        = ResourceVanished
    , ioe_location    = B8.unpack funcName
    , ioe_description = "failed to fetch file descriptor"
    , ioe_errno       = Nothing
    , ioe_filename    = Nothing
    }

-- | An 'IOError' located at @libpq@ with the given description.
libPQError :: ByteString -> IOError
libPQError desc = IOError
    { ioe_handle      = Nothing
    , ioe_type        = OtherError
    , ioe_location    = "libpq"
    , ioe_description = B8.unpack desc
    , ioe_errno       = Nothing
    , ioe_filename    = Nothing
    }

-- | Throw a 'libPQError' carrying the connection's current error
--   message,  or the supplied default if libpq has none.
throwLibPQError :: PQ.Connection -> ByteString -> IO a
throwLibPQError conn default_desc = do
    msg <- fromMaybe default_desc <$> PQ.errorMessage conn
    throwIO $! libPQError msg

-- | Throw a 'FormatError' for the given query,  recording a flattened
--   rendering of each parameter.
fmtError :: String -> Query -> [Action] -> a
fmtError msg q xs = throw FormatError
    { fmtMessage = msg
    , fmtQuery   = q
    , fmtParams  = map twiddle xs
    }
  where
    twiddle (Plain b)            = toByteString b
    twiddle (Escape s)           = s
    twiddle (EscapeByteA s)      = s
    twiddle (EscapeIdentifier s) = s
    twiddle (Many ys)            = B.concat (map twiddle ys)

-- | Like 'fmtError',  but takes the message as UTF-8 bytes.
--
--   Bytes that are not valid UTF-8 are replaced with U+FFFD rather than
--   raising a decoding exception,  so that the caller always receives
--   the intended 'FormatError'.
fmtErrorBs :: Query -> [Action] -> ByteString -> a
fmtErrorBs q xs msg = fmtError (T.unpack decoded) q xs
  where
    decoded = TE.decodeUtf8With (\_ _ -> Just '\xFFFD') msg

-- | Quote bytestring or throw 'FormatError'
quote :: Query -> [Action] -> Either ByteString ByteString -> Builder
quote q xs = either (fmtErrorBs q xs) (inQuotes . byteString)

-- | Render one 'Action' to a 'Builder',  escaping through the
--   connection where required.  A failed escape surfaces as a
--   'FormatError' when the resulting 'Builder' is evaluated.
buildAction :: Connection        -- ^ Connection for string escaping
            -> Query             -- ^ Query for message error
            -> [Action]          -- ^ List of parameters for message error
            -> Action            -- ^ Action to build
            -> IO Builder
buildAction _ _ _     (Plain b)            = pure b
buildAction conn q xs (Escape s)           = quote q xs <$> escapeStringConn conn s
buildAction conn q xs (EscapeByteA s)      = quote q xs <$> escapeByteaConn conn s
buildAction conn q xs (EscapeIdentifier s) =
    either (fmtErrorBs q xs) byteString <$> escapeIdentifier conn s
buildAction conn q xs (Many ys)            =
    mconcat <$> mapM (buildAction conn q xs) ys

-- | Turn a 'Maybe' result from libpq into an 'Either',  using the
--   connection's error message (or the empty string) for the failure.
checkError :: PQ.Connection -> Maybe a -> IO (Either ByteString a)
checkError _ (Just x) = return $ Right x
checkError c Nothing  = Left . fromMaybe "" <$> PQ.errorMessage c

-- | Run a libpq escaping function against the connection's handle and
--   report failure via 'checkError'.
escapeWrap :: (PQ.Connection -> ByteString -> IO (Maybe ByteString))
           -> Connection
           -> ByteString
           -> IO (Either ByteString ByteString)
escapeWrap f conn s =
    withConnection conn $ \c ->
        f c s >>= checkError c

-- | 'PQ.escapeStringConn' lifted to 'Connection'.
escapeStringConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeStringConn = escapeWrap PQ.escapeStringConn

-- | 'PQ.escapeIdentifier' lifted to 'Connection'.
escapeIdentifier :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeIdentifier = escapeWrap PQ.escapeIdentifier

-- | 'PQ.escapeByteaConn' lifted to 'Connection'.
escapeByteaConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeByteaConn = escapeWrap PQ.escapeByteaConn