{-# LANGUAGE CPP, DeriveDataTypeable #-}

------------------------------------------------------------------------------
-- |
-- Module:      Database.PostgreSQL.Simple.Copy
-- Copyright:   (c) 2013 Leon P Smith
-- License:     BSD3
-- Maintainer:  Leon P Smith <leon@melding-monads.com>
-- Stability:   experimental
--
-- Mid-level support for COPY IN and COPY OUT.  See
-- <https://www.postgresql.org/docs/9.5/static/sql-copy.html> for
-- more information.
--
-- To use this binding,  first call 'copy' with an appropriate
-- query as documented in the link above.  Then, in the case of a
-- @COPY TO STDOUT@ query,  call 'getCopyData' repeatedly until it
-- returns 'CopyOutDone'.   In the case of a @COPY FROM STDIN@
-- query,  call 'putCopyData' repeatedly and then finish by calling
-- either 'putCopyEnd' to proceed or 'putCopyError' to abort.
--
-- You cannot issue another query on the same connection while a copy
-- is ongoing; attempting to do so will result in an exception.  It is
-- harmless to concurrently call @getNotification@ on a connection while
-- it is in a @CopyIn@ or @CopyOut@ state; however, be aware that current
-- versions of the PostgreSQL backend will not deliver notifications to a
-- client while a transaction is ongoing.
--
------------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Copy
    ( copy
    , copy_
    , CopyOutResult(..)
    , foldCopyData
    , getCopyData
    , putCopyData
    , putCopyEnd
    , putCopyError
    ) where

import           Control.Applicative
import           Control.Concurrent
import           Control.Exception  ( throwIO )
import qualified Data.Attoparsec.ByteString.Char8 as P
import           Data.Typeable(Typeable)
import           Data.Int(Int64)
import qualified Data.ByteString.Char8 as B
import qualified Database.PostgreSQL.LibPQ as PQ
import           Database.PostgreSQL.Simple
import           Database.PostgreSQL.Simple.Types
import           Database.PostgreSQL.Simple.Internal hiding (result, row)


-- | Issue a @COPY FROM STDIN@ or @COPY TO STDOUT@ query.  A @COPY FROM STDIN@
--   query moves the connection into the @CopyIn@ state; a @COPY TO STDOUT@
--   query moves it into the @CopyOut@ state.  The connection must be in the
--   ready state when this is called.  The parameters are substituted into
--   the query template (via 'formatQuery') before the query is sent.

copy :: ( ToRow params ) => Connection -> Query -> params -> IO ()
copy conn template qs =
    formatQuery conn template qs
      >>= doCopy "Database.PostgreSQL.Simple.Copy.copy" conn template


-- | Issue a @COPY FROM STDIN@ or @COPY TO STDOUT@ query exactly as written;
--   no parameter substitution is performed.  A @COPY FROM STDIN@ query
--   moves the connection into the @CopyIn@ state; a @COPY TO STDOUT@ query
--   moves it into the @CopyOut@ state.  The connection must be in the ready
--   state when this is called.

copy_ :: Connection -> Query -> IO ()
copy_ conn template@(Query q) =
    doCopy "Database.PostgreSQL.Simple.Copy.copy_" conn template q

-- | Shared implementation of 'copy' and 'copy_': execute the already
--   formatted query @q@ and check that the server switched the connection
--   into a copy state.
--
--   @funcName@ prefixes error messages, and @template@ is the query reported
--   in any 'QueryError'.  @CopyOut@ and @CopyIn@ statuses return normally.
--   Statuses belonging to ordinary (non-copy) results, and the unsupported
--   copy-both / single-row statuses, raise 'QueryError'.  Error statuses
--   are reported through 'throwResultError'.

doCopy :: B.ByteString -> Connection -> Query -> B.ByteString -> IO ()
doCopy funcName conn template q = do
    result <- exec conn q
    status <- PQ.resultStatus result
    let queryError :: String -> IO a
        queryError msg =
            throwIO (QueryError (B.unpack funcName ++ " " ++ msg) template)

        -- The statement ran but did not leave the connection in a copy
        -- state; report the status that came back.
        notACopy :: IO a
        notACopy = queryError (show status)

        serverError :: IO a
        serverError = throwResultError funcName result status

    case status of
      PQ.CopyOut       -> return ()
      PQ.CopyIn        -> return ()
      PQ.EmptyQuery    -> notACopy
      PQ.CommandOk     -> notACopy
      PQ.TuplesOk      -> notACopy
#if MIN_VERSION_postgresql_libpq(0,9,3)
      PQ.CopyBoth      -> queryError "COPY BOTH is not supported"
#endif
#if MIN_VERSION_postgresql_libpq(0,9,2)
      PQ.SingleTuple   -> queryError "single-row mode is not supported"
#endif
      PQ.BadResponse   -> serverError
      PQ.NonfatalError -> serverError
      PQ.FatalError    -> serverError

-- | One step of a @COPY TO STDOUT@ transfer, as returned by 'getCopyData'.
data CopyOutResult
   = CopyOutRow !B.ByteString
     -- ^ Data representing either exactly one row of the result, or header
     --   or footer data depending on format.
   | CopyOutDone {-# UNPACK #-} !Int64
     -- ^ No more rows, and a count of the number of rows returned.
   deriving (Show, Eq, Typeable)


-- | Fold over the output of a @COPY TO STDOUT@ query.  Every chunk returned
--   by 'getCopyData' as a 'CopyOutRow' is fed to the accumulating function.
--   When 'CopyOutDone' arrives, the final accumulator and its row count are
--   passed to the post-processing function, whose result is returned.  The
--   connection must be in the @CopyOut@ state when this is called.  The
--   accumulator is forced to weak head normal form at every step.
--
--   __Example__
--
--   > (acc, count) <- foldCopyData conn
--   >     (\acc row -> return (row:acc))
--   >     (\acc count -> return (acc, count))
--   >     []

foldCopyData
  :: Connection                   -- ^ Database connection
  -> (a -> B.ByteString -> IO a)  -- ^ Accumulate one row of the result
  -> (a -> Int64 -> IO b)         -- ^ Post-process accumulator with a count of rows
  -> a                            -- ^ Initial accumulator
  -> IO b                         -- ^ Result
foldCopyData conn step finish = go
  where
    go !acc = do
        next <- getCopyData conn
        case next of
          CopyOutRow row    -> step acc row >>= go
          CopyOutDone count -> finish acc count


-- | Retrieve the next chunk of data from a @COPY TO STDOUT@ query.  The
--   connection must be in the @CopyOut@ state when this is called.  When
--   the result is a 'CopyOutRow', the connection stays in the @CopyOut@
--   state.  When it is 'CopyOutDone', the connection has gone back to the
--   ready state.

getCopyData :: Connection -> IO CopyOutResult
getCopyData conn = withConnection conn fetch
  where
    funcName :: B.ByteString
    funcName = "Database.PostgreSQL.Simple.Copy.getCopyData"

    -- Flag passed as the second argument of 'PQ.getCopyData'.  Non-Windows
    -- builds pass True and wait on the socket themselves when no data is
    -- ready (see 'awaitInput'); Windows builds pass False.
    nonBlocking :: Bool
#if defined(mingw32_HOST_OS)
    nonBlocking = False
#else
    nonBlocking = True
#endif

    fetch :: PQ.Connection -> IO CopyOutResult
    fetch pqconn = do
      chunk <- PQ.getCopyData pqconn nonBlocking
      case chunk of
        PQ.CopyOutRow rowdata -> return $! CopyOutRow rowdata
        PQ.CopyOutDone        -> CopyOutDone <$> getCopyCommandTag funcName pqconn
        PQ.CopyOutWouldBlock  -> awaitInput pqconn
        PQ.CopyOutError       -> do
          mmsg <- PQ.errorMessage pqconn
          throwIO SqlError { sqlState       = ""
                           , sqlExecStatus  = FatalError
                           , sqlErrorMsg    = maybe "" id mmsg
                           , sqlErrorDetail = ""
                           , sqlErrorHint   = funcName
                           }

    -- Invoked when 'PQ.getCopyData' reports that it would block.
    awaitInput :: PQ.Connection -> IO CopyOutResult
#if defined(mingw32_HOST_OS)
    -- The Windows build passes False to 'PQ.getCopyData', so this result is
    -- not expected there.
    awaitInput _ =
        fail (B.unpack funcName ++ ": the impossible happened")
#else
    awaitInput pqconn = do
      mfd <- PQ.socket pqconn
      case mfd of
        Nothing -> throwIO (fdError funcName)
        Just fd -> do
          threadWaitRead fd
          _ <- PQ.consumeInput pqconn
          fetch pqconn
#endif


-- | Send a chunk of data to a @COPY FROM STDIN@ query.  The chunk need not
--   correspond to a single row, or even to a whole number of rows.  Calling
--   @putCopyData conn a >> putCopyData conn b@ has the same net effect as
--   @putCopyData conn c@ whenever @c == BS.append a b@.
--
--   The connection must be in the @CopyIn@ state when this is called;
--   otherwise a 'SqlError' exception is raised.  The connection stays in
--   the @CopyIn@ state afterwards.

putCopyData :: Connection -> B.ByteString -> IO ()
putCopyData conn dat = withConnection conn (doCopyIn funcName send)
  where
    funcName :: B.ByteString
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyData"

    send :: PQ.Connection -> IO PQ.CopyInResult
    send pqconn = PQ.putCopyData pqconn dat


-- | Finish a @COPY FROM STDIN@ query successfully.  Returns the number of
--   rows processed, taken from the server's @COPY n@ command tag.
--
--   The connection must be in the @CopyIn@ state when this is called;
--   otherwise a 'SqlError' exception is raised.  Afterwards the connection
--   is back in the ready state.

putCopyEnd :: Connection -> IO Int64
putCopyEnd conn = withConnection conn finish
  where
    funcName :: B.ByteString
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyEnd"

    finish :: PQ.Connection -> IO Int64
    finish pqconn =
        doCopyIn funcName (\c -> PQ.putCopyEnd c Nothing) pqconn
          >> getCopyCommandTag funcName pqconn


-- | Abort a @COPY FROM STDIN@ query.  The string argument is an arbitrary
--   error message that may show up in the PostgreSQL server's log.
--
--   The connection must be in the @CopyIn@ state when this is called;
--   otherwise a 'SqlError' exception is raised.  Afterwards the connection
--   is back in the ready state.

putCopyError :: Connection -> B.ByteString -> IO ()
putCopyError conn err = withConnection conn abort
  where
    funcName :: B.ByteString
    funcName = "Database.PostgreSQL.Simple.Copy.putCopyError"

    abort :: PQ.Connection -> IO ()
    abort pqconn = do
      doCopyIn funcName (\c -> PQ.putCopyEnd c (Just err)) pqconn
      -- Discard whatever results the aborted copy left on the connection.
      consumeResults pqconn


-- | Run a libpq copy-in action, such as sending data or ending the copy.
--   If libpq reports that the action would block, wait until the socket is
--   writable and retry it.
--
--   A 'PQ.CopyInError' result is raised as a 'SqlError' that carries the
--   connection's error message, with @funcName@ stored as the hint.  If
--   the connection has no socket, the 'fdError' for @funcName@ is thrown.

doCopyIn :: B.ByteString -> (PQ.Connection -> IO PQ.CopyInResult)
         -> PQ.Connection -> IO ()
doCopyIn funcName action = attempt
  where
    attempt pqconn = action pqconn >>= \stat -> case stat of
        PQ.CopyInOk         -> return ()
        PQ.CopyInError      -> raiseConnectionError pqconn
        PQ.CopyInWouldBlock ->
            PQ.socket pqconn
              >>= maybe (throwIO (fdError funcName))
                        (\fd -> threadWaitWrite fd >> attempt pqconn)

    raiseConnectionError pqconn = do
        mmsg <- PQ.errorMessage pqconn
        throwIO SqlError { sqlState       = ""
                         , sqlExecStatus  = FatalError
                         , sqlErrorMsg    = maybe "" id mmsg
                         , sqlErrorDetail = ""
                         , sqlErrorHint   = funcName
                         }
{-# INLINE doCopyIn #-}

-- | Fetch the final result of a finished copy and drain any remaining
--   results.  Returns the row count parsed from the result's @COPY n@
--   command status.
--
--   Fails (via 'fail' in 'IO') if no result or command status is available,
--   or if the command status is not exactly @COPY n@.  In the latter case,
--   the connection's error message, if any, is appended to the failure
--   text.

getCopyCommandTag :: B.ByteString -> PQ.Connection -> IO Int64
getCopyCommandTag funcName pqconn = do
    result  <- orFail =<< PQ.getResult pqconn
    cmdStat <- orFail =<< PQ.cmdStatus result
    consumeResults pqconn
    case P.parseOnly copyRowCount cmdStat of
      Right n -> return $! n
      Left  _ -> do
          mmsg <- PQ.errorMessage pqconn
          fail (errCmdStatusFmt ++ maybe "" describe mmsg)
  where
    orFail :: Maybe x -> IO x
    orFail = maybe (fail errCmdStatus) return

    -- Accepts exactly "COPY <n>" with nothing after the number.
    copyRowCount :: P.Parser Int64
    copyRowCount = P.string "COPY " *> P.decimal <* P.endOfInput

    describe :: B.ByteString -> String
    describe msg = "\nConnection error: " ++ B.unpack msg

    errCmdStatus :: String
    errCmdStatus    = B.unpack funcName ++ ": failed to fetch command status"

    errCmdStatusFmt :: String
    errCmdStatusFmt = B.unpack funcName ++ ": failed to parse command status"


-- | Call 'PQ.getResult' repeatedly until it returns 'Nothing', discarding
--   every result it yields.

consumeResults :: PQ.Connection -> IO ()
consumeResults pqconn =
    PQ.getResult pqconn >>= maybe (return ()) (\_ -> consumeResults pqconn)