{-# OPTIONS_GHC -Wall -fno-warn-name-shadowing #-}
{-# LANGUAGE RecordWildCards, OverloadedStrings, ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts, ViewPatterns, NamedFieldPuns, TupleSections #-}

-- | A front-end implementation for the PostgreSQL database protocol
--   version 3.0 (implemented in PostgreSQL 7.4 and later).

module Database.PostgreSQL.Base
  (begin
  ,rollback
  ,commit
  ,query
  ,exec
  ,escapeBS
  ,connect
  ,defaultConnectInfo
  ,close
  ,withDB
  ,withTransaction
  ,newPool
  ,pconnect
  ,withPoolConnection)
  where

import           Database.PostgreSQL.Base.Types

import           Control.Concurrent
import           Control.Monad
import           Control.Monad.CatchIO          (MonadCatchIO)
import qualified Control.Monad.CatchIO          as E
import           Control.Monad.Fix
import           Control.Monad.State            (MonadState,execStateT,modify)
import           Control.Monad.Trans
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.Put
import           Data.ByteString                (ByteString)
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as L
import qualified Data.ByteString.Lazy.UTF8      as L (toString,fromString)
import           Data.ByteString.UTF8           (toString,fromString)
import           Data.Int
import           Data.List
import           Data.IntMap.Strict             (IntMap)
import qualified Data.IntMap.Strict             as IntMap
import           Data.Maybe
import           Data.Monoid
import           Network
import           Prelude
import           System.IO                      hiding (hPutStr)

--------------------------------------------------------------------------------
-- Exported values

-- | Default information for setting up a connection.
--
-- Defaults are as follows:
--
-- * Server on @127.0.0.1@, port @5432@
--
-- * User @postgres@
--
-- * No password (the empty string)
--
-- * Empty database name; callers will normally override
--   'connectDatabase'
--
-- Use as in the following example:
--
-- > connect defaultConnectInfo { connectHost = "db.example.com" }
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo {
                       connectHost = "127.0.0.1"
                     , connectPort = 5432
                     , connectUser = "postgres"
                     , connectPassword = ""
                     , connectDatabase = ""
                     }

-- | Create a new connection pool.
newPool :: MonadIO m
        => ConnectInfo -- ^ Connect info.
        -> m Pool
newPool info = liftIO $ do
  var <- newMVar $ PoolState {
           poolConnections = []
         , poolConnectInfo = info
         }
  return $ Pool var

-- | Connect using the connection pool. An idle pooled connection is
-- reused when one is available; otherwise a fresh connection is
-- opened with the pool's 'ConnectInfo'.
pconnect :: MonadIO m => Pool -> m Connection
pconnect (Pool var) = liftIO $ do
  modifyMVar var $ \state@PoolState{..} -> do
    case poolConnections of
      []           -> do conn <- connect poolConnectInfo
                         return (state,conn)
      (conn:conns) -> return (state { poolConnections = conns },conn)

-- | Restore a connection to the pool. Connections whose handle is
-- gone or has been closed are dropped rather than pooled again.
restore :: MonadIO m => Pool -> Connection -> m ()
restore (Pool var) conn = liftIO $ do
  handle <- readMVar $ connectionHandle conn
  modifyMVar_ var $ \state -> do
    case handle of
      Nothing -> return state
      Just h  -> do
        -- Only a handle that is still open is worth keeping around.
        open <- hIsOpen h
        if open
           then return state { poolConnections = conn : poolConnections state }
           else return state

-- | Use the connection pool. The connection is handed back to the
-- pool afterwards, also when the action throws. The action's result
-- is discarded.
withPoolConnection
  :: (MonadCatchIO m,MonadIO m)
  => Pool                 -- ^ The connection pool.
  -> (Connection -> m a)  -- ^ Use the connection.
  -> m ()
withPoolConnection pool m = do
  _ <- E.bracket (pconnect pool) (restore pool) m
  return ()

-- | Connect with the given username to the given database. Will throw
--   an exception if it cannot connect.
--
-- If anything fails after the socket has been opened (most notably
-- authentication), the handle is closed before the exception is
-- re-thrown, so a failed connection attempt does not leak a socket.
connect :: MonadIO m => ConnectInfo -> m Connection -- ^ The database connection.
connect connectInfo@ConnectInfo{..} = liftIO $ withSocketsDo $ do
  h <- connectTo connectHost (PortNumber $ fromIntegral connectPort)
  -- From here on we own the handle: release it on any failure.
  flip E.onException (hClose h) $ do
    hSetBuffering h NoBuffering
    var <- newMVar (Just h)
    types <- newMVar IntMap.empty
    let conn = Connection var types
    authenticate conn connectInfo
    return conn

-- | Run an action with a connection and close the connection
--   afterwards (protects against exceptions).
withDB :: (MonadCatchIO m,MonadIO m) => ConnectInfo -> (Connection -> m a) -> m a
withDB connectInfo m = E.bracket (liftIO $ connect connectInfo) (liftIO . close) m

-- | With a transaction, do some action (protects against exceptions).
withTransaction :: (MonadCatchIO m,MonadIO m) => Connection -> m a -> m a
withTransaction conn act = do
  begin conn
  -- Roll back (and re-throw) if the action fails; commit otherwise.
  outcome <- act `E.onException` rollback conn
  commit conn
  return outcome

-- | Rollback a transaction.
rollback :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
rollback conn = exec conn (fromString ("ABORT;" :: String)) >> return ()

-- | Commit a transaction.
commit :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
commit conn = exec conn (fromString ("COMMIT;" :: String)) >> return ()

-- | Begin a transaction.
begin :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
begin conn = exec conn (fromString ("BEGIN;" :: String)) >> return ()

-- | Close a connection. Can safely be called any number of times:
-- once the handle has been closed the connection holds 'Nothing' and
-- further calls are no-ops.
close :: MonadIO m
      => Connection -- ^ The connection.
      -> m ()
close Connection{connectionHandle} = liftIO $
  modifyMVar_ connectionHandle $ \mhandle -> do
    maybe (return ()) hClose mhandle
    return Nothing

-- | Run a simple query on a connection, returning the field
-- descriptions and the rows. A statement that yields no row
-- description gives @([],[])@.
query :: (MonadCatchIO m)
      => Connection -- ^ The connection.
      -> ByteString -- ^ The query.
      -> m ([Field],[[Maybe ByteString]])
query conn sql = do
  (_,mresult) <- execQuery conn sql
  return (fromMaybe ([],[]) mresult)

-- | Run a simple query on a connection, returning the row count taken
-- from the command tag (0 when absent) together with the fields and
-- rows, if the server sent a row description.
execQuery :: (MonadCatchIO m)
      => Connection -- ^ The connection.
      -> ByteString -- ^ The query.
      -> m (Integer,Maybe ([Field],[[Maybe ByteString]]))
execQuery conn sql = liftIO $
  withConnection conn $ \h -> do
    types <- readMVar (connectionObjects conn)
    res <- sendQuery types h sql
    case resultType res of
      ErrorResponse      -> E.throw $ QueryError (fmap L.toString (resultError res))
      EmptyQueryResponse -> E.throw QueryEmpty
      _ ->
        let tagCount = fromMaybe 0 (resultTagRows res)
        in case resultDesc res of
             Nothing     -> return (tagCount,Nothing)
             Just fields -> return (tagCount,Just (fields,resultRows res))

-- | Exec a command.
exec :: (MonadCatchIO m)
     => Connection
     -> ByteString
     -> m Integer
exec conn sql = do
  (affected,_) <- execQuery conn sql
  return affected

-- | PostgreSQL protocol version supported by this library.
protocolVersion :: Int32
protocolVersion = 196608

-- | Escape a string for PostgreSQL: backslashes and single quotes
-- are doubled, everything else is passed through.
escape :: String -> String
escape = concatMap escapeChar
  where escapeChar '\\' = ['\\','\\']
        escapeChar '\'' = ['\'','\'']
        escapeChar c    = [c]

-- | Escape a string for PostgreSQL.
escapeBS :: ByteString -> ByteString
escapeBS bytes = fromString (escape (toString bytes))

--------------------------------------------------------------------------------
-- Authentication

-- | Run the authentication procedure, then load the type object ids
-- and store them in the connection.
authenticate :: Connection -> ConnectInfo -> IO ()
authenticate conn connectInfo =
  withConnection conn $ \h -> do
    sendStartUp h connectInfo
    getConnectInfoResponse h connectInfo
    objects <- objectIds h
    modifyMVar_ (connectionObjects conn) (const (return objects))

-- | Send the start-up message.
sendStartUp :: Handle -> ConnectInfo -> IO ()
sendStartUp h info =
  sendBlock h Nothing $ do
    int32 protocolVersion
    -- Parameter name/value pairs, each null-terminated.
    mapM_ (string . fromString)
          ["user",connectUser info
          ,"database",connectDatabase info]
    zero

-- | Wait for and process the authentication response from the server.
getConnectInfoResponse :: Handle -> ConnectInfo -> IO ()
getConnectInfoResponse h conninfo = do
  (typ,block) <- getMessage h
  -- TODO: Handle authentication failure. Handle information messages
  --       that are sent, maybe store in the connection value for
  --       later inspection.
  case typ of
    AuthenticationOk
      | param == 0 -> waitForReady h
      | param == 3 -> sendPassClearText h conninfo
      -- TODO: | param == 5 -> sendPassMd5 h conninfo salt
      | otherwise  -> E.throw $ UnsupportedAuthenticationMethod param (show block)
        where param = decode block :: Int32
              _salt = flip runGet block $ do
                        _ <- getInt32
                        getWord8

    els -> E.throw $ AuthenticationFailed (show (els,block))

-- | Send the pass as clear text and wait for connect response.
sendPassClearText :: Handle -> ConnectInfo -> IO ()
sendPassClearText h conninfo = do
  sendMessage h PasswordMessage $
    string (fromString (connectPassword conninfo))
  getConnectInfoResponse h conninfo

-- -- | Send the pass as salted MD5 and wait for connect response.
-- sendPassMd5 :: Handle -> ConnectInfo -> Word8 -> IO ()
-- sendPassMd5 h conninfo@ConnectInfo{..} salt = do
--   -- TODO: Need md5 library with salt support.
--   sendMessage h PasswordMessage $
--     string (fromString connectPassword)
--   getConnectInfoResponse h conninfo

--------------------------------------------------------------------------------
-- Initialization

-- | Ask the server for its type names and object ids, and build a map
-- from object id to 'Type' for every type name this library knows.
objectIds :: Handle -> IO (IntMap Type)
objectIds h = do
  res <- sendQuery IntMap.empty h q
  case resultType res of
    ErrorResponse -> E.throw $ InitializationError "Couldn't get types."
    _             -> return $ IntMap.fromList $ mapMaybe entry (resultRows res)

  where q = fromString ("SELECT typname, oid FROM pg_type" :: String)
        -- A usable row has exactly two non-null columns: a known type
        -- name and a numeric object id.
        entry row =
          case map toString (catMaybes row) of
            [typeName,readMay -> Just objId] ->
              fmap (\typ -> (fromIntegral objId,typ)) (typeFromName typeName)
            _ -> Nothing

--------------------------------------------------------------------------------
-- Queries and commands

-- | Send a simple query.
sendQuery :: IntMap Type -> Handle -> ByteString -> IO Result
sendQuery types h sql = do
  sendMessage h Query $ string sql
  execStateT collect emptyResponse

  where emptyResponse = Result [] Nothing Nothing [] UnknownMessageType Nothing
        -- Keep reading messages, folding each into the result, until
        -- the server reports that it is ready for the next query.
        collect = do
          (typ,block) <- liftIO $ getMessage h
          let setStatus = modify $ \r -> r { resultType = typ }
          case typ of
            ReadyForQuery ->
              -- Rows were accumulated newest-first; restore their order.
              modify $ \r -> r { resultRows = reverse (resultRows r) }
            _ -> do
              case typ of
                EmptyQueryResponse -> setStatus
                CommandComplete    -> setStatus >> setCommandTag block
                ErrorResponse      -> do
                  modify $ \r -> r { resultError = Just block }
                  setStatus
                RowDescription     -> getRowDesc types block
                DataRow            -> getDataRow block
                _                  -> return ()
              collect

-- | CommandComplete returns a 'tag' which indicates how many rows were
-- affected, or returned, as a result of the command.
-- See http://developer.postgresql.org/pgdocs/postgres/protocol-message-formats.html
setCommandTag :: MonadState Result m => L.ByteString -> m ()
setCommandTag block = modify $ \r -> r { resultTagRows = affected }
  where affected =
          case tagWords of
            ["INSERT",_oid,readMay -> Just n]                    -> Just n
            [cmd,readMay -> Just n] | cmd `elem` countedCommands -> Just n
            _                                                    -> Nothing
        tagWords = words (concatMap toString (L.toChunks (runGet getString block)))
        countedCommands = ["DELETE","UPDATE","SELECT","MOVE","FETCH"]

-- | Update the row description of the result.
getRowDesc :: MonadState Result m => IntMap Type -> L.ByteString -> m ()
getRowDesc types block =
  modify $ \r -> r { resultDesc = Just (parseFields types (runGet parseMsg block)) }
    where parseMsg = do
            fieldCount <- getInt16
            replicateM (fromIntegral fieldCount) parseField
          parseField = do
            name <- getString
            objid <- getInt32
            colid <- getInt16
            dtype <- getInt32
            size <- getInt16
            modifier <- getInt32
            code <- getInt16
            return (name,objid,colid,dtype,size,modifier,code)

-- | Parse a row description.
--
-- Parts of the row description are:
--
-- String: The field name.
--
-- Int32: If the field can be identified as a column of a specific
-- table, the object ID of the table; otherwise zero.
--
-- Int16: If the field can be identified as a column of a specific
-- table, the attribute number of the column; otherwise zero.
--
-- Int32: The object ID of the field's data type.
--
-- Int16: The data type size (see pg_type.typlen). Note that negative
-- values denote variable-width types.
--
-- Int32: The type modifier (see pg_attribute.atttypmod). The meaning
-- of the modifier is type-specific.
--
-- Int16: The format code being used for the field. Currently will be
-- zero (text) or one (binary). In a RowDescription returned from the
-- statement variant of Describe, the format code is not yet known and
-- will always be zero.
--
parseFields :: IntMap Type
            -> [(L.ByteString,Int32,Int16,Int32,Int16,Int32,Int16)]
            -> [Field]
parseFields types = map toField
  where
    -- Only the data type object id and the format code are used so
    -- far; the remaining components are ignored.
    toField (_fieldName,_objectId,_attrId,typeOid,_typeSize,_typeModifier,code) =
      Field { fieldType = parseType types typeOid
            , fieldFormatCode = parseFormatCode code
            }

-- These aren't used (yet).

-- -- | Parse an object ID. 0 means no object.
-- parseObjId :: Int32 -> Maybe ObjectId
-- parseObjId 0 = Nothing
-- parseObjId n = Just (ObjectId n)

-- -- | Parse an attribute ID. 0 means no object.
-- parseAttrId :: Int16 -> Maybe ObjectId
-- parseAttrId 0 = Nothing
-- parseAttrId n = Just (ObjectId $ fromIntegral n)

-- | Parse a number into a type.
parseType :: IntMap Type -> Int32 -> Type
parseType types objId =
  fromMaybe unknownType (IntMap.lookup (fromIntegral objId) types)
  where unknownType =
          error $ "parseType: Unknown type given by object-id: " ++ show objId

-- | Look up the 'Type' for a PostgreSQL type name, if it is one of
-- the names listed in 'fieldTypes'.
typeFromName :: String -> Maybe Type
typeFromName name = lookup name fieldTypes

-- | The type names this library understands.
fieldTypes :: [(String, Type)]
fieldTypes =
  [("bool",Boolean)
  ,("int2",Short)
  ,("integer",Long)
  ,("int",Long)
  ,("int4",Long)
  ,("int8",LongLong)
  ,("timestamptz",TimestampWithZone)
  ,("varchar",CharVarying)
  ,("text",Text)]

-- This isn't used yet.
-- -- | Parse a type's size.
-- parseSize :: Int16 -> Size
-- parseSize (-1) = Varying
-- parseSize n    = Size n

-- This isn't used yet.
-- -- | Parse a type-specific modifier.
-- parseModifier :: Type -> Int32 -> Maybe Modifier
-- parseModifier _typ _modifier = Nothing

-- | Parse a format code (text or binary). Anything other than 1 is
-- treated as text.
parseFormatCode :: Int16 -> FormatCode
parseFormatCode code =
  case code of
    1 -> BinaryCode
    _ -> TextCode

-- | Add a data row to the response. Rows are prepended, so the
-- accumulated list is in reverse order of arrival.
getDataRow :: MonadState Result m => L.ByteString -> m ()
getDataRow block =
  modify $ \r -> r { resultRows = runGet parseMsg block : resultRows r }
    where parseMsg = do
            values <- getInt16
            replicateM (fromIntegral values) parseValue
          -- A length of -1 marks a NULL column value.
          parseValue = do
            size <- getInt32
            if size == -1
               then return Nothing
               else fmap Just (getByteString (fromIntegral size))

-- TODO:
-- getNotice :: MonadState Result m => L.ByteString -> m ()
-- getNotice block =
--   return ()
--  modify $ \r -> r { responseNotices = runGet parseMsg block : responseNotices r }
--    where parseMsg = return ""

-- | The message type identified by a type byte, if known.
typeFromChar :: Char -> Maybe MessageType
typeFromChar = flip lookup types

-- | The type byte used for a message type, if it has one.
charFromType :: MessageType -> Maybe Char
charFromType typ = listToMaybe [ char | (char,typ') <- types, typ' == typ ]

-- | Message type bytes and the message types they stand for.
types :: [(Char, MessageType)]
types =
  [('C',CommandComplete)
  ,('T',RowDescription)
  ,('D',DataRow)
  ,('I',EmptyQueryResponse)
  ,('E',ErrorResponse)
  ,('Z',ReadyForQuery)
  ,('N',NoticeResponse)
  ,('R',AuthenticationOk)
  ,('Q',Query)
  ,('p',PasswordMessage)]

-- | Blocks until receives ReadyForQuery.
waitForReady :: Handle -> IO ()
waitForReady h = loop
  where
    -- Skip every message until a ReadyForQuery carrying the idle
    -- status ('I') arrives; an ErrorResponse aborts with GeneralError.
    loop = do
      (typ,block) <- getMessage h
      case typ of
        ErrorResponse -> E.throw $ GeneralError $ show block
        ReadyForQuery | decode block == 'I' -> return ()
        _                                   -> loop

--------------------------------------------------------------------------------
-- Connections

-- | Atomically perform an action with the database handle, if there is one.
-- Throws 'ConnectionLost' when the connection has no handle.
withConnection :: Connection -> (Handle -> IO a) -> IO a
withConnection Connection{..} m = do
  withMVar connectionHandle $ \h -> do
    case h of
      Just h -> m h
      -- TODO: Use extensible exceptions.
      Nothing -> E.throw ConnectionLost

-- | Send a block of bytes on a handle, prepending the message type
--   and complete length.
sendMessage :: Handle -> MessageType -> Put -> IO ()
sendMessage h typ output =
  case charFromType typ of
    Just char -> sendBlock h (Just char) output
    Nothing   -> error $ "sendMessage: Bad message type " ++ show typ

-- | Send a block of bytes on a handle, prepending the complete length.
-- The length covers the 4-byte length field itself plus the payload,
-- but not the optional leading type byte.
sendBlock :: Handle -> Maybe Char -> Put -> IO ()
sendBlock h typ output = do
  L.hPutStr h bytes
    where bytes = start `mappend` out
          start = runPut $ do
            maybe (return ()) (put . toByte) typ
            int32 $ fromIntegral int32Size +
                    fromIntegral (L.length out)
          out = runPut output
          toByte c = fromIntegral (fromEnum c) :: Word8

-- | Get a message (block) from the stream.
--
-- Returns the message type ('UnknownMessageType' for an unrecognised
-- type byte) together with the payload that follows the length field.
-- Throws 'ConnectionLost' if the stream ends before a complete
-- message has been read, instead of failing later with a decoding
-- error or silently returning a truncated payload.
getMessage :: Handle -> IO (MessageType,L.ByteString)
getMessage h = do
  messageType <- hGetExactly 1
  blockLength <- hGetExactly int32Size
  let typ = decode messageType
      -- The advertised length includes the length field itself.
      rest = fromIntegral (decode blockLength :: Int32) - int32Size
  block <- hGetExactly rest
  return (fromMaybe UnknownMessageType $ typeFromChar typ,block)
  where
    -- Read exactly n bytes; a short read means the peer closed the
    -- connection mid-message.
    hGetExactly n = do
      bytes <- L.hGet h n
      when (L.length bytes < fromIntegral n) $ E.throw ConnectionLost
      return bytes

--------------------------------------------------------------------------------
-- Binary input/output

-- | Put an already-encoded (strict) byte string, null-terminating it.
string :: B.ByteString -> Put
string s = do putByteString s; zero

-- | Put a Haskell 32-bit integer.
int32 :: Int32 -> Put
int32 n = put n

-- | Put zero-byte terminator.
zero :: Put
zero = putWord8 0

-- | To avoid magic numbers, size of a 32-bit integer in bytes.
int32Size :: Int
int32Size = 4

-- | Read a 16-bit integer.
getInt16 :: Get Int16
getInt16 = get :: Get Int16

-- | Read a 32-bit integer.
getInt32 :: Get Int32
getInt32 = get :: Get Int32

-- | Read a null-terminated string; the terminator is consumed and
-- not included in the result.
getString :: Get L.ByteString
getString = getLazyByteStringNul

-- | Read a value, succeeding only when there is exactly one parse
-- that consumes the whole input.
readMay :: Read a => String -> Maybe a
readMay x
  | [(v,"")] <- reads x = Just v
  | otherwise           = Nothing