module Database.PostgreSQL.Base
(begin
,rollback
,commit
,query
,exec
,escapeBS
,connect
,defaultConnectInfo
,close
,withDB
,withTransaction
,newPool
,pconnect
,withPoolConnection)
where
import Database.PostgreSQL.Base.Types
import Control.Concurrent
import Control.Monad
import Control.Monad.CatchIO (MonadCatchIO)
import qualified Control.Monad.CatchIO as E
import Control.Monad.Fix
import Control.Monad.State (MonadState,execStateT,modify)
import Control.Monad.Trans
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.UTF8 as L (toString,fromString)
import Data.ByteString.UTF8 (toString,fromString)
import Data.Int
import Data.List
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Maybe
import Data.Monoid
import Network
import Prelude
import System.IO hiding (hPutStr)
-- | Baseline connection settings: host 127.0.0.1, port 5432, user
-- @postgres@, empty password and empty database name. Override individual
-- fields with record-update syntax.
defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
  ConnectInfo
    { connectHost     = "127.0.0.1"
    , connectPort     = 5432
    , connectUser     = "postgres"
    , connectPassword = ""
    , connectDatabase = ""
    }
-- | Create a pool that holds no connections yet. The given 'ConnectInfo' is
-- kept in the pool state so connections can be opened on demand.
newPool :: MonadIO m
        => ConnectInfo
        -> m Pool
newPool info = liftIO (liftM Pool (newMVar initialState))
  where
    initialState = PoolState { poolConnections = [], poolConnectInfo = info }
-- | Borrow a connection from the pool. An idle connection is reused if one
-- is available; otherwise a new one is opened with the pool's 'ConnectInfo'.
-- The pool's MVar stays taken while the new connection is being opened.
pconnect :: MonadIO m => Pool -> m Connection
pconnect (Pool var) = liftIO $ modifyMVar var takeOrOpen
  where
    takeOrOpen st@PoolState{..} =
      case poolConnections of
        (idle:rest) -> return (st { poolConnections = rest }, idle)
        [] -> do
          fresh <- connect poolConnectInfo
          return (st, fresh)
-- | Give a borrowed connection back to the pool. The connection is re-added
-- only if it still has a handle and 'hIsOpen' reports that handle as open.
-- Otherwise it is silently dropped.
restore :: MonadIO m => Pool -> Connection -> m ()
restore (Pool var) conn = liftIO $ do
  mhandle <- readMVar (connectionHandle conn)
  modifyMVar_ var $ \st ->
    case mhandle of
      Just h -> do
        open <- hIsOpen h
        return $ if open
          then st { poolConnections = conn : poolConnections st }
          else st
      Nothing -> return st
-- | Run an action with a connection borrowed from the pool, then hand the
-- connection back with 'restore'. The action's result is discarded.
--
-- If the action throws, the connection is closed before it is released.
-- 'restore' then sees no handle and drops it, so a connection whose
-- protocol or transaction state is unknown (for example, interrupted
-- mid-message or left inside an open transaction) is never recycled.
-- The exception is re-thrown.
withPoolConnection
  :: (MonadCatchIO m,MonadIO m)
  => Pool
  -> (Connection -> m a)
  -> m ()
withPoolConnection pool m = do
  _ <- E.bracket (pconnect pool) (restore pool) $ \conn ->
         m conn `E.onException` close conn
  return ()
-- | Open a TCP connection to the server, perform the startup and
-- authentication handshake, and load the server's type-OID table.
--
-- If the handshake throws (for example, the server rejects the
-- credentials), the socket is closed before the exception propagates, so
-- the handle does not leak.
connect :: MonadIO m => ConnectInfo -> m Connection
connect connectInfo@ConnectInfo{..} = liftIO $ withSocketsDo $ do
  h <- connectTo connectHost (PortNumber $ fromIntegral connectPort)
  hSetBuffering h NoBuffering
  handleVar <- newMVar (Just h)
  objectsVar <- newMVar IntMap.empty
  let conn = Connection handleVar objectsVar
  authenticate conn connectInfo `E.onException` hClose h
  return conn
-- | Open a connection, run the action with it, and close the connection
-- afterwards, even when the action throws.
withDB :: (MonadCatchIO m,MonadIO m) => ConnectInfo -> (Connection -> m a) -> m a
withDB connectInfo = E.bracket (connect connectInfo) close
-- | Wrap an action in a transaction. 'begin' runs first and 'commit' runs
-- after the action succeeds. If the action throws, 'rollback' runs and the
-- exception is re-thrown. Failures in the final 'commit' are not handled
-- here.
withTransaction :: (MonadCatchIO m,MonadIO m) => Connection -> m a -> m a
withTransaction conn act = do
  begin conn
  result <- E.onException act (rollback conn)
  commit conn
  return result
-- | Send @ABORT;@ to roll back the current transaction, ignoring the row
-- count.
rollback :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
rollback conn = exec conn (fromString ("ABORT;" :: String)) >> return ()
-- | Send @COMMIT;@ to commit the current transaction, ignoring the row count.
commit :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
commit conn = exec conn (fromString ("COMMIT;" :: String)) >> return ()
-- | Send @BEGIN;@ to start a transaction, ignoring the row count.
begin :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
begin conn = exec conn (fromString ("BEGIN;" :: String)) >> return ()
-- | Close the connection's handle if it has one, and clear the stored handle
-- to @Nothing@. Closing an already-closed connection does nothing.
close :: MonadIO m => Connection
      -> m ()
close Connection{connectionHandle} =
  liftIO $ modifyMVar_ connectionHandle $ \current -> do
    maybe (return ()) hClose current
    return Nothing
-- | Run a SQL string and return its column descriptions and rows. If the
-- server sent no row description, the result is @([],[])@. Server errors
-- propagate as exceptions raised by 'execQuery'.
query :: (MonadCatchIO m)
      => Connection
      -> ByteString
      -> m ([Field],[[Maybe ByteString]])
query conn sql =
  execQuery conn sql >>= \(_, described) ->
    return (fromMaybe ([],[]) described)
-- | Send a SQL string and collect the reply. Returns a pair:
--
-- * the row count parsed from the command tag, or 0 if the tag has none;
-- * the column descriptions and rows, if the server sent a row description.
--
-- Throws 'QueryError' (carrying the raw error payload decoded as UTF-8) on
-- an ErrorResponse, and 'QueryEmpty' on an EmptyQueryResponse.
execQuery :: (MonadCatchIO m)
          => Connection
          -> ByteString
          -> m (Integer,Maybe ([Field],[[Maybe ByteString]]))
execQuery conn sql = liftIO $ withConnection conn $ \h -> do
  oidMap <- readMVar (connectionObjects conn)
  Result{..} <- sendQuery oidMap h sql
  let rowCount = fromMaybe 0 resultTagRows
  case resultType of
    ErrorResponse      -> E.throw $ QueryError (fmap L.toString resultError)
    EmptyQueryResponse -> E.throw QueryEmpty
    _ ->
      case resultDesc of
        Just fields -> return (rowCount, Just (fields, resultRows))
        Nothing     -> return (rowCount, Nothing)
-- | Run a SQL string for its effect and return the row count parsed from the
-- command tag. The count is 0 when the tag carries no count.
exec :: (MonadCatchIO m)
     => Connection
     -> ByteString
     -> m Integer
exec conn sql =
  execQuery conn sql >>= \(affected, _) -> return affected
-- | Protocol version sent in the startup packet: 3 in the high 16 bits and
-- 0 in the low 16 bits (decimal 196608).
protocolVersion :: Int32
protocolVersion = 0x00030000
-- | Escape a string for use inside a single-quoted SQL literal by doubling
-- every backslash and every single quote. All other characters are copied
-- unchanged.
--
-- NOTE(review): doubling backslashes is only correct when the server treats
-- backslash as an escape inside '...' literals (standard_conforming_strings
-- off, or E'...' literals) — confirm against the target server settings.
escape :: String -> String
escape = concatMap escapeChar
  where
    escapeChar '\\' = "\\\\"
    escapeChar '\'' = "''"
    escapeChar c    = [c]
-- | Apply 'escape' to a UTF-8 encoded 'ByteString': decode it, escape it,
-- and re-encode the result.
escapeBS :: ByteString -> ByteString
escapeBS bs = fromString (escape (toString bs))
-- | Run the startup handshake on a newly opened connection, all while
-- holding the connection's handle:
--
-- 1. send the startup packet;
-- 2. process the authentication exchange;
-- 3. replace 'connectionObjects' with the type table from 'objectIds'.
authenticate :: Connection -> ConnectInfo -> IO ()
authenticate conn@Connection{..} connectInfo =
  withConnection conn $ \h -> do
    sendStartUp h connectInfo
    getConnectInfoResponse h connectInfo
    typeTable <- objectIds h
    modifyMVar_ connectionObjects (const (return typeTable))
-- | Send the startup packet, which has no type byte. It contains the
-- protocol version, then the @user@ and @database@ key/value pairs (each
-- string NUL-terminated), then a final zero byte.
sendStartUp :: Handle -> ConnectInfo -> IO ()
sendStartUp h ConnectInfo{..} =
  sendBlock h Nothing $ do
    int32 protocolVersion
    param "user" connectUser
    param "database" connectDatabase
    zero
  where
    param key value = do
      string (fromString key)
      string (fromString value)
-- | Read the server's reply to the startup or password message and act on
-- it. The previous version computed an unused (and mis-sized) "salt"; that
-- dead code has been removed.
--
-- For an 'AuthenticationOk' message (type byte @R@), the payload starts with
-- an Int32 method code:
--
-- * 0: continue with 'waitForReady';
-- * 3: reply through 'sendPassClearText';
-- * anything else: throw 'UnsupportedAuthenticationMethod'.
--
-- Any other message type throws 'AuthenticationFailed'.
getConnectInfoResponse :: Handle -> ConnectInfo -> IO ()
getConnectInfoResponse h conninfo = do
  (typ, block) <- getMessage h
  case typ of
    AuthenticationOk ->
      case (decode block :: Int32) of
        0      -> waitForReady h
        3      -> sendPassClearText h conninfo
        method -> E.throw $ UnsupportedAuthenticationMethod method (show block)
    other -> E.throw $ AuthenticationFailed (show (other, block))
-- | Send the configured password in cleartext as a PasswordMessage, then
-- process whatever the server sends next.
sendPassClearText :: Handle -> ConnectInfo -> IO ()
sendPassClearText h conninfo = do
  sendMessage h PasswordMessage (string (fromString (connectPassword conninfo)))
  getConnectInfoResponse h conninfo
-- | Query @pg_type@ and build a map from type OID to the driver's 'Type'.
--
-- A row is skipped when its non-NULL columns are not exactly a type name
-- plus a parseable OID, or when the name is not known to 'typeFromName'.
-- Throws 'InitializationError' if the server answers with an error.
objectIds :: Handle -> IO (IntMap Type)
objectIds h = do
  Result{..} <- sendQuery IntMap.empty h q
  case resultType of
    ErrorResponse -> E.throw $ InitializationError "Couldn't get types."
    _ -> return $ IntMap.fromList (mapMaybe toEntry resultRows)
  where
    q = fromString ("SELECT typname, oid FROM pg_type" :: String)
    toEntry row =
      case map toString (catMaybes row) of
        [typeName, readMay -> Just objId] ->
          fmap (\typ -> (fromIntegral (objId :: Integer), typ))
               (typeFromName typeName)
        _ -> Nothing
-- | Send a simple-protocol Query message and fold the server's replies into
-- a 'Result', stopping when ReadyForQuery arrives.
--
-- * RowDescription: resolved against the supplied OID map.
-- * DataRow: prepended as it arrives; the list is reversed once at
--   ReadyForQuery so rows come out in arrival order.
-- * CommandComplete, EmptyQueryResponse, ErrorResponse: each records its
--   message type as the result status, so the last of these wins.
--   ErrorResponse also keeps its payload, and CommandComplete parses its
--   command tag.
-- * Any other message type is ignored.
sendQuery :: IntMap Type -> Handle -> ByteString -> IO Result
sendQuery oidMap h sql = do
  sendMessage h Query (string sql)
  execStateT loop emptyResponse
  where
    emptyResponse = Result [] Nothing Nothing [] UnknownMessageType Nothing
    loop = do
      (typ, block) <- liftIO (getMessage h)
      let setStatus = modify (\r -> r { resultType = typ })
      case typ of
        ReadyForQuery ->
          modify (\r -> r { resultRows = reverse (resultRows r) })
        EmptyQueryResponse -> setStatus >> loop
        CommandComplete -> do
          setStatus
          setCommandTag block
          loop
        ErrorResponse -> do
          modify (\r -> r { resultError = Just block })
          setStatus
          loop
        RowDescription -> getRowDesc oidMap block >> loop
        DataRow -> getDataRow block >> loop
        _ -> loop
-- | Parse the row count from a CommandComplete tag and store it in
-- 'resultTagRows'.
--
-- Tags with a count are recognised in two forms:
--
-- * @INSERT oid rows@;
-- * @<cmd> rows@, where @<cmd>@ is DELETE, UPDATE, SELECT, MOVE or FETCH.
--
-- Any other tag stores 'Nothing'.
--
-- The tag is now decoded as one UTF-8 string. Previously each lazy chunk was
-- decoded separately, which could corrupt a multibyte sequence that
-- straddled a chunk boundary.
setCommandTag :: MonadState Result m => L.ByteString -> m ()
setCommandTag block = modify $ \r -> r { resultTagRows = rowCount }
  where
    rowCount =
      case words (L.toString (runGet getString block)) of
        ["INSERT", _oid, readMay -> Just n] -> Just n
        [cmd, readMay -> Just n] | cmd `elem` cmds -> Just n
        _ -> Nothing
    cmds = ["DELETE","UPDATE","SELECT","MOVE","FETCH"]
-- | Parse a RowDescription message and store the resulting fields in
-- 'resultDesc'.
--
-- The message is an Int16 field count followed by one fixed-layout entry per
-- field: a name string and six integers. Of those integers, only the fourth
-- (the type OID) and the seventh (the format code) are used, by
-- 'parseFields'.
getRowDesc :: MonadState Result m => IntMap Type -> L.ByteString -> m ()
getRowDesc oidMap block =
  modify (\r -> r { resultDesc = Just described })
  where
    described = parseFields oidMap (runGet parseMsg block)
    parseMsg = do
      fieldCount <- getInt16
      replicateM (fromIntegral fieldCount) fieldEntry
    fieldEntry = do
      name     <- getString
      objId    <- getInt32
      colId    <- getInt16
      dataType <- getInt32
      size     <- getInt16
      modifier <- getInt32
      code     <- getInt16
      return (name, objId, colId, dataType, size, modifier, code)
-- | Turn raw RowDescription entries into 'Field's. Only two columns of each
-- entry are kept: the type OID, resolved with 'parseType', and the format
-- code, decoded with 'parseFormatCode'.
parseFields :: IntMap Type
            -> [(L.ByteString,Int32,Int16,Int32,Int16,Int32,Int16)]
            -> [Field]
parseFields oidMap = map toField
  where
    toField (_fieldName, _, _, typeOid, _, _, code) =
      Field { fieldType       = parseType oidMap typeOid
            , fieldFormatCode = parseFormatCode code
            }
-- | Look up a type OID in the OID map.
--
-- NOTE(review): an OID missing from the map calls 'error'. The error is
-- raised lazily, when the resulting 'Type' is forced, so any column type not
-- listed in 'fieldTypes' fails at that point.
parseType :: IntMap Type -> Int32 -> Type
parseType oidMap objId =
  fromMaybe
    (error $ "parseType: Unknown type given by object-id: " ++ show objId)
    (IntMap.lookup (fromIntegral objId) oidMap)
-- | Map a PostgreSQL type name to the driver's 'Type', if 'fieldTypes'
-- lists it.
typeFromName :: String -> Maybe Type
typeFromName name = lookup name fieldTypes
-- | The PostgreSQL type names this driver understands, paired with the
-- 'Type' each one maps to.
--
-- NOTE(review): "integer" and "int" look like SQL spellings rather than
-- @pg_type.typname@ values, so they may never match the names 'objectIds'
-- reads — confirm.
fieldTypes :: [(String, Type)]
fieldTypes =
  [ ("bool",        Boolean)
  , ("int2",        Short)
  , ("integer",     Long)
  , ("int",         Long)
  , ("int4",        Long)
  , ("int8",        LongLong)
  , ("timestamptz", TimestampWithZone)
  , ("varchar",     CharVarying)
  , ("text",        Text)
  ]
-- | Decode a RowDescription format code: 1 means binary, and every other
-- value is treated as text.
parseFormatCode :: Int16 -> FormatCode
parseFormatCode code
  | code == 1 = BinaryCode
  | otherwise = TextCode
-- | Parse a DataRow message and prepend its columns to 'resultRows'.
-- 'sendQuery' reverses the list at the end so rows come out in arrival
-- order.
--
-- The message is an Int16 column count. Each column then has an Int32
-- length followed by that many value bytes. A negative length (-1 in the
-- protocol) marks a NULL column, which has no value bytes and becomes
-- 'Nothing'.
--
-- Previously the code tested @size == 1@. That turned every 1-byte value
-- (e.g. a boolean @t@/@f@) into 'Nothing' without consuming its byte, which
-- desynchronised the rest of the row, and it decoded real NULLs as
-- @Just ""@.
getDataRow :: MonadState Result m => L.ByteString -> m ()
getDataRow block =
  modify $ \r -> r { resultRows = runGet parseMsg block : resultRows r }
  where
    parseMsg = do
      values <- getInt16
      forM [1 .. values] $ \_ -> do
        size <- getInt32
        if size < 0
          then return Nothing
          else liftM Just (getByteString (fromIntegral size))
-- | Map a protocol type byte to its 'MessageType', if the 'types' table
-- lists it.
typeFromChar :: Char -> Maybe MessageType
typeFromChar = (`lookup` types)
-- | Reverse lookup in 'types': the type byte for a 'MessageType'. If the
-- message type appears more than once, the first entry wins.
charFromType :: MessageType -> Maybe Char
charFromType typ = lookup typ [ (t, c) | (c, t) <- types ]
-- | Table of protocol type bytes and the 'MessageType' each one denotes.
-- 'typeFromChar' and 'charFromType' both look up entries here.
types :: [(Char, MessageType)]
types =
  [ ('C', CommandComplete)
  , ('T', RowDescription)
  , ('D', DataRow)
  , ('I', EmptyQueryResponse)
  , ('E', ErrorResponse)
  , ('Z', ReadyForQuery)
  , ('N', NoticeResponse)
  , ('R', AuthenticationOk)
  , ('Q', Query)
  , ('p', PasswordMessage)
  ]
-- | Read server messages until ReadyForQuery arrives, ignoring all other
-- messages except errors.
--
-- * An ErrorResponse raises 'GeneralError' with its raw payload.
-- * A ReadyForQuery with status @'I'@ returns normally.
-- * A ReadyForQuery with any other status now raises 'GeneralError'.
--   Previously the loop went on reading, but the server sends nothing more
--   until the client issues a command, so that read would block forever.
waitForReady :: Handle -> IO ()
waitForReady h = loop
  where
    loop = do
      (typ, block) <- getMessage h
      case typ of
        ErrorResponse -> E.throw $ GeneralError $ show block
        ReadyForQuery
          | decode block == 'I' -> return ()
          | otherwise ->
              E.throw $ GeneralError $
                "waitForReady: unexpected transaction status " ++ show block
        _ -> loop
-- | Run an action with the connection's handle. The handle MVar is held for
-- the whole action, which serialises use of the connection. Throws
-- 'ConnectionLost' if the connection has already been closed.
withConnection :: Connection -> (Handle -> IO a) -> IO a
withConnection Connection{connectionHandle} m =
  withMVar connectionHandle (maybe (E.throw ConnectionLost) m)
-- | Send a typed protocol message. The type byte is looked up with
-- 'charFromType'; a message type with no entry there is a programming error
-- and calls 'error'.
sendMessage :: Handle -> MessageType -> Put -> IO ()
sendMessage h typ output =
  maybe (error $ "sendMessage: Bad message type " ++ show typ)
        (\c -> sendBlock h (Just c) output)
        (charFromType typ)
-- | Write one frame to the handle. The frame is:
--
-- 1. an optional type byte;
-- 2. an 'int32' length covering the length field itself plus the payload
--    (but not the type byte);
-- 3. the payload produced by the given 'Put'.
sendBlock :: Handle -> Maybe Char -> Put -> IO ()
sendBlock h typ output = L.hPutStr h (header `mappend` payload)
  where
    payload = runPut output
    header = runPut $ do
      maybe (return ()) (putWord8 . fromIntegral . fromEnum) typ
      int32 (fromIntegral int32Size + fromIntegral (L.length payload))
-- | Read one backend message. The wire format is a type byte, then an Int32
-- length that counts itself but not the type byte, then the payload.
-- Returns the decoded 'MessageType' (or 'UnknownMessageType') and the raw
-- payload.
--
-- Fixes:
--
-- * The payload size is the declared length minus the 4-byte length field.
--   The previous code dropped the @-@ operator, which does not type-check.
-- * A short read of the 5-byte header (the server closed the socket) now
--   throws 'ConnectionLost' instead of failing inside 'decode'.
getMessage :: Handle -> IO (MessageType,L.ByteString)
getMessage h = do
  messageType <- L.hGet h 1
  blockLength <- L.hGet h int32Size
  when (L.length messageType /= 1
        || L.length blockLength /= fromIntegral int32Size) $
    E.throw ConnectionLost
  let typ = decode messageType :: Char
      rest = fromIntegral (decode blockLength :: Int32) - int32Size
  block <- L.hGet h rest
  return (fromMaybe UnknownMessageType (typeFromChar typ), block)
-- | Write the bytes of a strict 'B.ByteString', then a terminating NUL byte.
string :: B.ByteString -> Put
string s = putByteString s >> zero

-- | Write an 'Int32' using its 'Binary' encoding.
int32 :: Int32 -> Put
int32 n = put n

-- | Write a single zero byte.
zero :: Put
zero = putWord8 0

-- | Width in bytes of an encoded 'Int32'; used for message length fields.
int32Size :: Int
int32Size = 4
-- | Read a big-endian 16-bit signed integer.
getInt16 :: Get Int16
getInt16 = fmap fromIntegral getWord16be

-- | Read a big-endian 32-bit signed integer.
getInt32 :: Get Int32
getInt32 = fmap fromIntegral getWord32be

-- | Read a NUL-terminated string. The NUL is consumed but not included in
-- the result.
getString :: Get L.ByteString
getString = getLazyByteStringNul
-- | Total replacement for 'read'. Succeeds only when 'reads' yields exactly
-- one parse and that parse consumes the whole input; trailing characters,
-- including whitespace, make it return 'Nothing'.
readMay :: Read a => String -> Maybe a
readMay str =
  case reads str of
    [(value, "")] -> Just value
    _             -> Nothing