{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module Database.PostgreSQL.Typed.Protocol (
PGDatabase(..)
, defaultPGDatabase
, PGConnection
, PGError(..)
#ifdef VERSION_tls
, PGTlsMode(..)
, PGTlsValidateMode (..)
#endif
, pgErrorCode
, pgConnectionDatabase
, pgTypeEnv
, pgConnect
, pgDisconnect
, pgReconnect
, pgDescribe
, pgSimpleQuery
, pgSimpleQueries_
, pgPreparedQuery
, pgPreparedLazyQuery
, pgCloseStatement
, pgBegin
, pgCommit
, pgRollback
, pgCommitAll
, pgRollbackAll
, pgTransaction
, pgDisconnectOnce
, pgRun
, PGPreparedStatement
, pgPrepare
, pgClose
, PGColDescription(..)
, PGRowDescription
, pgBind
, pgFetch
, PGNotification(..)
, pgGetNotification
, pgGetNotifications
#ifdef VERSION_tls
, pgTlsValidate
#endif
, pgSupportsTls
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>), (<$))
#endif
import Control.Arrow ((&&&), first, second)
import Control.Exception (Exception, onException, finally, throwIO)
#ifdef VERSION_tls
import Control.Exception (catch)
#endif
import Control.Monad (void, liftM2, replicateM, when, unless)
#ifdef VERSION_cryptonite
import qualified Crypto.Hash as Hash
import qualified Data.ByteArray.Encoding as BA
#endif
import qualified Data.Binary.Get as G
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Char8 as BSC
import Data.ByteString.Internal (w2c, createAndTrim)
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Lazy.Char8 as BSLC
import Data.ByteString.Lazy.Internal (smallChunkSize)
#ifdef VERSION_tls
import Data.Default (def)
#endif
import qualified Data.Foldable as Fold
import Data.IORef (IORef, newIORef, writeIORef, readIORef, atomicModifyIORef, atomicModifyIORef', modifyIORef')
import Data.Int (Int32, Int16)
import qualified Data.Map.Lazy as Map
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (mempty)
#endif
import Data.Time.Clock (getCurrentTime)
import Data.Tuple (swap)
import Data.Typeable (Typeable)
#if !MIN_VERSION_base(4,8,0)
import Data.Word (Word)
#endif
import Data.Word (Word32, Word8)
#ifdef VERSION_tls
import Data.X509 (SignedCertificate, HashALG(HashSHA256))
import Data.X509.Memory (readSignedObjectFromMemory)
import Data.X509.CertificateStore (makeCertificateStore)
import qualified Data.X509.Validation
#endif
#ifndef mingw32_HOST_OS
import Foreign.C.Error (eAGAIN, eINTR, eWOULDBLOCK, getErrno, errnoToIOError)
import Foreign.C.Types (CChar(..), CInt(..), CSize(..))
import Foreign.Ptr (Ptr, castPtr)
import GHC.IO.Exception (IOErrorType(InvalidArgument))
#endif
import qualified Network.Socket as Net
import qualified Network.Socket.ByteString as NetBS
import qualified Network.Socket.ByteString.Lazy as NetBSL
#ifdef VERSION_tls
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
#endif
import System.IO (stderr, hPutStrLn)
import System.IO.Error (IOError, mkIOError, eofErrorType, ioError, ioeSetErrorString)
import System.IO.Unsafe (unsafeInterleaveIO)
import Text.Read (readMaybe)
import Text.Show.Functions ()
import Database.PostgreSQL.Typed.Types
import Database.PostgreSQL.Typed.Dynamic
-- |Local view of the protocol synchronisation state of a connection.
-- 'pgSend' moves it between 'StateUnsync', 'StatePending' and 'StateClosed';
-- a ReadyForQuery message from the server sets one of the three idle states.
data PGState
  = StateUnsync -- ^ Messages were sent that have not yet been followed by a Sync.
  | StatePending -- ^ A Sync or simple query was sent; waiting for ReadyForQuery.
  | StateIdle -- ^ ReadyForQuery with status @I@.
  | StateTransaction -- ^ ReadyForQuery with status @T@.
  | StateTransactionFailed -- ^ ReadyForQuery with status @E@.
  | StateClosed -- ^ Terminate was sent or end of input was seen; no further use.
  deriving (Show, Eq)
#ifdef VERSION_tls
-- |How the server certificate is checked when TLS validation is enabled.
data PGTlsValidateMode
  = TlsValidateFull -- ^ Use the TLS library's default certificate checks (including the host name).
  | TlsValidateCA -- ^ Validate against the CA, but skip the host name (FQHN) check.
  deriving (Show, Eq)
-- |Whether, and how, a connection uses TLS.
data PGTlsMode
  = TlsDisabled -- ^ Plain, unencrypted connection.
  | TlsNoValidate -- ^ TLS, accepting any server certificate.
  | TlsValidate PGTlsValidateMode SignedCertificate
    -- ^ TLS, trusting only the given certificate as CA.
  deriving (Eq, Show)
-- |Build a validating TLS mode from PEM data.  The first certificate that
-- can be read from the input becomes the trusted CA; if none can be read, the
-- result is a 'Left' with an explanatory message.
pgTlsValidate :: PGTlsValidateMode -> BSC.ByteString -> Either String PGTlsMode
pgTlsValidate mode certPem = pickFirst (readSignedObjectFromMemory certPem)
  where
    pickFirst (cert:_) = Right (TlsValidate mode cert)
    pickFirst [] = Left "Could not parse any certificate in PEM"
-- |'True' when the connection's transport is a TLS session.
pgSupportsTls :: PGConnection -> Bool
pgSupportsTls conn =
  case connHandle conn of
    PGTlsContext _ -> True
    _ -> False
#else
-- |Whether the connection uses TLS.  This is always 'False' in builds
-- without TLS support.
pgSupportsTls :: PGConnection -> Bool
pgSupportsTls = const False
#endif
-- |Settings describing how to connect to a database, passed to 'pgConnect'.
-- The derived 'Show' relies on the function instance from
-- "Text.Show.Functions" for the logging field.
data PGDatabase = PGDatabase
  { pgDBAddr :: Either (Net.HostName, Net.ServiceName) Net.SockAddr -- ^ Host name and service (port) to resolve, or a socket address to use directly.
  , pgDBName :: BS.ByteString -- ^ Database name, sent as @database@ in the startup message.
  , pgDBUser, pgDBPass :: BS.ByteString -- ^ User name and password used for authentication.
  , pgDBParams :: [(BS.ByteString, BS.ByteString)] -- ^ Extra parameters appended to the startup message.
  , pgDBDebug :: Bool -- ^ When set, every message sent and received is written to stderr.
  , pgDBLogMessage :: MessageFields -> IO () -- ^ Handler for server notices and unexpected messages.
#ifdef VERSION_tls
  , pgDBTLS :: PGTlsMode -- ^ TLS mode for the connection.
#endif
  } deriving (Show)
-- |Two database settings are equal when they describe the same server,
-- database, credentials and startup parameters (and, with TLS support, the
-- same TLS mode).  The debug flag and the log handler are not compared.
instance Eq PGDatabase where
  x == y =
    pgDBAddr x == pgDBAddr y
      && pgDBName x == pgDBName y
      && pgDBUser x == pgDBUser y
      && pgDBPass x == pgDBPass y
      && pgDBParams x == pgDBParams y
#ifdef VERSION_tls
      && pgDBTLS x == pgDBTLS y
#endif
-- |A server-side named prepared statement.  It is identified by a number
-- drawn from the connection's statement counter; the name sent to the server
-- is that number in decimal (see 'preparedStatementName').
newtype PGPreparedStatement = PGPreparedStatement Integer
  deriving (Eq, Show)
-- |Protocol-level name of a prepared statement: its number rendered in decimal.
preparedStatementName :: PGPreparedStatement -> BS.ByteString
preparedStatementName stmt =
  case stmt of
    PGPreparedStatement num -> BSC.pack (show num)
-- |The transport underneath a connection.
data PGHandle
  = PGSocket Net.Socket -- ^ A plain socket.
#ifdef VERSION_tls
  | PGTlsContext TLS.Context -- ^ A TLS session established over the socket.
#endif
-- |Render a builder and write it to the transport.  This does not call
-- 'pgFlush'.
pgPutBuilder :: PGHandle -> B.Builder -> IO ()
pgPutBuilder h b =
  case h of
    PGSocket s -> NetBSL.sendAll s lbs
#ifdef VERSION_tls
    PGTlsContext ctx -> TLS.sendData ctx lbs
#endif
  where
    lbs = B.toLazyByteString b
-- |Write a strict byte string to the transport.  This does not call
-- 'pgFlush'.
pgPut :: PGHandle -> BS.ByteString -> IO ()
pgPut h bs =
  case h of
    PGSocket s -> NetBS.sendAll s bs
#ifdef VERSION_tls
    PGTlsContext ctx -> TLS.sendData ctx (BSL.fromStrict bs)
#endif
-- |Read some bytes from the transport.  An empty result is treated by
-- callers as end of input.
--
-- On a socket, at most @count@ bytes are read.  A TLS session ignores
-- @count@ and returns whatever 'TLS.recvData' yields.
pgGetSome :: PGHandle -> Int -> IO BSC.ByteString
pgGetSome h count =
  case h of
    PGSocket s -> NetBS.recv s count
#ifdef VERSION_tls
    PGTlsContext ctx -> TLS.recvData ctx
#endif
-- |Close the transport.
--
-- For a TLS session, a close notification ('TLS.bye') is attempted first.
-- An 'IOError' while sending it is ignored, so the context is still closed.
pgCloseHandle :: PGHandle -> IO ()
pgCloseHandle h =
  case h of
    PGSocket s -> Net.close s
#ifdef VERSION_tls
    PGTlsContext ctx -> do
      catch (TLS.bye ctx) (\(_ :: IOError) -> return ())
      TLS.contextClose ctx
#endif
-- |Flush pending output.  TLS sessions are flushed with
-- 'TLS.contextFlush'; for a plain socket this does nothing.
pgFlush :: PGConnection -> IO ()
pgFlush conn =
  case connHandle conn of
    PGSocket _ -> pure ()
#ifdef VERSION_tls
    PGTlsContext ctx -> TLS.contextFlush ctx
#endif
-- |An established connection to a PostgreSQL server.
--
-- NOTE(review): all mutable state lives in unsynchronised 'IORef's and the
-- protocol is strictly request/response, so presumably a connection must not
-- be used from several threads at once — confirm.
data PGConnection = PGConnection
  { connHandle :: PGHandle -- ^ Transport (socket or TLS session).
  , connDatabase :: !PGDatabase -- ^ Settings the connection was made with.
  , connPid :: !Word32 -- ^ Backend process id, from BackendKeyData.
  , connKey :: !Word32 -- ^ Backend secret key, from BackendKeyData.
  , connTypeEnv :: PGTypeEnv -- ^ Type environment built from server parameters at startup.
  , connParameters :: IORef (Map.Map BS.ByteString BS.ByteString) -- ^ Latest value of each ParameterStatus reported by the server.
  , connPreparedStatementCount :: IORef Integer -- ^ Counter used to number new prepared statements.
  , connPreparedStatementMap :: IORef (Map.Map (BS.ByteString, [OID]) PGPreparedStatement) -- ^ Statements prepared by 'pgPreparedBind', keyed by SQL text and parameter types.
  , connState :: IORef PGState -- ^ Local protocol synchronisation state.
  , connInput :: IORef (G.Decoder PGBackendMessage) -- ^ Decoder holding partially received input between reads.
  , connTransaction :: IORef Word -- ^ Nesting depth of 'pgBegin' calls.
  , connNotifications :: IORef (Queue PGNotification) -- ^ Notifications received while waiting for other results.
  }
-- |Description of one result column, decoded from a RowDescription message.
data PGColDescription = PGColDescription
  { pgColName :: BS.ByteString -- ^ Column name.
  , pgColTable :: !OID -- ^ OID of the table the column comes from, or 0 if none.
  , pgColNumber :: !Int16 -- ^ Attribute number of the column within that table.
  , pgColType :: !OID -- ^ OID of the column's data type.
  , pgColSize :: !Int16 -- ^ Data type size as reported by the server.
  , pgColModifier :: !Int32 -- ^ Type modifier as reported by the server.
  , pgColBinary :: !Bool -- ^ Whether values arrive in binary (format code 1) rather than text.
  } deriving (Show)
-- |Descriptions of all result columns; empty for a statement that returns no rows (NoData).
type PGRowDescription = [PGColDescription]
-- |Fields of an ErrorResponse or NoticeResponse, keyed by their one-byte field code (this module reads @'S'@ severity, @'C'@ SQLSTATE code, @'M'@ message and @'D'@ detail).
type MessageFields = Map.Map Char BS.ByteString
-- |An asynchronous notification (NotificationResponse) sent by the server.
data PGNotification = PGNotification
  { pgNotificationPid :: !Word32 -- ^ Process id of the backend that sent it.
  , pgNotificationChannel :: !BS.ByteString -- ^ Channel name.
  , pgNotificationPayload :: BSL.ByteString -- ^ Payload string.
  } deriving (Show)
-- |A FIFO queue kept as two lists.  New items are pushed onto the first
-- list; items are taken from the second, which is refilled by reversing the
-- first when it runs out.
data Queue a = Queue [a] [a]

-- |A queue with no items.
emptyQueue :: Queue a
emptyQueue = Queue [] []

-- |Add an item at the back of the queue.
enQueue :: a -> Queue a -> Queue a
enQueue item (Queue back front) = Queue (item : back) front

-- |Take the item at the front of the queue.
-- Returns the remaining queue and the item, or the queue unchanged and
-- 'Nothing' when it is empty.
deQueue :: Queue a -> (Queue a, Maybe a)
deQueue q@(Queue back front) =
  case front of
    x : rest -> (Queue back rest, Just x)
    [] -> case reverse back of
      x : rest -> (Queue [] rest, Just x)
      [] -> (q, Nothing)
-- |Messages sent from the client to the server; encoded by 'messageBody'.
data PGFrontendMessage
  = StartupMessage [(BS.ByteString, BS.ByteString)] -- ^ Startup parameter name/value pairs.
  | CancelRequest !Word32 !Word32 -- ^ Backend process id and secret key.
  | Bind { portalName :: BS.ByteString, statementName :: BS.ByteString, bindParameters :: PGValues, binaryColumns :: [Bool] }
    -- ^ Bind parameter values to a statement, creating a portal; @binaryColumns@ requests binary result columns.
  | CloseStatement { statementName :: BS.ByteString } -- ^ Close a prepared statement.
  | ClosePortal { portalName :: BS.ByteString } -- ^ Close a portal.
  | DescribeStatement { statementName :: BS.ByteString } -- ^ Ask for a statement's parameter and row descriptions.
  | DescribePortal { portalName :: BS.ByteString } -- ^ Ask for a portal's row description.
  | Execute { portalName :: BS.ByteString, executeRows :: !Word32 } -- ^ Run a portal, returning at most @executeRows@ rows.
  | Flush -- ^ Ask the server to send any pending output.
  | Parse { statementName :: BS.ByteString, queryString :: BSL.ByteString, parseTypes :: [OID] }
    -- ^ Prepare a statement, optionally fixing parameter types.
  | PasswordMessage BS.ByteString -- ^ Password (cleartext or MD5-hashed) for authentication.
  | SimpleQuery { queryString :: BSL.ByteString } -- ^ Run SQL with the simple query protocol.
  | Sync -- ^ End an extended-protocol sequence; the server answers with ReadyForQuery.
  | Terminate -- ^ Close the session.
  deriving (Show)
-- |Messages received from the server; decoded by 'getMessageBody'.
data PGBackendMessage
  = AuthenticationOk
  | AuthenticationCleartextPassword
  | AuthenticationMD5Password BS.ByteString -- ^ 4-byte salt.
  | BackendKeyData Word32 Word32 -- ^ Backend process id and secret key.
  | BindComplete
  | CloseComplete
  | CommandComplete BS.ByteString -- ^ Command tag; the row count, if any, is its last word.
  | DataRow PGValues -- ^ One result row; values are always decoded as text (see 'fixBinary').
  | EmptyQueryResponse
  | ErrorResponse { messageFields :: MessageFields }
  | NoData
  | NoticeResponse { messageFields :: MessageFields }
  | NotificationResponse PGNotification
  | ParameterDescription [OID] -- ^ Parameter type OIDs of a described statement.
  | ParameterStatus BS.ByteString BS.ByteString -- ^ Server parameter name and value.
  | ParseComplete
  | PortalSuspended -- ^ Execute stopped at its row limit; more rows remain.
  | ReadyForQuery PGState -- ^ Transaction status reported by the server.
  | RowDescription PGRowDescription
  deriving (Show)
-- |An error reported by the server in an ErrorResponse.  It is thrown when
-- such a response arrives while a caller waits for backend messages.
newtype PGError = PGError { pgErrorFields :: MessageFields }
  deriving (Typeable)

-- |Rendered with 'displayMessage', the same format 'defaultLogMessage' uses.
instance Show PGError where
  show (PGError m) = displayMessage m

instance Exception PGError
-- |Render error/notice fields as text in the form
-- @PG<severity> [<code>]: <message>@.  The bracketed code is omitted when
-- absent, and the detail field follows on a new line when present.
displayMessage :: MessageFields -> String
displayMessage m = concat ["PG", field 'S', codePart, field 'M', detailPart]
  where
    field key = BSC.unpack (Map.findWithDefault BS.empty key m)
    codePart = case field 'C' of
      "" -> ": "
      code -> " [" ++ code ++ "]: "
    detailPart = case field 'D' of
      "" -> ""
      detail -> '\n' : detail
-- |Build message fields with the given message (@'M'@) and detail (@'D'@)
-- text, for locally generated log entries.
makeMessage :: BS.ByteString -> BS.ByteString -> MessageFields
makeMessage msgText detail = Map.fromList [('M', msgText), ('D', detail)]
-- |The SQLSTATE code (field @'C'@) of a server error, or empty if it is absent.
pgErrorCode :: PGError -> BS.ByteString
pgErrorCode = Map.findWithDefault BS.empty 'C' . pgErrorFields
-- |Default log handler: print the formatted message to stderr.
defaultLogMessage :: MessageFields -> IO ()
defaultLogMessage fields = hPutStrLn stderr (displayMessage fields)
-- |Default connection settings: user @postgres@ (no password) and database
-- @postgres@ on 127.0.0.1 port 5432, no extra parameters, debugging off, and
-- notices logged to stderr.  With TLS support, TLS is disabled.
defaultPGDatabase :: PGDatabase
defaultPGDatabase = PGDatabase
  { pgDBAddr = Right $ Net.SockAddrInet 5432 (Net.tupleToHostAddress (127,0,0,1))
  , pgDBName = "postgres"
  , pgDBUser = "postgres"
  , pgDBPass = BS.empty
  , pgDBParams = []
  , pgDBDebug = False
  , pgDBLogMessage = defaultLogMessage
#ifdef VERSION_tls
  , pgDBTLS = TlsDisabled
#endif
  }
-- |If the connection's settings enable debugging, write the message to
-- stderr, prefixed with the current time.
connDebugMsg :: PGConnection -> String -> IO ()
connDebugMsg c msg =
  when (pgDBDebug (connDatabase c)) $
    getCurrentTime >>= \now -> hPutStrLn stderr (show now ++ msg)
-- |Pass message fields to the log handler configured for the connection.
connLogMessage :: PGConnection -> MessageFields -> IO ()
connLogMessage c fields = pgDBLogMessage (connDatabase c) fields
-- |The database information for this connection, i.e. the settings it was
-- opened (or last reused by 'pgReconnect') with.
pgConnectionDatabase :: PGConnection -> PGDatabase
pgConnectionDatabase c = connDatabase c
-- |The type environment for this connection, filled in from the server's
-- startup parameters.
pgTypeEnv :: PGConnection -> PGTypeEnv
pgTypeEnv c = connTypeEnv c
#ifdef VERSION_cryptonite
-- |Base16 (hex) encoding of the MD5 digest of the input, as used by MD5
-- password authentication.
md5 :: BS.ByteString -> BS.ByteString
md5 input = BA.convertToBase BA.Base16 digest
  where
    digest :: Hash.Digest Hash.MD5
    digest = Hash.hash input
#endif
-- |A single zero byte, the string terminator used throughout the protocol.
nul :: B.Builder
nul = B.char7 '\NUL'
-- |Encode a strict byte string followed by a terminating zero byte.
byteStringNul :: BS.ByteString -> B.Builder
byteStringNul s = mconcat [B.byteString s, nul]
-- |Encode a lazy byte string followed by a terminating zero byte.
lazyByteStringNul :: BSL.ByteString -> B.Builder
lazyByteStringNul s = mconcat [B.lazyByteString s, nul]
-- |Encode a frontend message as its optional one-byte type code and its
-- body.  The 32-bit length prefix is added by 'pgSend'.  'StartupMessage'
-- and 'CancelRequest' carry no type code.
messageBody :: PGFrontendMessage -> (Maybe Char, B.Builder)
-- Protocol version 3.0 (0x00030000), then NUL-terminated name/value pairs
-- and a final terminating NUL.
messageBody (StartupMessage kv) = (Nothing, B.word32BE 0x30000
  <> Fold.foldMap (\(k, v) -> byteStringNul k <> byteStringNul v) kv <> nul)
-- Fixed request code, then the backend process id and secret key.
messageBody (CancelRequest pid key) = (Nothing, B.word32BE 80877102
  <> B.word32BE pid <> B.word32BE key)
messageBody Bind{ portalName = d, statementName = n, bindParameters = p, binaryColumns = bc } = (Just 'B',
  byteStringNul d
  <> byteStringNul n
  -- Parameter format codes: a count of 0 (all text) unless some parameter
  -- is binary, in which case one code (0 text, 1 binary) per parameter.
  <> (if any fmt p
    then B.word16BE (fromIntegral $ length p) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum . fmt) p
    else B.word16BE 0)
  -- Parameter values: count, then each value as length (-1 for NULL) and bytes.
  <> B.word16BE (fromIntegral $ length p) <> Fold.foldMap val p
  -- Result format codes: a count of 0 (all text) unless some column is
  -- requested in binary, in which case one code per listed column.
  <> (if or bc
    then B.word16BE (fromIntegral $ length bc) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum) bc
    else B.word16BE 0))
  where
    fmt (PGBinaryValue _) = True
    fmt _ = False
    val PGNullValue = B.int32BE (-1)
    val (PGTextValue v) = B.word32BE (fromIntegral $ BS.length v) <> B.byteString v
    val (PGBinaryValue v) = B.word32BE (fromIntegral $ BS.length v) <> B.byteString v
-- Close and Describe select their target with 'S' (statement) or 'P' (portal).
messageBody CloseStatement{ statementName = n } = (Just 'C',
  B.char7 'S' <> byteStringNul n)
messageBody ClosePortal{ portalName = n } = (Just 'C',
  B.char7 'P' <> byteStringNul n)
messageBody DescribeStatement{ statementName = n } = (Just 'D',
  B.char7 'S' <> byteStringNul n)
messageBody DescribePortal{ portalName = n } = (Just 'D',
  B.char7 'P' <> byteStringNul n)
-- Portal name and maximum row count (callers use 0 to fetch all rows).
messageBody Execute{ portalName = n, executeRows = r } = (Just 'E',
  byteStringNul n <> B.word32BE r)
messageBody Flush = (Just 'H', mempty)
-- Statement name, query text, then the (possibly empty) list of parameter type OIDs.
messageBody Parse{ statementName = n, queryString = s, parseTypes = t } = (Just 'P',
  byteStringNul n <> lazyByteStringNul s
  <> B.word16BE (fromIntegral $ length t) <> Fold.foldMap B.word32BE t)
messageBody (PasswordMessage s) = (Just 'p',
  B.byteString s <> nul)
messageBody SimpleQuery{ queryString = s } = (Just 'Q',
  lazyByteStringNul s)
messageBody Sync = (Just 'S', mempty)
messageBody Terminate = (Just 'X', mempty)
-- |Send one frontend message on the connection.
--
-- The function does three things:
--
-- * Updates the local sync state: 'StatePending' after 'Sync' or
--   'SimpleQuery', 'StateClosed' after 'Terminate', and 'StateUnsync'
--   otherwise.  A closed connection stays closed.
-- * Logs the message when debugging is enabled.
-- * Writes the type code, the 32-bit length (which counts itself but not the
--   type code) and the body.
--
-- Output is not flushed; callers follow up with 'pgFlush'.
--
-- Header and body go out as one write.  Writing them separately costs two
-- writes per message: an extra tiny TCP segment, since 'pgConnect' sets
-- @TCP_NODELAY@, or an extra TLS record.
pgSend :: PGConnection -> PGFrontendMessage -> IO ()
pgSend c@PGConnection{ connHandle = h, connState = sr } msg = do
  modifyIORef' sr $ state msg
  connDebugMsg c $ "> " ++ show msg
  pgPutBuilder h $ Fold.foldMap B.char7 t
    <> B.word32BE (fromIntegral $ 4 + BS.length b)
    <> B.byteString b
  where
    -- The body is rendered to a strict string first because its length has
    -- to be written before it.
    (t, b) = second (BSL.toStrict . B.toLazyByteString) $ messageBody msg
    state _ StateClosed = StateClosed
    state Sync _ = StatePending
    state SimpleQuery{} _ = StatePending
    state Terminate _ = StateClosed
    state _ _ = StateUnsync
-- |Parse a NUL-terminated string, returning it as a strict byte string
-- without the terminator.
getByteStringNul :: G.Get BS.ByteString
getByteStringNul = BSL.toStrict <$> G.getLazyByteStringNul
-- |Parse the field list of an ErrorResponse or NoticeResponse.  Each field
-- is a one-byte code followed by a NUL-terminated value; a zero code ends
-- the list.
getMessageFields :: G.Get MessageFields
getMessageFields = do
  code <- w2c <$> G.getWord8
  if code == '\0'
    then return Map.empty
    else Map.insert code <$> getByteStringNul <*> getMessageFields
-- |Decode the body of a backend message, given its one-byte type code.
-- 'getMessage' has already consumed the type and length and isolates the
-- body.  Unknown message types, authentication methods and ready states fail
-- the decoder.
getMessageBody :: Char -> G.Get PGBackendMessage
-- Authentication request: 0 = ok, 3 = cleartext password, 5 = MD5 with a 4-byte salt.
getMessageBody 'R' = auth =<< G.getWord32be where
  auth 0 = return AuthenticationOk
  auth 3 = return AuthenticationCleartextPassword
  auth 5 = AuthenticationMD5Password <$> G.getByteString 4
  auth op = fail $ "pgGetMessage: unsupported authentication type: " ++ show op
-- ParameterDescription: count, then one type OID per parameter.
getMessageBody 't' = do
  numParams <- G.getWord16be
  ParameterDescription <$> replicateM (fromIntegral numParams) G.getWord32be
-- RowDescription: count, then one field description per column.
getMessageBody 'T' = do
  numFields <- G.getWord16be
  RowDescription <$> replicateM (fromIntegral numFields) getField where
    getField = do
      name <- getByteStringNul
      oid <- G.getWord32be
      col <- G.getWord16be
      typ' <- G.getWord32be
      siz <- G.getWord16be
      tmod <- G.getWord32be
      fmt <- G.getWord16be
      return $ PGColDescription
        { pgColName = name
        , pgColTable = oid
        , pgColNumber = fromIntegral col
        , pgColType = typ'
        , pgColSize = fromIntegral siz
        , pgColModifier = fromIntegral tmod
        -- Format code 0 = text, 1 = binary; 'toEnum' raises an error for any other value.
        , pgColBinary = toEnum (fromIntegral fmt)
        }
-- ReadyForQuery: one status byte.
getMessageBody 'Z' = ReadyForQuery <$> (rs . w2c =<< G.getWord8) where
  rs 'I' = return StateIdle
  rs 'T' = return StateTransaction
  rs 'E' = return StateTransactionFailed
  rs s = fail $ "pgGetMessage: unknown ready state: " ++ show s
getMessageBody '1' = return ParseComplete
getMessageBody '2' = return BindComplete
getMessageBody '3' = return CloseComplete
getMessageBody 'C' = CommandComplete <$> getByteStringNul
getMessageBody 'S' = liftM2 ParameterStatus getByteStringNul getByteStringNul
-- DataRow: count, then each value as length (0xFFFFFFFF = NULL) and bytes.
-- Values are always decoded as text; callers reinterpret binary columns.
getMessageBody 'D' = do
  numFields <- G.getWord16be
  DataRow <$> replicateM (fromIntegral numFields) (getField =<< G.getWord32be) where
    getField 0xFFFFFFFF = return PGNullValue
    getField len = PGTextValue <$> G.getByteString (fromIntegral len)
getMessageBody 'K' = liftM2 BackendKeyData G.getWord32be G.getWord32be
getMessageBody 'E' = ErrorResponse <$> getMessageFields
getMessageBody 'I' = return EmptyQueryResponse
getMessageBody 'n' = return NoData
getMessageBody 's' = return PortalSuspended
getMessageBody 'N' = NoticeResponse <$> getMessageFields
-- NotificationResponse: backend pid, channel name, payload.
getMessageBody 'A' = NotificationResponse <$> do
  PGNotification
    <$> G.getWord32be
    <*> getByteStringNul
    <*> G.getLazyByteStringNul
getMessageBody t = fail $ "pgGetMessage: unknown message type: " ++ show t
-- |A fresh incremental decoder for one backend message.  A message is a
-- type byte, a 32-bit big-endian length that counts itself, and the body.
-- The body is parsed in isolation, so it cannot read past its declared
-- length.
getMessage :: G.Decoder PGBackendMessage
getMessage = G.runGetIncremental oneMessage
  where
    oneMessage = do
      tag <- w2c <$> G.getWord8
      len <- G.getWord32be
      G.isolate (fromIntegral len - 4) (getMessageBody tag)
-- |Policy for 'pgRecv', selected by the result type @m@ the caller asks for.
-- Each message hook returns 'Just' a result to stop receiving, or 'Nothing'
-- to keep reading.
class Show m => RecvMsg m where
  -- |Read more raw input.  A 'Left' stops 'pgRecv' and returns that value
  -- without waiting for a message.  The default blocks.  On end of input it
  -- marks the connection closed, closes the handle and throws an EOF 'IOError'.
  recvMsgData :: PGConnection -> IO (Either m BS.ByteString)
  recvMsgData c = do
    r <- pgGetSome (connHandle c) smallChunkSize
    if BS.null r
      then do
        writeIORef (connState c) StateClosed
        pgCloseHandle (connHandle c)
        ioError $ mkIOError eofErrorType "PGConnection" Nothing Nothing
      else
        return (Right r)
  -- |Result to produce when a ReadyForQuery completes a pending sync.
  -- 'Nothing' (the default) means keep reading.
  recvMsgSync :: Maybe m
  recvMsgSync = Nothing
  -- |Handle a notification.  By default it is queued on the connection for
  -- 'pgGetNotification(s)'.
  recvMsgNotif :: PGConnection -> PGNotification -> IO (Maybe m)
  recvMsgNotif c n = Nothing <$
    modifyIORef' (connNotifications c) (enQueue n)
  -- |Handle an ErrorResponse.  By default it is only logged.
  recvMsgErr :: PGConnection -> MessageFields -> IO (Maybe m)
  recvMsgErr c m = Nothing <$
    connLogMessage c m
  -- |Handle any other message.  By default it is logged as unexpected.
  recvMsg :: PGConnection -> PGBackendMessage -> IO (Maybe m)
  recvMsg c m = Nothing <$
    connLogMessage c (makeMessage (BSC.pack $ "Unexpected server message: " ++ show m) "Each statement should only contain a single query")
-- |Result type for a non-blocking 'pgRecv'.  It is returned as soon as no
-- more input is immediately available on the socket.  Only plain sockets on
-- non-Windows platforms are supported.
--
-- NOTE(review): 'recvNonBlock' also returns an empty string when the peer has
-- closed the connection, so end of input is reported as 'RecvNonBlock' rather
-- than detected as EOF.
data RecvNonBlock = RecvNonBlock deriving (Show)
instance RecvMsg RecvNonBlock where
#ifndef mingw32_HOST_OS
  recvMsgData PGConnection{connHandle=PGSocket s} = do
    r <- recvNonBlock s smallChunkSize
    if BS.null r
      then return (Left RecvNonBlock)
      else return (Right r)
#else
  recvMsgData PGConnection{connHandle=PGSocket _} =
    throwIO (userError "Non-blocking recvMsgData is not supported on mingw32 ATM")
#endif
#ifdef VERSION_tls
  recvMsgData PGConnection{connHandle=PGTlsContext _} =
    throwIO (userError "Non-blocking recvMsgData is not supported on TLS connections")
#endif
-- |Result type for a 'pgRecv' that waits until a pending sync completes.
-- Other messages get the default handling: errors and unexpected messages
-- are logged, and notifications are queued.
data RecvSync = RecvSync deriving (Show)
instance RecvMsg RecvSync where
  recvMsgSync = Just RecvSync
-- |Receiving a 'PGNotification' returns the next notification directly
-- instead of queueing it.
instance RecvMsg PGNotification where
  recvMsgNotif _ = return . Just
-- |Receiving a 'PGBackendMessage' returns the next message not handled
-- internally by 'pgRecv'.  An ErrorResponse is thrown as 'PGError'.
instance RecvMsg PGBackendMessage where
  recvMsgErr _ = throwIO . PGError
  recvMsg _ = return . Just
-- |Receiving 'Either' returns each unhandled message as 'Left', and
-- 'Right' 'RecvSync' once a pending sync completes.  An ErrorResponse is
-- thrown as 'PGError'.
instance RecvMsg (Either PGBackendMessage RecvSync) where
  recvMsgSync = Just $ Right RecvSync
  recvMsgErr _ = throwIO . PGError
  recvMsg _ = return . Just . Left
-- |Receive and process backend messages until one yields a result of the
-- requested type @m@ (see 'RecvMsg').
--
-- Parameter status changes and notices are handled here, and ReadyForQuery
-- updates the connection state.  Errors, notifications and all other
-- messages are delegated to the 'RecvMsg' hooks.  The decoder state is kept
-- in 'connInput' between calls and is saved before any exception
-- propagates, so a later call resumes from the same point.
pgRecv :: RecvMsg m => PGConnection -> IO m
pgRecv c@PGConnection{ connInput = dr, connState = sr } =
  rcv =<< readIORef dr where
    -- Save the decoder for the next call.
    next = writeIORef dr
    -- A fresh decoder primed with leftover input.
    new = G.pushChunk getMessage
    -- A whole message was decoded; any remaining bytes seed the next decoder.
    rcv (G.Done b _ m) = do
      connDebugMsg c $ "< " ++ show m
      got (new b) m
    -- Undecodable input: reset the decoder (dropping buffered data) and fail.
    rcv (G.Fail _ _ r) = next (new BS.empty) >> fail r
    -- More input is needed; a 'Left' from 'recvMsgData' stops here.
    rcv d@(G.Partial r) = recvMsgData c `onException` next d >>=
      either (<$ next d) (rcv . r . Just)
    msg (ParameterStatus k v) = Nothing <$
      modifyIORef' (connParameters c) (Map.insert k v)
    msg (NoticeResponse m) = Nothing <$
      connLogMessage c m
    msg (ErrorResponse m) =
      recvMsgErr c m
    -- Record the server's state; if a sync was pending, this completes it.
    msg m@(ReadyForQuery s) = do
      s' <- atomicModifyIORef' sr (s, )
      if s' == StatePending
        then return recvMsgSync
        else recvMsg c m
    msg (NotificationResponse n) =
      recvMsgNotif c n
    -- Mark a sync as pending so that the ReadyForQuery ending startup is
    -- reported through 'recvMsgSync' (which 'pgConnect' waits for).
    msg m@AuthenticationOk = do
      writeIORef sr StatePending
      recvMsg c m
    msg m = recvMsg c m
    -- Stop with a result, or keep decoding from decoder @d@.
    got d m = msg m `onException` next d >>=
      maybe (rcv d) (<$ next d)
-- |Connect to a PostgreSQL server.
--
-- Steps:
--
-- * Resolve the address.
-- * Open a socket, upgrading it to TLS if configured.
-- * Send the startup message.
-- * Answer the authentication requests this module supports: cleartext,
--   plus MD5 when built with cryptonite.
--
-- The returned connection's type environment is filled from the server
-- parameters reported during startup.
--
-- If anything fails before the connection is ready (refused connection, TLS
-- failure, rejected password, unexpected message), the socket or TLS
-- context is closed before the exception propagates.  Failed attempts
-- therefore do not leak file descriptors.
pgConnect :: PGDatabase -> IO PGConnection
pgConnect db = do
  param <- newIORef Map.empty
  state <- newIORef StateUnsync
  prepc <- newIORef 0
  prepm <- newIORef Map.empty
  input <- newIORef getMessage
  tr <- newIORef 0
  notif <- newIORef emptyQueue
  -- Net.getAddrInfo throws instead of returning an empty list, so 'head' is safe.
  addr <- either
    (\(h,p) -> head <$> Net.getAddrInfo (Just defai) (Just h) (Just p))
    (\a -> return defai{ Net.addrAddress = a, Net.addrFamily = case a of
      Net.SockAddrInet{} -> Net.AF_INET
      Net.SockAddrInet6{} -> Net.AF_INET6
      Net.SockAddrUnix{} -> Net.AF_UNIX
      _ -> Net.AF_UNSPEC })
    $ pgDBAddr db
  sock <- Net.socket (Net.addrFamily addr) (Net.addrSocketType addr) (Net.addrProtocol addr)
  -- Until a handle exists, the raw socket is the only thing to clean up.
  pgHandle <- (do
      unless (Net.addrFamily addr == Net.AF_UNIX) $ Net.setSocketOption sock Net.NoDelay 1
      Net.connect sock $ Net.addrAddress addr
      mkPGHandle db sock)
    `onException` Net.close sock
  let c = PGConnection
        { connHandle = pgHandle
        , connDatabase = db
        , connPid = 0
        , connKey = 0
        , connParameters = param
        , connPreparedStatementCount = prepc
        , connPreparedStatementMap = prepm
        , connState = state
        , connTypeEnv = unknownPGTypeEnv
        , connInput = input
        , connTransaction = tr
        , connNotifications = notif
        }
      -- On end of input, 'recvMsgData' already closes the handle and marks
      -- the state closed.  Only close here if that has not happened.
      closeOnFailure = do
        s <- readIORef state
        unless (s == StateClosed) $ do
          writeIORef state StateClosed
          pgCloseHandle pgHandle
  (do
      pgSend c $ StartupMessage $
        [ ("user", pgDBUser db)
        , ("database", pgDBName db)
        , ("client_encoding", "UTF8")
        , ("standard_conforming_strings", "on")
        , ("bytea_output", "hex")
        , ("DateStyle", "ISO, YMD")
        , ("IntervalStyle", "iso_8601")
        ] ++ pgDBParams db
      pgFlush c
      conn c)
    `onException` closeOnFailure
  where
    defai = Net.defaultHints{ Net.addrSocketType = Net.Stream }
    conn c = pgRecv c >>= msg c
    -- Startup finished: derive the type environment from the reported parameters.
    msg c (Right RecvSync) = do
      cp <- readIORef (connParameters c)
      return c
        { connTypeEnv = PGTypeEnv
          { pgIntegerDatetimes = fmap ("on" ==) $ Map.lookup "integer_datetimes" cp
          , pgServerVersion = Map.lookup "server_version" cp
          }
        }
    msg c (Left (BackendKeyData p k)) = conn c{ connPid = p, connKey = k }
    msg c (Left AuthenticationOk) = conn c
    msg c (Left AuthenticationCleartextPassword) = do
      pgSend c $ PasswordMessage $ pgDBPass db
      pgFlush c
      conn c
#ifdef VERSION_cryptonite
    -- MD5 auth: "md5" ++ md5(md5(password ++ user) ++ salt).
    msg c (Left (AuthenticationMD5Password salt)) = do
      pgSend c $ PasswordMessage $ "md5" `BS.append` md5 (md5 (pgDBPass db <> pgDBUser db) `BS.append` salt)
      pgFlush c
      conn c
#endif
    msg _ (Left m) = fail $ "pgConnect: unexpected response: " ++ show m
-- |Wrap a connected socket in the transport selected by the database's TLS
-- mode.  The socket is either used directly, or an SSLRequest is sent and, if
-- the server accepts, a TLS handshake is performed over it.  Builds without
-- TLS support always use the socket directly.
mkPGHandle :: PGDatabase -> Net.Socket -> IO PGHandle
#ifdef VERSION_tls
mkPGHandle db sock =
  case pgDBTLS db of
    TlsDisabled -> pure (PGSocket sock)
    TlsNoValidate -> mkTlsContext
    TlsValidate _ _ -> mkTlsContext
  where
    -- Send SSLRequest and read exactly one reply byte ('S' = proceed with
    -- TLS, 'N' = refused).  Nothing after it is consumed before the handshake.
    mkTlsContext = do
      NetBSL.sendAll sock sslRequest
      resp <- NetBS.recv sock 1
      case resp of
        "S" -> do
          ctx <- TLS.contextNew sock params
          void $ TLS.handshake ctx
          pure $ PGTlsContext ctx
        "N" -> throwIO (userError "Server does not support TLS")
        _ -> throwIO (userError "Unexpected response from server when issuing SSLRequest")
    params = (TLS.defaultParamsClient tlsHost tlsPort)
      { TLS.clientSupported =
          def { TLS.supportedCiphers = TLS.ciphersuite_strong }
      , TLS.clientShared = clientShared
      , TLS.clientHooks = clientHooks
      }
    -- Host/port identifiers handed to the TLS client parameters.  Socket
    -- addresses have no host name, so a path or placeholder is used.
    tlsHost = case pgDBAddr db of
      Left (h,_) -> h
      Right (Net.SockAddrUnix s) -> s
      Right _ -> "some-socket"
    tlsPort = case pgDBAddr db of
      Left (_,p) -> BSC.pack p
      Right _ -> "socket"
    -- Without a CA certificate, validation is skipped via 'noValidate'.
    -- The TlsDisabled case is never reached here, since mkTlsContext is not
    -- used for it.
    clientShared =
      case pgDBTLS db of
        TlsDisabled -> def { TLS.sharedValidationCache = noValidate }
        TlsNoValidate -> def { TLS.sharedValidationCache = noValidate }
        TlsValidate _ sc -> def { TLS.sharedCAStore = makeCertificateStore [sc] }
    -- CA-only validation swaps in a validator with the host name check disabled.
    clientHooks =
      case pgDBTLS db of
        TlsValidate TlsValidateCA _ -> def { TLS.onServerCertificate = validateNoCheckFQHN }
        _ -> def
    validateNoCheckFQHN = Data.X509.Validation.validate HashSHA256 def (def { TLS.checkFQHN = False })
    -- A validation cache that accepts every certificate.
    noValidate = TLS.ValidationCache
      (\_ _ _ -> return TLS.ValidationCachePass)
      (\_ _ _ -> return ())
    -- SSLRequest: length 8, then the request code 80877103.
    sslRequest = B.toLazyByteString (B.word32BE 8 <> B.word32BE 80877103)
#else
mkPGHandle _ sock = pure (PGSocket sock)
#endif
-- |Close a database connection.
--
-- Sends Terminate and flushes it, as every other send sequence in this
-- module does; without the flush, a TLS session may close before the
-- message leaves.  The underlying handle is always closed, even if sending
-- fails.
pgDisconnect :: PGConnection -- ^ a handle from 'pgConnect'
  -> IO ()
pgDisconnect c@PGConnection{ connHandle = h } =
  (pgSend c Terminate >> pgFlush c) `finally` pgCloseHandle h
-- |Disconnect cleanly with 'pgDisconnect', unless the connection is already
-- marked closed.
pgDisconnectOnce :: PGConnection -- ^ a handle from 'pgConnect'
  -> IO ()
pgDisconnectOnce c = do
  alreadyClosed <- (== StateClosed) <$> readIORef (connState c)
  unless alreadyClosed (pgDisconnect c)
-- |Reuse a connection for the given settings when possible.
--
-- If the settings compare equal to the connection's (the 'Eq' instance
-- ignores the debug flag and log handler) and the connection is still open,
-- the same connection is returned carrying the new settings.  Otherwise the
-- old connection is closed if needed and a new one is opened.
pgReconnect :: PGConnection -> PGDatabase -> IO PGConnection
pgReconnect c d = do
  st <- readIORef (connState c)
  let reusable = connDatabase c == d && st /= StateClosed
  if reusable
    then return c{ connDatabase = d }
    else pgDisconnectOnce c >> pgConnect d
-- |Bring the connection to a synchronised state before a new request.
-- It waits for an outstanding sync, or sends (and flushes) a Sync first if
-- messages were sent without one.  Fails if the connection is closed.
pgSync :: PGConnection -> IO ()
pgSync c = readIORef (connState c) >>= \st ->
    case st of
      StateClosed -> fail "pgSync: operation on closed connection"
      StatePending -> awaitSync
      StateUnsync -> pgSend c Sync >> pgFlush c >> awaitSync
      _ -> return ()
  where
    awaitSync = do
      RecvSync <- pgRecv c
      return ()
-- |Column descriptions from a describe response.  'NoData' gives an empty
-- list; any other message is a fatal error.
rowDescription :: PGBackendMessage -> PGRowDescription
rowDescription m =
  case m of
    RowDescription d -> d
    NoData -> []
    _ -> error $ "describe: unexpected response: " ++ show m
-- |Describe an SQL statement without running it.  The statement is parsed
-- as the unnamed statement and described.  The result is the parameter type
-- OIDs plus, for each result column, its name, type OID and whether it may
-- be NULL.
pgDescribe :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> [OID] -- ^ Optional parameter type specifications
  -> Bool -- ^ Look up nullability of table columns; otherwise assume every column is nullable
  -> IO ([OID], [(BS.ByteString, OID, Bool)]) -- ^ Parameter types, and result column names, types and nullability
pgDescribe h sql types nulls = do
  pgSync h
  pgSend h Parse{ queryString = sql, statementName = BS.empty, parseTypes = types }
  pgSend h DescribeStatement{ statementName = BS.empty }
  pgSend h Sync
  pgFlush h
  ParseComplete <- pgRecv h
  ParameterDescription ps <- pgRecv h
  (,) ps <$> (mapM desc . rowDescription =<< pgRecv h)
  where
    desc (PGColDescription{ pgColName = name, pgColTable = tab, pgColNumber = col, pgColType = typ }) = do
      n <- nullable tab col
      return (name, typ, n)
    -- Only columns taken directly from a table (non-zero table OID) can be
    -- looked up in pg_attribute.  Everything else is assumed nullable.
    -- 26 and 21 are the OIDs of the oid and int2 types of the two parameters.
    nullable oid col
      | nulls && oid /= 0 = do
        (_, r) <- pgPreparedQuery h "SELECT attnotnull FROM pg_catalog.pg_attribute WHERE attrelid = $1 AND attnum = $2" [26, 21] [pgEncodeRep (oid :: OID), pgEncodeRep (col :: Int16)] []
        case r of
          [[s]] -> return $ not $ pgDecodeRep s
          [] -> return True
          _ -> fail $ "Failed to determine nullability of column #" ++ show col
      | otherwise = return True
-- |Row count from a CommandComplete tag, taken from its last word (e.g. 3
-- for @INSERT 0 3@).  Returns -1 when the tag is empty or its last word is
-- not a number.
rowsAffected :: (Integral i, Read i) => BS.ByteString -> i
rowsAffected tag =
  case reverse (BSC.words tag) of
    [] -> -1
    lastWord : _ -> fromMaybe (-1) (readMaybe (BSC.unpack lastWord))
-- |Relabel row values to match the requested column formats: a value in a
-- column marked 'True' becomes binary, one marked 'False' becomes text.
-- Values are paired with flags positionally.  Values beyond the flag list
-- are kept unchanged; the result ends when the values do.
fixBinary :: [Bool] -> PGValues -> PGValues
fixBinary (flag:flags) (v:vs) = adjust flag v : fixBinary flags vs
  where
    adjust False (PGBinaryValue x) = PGTextValue x
    adjust True (PGTextValue x) = PGBinaryValue x
    adjust _ x = x
fixBinary _ vs = vs
-- |Run a single SQL statement with the simple query protocol and collect
-- all its rows.  Columns the server reports as binary are marked binary in
-- the result.
--
-- Only the first statement's result is read.  Its ReadyForQuery, and any
-- further results, are consumed by the next 'pgSync', which logs them as
-- unexpected messages.
pgSimpleQuery :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> IO (Int, [PGValues]) -- ^ The number of rows affected and a list of result rows
pgSimpleQuery h sql = do
  pgSync h
  pgSend h $ SimpleQuery sql
  pgFlush h
  go start where
    -- Receive the next message and pass it to a handler.
    go = (pgRecv h >>=)
    start (RowDescription rd) = go $ row (map pgColBinary rd) id
    start (CommandComplete c) = got c []
    start EmptyQueryResponse = return (0, [])
    start m = fail $ "pgSimpleQuery: unexpected response: " ++ show m
    -- Rows are accumulated in a difference list to keep their order.
    row bc r (DataRow fs) = go $ row bc (r . (fixBinary bc fs :))
    row _ r (CommandComplete c) = got c (r [])
    row _ _ m = fail $ "pgSimpleQuery: unexpected row: " ++ show m
    got c r = return (rowsAffected c, r)
-- |Run SQL that may contain several statements with the simple query
-- protocol, discarding all results.  It waits until the server is ready
-- again and fails on any unexpected message.
pgSimpleQueries_ :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> IO ()
pgSimpleQueries_ h sql = do
  pgSync h
  pgSend h $ SimpleQuery sql
  pgFlush h
  loop
  where
    loop = pgRecv h >>= \r ->
      case r of
        Right RecvSync -> return ()
        Left RowDescription{} -> loop
        Left CommandComplete{} -> loop
        Left EmptyQueryResponse -> loop
        Left DataRow{} -> loop
        _ -> fail $ "pgSimpleQueries_: unexpected response: " ++ show r
-- |Queue Parse (only if the statement is not yet cached) and Bind messages
-- for a cached prepared statement, bound into the unnamed portal.  Nothing
-- is flushed.
--
-- Returns an action to run after the caller has sent Execute (and Sync or
-- Flush) and flushed.  That action reads ParseComplete, if a Parse was sent,
-- and then BindComplete.
pgPreparedBind :: PGConnection -> BS.ByteString -> [OID] -> PGValues -> [Bool] -> IO (IO ())
pgPreparedBind c sql types bind bc = do
  pgSync c
  m <- readIORef (connPreparedStatementMap c)
  -- (already prepared?, statement): reuse the cached statement, or allocate a new number.
  (p, n) <- maybe
    (atomicModifyIORef' (connPreparedStatementCount c) (succ &&& (,) False . PGPreparedStatement))
    (return . (,) True) $ Map.lookup key m
  unless p $
    pgSend c Parse{ queryString = BSL.fromStrict sql, statementName = preparedStatementName n, parseTypes = types }
  pgSend c Bind{ portalName = BS.empty, statementName = preparedStatementName n, bindParameters = bind, binaryColumns = bc }
  let
    go = pgRecv c >>= start
    -- Cache the statement only once the server confirms it parsed.
    start ParseComplete = do
      modifyIORef' (connPreparedStatementMap c) $
        Map.insert key n
      go
    start BindComplete = return ()
    start r = fail $ "pgPrepared: unexpected response: " ++ show r
  return go
  where key = (sql, types)
-- |Run a statement through the prepared-statement cache and read all its
-- rows.  The statement is prepared on first use for this SQL text and
-- parameter types, then bound and executed.
--
-- Returns the row count from CommandComplete (-1 if the tag has none, 0 for
-- an empty query) and the rows.  Columns marked 'True' in the format list
-- are relabelled as binary.
pgPreparedQuery :: PGConnection -> BS.ByteString -- ^ SQL statement with placeholders
  -> [OID] -- ^ Optional parameter type specifications
  -> PGValues -- ^ Parameters to bind to placeholders
  -> [Bool] -- ^ Requested binary format for result columns
  -> IO (Int, [PGValues])
pgPreparedQuery c sql types bind bc = do
  start <- pgPreparedBind c sql types bind bc
  -- executeRows = 0: fetch all rows.
  pgSend c Execute{ portalName = BS.empty, executeRows = 0 }
  pgSend c Sync
  pgFlush c
  start
  go id
  where
    -- Rows are accumulated in a difference list to keep their order.
    go r = pgRecv c >>= row r
    row r (DataRow fs) = go (r . (fixBinary bc fs :))
    row r (CommandComplete d) = return (rowsAffected d, r [])
    row r EmptyQueryResponse = return (0, r [])
    row _ m = fail $ "pgPreparedQuery: unexpected row: " ++ show m
-- |Like 'pgPreparedQuery', but the rows are returned lazily.  Each batch is
-- requested with Execute limited to @count@ rows followed by Flush (no
-- Sync), and only when the list is consumed that far.  Because no Sync is
-- sent, the connection stays unsynchronised until the next 'pgSync'.
--
-- NOTE(review): the list is produced with 'unsafeInterleaveIO', so network
-- I/O happens whenever the list is forced.  Presumably the result must be
-- fully consumed before the connection is used for anything else — confirm.
pgPreparedLazyQuery :: PGConnection -> BS.ByteString -> [OID] -> PGValues -> [Bool] -> Word32 -- ^ Rows to request per batch
  -> IO [PGValues]
pgPreparedLazyQuery c sql types bind bc count = do
  start <- pgPreparedBind c sql types bind bc
  unsafeInterleaveIO $ do
    execute
    start
    go id
  where
    execute = do
      pgSend c Execute{ portalName = BS.empty, executeRows = count }
      pgSend c Flush
      pgFlush c
    go r = pgRecv c >>= row r
    row r (DataRow fs) = go (r . (fixBinary bc fs :))
    -- Batch exhausted but the portal has more rows: fetch the rest on demand.
    row r PortalSuspended = r <$> unsafeInterleaveIO (execute >> go id)
    row r (CommandComplete _) = return (r [])
    row r EmptyQueryResponse = return (r [])
    row _ m = fail $ "pgPreparedLazyQuery: unexpected row: " ++ show m
-- |Close the cached prepared statement for this SQL text and parameter
-- types (as created by 'pgPreparedQuery' / 'pgPreparedLazyQuery') and remove
-- it from the connection's cache.  Does nothing if no such statement is
-- cached.
pgCloseStatement :: PGConnection -> BS.ByteString -> [OID] -> IO ()
pgCloseStatement c sql types = do
  -- Strict modify, like every other update in this module, so the cache
  -- 'IORef' is not left holding an unevaluated thunk.  Deleting through
  -- 'Map.updateLookupWithKey' also returns the removed entry.
  mn <- atomicModifyIORef' (connPreparedStatementMap c) $
    swap . Map.updateLookupWithKey (\_ _ -> Nothing) (sql, types)
  Fold.mapM_ (pgClose c) mn
-- |Begin a transaction.  When one is already open, a savepoint named
-- @pgt<depth>@ is created instead, so transactions nest.
pgBegin :: PGConnection -> IO ()
pgBegin c = do
  depth <- atomicModifyIORef' (connTransaction c) (\n -> (succ n, n))
  void $ pgSimpleQuery c $ BSLC.pack $ beginSql depth
  where
    beginSql 0 = "BEGIN"
    beginSql d = "SAVEPOINT pgt" ++ show d
-- |Decrement the transaction depth, returning (new depth, new depth).
-- At depth 0 the depth stays 0 and the returned value is an error thunk,
-- which fails when forced.
predTransaction :: Word -> (Word, Word)
predTransaction n
  | n == 0 = (0, error "pgTransaction: no transactions")
  | otherwise = let m = n - 1 in (m, m)
-- |Roll back the innermost open transaction.  At the outermost level this
-- is a plain ROLLBACK; otherwise it rolls back to the matching savepoint.
-- Fails with an error if no transaction is open.
pgRollback :: PGConnection -> IO ()
pgRollback c = do
  depth <- atomicModifyIORef' (connTransaction c) predTransaction
  void $ pgSimpleQuery c $ BSLC.pack $ rollbackSql depth
  where
    rollbackSql 0 = "ROLLBACK"
    rollbackSql d = "ROLLBACK TO SAVEPOINT pgt" ++ show d
-- |Commit the innermost open transaction.  At the outermost level this is
-- COMMIT; otherwise it releases the matching savepoint.  Fails with an error
-- if no transaction is open.
pgCommit :: PGConnection -> IO ()
pgCommit c = do
  depth <- atomicModifyIORef' (connTransaction c) predTransaction
  void $ pgSimpleQuery c $ BSLC.pack $ commitSql depth
  where
    commitSql 0 = "COMMIT"
    commitSql d = "RELEASE SAVEPOINT pgt" ++ show d
-- |Roll back the whole transaction, including all nested savepoints, and
-- reset the nesting depth to zero.
pgRollbackAll :: PGConnection -> IO ()
pgRollbackAll c = do
  writeIORef (connTransaction c) 0
  void (pgSimpleQuery c (BSLC.pack "ROLLBACK"))
-- |Commit the whole transaction, including all nested savepoints, and
-- reset the nesting depth to zero.
pgCommitAll :: PGConnection -> IO ()
pgCommitAll c = do
  writeIORef (connTransaction c) 0
  void (pgSimpleQuery c (BSLC.pack "COMMIT"))
-- |Run an action inside a transaction, or inside a savepoint when nested
-- (see 'pgBegin').  It commits if the action succeeds.  If the action or
-- the commit throws, it rolls back and rethrows.
pgTransaction :: PGConnection -> IO a -> IO a
pgTransaction c f = do
  pgBegin c
  (f >>= \r -> r <$ pgCommit c) `onException` pgRollback c
-- |Parse, bind and execute a statement through the unnamed statement and
-- portal, fetching at most one row and discarding it.  Returns 'Nothing' if
-- the portal was suspended with more rows remaining.  Otherwise returns the
-- rows-affected count (0 for an empty query).
pgRun :: PGConnection -> BSL.ByteString -> [OID] -> PGValues -> IO (Maybe Integer)
pgRun c sql types bind = do
  pgSync c
  mapM_ (pgSend c)
    [ Parse{ queryString = sql, statementName = BS.empty, parseTypes = types }
    , Bind{ portalName = BS.empty, statementName = BS.empty, bindParameters = bind, binaryColumns = [] }
    , Execute{ portalName = BS.empty, executeRows = 1 }
    , Sync
    ]
  pgFlush c
  awaitResult
  where
    awaitResult = pgRecv c >>= \r ->
      case r of
        ParseComplete -> awaitResult
        BindComplete -> awaitResult
        DataRow _ -> awaitResult
        PortalSuspended -> return Nothing
        CommandComplete d -> return (Just $ rowsAffected d)
        EmptyQueryResponse -> return (Just 0)
        _ -> fail $ "pgRun: unexpected response: " ++ show r
-- |Prepare a named statement (numbered from the connection's counter) and
-- wait for the server to confirm it parsed.
pgPrepare :: PGConnection -> BSL.ByteString -> [OID] -> IO PGPreparedStatement
pgPrepare c sql types = do
  stmt <- atomicModifyIORef' (connPreparedStatementCount c) (\k -> (succ k, PGPreparedStatement k))
  pgSync c
  mapM_ (pgSend c)
    [ Parse{ queryString = sql, statementName = preparedStatementName stmt, parseTypes = types }
    , Sync
    ]
  pgFlush c
  ParseComplete <- pgRecv c
  return stmt
-- |Close a prepared statement and the portal of the same name, waiting for
-- both closes to be confirmed.
pgClose :: PGConnection -> PGPreparedStatement -> IO ()
pgClose c n = do
  pgSync c
  mapM_ (pgSend c)
    [ ClosePortal{ portalName = name }
    , CloseStatement{ statementName = name }
    , Sync
    ]
  pgFlush c
  CloseComplete <- pgRecv c
  CloseComplete <- pgRecv c
  return ()
  where
    name = preparedStatementName n
-- |Bind parameters to a statement from 'pgPrepare'.  This creates a portal
-- with the same name, closing any previous one first, and returns the
-- result column descriptions.
pgBind :: PGConnection -> PGPreparedStatement -> PGValues -> IO PGRowDescription
pgBind c n bind = do
  pgSync c
  mapM_ (pgSend c)
    [ ClosePortal{ portalName = sn }
    , Bind{ portalName = sn, statementName = sn, bindParameters = bind, binaryColumns = [] }
    , DescribePortal{ portalName = sn }
    , Sync
    ]
  pgFlush c
  CloseComplete <- pgRecv c
  BindComplete <- pgRecv c
  rowDescription <$> pgRecv c
  where
    sn = preparedStatementName n
-- |Fetch up to @count@ rows from the portal created by 'pgBind' for this
-- statement.
--
-- The count is 'Nothing' when the portal was suspended with rows remaining.
-- When the portal completes, it is closed and the rows-affected count is
-- returned.
pgFetch :: PGConnection -> PGPreparedStatement -> Word32 -- ^ Maximum number of rows to return, or 0 for all
  -> IO ([PGValues], Maybe Integer)
pgFetch c n count = do
  pgSync c
  pgSend c Execute{ portalName = preparedStatementName n, executeRows = count }
  pgSend c Sync
  pgFlush c
  go where
    go = pgRecv c >>= res
    res (DataRow v) = first (v :) <$> go
    res PortalSuspended = return ([], Nothing)
    -- Portal exhausted: wait for the pending sync, then close the portal.
    res (CommandComplete d) = do
      pgSync c
      pgSend c ClosePortal{ portalName = preparedStatementName n }
      pgSend c Sync
      pgFlush c
      CloseComplete <- pgRecv c
      return ([], Just $ rowsAffected d)
    res EmptyQueryResponse = return ([], Just 0)
    res m = fail $ "pgFetch: unexpected response: " ++ show m
-- |Next notification: the oldest queued one if any, otherwise block until
-- the server sends one.
pgGetNotification :: PGConnection -> IO PGNotification
pgGetNotification c = do
  queued <- atomicModifyIORef' (connNotifications c) deQueue
  case queued of
    Just n -> return n
    Nothing -> pgRecv c
-- |All notifications received so far, oldest first; the queue is emptied.
-- Any input immediately available is processed first, without blocking.
pgGetNotifications :: PGConnection -> IO [PGNotification]
pgGetNotifications c = do
  RecvNonBlock <- pgRecv c
  Queue back front <- atomicModifyIORef' (connNotifications c) (\q -> (emptyQueue, q))
  return (front ++ reverse back)
#ifndef mingw32_HOST_OS
-- |Receive up to @nbytes@ bytes from the socket without blocking.  An empty
-- result means no data was available right now; a closed peer also produces
-- an empty result.  A negative length is an 'InvalidArgument' error.
recvNonBlock
  :: Net.Socket -- ^ Connected socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO BS.ByteString
recvNonBlock s nbytes =
  if nbytes < 0
    then ioError (mkInvalidRecvArgError "Database.PostgreSQL.Typed.Protocol.recvNonBlock")
    else createAndTrim nbytes (\ptr -> recvBufNonBlock s ptr nbytes)
-- |Receive up to @nbytes@ bytes into @ptr@ without blocking, returning the
-- number of bytes read.
--
-- * 0 means no data is available now (@EWOULDBLOCK@ or, on platforms where
--   it differs, @EAGAIN@).  It also means the peer closed the connection,
--   since recv returns 0 at end of stream.
-- * A call interrupted by a signal (@EINTR@) is retried.
-- * Any other failure, or a non-positive length, is thrown as an 'IOError'.
recvBufNonBlock :: Net.Socket -> Ptr Word8 -> Int -> IO Int
recvBufNonBlock s ptr nbytes
  | nbytes <= 0 = ioError (mkInvalidRecvArgError "Database.PostgreSQL.Typed.Protocol.recvBufNonBlock")
  | otherwise = attempt
  where
    attempt = do
      len <-
#if MIN_VERSION_network(3,1,0)
        Net.withFdSocket s $ \fd ->
#elif MIN_VERSION_network(3,0,0)
        Net.fdSocket s >>= \fd ->
#else
        let fd = Net.fdSocket s in
#endif
          c_recv fd (castPtr ptr) (fromIntegral nbytes) 0
      if len == -1
        then do
          errno <- getErrno
          if errno == eINTR
            then attempt
            else if errno == eWOULDBLOCK || errno == eAGAIN
              then return 0
              else throwIO (errnoToIOError "recvBufNonBlock" errno Nothing (Just "Database.PostgreSQL.Typed"))
        else
          return $ fromIntegral len
-- |An 'InvalidArgument' I/O error at the given location, with the
-- description "non-positive length"; raised for bad receive lengths.
mkInvalidRecvArgError :: String -> IOError
mkInvalidRecvArgError loc =
  let base = mkIOError InvalidArgument loc Nothing Nothing
  in ioeSetErrorString base "non-positive length"
-- |Direct binding to C @recv(2)@.  The arguments are the file descriptor,
-- buffer, buffer length and flags.  The result is the byte count, 0 at end
-- of stream, or -1 with @errno@ set.
--
-- NOTE(review): declared @unsafe@, so a call that blocks would stall the
-- calling capability.  This assumes the socket is in non-blocking mode
-- (presumably set by the network package) — confirm.
foreign import ccall unsafe "recv"
  c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif