{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.PostgreSQL.Pure.Internal.Connection
( connect
, disconnect
, withConnection
) where
import qualified Database.PostgreSQL.Pure.Internal.Builder as Builder
import Database.PostgreSQL.Pure.Internal.Data (Address (AddressNotResolved, AddressResolved),
                                               AuthenticationMD5Password (AuthenticationMD5Password),
                                               AuthenticationResponse (AuthenticationCleartextPasswordResponse, AuthenticationMD5PasswordResponse, AuthenticationOkResponse),
                                               BackendKey, BackendKeyData (BackendKeyData),
                                               BackendParameters, Buffer (Buffer),
                                               Config (Config, address, database, password, receptionBufferSize, sendingBufferSize, user),
                                               Connection (Connection, config, receptionBuffer, sendingBuffer, socket),
                                               ParameterStatus (ParameterStatus), Pid,
                                               ReadyForQuery (ReadyForQuery), Salt,
                                               TransactionState (Idle))
import qualified Database.PostgreSQL.Pure.Internal.Exception as Exception
import qualified Database.PostgreSQL.Pure.Internal.Parser as Parser
import Database.PostgreSQL.Pure.Internal.SocketIO (SocketIO, buildAndSend, receive, runSocketIO, send)

import Control.Exception.Safe (assert, bracket, bracketOnError, finally)
import Control.Monad (void)
import Control.Monad.Reader (ask)
import qualified Data.Attoparsec.ByteString as AP
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Internal as BSI
import qualified Data.ByteString.UTF8 as BSU
import qualified Data.Map.Strict as Map
import qualified Network.Socket as NS

#ifdef PURE_MD5
import qualified Data.Digest.Pure.MD5 as MD5
#else
import qualified Crypto.Hash.MD5 as MD5
#endif
-- | Run an action against a freshly established connection.
--
-- The address from the 'Config' is resolved if necessary, a socket is opened,
-- the connection is initialised on it, and the resulting 'Connection' is
-- handed to the callback. The socket is closed with 'NS.close' when the
-- callback finishes or throws; no terminate message is sent on this path.
-- The whole action runs under 'Exception.convert'.
withConnection :: Config -> (Connection -> IO a) -> IO a
withConnection config@Config { address } f =
  Exception.convert $ do
    target <- toAddrInfo address
    bracket (open target) NS.close $ \sock -> connect' sock config >>= f
  where
    -- An already-resolved address only needs wrapping; otherwise look it up.
    toAddrInfo (AddressResolved a)      = pure (addrInfo a)
    toAddrInfo (AddressNotResolved h s) = resolve h s
-- | Open a connection described by the given 'Config'.
--
-- The address is resolved if necessary, a socket is opened and the connection
-- is initialised on it. The returned 'Connection' owns the socket; release it
-- with 'disconnect'. The whole action runs under 'Exception.convert'.
--
-- If initialisation fails after the socket has been opened, the socket is
-- closed before the exception propagates, so a failed attempt does not leak
-- a file descriptor.
connect :: Config -> IO Connection
connect config@Config { address } =
  Exception.convert $ do
    addr <-
      case address of
        AddressResolved a      -> pure $ addrInfo a
        AddressNotResolved h s -> resolve h s
    -- bracketOnError closes the socket only on failure; on success ownership
    -- passes to the returned Connection.
    bracketOnError (open addr) NS.close $ \sock ->
      connect' sock config
-- | Initialise a connection on an already connected socket.
--
-- Allocates the sending and reception buffers with the sizes given in the
-- 'Config' and runs 'initializeConnection' over the socket.
connect' :: NS.Socket -> Config -> IO Connection
connect' sock config@Config { sendingBufferSize, receptionBufferSize } = do
  outgoing <- newBuffer sendingBufferSize
  incoming <- newBuffer receptionBufferSize
  runSocketIO sock outgoing incoming config initializeConnection
  where
    -- Allocate raw memory of the requested size and pair it with that size.
    newBuffer size = (`Buffer` size) <$> BSI.mallocByteString size
-- | Close a connection.
--
-- Sends the terminate message over the connection's socket and then closes
-- the socket. The socket is closed even when sending the message throws
-- (for example because the peer has already gone away); the exception is
-- still propagated afterwards. The whole action runs under
-- 'Exception.convert'.
disconnect :: Connection -> IO ()
disconnect Connection { socket, sendingBuffer, receptionBuffer, config } =
  Exception.convert $
    -- `finally` guarantees the descriptor is released on every exit path.
    runSocketIO socket sendingBuffer receptionBuffer config terminate
      `finally` NS.close socket
-- | Hints used for address resolution and as the template for 'addrInfo':
-- stream sockets, TCP, with the @AI_ADDRCONFIG@ flag set.
addrInfoHints :: NS.AddrInfo
addrInfoHints =
  NS.defaultHints
    { NS.addrFlags = [NS.AI_ADDRCONFIG]
    , NS.addrSocketType = NS.Stream
    , NS.addrProtocol = tcp
    }
  where
    -- IANA-assigned protocol number for TCP.
    tcp = 6
-- | Build an 'NS.AddrInfo' for an already resolved socket address.
--
-- Starts from 'addrInfoHints' and fills in the address, the address family
-- matching the address constructor, and a protocol number that is valid for
-- that family.
addrInfo :: NS.SockAddr -> NS.AddrInfo
addrInfo address =
  addrInfoHints
    { NS.addrAddress = address
    , NS.addrFamily = family
    , NS.addrProtocol = protocol
    }
  where
    -- Protocol number carried by the hints (TCP); meaningful for IP sockets.
    hintedProtocol = NS.addrProtocol addrInfoHints
    (family, protocol) =
      case address of
        NS.SockAddrInet {}  -> (NS.AF_INET, hintedProtocol)
        NS.SockAddrInet6 {} -> (NS.AF_INET6, hintedProtocol)
        -- Unix domain sockets are not TCP: asking the OS for protocol 6 on
        -- AF_UNIX is rejected, so let it pick the family's default.
        NS.SockAddrUnix {}  -> (NS.AF_UNIX, NS.defaultProtocol)
#if !MIN_VERSION_network(3,0,0)
        _                   -> (NS.AF_UNSPEC, hintedProtocol)
#endif
-- | Resolve a host name and service name using 'addrInfoHints' and return the
-- first result.
--
-- An empty result list is handled with 'Exception.cantReachHere'.
resolve :: NS.HostName -> NS.ServiceName -> IO NS.AddrInfo
resolve host service =
  firstResult =<< NS.getAddrInfo (Just addrInfoHints) (Just host) (Just service)
  where
    firstResult (candidate:_) = return candidate
    firstResult []            = Exception.cantReachHere
-- | Create a socket for the given address information and connect it.
--
-- If connecting fails (refused, unreachable, interrupted, ...) the freshly
-- created socket is closed before the exception propagates, so a failed
-- attempt does not leak a file descriptor. On success the connected socket
-- is returned and the caller becomes responsible for closing it.
open :: NS.AddrInfo -> IO NS.Socket
open addr =
  bracketOnError
    (NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr))
    NS.close
    $ \sock -> do
      NS.connect sock $ NS.addrAddress addr
      return sock
-- | Perform the startup exchange and authentication, then assemble the
-- 'Connection' from the results and the environment of the 'SocketIO' action
-- (socket, sending buffer, reception buffer and configuration).
initializeConnection :: SocketIO Connection
initializeConnection = do
  (backendParameters, processId, backendKey) <- authenticate =<< startup
  (sock, outgoing, incoming, config) <- ask
  pure $
    Connection sock processId backendKey backendParameters outgoing incoming config
-- | Send the startup message for the configured user and database, then
-- receive the server's authentication response.
startup :: SocketIO AuthenticationResponse
startup =
  ask >>= \(_, _, _, Config { user, database }) ->
    buildAndSend (Builder.startup user database)
      >> receive Parser.authentication
-- | Complete authentication according to the server's response and read the
-- session information that follows.
--
-- * 'AuthenticationOkResponse': nothing is sent.
-- * 'AuthenticationCleartextPasswordResponse': the configured password is
--   sent as is.
-- * 'AuthenticationMD5PasswordResponse': the password is sent hashed with
--   'hashMD5' using the salt supplied by the server.
--
-- Afterwards the parameter statuses, the backend key data and the
-- ready-for-query message are received; the result is the parameter map, the
-- process id and the backend key.
authenticate :: AuthenticationResponse -> SocketIO (BackendParameters, Pid, BackendKey)
authenticate response = do
  (_, _, _, Config { user, password }) <- ask
  case response of
    AuthenticationOkResponse ->
      pure ()
    AuthenticationCleartextPasswordResponse ->
      sendPassword (BSU.fromString password)
    AuthenticationMD5PasswordResponse (AuthenticationMD5Password salt) ->
      sendPassword (hashMD5 user password salt)
  receive sessionInfo
  where
    -- Send a password message and wait for the server to accept it.
    sendPassword credential = do
      buildAndSend $ Builder.password credential
      void $ receive Parser.authenticationOk

    -- Everything the server sends between authentication and the first
    -- ready-for-query message.
    sessionInfo = do
      parameters <- Map.fromList . map toPair <$> AP.many' Parser.parameterStatus
      BackendKeyData pid bk <- Parser.backendKeyData
      ReadyForQuery ts <- Parser.readyForQuery
      -- A new session is expected to start outside of any transaction.
      assert (ts == Idle) $ pure (parameters, pid, bk)

    toPair (ParameterStatus k v) = (k, v)
-- | Send the message built by 'Builder.terminate' to the server.
--
-- No response is received here; 'disconnect' closes the socket afterwards.
terminate :: SocketIO ()
terminate = send Builder.terminate
-- | Compute the value sent for MD5 password authentication:
--
-- > "md5" <> hex (md5 (hex (md5 (password <> user)) <> salt))
--
-- The user name and password are converted with 'BSU.fromString' before
-- hashing. The MD5 implementation is chosen at compile time by the
-- @PURE_MD5@ flag.
hashMD5 :: String -> String -> Salt -> BS.ByteString
hashMD5 user password salt =
  "md5" <> hexDigest (inner <> salt)
  where
    -- First round: hex-encoded digest of password followed by user name.
    inner = hexDigest (BSU.fromString password <> BSU.fromString user)
#ifdef PURE_MD5
    hexDigest = B16.encode . MD5.md5DigestBytes . MD5.hash'
#else
    hexDigest = B16.encode . MD5.hash
#endif