{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Database.PostgreSQL.Pure.Internal.Query
(
parse
, Bind (..)
, Execute (..)
, flush
, sync
, Message (..)
, Close (..)
, begin
, commit
, rollback
) where
import qualified Database.PostgreSQL.Pure.Internal.Builder as Builder
import Database.PostgreSQL.Pure.Internal.Data (BackendParameters,
BindParameterFormatCodes (BindParameterFormatCodesAll),
BindResultFormatCodes (BindResultFormatCodesEach),
CloseProcedure (CloseProcedure),
ColumnInfo (ColumnInfo, formatCode),
CommandComplete (CommandComplete),
Connection (Connection, config, receptionBuffer, sendingBuffer, socket),
DataRow (DataRow), ErrorFields,
ExecuteResult (ExecuteComplete, ExecuteEmptyQuery, ExecuteSuspended),
Executed (Executed),
ExecutedProcedure (ExecutedProcedure),
FormatCode (BinaryFormat), FromRecord, MessageResult,
Notice (Notice), Oid,
ParameterDescription (ParameterDescription),
Portal (Portal), PortalName,
PortalProcedure (PortalProcedure),
PreparedStatement (PreparedStatement),
PreparedStatementName,
PreparedStatementProcedure (PreparedStatementProcedure),
Query, ReadyForQuery (ReadyForQuery),
RowDescription (RowDescription), StringDecoder,
StringEncoder, ToRecord (toRecord), TransactionState,
TypeLength (FixedLength))
import qualified Database.PostgreSQL.Pure.Internal.Data as Data
import qualified Database.PostgreSQL.Pure.Internal.Exception as Exception
import qualified Database.PostgreSQL.Pure.Internal.Parser as Parser
import Database.PostgreSQL.Pure.Internal.SocketIO (buildAndSend, receive, runSocketIO, send)
import Control.Applicative ((<|>))
import Control.Exception.Safe (throw, try)
import Control.Monad (void, when)
import Control.Monad.State.Strict (put)
import qualified Data.Attoparsec.ByteString as AP
import qualified Data.Attoparsec.Combinator as AP
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Char8 as BSC
import Data.Functor (($>))
import Data.List (genericLength)
import GHC.Records (HasField (getField))
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
#endif
-- | Build the procedure that sends a Parse message for a prepared statement.
--
-- The third argument is either the numbers of parameters and result columns
-- (@Left@), in which case a Describe message is also sent and the types are
-- taken from the backend's answer, or the explicit parameter and result OIDs
-- (@Right@), whose list lengths are used as the counts.
parse
  :: PreparedStatementName
  -> Query
  -> Either (Word, Word) ([Oid], [Oid])
  -> PreparedStatementProcedure
parse name query = either fromCounts fromOids
  where
    fromCounts (parameterLength, resultLength) =
      parse' name query parameterLength resultLength Nothing
    fromOids oids@(parameterOids, resultOids) =
      parse' name query (genericLength parameterOids) (genericLength resultOids) (Just oids)
-- | Shared implementation of 'parse'.
--
-- With known OIDs only the Parse message is built, and the result columns get
-- placeholder 'ColumnInfo's carrying just their OID. Without them, a Describe
-- message is appended and the parser reads the backend's ParameterDescription
-- followed by either a RowDescription or NoData (no result columns).
parse' :: PreparedStatementName -> Query -> Word -> Word -> Maybe ([Oid], [Oid]) -> PreparedStatementProcedure
parse' name query parameterLength resultLength oids =
  PreparedStatementProcedure name parameterLength resultLength (fst <$> oids) request response
  where
    -- Stand-in column info for a result column known only by its OID.
    placeholderInfo oid = ColumnInfo "" 0 0 oid (FixedLength 0) 0 BinaryFormat

    request =
      case oids of
        Just (knownParameterOids, _) -> Builder.parse name query knownParameterOids
        Nothing -> Builder.parse name query [] <> Builder.describePreparedStatement name

    response = do
      Parser.parseComplete
      (statementParameterOids, statementResultInfos) <-
        case oids of
          Just (knownParameterOids, knownResultOids) ->
            pure (knownParameterOids, placeholderInfo <$> knownResultOids)
          Nothing -> described
      pure $ PreparedStatement name statementParameterOids statementResultInfos

    -- Responses to the Describe message appended by 'request'.
    described = do
      ParameterDescription describedOids <- Parser.parameterDescription
      describedInfos <-
        AP.choice
          [ do
              RowDescription infos <- Parser.rowDescription
              pure infos
          , Parser.noData $> []
          ]
      pure (describedOids, describedInfos)
-- | Things that can be bound to a portal with a Bind message.
class Bind ps where
  -- | Build the procedure that binds the given parameters to a new portal.
  --
  -- In the instances in this module the parameters are encoded with
  -- 'toRecord', which runs in @m@.
  bind
    :: (ToRecord param, MonadFail m)
    => PortalName        -- ^ name of the portal to create
    -> FormatCode        -- ^ format used for every bound parameter
    -> FormatCode        -- ^ format requested for every result column
    -> BackendParameters -- ^ handed to the parameter encoder
    -> StringEncoder     -- ^ handed to the parameter encoder
    -> param             -- ^ the parameter values
    -> ps                -- ^ the statement (or statement procedure) to bind
    -> m PortalProcedure
-- Binding an already-parsed statement: only the Bind message is built, and the
-- statement's own parameter OIDs and result column infos are reused.
instance Bind PreparedStatement where
  bind name parameterFormat resultFormat backendParams encode parameters statement@(PreparedStatement statementName statementParameterOids statementResultInfos) = do
    record <- toRecord backendParams encode (Just statementParameterOids) (parameterFormat <$ statementParameterOids) parameters
    let
      -- Each result column is reported in the requested result format.
      withResultFormat info = info { formatCode = resultFormat }
      request =
        Builder.bind name statementName
          (BindParameterFormatCodesAll parameterFormat)
          record
          (BindResultFormatCodesEach $ resultFormat <$ statementResultInfos)
      response =
        Parser.bindComplete
          $> (statement, Portal name (withResultFormat <$> statementResultInfos) statement)
    pure $ PortalProcedure name resultFormat request response
-- Binding a statement procedure that has not been sent yet: its own messages
-- go first, followed by Bind. The bound statement's result columns are only
-- known once the statement procedure's parser has run.
instance Bind PreparedStatementProcedure where
  bind name parameterFormat resultFormat backendParams encode parameters (PreparedStatementProcedure statementName parameterLength resultLength knownParameterOids statementBuilder statementParser) = do
    record <- toRecord backendParams encode knownParameterOids (replicate (fromIntegral parameterLength) parameterFormat) parameters
    let
      withResultFormat info = info { formatCode = resultFormat }
      resultFormats = replicate (fromIntegral resultLength) resultFormat
      request =
        statementBuilder
          <> Builder.bind name statementName (BindParameterFormatCodesAll parameterFormat) record (BindResultFormatCodesEach resultFormats)
      response = do
        statement@PreparedStatement { resultInfos } <- statementParser
        Parser.bindComplete
        pure (statement, Portal name (withResultFormat <$> resultInfos) statement)
    pure $ PortalProcedure name resultFormat request response
-- | Things that can be run with an Execute message.
class Execute p where
  -- | Build the procedure that executes a portal and decodes its rows.
  execute
    :: FromRecord result
    => Word          -- ^ row limit forwarded (via 'fromIntegral') to the Execute message
    -> StringDecoder -- ^ decoder handed to the row parser
    -> p             -- ^ the portal (or portal procedure) to execute
    -> ExecutedProcedure result
-- An already-bound portal: only the Execute message is built, and rows are
-- decoded with the portal's own column infos.
instance Execute Portal where
  execute rowLimit decode portal@(Portal portalName columnInfos statement@PreparedStatement {}) =
    ExecutedProcedure request response
    where
      request = Builder.execute portalName $ fromIntegral rowLimit
      response = executeParser statement portal columnInfos decode
-- A portal procedure that has not been sent yet: its own messages go first,
-- followed by Execute. Rows are decoded with the statement's column infos,
-- re-tagged with the portal procedure's result format.
instance Execute PortalProcedure where
  execute rowLimit decode (PortalProcedure portalName resultFormat portalBuilder portalParser) =
    ExecutedProcedure request response
    where
      request = portalBuilder <> Builder.execute portalName (fromIntegral rowLimit)
      response = do
        (statement@(PreparedStatement _ _ statementInfos), portal) <- portalParser
        let columnInfos = (\info -> info { formatCode = resultFormat }) <$> statementInfos
        executeParser statement portal columnInfos decode
-- | Parser for the backend's response to an Execute message.
--
-- Consumes every DataRow that decodes, then an optional notice, then exactly
-- one of CommandComplete, EmptyQuery or PortalSuspended. Returns the statement
-- and portal as given, the execution result with the decoded rows, and the
-- notice's fields, if a notice was present.
executeParser :: forall r. FromRecord r => PreparedStatement -> Portal -> [ColumnInfo] -> StringDecoder -> AP.Parser (PreparedStatement, Portal, Executed r, Maybe ErrorFields)
executeParser ps p infos decode = do
  records <- ((\(DataRow d) -> d) <$>) <$> AP.many' (Parser.dataRow decode infos)
  -- 'AP.many'' stops silently at the first DataRow that fails to decode. If a
  -- DataRow is still next, decode it again so the decoder's own error is
  -- reported, rather than the generic failure of the terminator parsers below.
  -- This must happen whether or not earlier rows decoded; otherwise a decode
  -- failure on any row after the first is masked.
  undecodableRowFollows <- AP.option False $ AP.lookAhead Parser.dataRowRaw >> pure True
  when undecodableRowFollows $ do
    void (Parser.dataRow decode infos :: AP.Parser (DataRow r))
    -- The line above just failed inside 'AP.many'' on the same input, so it fails again.
    fail "can't reach here"
  err <- AP.option Nothing $ (\(Notice err) -> Just err) <$> Parser.notice
  result <-
    ((\(CommandComplete tag) -> ExecuteComplete tag) <$> Parser.commandComplete)
      <|> (Parser.emptyQuery >> pure ExecuteEmptyQuery)
      <|> (Parser.portalSuspended >> pure ExecuteSuspended)
  pure (ps, p, Executed result records p, err)
-- | Things that can be closed with a Close message.
class Close p where
  -- | Build the procedure that sends Close for the given object and expects
  -- CloseComplete in reply.
  close :: p -> CloseProcedure

-- A prepared statement is closed by its name.
instance Close PreparedStatement where
  close statement =
    CloseProcedure
      (Builder.closePreparedStatement (getField @"name" statement))
      Parser.closeComplete

-- A portal is closed by its name.
instance Close Portal where
  close portal =
    CloseProcedure
      (Builder.closePortal (getField @"name" portal))
      Parser.closeComplete
-- | A batch of frontend messages together with the parser for the backend's
-- responses to them.
--
-- A type that has a @builder@ field and a @parser@ field of the right types
-- gets both methods from the default signatures.
class Message m where
  -- | The encoded messages to send.
  builder :: m -> BSB.Builder
  default builder :: HasField "builder" m BSB.Builder => m -> BSB.Builder
  builder message = getField @"builder" message

  -- | Parser for the responses to the messages produced by 'builder'.
  parser :: m -> AP.Parser (MessageResult m)
  default parser :: HasField "parser" m (AP.Parser (MessageResult m)) => m -> AP.Parser (MessageResult m)
  parser message = getField @"parser" message

-- The procedure types below rely on the field-based defaults.
instance Message PreparedStatementProcedure
instance Message PortalProcedure
instance Message (ExecutedProcedure r)
instance Message CloseProcedure
-- The empty batch: sends nothing and parses nothing.
instance Message () where
  builder = const mempty
  parser = const (pure ())
type instance MessageResult () = ()
-- Tuples send their components' messages left to right and parse the
-- responses in the same order.
instance (Message m0, Message m1) => Message (m0, m1) where
  builder (m0, m1) = mconcat [builder m0, builder m1]
  parser (m0, m1) = (,) <$> parser m0 <*> parser m1
type instance MessageResult (m0, m1) = (MessageResult m0, MessageResult m1)

instance (Message m0, Message m1, Message m2) => Message (m0, m1, m2) where
  builder (m0, m1, m2) = mconcat [builder m0, builder m1, builder m2]
  parser (m0, m1, m2) = (,,) <$> parser m0 <*> parser m1 <*> parser m2
type instance MessageResult (m0, m1, m2) = (MessageResult m0, MessageResult m1, MessageResult m2)

instance (Message m0, Message m1, Message m2, Message m3) => Message (m0, m1, m2, m3) where
  builder (m0, m1, m2, m3) = mconcat [builder m0, builder m1, builder m2, builder m3]
  parser (m0, m1, m2, m3) = (,,,) <$> parser m0 <*> parser m1 <*> parser m2 <*> parser m3
type instance MessageResult (m0, m1, m2, m3) = (MessageResult m0, MessageResult m1, MessageResult m2, MessageResult m3)
-- A list sends each element's messages in order and parses the responses in
-- the same order.
instance Message m => Message [m] where
  builder = foldMap builder
  parser = traverse parser
type instance MessageResult [m] = [MessageResult m]
-- | Send the messages of @m@ followed by a Flush message and parse the
-- responses.
--
-- When parsing raises 'Exception.InternalErrorResponse', the SocketIO state is
-- reset to 'mempty', a Sync message is sent and the following ReadyForQuery is
-- read. The error is then re-thrown with the reported transaction state. Any
-- other exception caught here is re-thrown unchanged.
flush :: Message m => Connection -> m -> IO (MessageResult m)
flush Connection { socket, sendingBuffer, receptionBuffer, config } m =
  Exception.convert $
    runSocketIO socket sendingBuffer receptionBuffer config $ do
      outcome <-
        try $ do
          buildAndSend $ builder m <> BSB.byteString Builder.flush
          receive $ parser m
      case outcome of
        Right result -> pure result
        Left (Exception.InternalErrorResponse fields _ _) -> do
          -- NOTE(review): the state reset here is presumably the unread input;
          -- Sync is what makes the backend answer with ReadyForQuery.
          put mempty
          send Builder.sync
          ReadyForQuery transactionState <- receive Parser.readyForQuery
          throw $ Exception.InternalErrorResponse fields (Just transactionState) mempty
        Left e -> throw e
-- | Send the messages of @m@ followed by a Sync message, then parse the
-- responses and the trailing ReadyForQuery.
--
-- Returns the parsed result and the transaction state reported by
-- ReadyForQuery. When parsing raises 'Exception.InternalErrorResponse', the
-- leftover carried by that exception is restored as the SocketIO state. The
-- ReadyForQuery is then read and the error re-thrown with its transaction
-- state. Any other exception caught here is re-thrown unchanged.
sync :: Message m => Connection -> m -> IO (MessageResult m, TransactionState)
sync Connection { socket, sendingBuffer, receptionBuffer, config } m =
  Exception.convert $
    runSocketIO socket sendingBuffer receptionBuffer config $ do
      outcome <-
        try $ do
          buildAndSend $ builder m <> BSB.byteString Builder.sync
          (result, ReadyForQuery transactionState) <- receive $ (,) <$> parser m <*> Parser.readyForQuery
          pure (result, transactionState)
      case outcome of
        Right resultAndState -> pure resultAndState
        Left (Exception.InternalErrorResponse fields _ leftover) -> do
          -- Sync has already been sent, so ReadyForQuery follows the error.
          put leftover
          ReadyForQuery transactionState <- receive Parser.readyForQuery
          throw $ Exception.InternalErrorResponse fields (Just transactionState) mempty
        Left e -> throw e
-- | Procedures that issue @BEGIN@, @COMMIT@ and @ROLLBACK@ respectively, each
-- via the module's internal @transact@ helper.
begin, commit, rollback :: ExecutedProcedure ()
begin = transact "BEGIN"
commit = transact "COMMIT"
rollback = transact "ROLLBACK"
-- | Build a procedure that runs the parameterless command @q@.
--
-- The procedure uses the statement name @""@ and the portal name @""@ with no
-- parameter or result OIDs. It uses binary format for parameters and results
-- and a row limit of 1. Binding presumably cannot fail because there are no
-- parameters; if it does, the message is raised with 'error'.
transact :: Query -> ExecutedProcedure ()
transact q =
  either error (execute 1 (pure . BSC.unpack)) $
    bind "" BinaryFormat BinaryFormat mempty (pure . BSC.pack) () statementProcedure
  where
    statementProcedure = parse "" q (Right ([], []))