{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}

module Database.Persist.Postgresql.Streaming.Internal
  ( rawSelectStream
  ) where

import           Control.Exception (throwIO)
import           Control.Monad.IO.Class (MonadIO(liftIO))
import           Control.Monad.Logger (LoggingT(..), logDebugNS)
import           Control.Monad.Reader.Class (MonadReader(ask))
import           Control.Monad.Trans.Class (lift)
import           Control.Monad.Trans.Reader (ReaderT(..))
import           Control.Monad.Trans.Resource (MonadResource, release)
import           Data.Acquire (Acquire, allocateAcquire, mkAcquire)
import           Data.Conduit (ConduitT, (.|))
import qualified Data.Conduit.Combinators as CC (mapM, yieldMany)
import qualified Data.Text as T (Text, append, pack)
import qualified Data.Text.Encoding as T (encodeUtf8)
import           Database.Persist.Sql.Types.Internal (SqlBackend(..))
import           Database.Persist.Postgresql
import           Database.Persist.Postgresql.Internal
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.Cursor as PGC
import qualified Database.PostgreSQL.Simple.Types as PG

-- | Run a @Text@ query, with interpolated @PersistValue@s, against a PostgreSQL
-- backend using cursors, parsing the results with a custom parser function and
-- streaming them back.
--
-- If the parser function returns @Left@ for any row, a 'PersistMarshalError'
-- (a 'PersistException') is thrown. Its message contains the parser's error
-- text and the values of the row that failed to parse.
--
-- The cursor resource is registered with the enclosing 'MonadResource' scope
-- via 'allocateAcquire' and is released explicitly once the stream has been
-- fully drained.
rawSelectStream
  :: MonadResource m
  => ([PersistValue] -> Either T.Text result)
  -> T.Text
  -> [PersistValue]
  -> ConduitT () result (ReaderT (RawPostgresql SqlBackend) m) ()
rawSelectStream parseRes query vals = do
  -- Build the cursor-backed source and attach the row parser to it; nothing
  -- is opened on the server until the Acquire is allocated below.
  parsedRes <- lift $ liftPersist $
    fmap (.| CC.mapM parse) <$> rawQueryResFromCursor query vals
  (releaseKey, src) <- allocateAcquire parsedRes
  src
  -- Free the cursor as soon as the stream is exhausted. If the consumer stops
  -- early, the release action stays registered with the resource scope.
  release releaseKey
 where
  -- Convert a parser failure into an exception in the streaming monad. The
  -- message reports the offending row ('resVals'), not the query parameters.
  parse resVals =
    case parseRes resVals of
      Left s ->
        liftIO $ throwIO $
          PersistMarshalError
            ("rawSelectStream: " <> s <> ", vals: " <> T.pack (show resVals))
      Right row ->
        return row

-- | Log a statement through the backend's logger and describe, as an
-- 'Acquire', a server-side cursor over it on the raw PostgreSQL connection.
--
-- The debug log entry (source @"SQL"@, the SQL text followed by @"; "@ and the
-- shown parameters) is emitted immediately via the backend's 'connLogFunc'.
-- The cursor itself is only opened when the returned 'Acquire' is allocated.
rawQueryResFromCursor
  :: (MonadIO m1, MonadIO m2, BackendCompatible SqlBackend backend)
  => T.Text
  -> [PersistValue]
  -> ReaderT (RawPostgresql backend) m1 (Acquire (ConduitT () [PersistValue] m2 ()))
rawQueryResFromCursor sql vals = do
  RawPostgresql wrapped pgConn <- ask
  let sqlBackend = projectBackend wrapped
      logLine = T.append sql (T.pack ("; " ++ show vals))
      pgQuery = PG.Query (T.encodeUtf8 sql)
  runLoggingT (logDebugNS "SQL" logLine) (connLogFunc sqlBackend)
  pure (withCursorStmt pgConn pgQuery vals)

-- | Describe a server-side cursor over a query as an 'Acquire'.
--
-- Acquiring the resource interpolates @vals@ into @query@ client-side with
-- 'PG.formatQuery' and declares a cursor for the result. Releasing it closes
-- the cursor. The acquired value is a conduit that repeatedly fetches a
-- fixed-size chunk of rows and yields them in order until the cursor reports
-- exhaustion.
--
-- NOTE(review): PostgreSQL only permits a non-@WITH HOLD@ cursor inside a
-- transaction block; this presumably relies on the caller running inside one.
withCursorStmt
  :: MonadIO m
  => PG.Connection
  -> PG.Query
  -> [PersistValue]
  -> Acquire (ConduitT () [PersistValue] m ())
withCursorStmt conn query vals =
  foldWithCursor `fmap` mkAcquire openC closeC
 where
  -- Rows requested per FETCH round-trip.
  chunkSize :: Int
  chunkSize = 256

  -- Interpolate parameters, then declare the cursor over the final SQL text.
  openC = do
    rawquery <- PG.formatQuery conn query (map P vals)
    PGC.declareCursor conn (PG.Query rawquery)

  closeC = PGC.closeCursor

  foldWithCursor cursor = go
   where
    go = do
      rows <- liftIO $ PGC.foldForward cursor chunkSize processRow []
      case rows of
        -- Left: foldForward's signal that the cursor is exhausted.
        Left final -> CC.yieldMany (reverse final)
        -- Right: a chunk was fetched; emit it and ask for the next one.
        Right nonfinal -> CC.yieldMany (reverse nonfinal) >> go

  -- Accumulate rows in reverse (O(1) cons); each chunk is reversed once when
  -- yielded, avoiding the quadratic cost of appending with '<>' per row.
  processRow :: [[PersistValue]] -> [P] -> IO [[PersistValue]]
  processRow acc row = pure (map unP row : acc)