{-# LANGUAGE CPP               #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports    #-} -- for doctest

module Database.PostgreSQL.Pure.Internal.Connection
  ( connect
  , disconnect
  , withConnection
  ) where

import qualified Database.PostgreSQL.Pure.Internal.Builder   as Builder
import           Database.PostgreSQL.Pure.Internal.Data      (Address (AddressNotResolved, AddressResolved),
                                                              AuthenticationMD5Password (AuthenticationMD5Password),
                                                              AuthenticationResponse (AuthenticationCleartextPasswordResponse, AuthenticationMD5PasswordResponse, AuthenticationOkResponse),
                                                              BackendKey, BackendKeyData (BackendKeyData),
                                                              BackendParameters, Buffer (Buffer),
                                                              Config (Config, address, database, password, receptionBufferSize, sendingBufferSize, user),
                                                              Connection (Connection, config, receptionBuffer, sendingBuffer, socket),
                                                              ParameterStatus (ParameterStatus), Pid,
                                                              ReadyForQuery (ReadyForQuery), Salt,
                                                              TransactionState (Idle))
import qualified Database.PostgreSQL.Pure.Internal.Exception as Exception
import qualified Database.PostgreSQL.Pure.Internal.Parser    as Parser
import           Database.PostgreSQL.Pure.Internal.SocketIO  (SocketIO, buildAndSend, receive, runSocketIO, send)

import           Control.Exception.Safe                      (assert, bracket, bracketOnError, finally)
import           Control.Monad                               (void)
import           Control.Monad.Reader                        (ask)
import qualified Data.Attoparsec.ByteString                  as AP
import qualified Data.ByteString                             as BS
import qualified Data.ByteString.Base16                      as B16
import qualified Data.ByteString.Internal                    as BSI
import qualified Data.ByteString.UTF8                        as BSU
import qualified Data.Map.Strict                             as Map
import qualified Network.Socket                              as NS

#ifdef PURE_MD5
import qualified Data.Digest.Pure.MD5                        as MD5
#else
import qualified "cryptohash-md5" Crypto.Hash.MD5            as MD5
#endif

-- | Bracket function for a connection.
--
-- Resolves the address from the 'Config' when necessary, opens a socket,
-- runs the startup/authentication handshake and hands the resulting
-- 'Connection' to the callback.  The socket is closed with 'NS.close' once
-- the callback returns or throws; no Terminate message is sent here.  The
-- whole action runs under 'Exception.convert'.
withConnection :: Config -> (Connection -> IO a) -> IO a
withConnection config@Config { address } f =
  Exception.convert $ do
    addr <- target
    bracket (open addr) NS.close $ \sock -> connect' sock config >>= f
  where
    -- Already-resolved addresses are used as-is; otherwise look them up.
    target =
      case address of
        AddressNotResolved host service -> resolve host service
        AddressResolved sockAddr        -> pure (addrInfo sockAddr)

-- | To connect to the server.
--
-- Resolves the address from the 'Config' when necessary, opens a socket and
-- runs the startup/authentication handshake.  If the handshake throws, the
-- socket is closed before the exception propagates, so it is not leaked.
-- On success the caller owns the returned 'Connection' and should release
-- it with 'disconnect'.
connect :: Config -> IO Connection
connect config@Config { address } =
  Exception.convert $ do
    addr <-
      case address of
        AddressResolved a      -> pure $ addrInfo a
        AddressNotResolved h s -> resolve h s
    -- Close the socket only on failure; on success it belongs to the
    -- returned 'Connection'.
    bracketOnError (open addr) NS.close $ \sock -> connect' sock config

-- | Run the startup handshake over an already-connected socket.
--
-- Allocates fresh sending and reception buffers, sized by
-- 'sendingBufferSize' and 'receptionBufferSize' from the 'Config', then runs
-- 'initializeConnection' in 'SocketIO' with them.
connect' :: NS.Socket -> Config -> IO Connection
connect' sock config@Config { sendingBufferSize, receptionBufferSize } = do
  sendBuffer <- newBuffer sendingBufferSize
  receiveBuffer <- newBuffer receptionBufferSize
  runSocketIO sock sendBuffer receiveBuffer config initializeConnection
  where
    -- A 'Buffer' pairs a freshly allocated byte area with its size.
    newBuffer size = do
      ptr <- BSI.mallocByteString size
      pure $ Buffer ptr size

-- | To disconnect from the server.
--
-- Sends the Terminate message over the connection and then closes the
-- socket.  The socket is closed even if sending Terminate throws, so the
-- descriptor is not leaked.  The whole action runs under
-- 'Exception.convert'.
disconnect :: Connection -> IO ()
disconnect Connection { socket, sendingBuffer, receptionBuffer, config } =
  Exception.convert $
    runSocketIO socket sendingBuffer receptionBuffer config terminate
      `finally` NS.close socket

-- | Base address hints.
--
-- Stream sockets, protocol number 6 (TCP), and the 'NS.AI_ADDRCONFIG' flag.
-- The hints are passed to 'NS.getAddrInfo' in 'resolve', and 'addrInfo'
-- uses them as the template for already-resolved addresses.
addrInfoHints :: NS.AddrInfo
addrInfoHints =
  NS.defaultHints
    { NS.addrFlags      = [NS.AI_ADDRCONFIG]
    , NS.addrSocketType = NS.Stream
    , NS.addrProtocol   = tcp
    }
  where
    tcp :: NS.ProtocolNumber
    tcp = 6

-- | Build an 'NS.AddrInfo' for an already-resolved socket address.
--
-- Starts from 'addrInfoHints' and derives the address family from the
-- 'NS.SockAddr' constructor.
--
-- The hints carry protocol number 6 (TCP).  Unix-domain sockets do not
-- accept that protocol (on Linux, @socket(AF_UNIX, SOCK_STREAM, 6)@ fails
-- with EPROTONOSUPPORT), so for 'NS.SockAddrUnix' the protocol is reset to
-- 'NS.defaultProtocol'.  Other addresses keep the protocol from the hints.
addrInfo :: NS.SockAddr -> NS.AddrInfo
addrInfo address =
  addrInfoHints
    { NS.addrAddress = address
    , NS.addrFamily =
        case address of
          NS.SockAddrInet {}  -> NS.AF_INET
          NS.SockAddrInet6 {} -> NS.AF_INET6
          NS.SockAddrUnix {}  -> NS.AF_UNIX
#if !MIN_VERSION_network(3,0,0)
          _                   -> NS.AF_UNSPEC
#endif
    , NS.addrProtocol =
        case address of
          NS.SockAddrUnix {} -> NS.defaultProtocol
          _                  -> NS.addrProtocol addrInfoHints
    }

-- | Look up a host name and service (port) using 'addrInfoHints' and return
-- the first candidate address.
--
-- 'NS.getAddrInfo' throws instead of returning an empty list when the
-- lookup fails, so the empty case is treated as unreachable.
resolve :: NS.HostName -> NS.ServiceName -> IO NS.AddrInfo
resolve host service =
  NS.getAddrInfo (Just addrInfoHints) (Just host) (Just service) >>= pickFirst
  where
    pickFirst :: [NS.AddrInfo] -> IO NS.AddrInfo
    pickFirst (addr:_) = pure addr
    pickFirst []       = Exception.cantReachHere

-- | Create a socket for the given address and connect it to that address.
--
-- The socket is created from the family, socket type and protocol in the
-- 'NS.AddrInfo'.  If 'NS.connect' throws (e.g. connection refused), the
-- just-created socket is closed before the exception propagates, so it is
-- not leaked.  On success the caller owns the returned socket.
open :: NS.AddrInfo -> IO NS.Socket
open addr =
  bracketOnError acquire NS.close $ \sock -> do
    NS.connect sock (NS.addrAddress addr)
    pure sock
  where
    acquire = NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr)

-- | Perform the startup handshake and assemble the 'Connection'.
--
-- Sends the startup message, completes authentication, and then combines
-- the socket, buffers and config from the 'SocketIO' environment with the
-- backend parameters, 'Pid' and backend key received from the server.
initializeConnection :: SocketIO Connection
initializeConnection = do
  (params, pid, key) <- startup >>= authenticate
  (sock, sendBuf, recvBuf, cfg) <- ask
  pure (Connection sock pid key params sendBuf recvBuf cfg)

-- | Send the startup message for the configured user and database, then
-- parse the server's first authentication reply.
startup :: SocketIO AuthenticationResponse
startup =
  ask >>= \(_, _, _, Config { user, database }) -> do
    buildAndSend (Builder.startup user database)
    receive Parser.authentication

-- | Answer the server's authentication request and read the rest of the
-- startup sequence.
--
-- * 'AuthenticationOkResponse': nothing is sent.
-- * 'AuthenticationCleartextPasswordResponse': the password is sent UTF-8
--   encoded.
-- * 'AuthenticationMD5PasswordResponse': the salted hash from 'hashMD5' is
--   sent.
--
-- After a password is sent, an AuthenticationOk reply is expected.  The
-- function then parses, in order, all ParameterStatus messages,
-- BackendKeyData, and ReadyForQuery.  It returns the backend parameters
-- (collected into a map), the 'Pid' and the 'BackendKey'.
--
-- The transaction state in ReadyForQuery is checked against 'Idle' only via
-- 'assert', so the check disappears when assertions are disabled.
authenticate :: AuthenticationResponse -> SocketIO (BackendParameters, Pid, BackendKey)
authenticate response = do
  (_, _, _, Config { user, password }) <- ask
  case response of
    AuthenticationOkResponse                                           -> pure ()
    AuthenticationCleartextPasswordResponse                            -> auth $ BSU.fromString password
    AuthenticationMD5PasswordResponse (AuthenticationMD5Password salt) -> auth $ hashMD5 user password salt
  -- The parser already yields the result tuple; return it directly instead
  -- of destructuring and rebuilding it.
  receive $ do
    bps <- Map.fromList . fmap (\(ParameterStatus k v) -> (k, v)) <$> AP.many' Parser.parameterStatus
    BackendKeyData pid bk <- Parser.backendKeyData
    ReadyForQuery ts <- Parser.readyForQuery
    assert (ts == Idle) $ pure (bps, pid, bk)
  where
    -- Send a password message and require AuthenticationOk in reply.
    auth pw = do
      buildAndSend $ Builder.password pw
      void $ receive Parser.authenticationOk

-- | Send the Terminate message produced by 'Builder.terminate'.
terminate :: SocketIO ()
terminate = send message
  where
    message = Builder.terminate

-- | Compute the response to an MD5 password challenge:
-- @"md5" <> hex (md5 (hex (md5 (password <> user)) <> salt))@, where @hex@
-- is Base16 encoding.  The user name and password are UTF-8 encoded first.
hashMD5 :: String -> String -> Salt -> BS.ByteString
hashMD5 user password salt =
  "md5" <> md5Hex (md5Hex (utf8 password <> utf8 user) <> salt)
  where
    utf8 :: String -> BS.ByteString
    utf8 = BSU.fromString

    md5Hex :: BS.ByteString -> BS.ByteString
#ifdef PURE_MD5
    md5Hex = B16.encode . MD5.md5DigestBytes . MD5.hash'
#else
    md5Hex = B16.encode . MD5.hash
#endif