-- | Low-level client side of the PostgreSQL wire protocol: opening and
-- closing a connection 'Handle', framing and parsing protocol messages,
-- describing a statement, and running simple queries and statements.
module Database.TemplatePG.Protocol ( PGException(..)
, pgConnect
, pgDisconnect
, describeStatement
, executeSimpleQuery
, executeSimpleStatement
) where
import Database.TemplatePG.Types
import Control.Exception
import Control.Monad (liftM, replicateM)
import Data.Binary
import qualified Data.Binary.Builder as B
import qualified Data.Binary.Get as G
import qualified Data.Binary.Put as P
import Data.ByteString.Internal (c2w, w2c)
import Data.ByteString.Lazy as L hiding (take, repeat, map, any, zipWith)
import Data.ByteString.Lazy.UTF8 hiding (length, decode, take)
import Data.Monoid
import Data.Typeable
import Network
import System.Environment
import System.IO hiding (putStr, putStrLn)
import System.IO.Error (isDoesNotExistError)
import Prelude hiding (putStr, putStrLn)
-- | Protocol messages understood by this module.
--
-- Constructors that have a 'putMessageBody' equation ('Describe', 'Execute',
-- 'Flush', 'Parse', 'SimpleQuery') are the ones this client serializes; the
-- others are produced by 'getMessageBody' when parsing incoming data. Most
-- incoming messages are recognised by tag only and their payload is dropped.
data PGMessage
  = Authentication
  | BackendKeyData
  | CommandComplete
  | DataRow [Maybe ByteString]
    -- ^ One result row; 'Nothing' marks a column whose length field was
    -- @0xFFFFFFFF@ (SQL NULL).
  | Describe String
    -- ^ Describe the prepared statement with the given name.
  | EmptyQueryResponse
  | ErrorResponse String String String
    -- ^ The @S@, @C@ and @M@ fields of the error, in that order.
  | Execute
  | Flush
  | NoData
  | NoticeResponse
  | ParameterDescription [PGType]
    -- ^ Types of the statement's parameters.
  | ParameterStatus
  | Parse String String
    -- ^ SQL text, then the name to give the prepared statement.
  | ParseComplete
  | ReadyForQuery
  | RowDescription [(String, PGType, Integer, Int)]
    -- ^ Per result column: name, type, OID of the originating table and the
    -- column's attribute number within that table.
  | SimpleQuery String
    -- ^ SQL text to run through the simple query protocol.
  | UnknownMessage
    -- ^ Any incoming message whose tag byte is not recognised.
-- | Exception thrown by 'pgReceive' when the server sends an error response.
-- The two fields are the error's @C@ (code) and @M@ (message) fields.
data PGException = PGException String String
  deriving (Show, Typeable)

instance Exception PGException
-- | Protocol version number sent in the startup packet: major version 3 in
-- the upper 16 bits, minor version 0 in the lower 16 bits.
protocolVersion :: Word32
protocolVersion = 0x00030000
-- | Whether protocol debugging output is enabled.
--
-- Returns 'True' when the @TPG_DEBUG@ environment variable is set (to any
-- value) and 'False' when it is not. Any other kind of 'IOError' raised
-- while reading the environment is propagated.
debug :: IO Bool
debug = do
  outcome <- tryJust missingOnly (getEnv "TPG_DEBUG")
  return $ either (const False) (const True) outcome
  where
    -- Only the "variable does not exist" error is treated as "disabled".
    missingOnly e
      | isDoesNotExistError e = Just ()
      | otherwise = Nothing
-- | Connect to a PostgreSQL server and complete the startup exchange.
--
-- Opens a connection, sends the startup packet (protocol version plus the
-- @user@ and @database@ parameters) and then blocks until the server
-- reports 'ReadyForQuery'.
--
-- The password argument is currently ignored: no authentication response is
-- ever sent to the server.
--
-- If anything fails after the connection has been opened (for example a
-- 'PGException' raised while waiting for the server), the handle is closed
-- before the exception is re-thrown, so a failed connect does not leak it.
pgConnect :: HostName  -- ^ the host to connect to
          -> PortID    -- ^ the port to connect on
          -> String    -- ^ the database to connect to
          -> String    -- ^ the username to connect as
          -> String    -- ^ the password (unused)
          -> IO Handle -- ^ a handle to communicate with the PostgreSQL server on
pgConnect host port db user _ = do
  h <- connectTo host port
  -- The caller never sees 'h' if startup throws, so it must be closed here.
  startup h `onException` hClose h
  return h
  where
    startup h = do
      hPut h $ B.toLazyByteString $ pgMessage handshake
      hFlush h
      _ <- pgWaitFor h [pgMessageID ReadyForQuery]
      return ()
    -- Startup packet body: version, name/value pairs, then a terminating 0.
    handshake = mconcat
      [ B.putWord32be protocolVersion
      , pgString "user", pgString user
      , pgString "database", pgString db
      , B.singleton 0 ]
-- | Frame a message body by prefixing it with its big-endian 32-bit length.
-- The length counts the four bytes of the length field itself, so it is the
-- body size plus 4. No message type byte is added.
pgMessage :: B.Builder -> B.Builder
pgMessage msg = B.putWord32be total `mappend` msg
  where
    bodySize = L.length (B.toLazyByteString msg)
    total = fromIntegral (bodySize + 4)
-- | Disconnect from the server by closing the handle. No protocol-level
-- message is sent before closing.
pgDisconnect :: Handle -- ^ a handle from 'pgConnect'
             -> IO ()
pgDisconnect h = hClose h
-- | Encode a 'String' as the protocol's string format: the UTF-8 bytes of
-- the text followed by a single zero byte.
pgString :: String -> B.Builder
pgString s = B.fromLazyByteString (fromString s) `mappend` B.singleton 0
-- | The one-byte tag that identifies a message type on the wire.
--
-- Several constructors share a tag ('DataRow' and 'Describe' both map to
-- @\'D\'@, 'ErrorResponse' and 'Execute' both map to @\'E\'@). Asking for
-- the tag of 'UnknownMessage' raises an error.
pgMessageID :: PGMessage -> Word8
pgMessageID = c2w . tag
  where
    tag Authentication           = 'R'
    tag BackendKeyData           = 'K'
    tag CommandComplete          = 'C'
    tag (DataRow _)              = 'D'
    tag (Describe _)             = 'D'
    tag EmptyQueryResponse       = 'I'
    tag (ErrorResponse _ _ _)    = 'E'
    tag Execute                  = 'E'
    tag Flush                    = 'H'
    tag NoData                   = 'n'
    tag NoticeResponse           = 'N'
    tag (ParameterDescription _) = 't'
    tag ParameterStatus          = 'S'
    tag (Parse _ _)              = 'P'
    tag ParseComplete            = '1'
    tag ReadyForQuery            = 'Z'
    tag (RowDescription _)       = 'T'
    tag (SimpleQuery _)          = 'Q'
    tag UnknownMessage           = error "Unknown message type"
-- | Wire (de)serialization of a complete message: a one-byte type tag, a
-- big-endian 32-bit length, then the body. The length counts its own four
-- bytes but not the tag byte.
--
-- 'put' only works for the messages 'putMessageBody' knows how to build;
-- 'get' parses the body with 'getMessageBody'.
instance Binary PGMessage where
  put m = do
    let body = B.toLazyByteString $ putMessageBody m
    P.putWord8 $ pgMessageID m
    P.putWord32be $ fromIntegral $ (L.length body) + 4
    P.putLazyByteString body
  get = do
    (typ, len) <- getMessageHeader
    -- The length field includes itself, so the body is 4 bytes shorter.
    body <- G.getLazyByteString ((fromIntegral len) - 4)
    -- Parse the body in isolation so a parser that reads too little of it
    -- cannot desynchronise the surrounding stream.
    return $ G.runGet (getMessageBody typ) body
-- | Build the body (everything after the tag and length) of an outgoing
-- message. Only the messages this client sends are supported; any other
-- constructor evaluates to 'undefined'.
putMessageBody :: PGMessage -> B.Builder
putMessageBody msg =
  case msg of
    -- 'S' followed by the statement name.
    Describe name -> B.singleton (c2w 'S') `mappend` pgString name
    -- An empty string followed by a 32-bit zero.
    Execute -> pgString "" `mappend` B.putWord32be 0
    Flush -> B.empty
    -- Statement name, SQL text, then a 16-bit zero.
    Parse sql name -> mconcat [pgString name, pgString sql, B.putWord16be 0]
    SimpleQuery sql -> pgString sql
    _ -> undefined
-- | Read the five-byte message header: the type tag byte and the big-endian
-- 32-bit length that follows it. The length is returned as read, without
-- any adjustment.
getMessageHeader :: Get (Word8, Int)
getMessageHeader = do
  tag <- G.getWord8
  size <- fromIntegral `liftM` G.getWord32be
  return (tag, size)
-- | Parse the body of an incoming message, given its type tag.
--
-- Only 'ParameterDescription', 'RowDescription', 'DataRow' and
-- 'ErrorResponse' have their payload decoded; the other recognised tags
-- yield a bare constructor. Unrecognised tags yield 'UnknownMessage'. An
-- error response lacking any of the @S@, @C@ or @M@ fields raises an error.
getMessageBody :: Word8 -- ^ the type tag byte of the message
               -> Get PGMessage
getMessageBody typ =
  case w2c typ of
    'R' -> return Authentication
    't' -> do
      n <- count16
      ParameterDescription `liftM` replicateM n paramType
    'T' -> do
      n <- count16
      RowDescription `liftM` replicateM n columnInfo
    'Z' -> do
      _ <- G.getWord8
      return ReadyForQuery
    '1' -> return ParseComplete
    'C' -> return CommandComplete
    'S' -> return ParameterStatus
    'D' -> do
      n <- count16
      DataRow `liftM` replicateM n columnValue
    'K' -> return BackendKeyData
    'E' -> do
      fields <- errorFields
      case ( lookup (c2w 'S') fields
           , lookup (c2w 'C') fields
           , lookup (c2w 'M') fields ) of
        (Just s, Just c, Just m) -> return $ ErrorResponse s c m
        _ -> error "Unreadable error response"
    'I' -> return EmptyQueryResponse
    'n' -> return NoData
    'N' -> return NoticeResponse
    _ -> return UnknownMessage
  where
    -- A 16-bit element count, as used by 't', 'T' and 'D' messages.
    count16 :: Get Int
    count16 = fromIntegral `liftM` G.getWord16be

    -- One parameter: a 32-bit type OID.
    paramType :: Get PGType
    paramType = (pgTypeFromOID . fromIntegral) `liftM` G.getWord32be

    -- One column description. The three trailing fields are read and
    -- discarded.
    columnInfo :: Get (String, PGType, Integer, Int)
    columnInfo = do
      name <- toString `liftM` G.getLazyByteStringNul
      tableOID <- fromIntegral `liftM` G.getWord32be
      attnum <- fromIntegral `liftM` G.getWord16be
      typeOID <- fromIntegral `liftM` G.getWord32be
      _ <- G.getWord16be
      _ <- G.getWord32be
      _ <- G.getWord16be
      return (name, pgTypeFromOID typeOID, tableOID, attnum)

    -- One column value: a 32-bit length, where 0xFFFFFFFF means NULL,
    -- followed by that many bytes.
    columnValue :: Get (Maybe ByteString)
    columnValue = do
      size <- fromIntegral `liftM` G.getWord32be
      if size == 0xFFFFFFFF
        then return Nothing
        else Just `liftM` G.getLazyByteString size

    -- Error fields: (code byte, NUL-terminated string) pairs, ended by a
    -- zero code byte.
    errorFields :: Get [(Word8, String)]
    errorFields = do
      code <- G.getWord8
      if code == 0
        then return []
        else do
          value <- G.getLazyByteStringNul
          rest <- errorFields
          return ((code, toString value) : rest)
-- | Serialize a message, write it to the handle and flush. When 'debug' is
-- enabled the encoded bytes are also printed to standard output first.
pgSend :: Handle -> PGMessage -> IO ()
pgSend h msg = do
  verbose <- debug
  if verbose then putStrLn wire else return ()
  hPut h wire
  hFlush h
  where
    wire = encode msg
-- | Read one complete message from the handle and decode it.
--
-- The five-byte header (tag + length) is read first, then the body. When
-- 'debug' is enabled the raw header and body are echoed to standard output.
--
-- Throws 'PGException' (with the error's code and message) if the message is
-- an 'ErrorResponse'; every other message is returned to the caller.
pgReceive :: Handle -> IO PGMessage
pgReceive h = do
  d <- debug
  (typ, len) <- G.runGet getMessageHeader `liftM` hGet h 5
  -- The length field counts its own four bytes, which were already consumed
  -- as part of the header, so only @len - 4@ bytes of body remain.
  body <- hGet h (len - 4)
  if d
    then do putStr (P.runPut (do P.putWord8 typ
                                 P.putWord32be (fromIntegral len)))
            putStrLn body
            hFlush stdout
    else return ()
  -- Reassemble the full message so it can go through the 'Binary' instance.
  let msg = decode $ cons typ (append (B.toLazyByteString $ B.putWord32be $ fromIntegral len) body)
  case msg of
    (ErrorResponse _ c m) -> throwIO (PGException c m)
    _                     -> return msg
-- | Receive messages until one arrives whose tag is in the given list, and
-- return it. Messages with other tags are silently discarded.
--
-- Exceptions from 'pgReceive' propagate, and because 'pgMessageID' has no
-- tag for 'UnknownMessage', receiving an unrecognised message raises an
-- error as soon as its tag is compared.
pgWaitFor :: Handle
          -> [Word8] -- ^ tags of the acceptable message types
          -> IO PGMessage
pgWaitFor h ids = loop
  where
    loop = do
      msg <- pgReceive h
      if any (== pgMessageID msg) ids
        then return msg
        else loop
-- | Ask the server to parse a statement and describe its parameters and
-- result columns.
--
-- Returns the parameter types and, for each result column, its name, type
-- and whether it may be NULL. A statement that returns no rows yields an
-- empty column list.
--
-- Nullability is looked up per column with a query against @pg_attribute@.
-- Columns whose table OID is 0 are reported as nullable without a lookup.
describeStatement :: Handle
                  -> String -- ^ SQL string
                  -> IO ([PGType], [(String, PGType, Bool)])
describeStatement h sql = do
  pgSend h $ Parse sql ""
  pgSend h $ Describe ""
  pgSend h Flush
  _ <- pgWaitFor h [pgMessageID ParseComplete]
  ParameterDescription params <- pgReceive h
  reply <- pgWaitFor h [c2w 'n', c2w 'T']
  case reply of
    NoData -> return (params, [])
    RowDescription cols -> do
      flags <- mapM isNullable cols
      return (params, zipWith describe cols flags)
    _ -> error ""
  where
    describe (name, typ, _, _) nullOk = (name, typ, nullOk)

    isNullable (_, _, oid, col)
      | oid == 0 = return True
      | otherwise = do
          rows <- executeSimpleQuery lookupSql h
          case rows of
            [[Just s]] -> return $ case toString s of
                                     "t" -> False
                                     "f" -> True
                                     _   -> error "Unexpected result from PostgreSQL"
            _ -> error $ "Can't determine nullability of column #" ++ show col
      where
        lookupSql = "SELECT attnotnull FROM pg_attribute WHERE attrelid = " ++ show oid ++ " AND attnum = " ++ show col
-- | Run a query through the simple query protocol and collect its rows.
--
-- Each row is a list of column values, with 'Nothing' for NULL. An empty
-- query yields @[[]]@. A first reply that is neither an empty-query response
-- nor a row description (for example a bare command completion) raises an
-- error.
executeSimpleQuery :: String -- ^ SQL string
                   -> Handle
                   -> IO ([[Maybe ByteString]])
executeSimpleQuery sql h = do
  pgSend h (SimpleQuery sql)
  reply <- pgWaitFor h [c2w 'C', c2w 'I', c2w 'T']
  case reply of
    EmptyQueryResponse -> return [[]]
    RowDescription _ -> collectRows
    _ -> error "executeSimpleQuery: Unexpected Message"
  where
    -- Gather data rows until the command completes.
    collectRows = do
      next <- pgWaitFor h [c2w 'C', c2w 'D']
      case next of
        CommandComplete -> return []
        DataRow fs -> (fs :) `liftM` collectRows
        _ -> error ""
-- | Run a statement through the simple query protocol, ignoring any result.
-- Returns once the server reports command completion or an empty query.
executeSimpleStatement :: String -- ^ SQL string
                       -> Handle
                       -> IO ()
executeSimpleStatement sql h = do
  pgSend h (SimpleQuery sql)
  reply <- pgWaitFor h [c2w 'C', c2w 'I']
  case reply of
    CommandComplete -> return ()
    EmptyQueryResponse -> return ()
    _ -> error "executeSimpleStatement: Unexpected Message"