{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- |A module for automatic, optimal protocol pipelining.
--
--  Protocol pipelining is a technique in which multiple requests are written
--  out to a single socket without waiting for the corresponding responses.
--  The pipelining of requests results in a dramatic improvement in protocol
--  performance.
--
--  [Optimal Pipelining] uses the fewest network packets possible.
--
--  [Automatic Pipelining] means that requests are implicitly pipelined as much
--      as possible, i.e. as long as a request's response is not used before any
--      subsequent requests.
--
module Database.Redis.ProtocolPipelining (
  Connection,
  connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
  ConnectionLostException(..),
  ConnectTimeout(..),
  PortID(..)
) where

import           Prelude
import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async (race)
import           Control.Concurrent.MVar
import           Control.Exception
import           Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import           Data.IORef
import           Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import           System.IO
import           System.IO.Error
import           System.IO.Unsafe

import           Database.Redis.Protocol

-- | Where the Redis server listens.
data PortID
  = PortNumber NS.PortNumber  -- ^ Port on the host name given to 'connect'.
  | UnixSocket String         -- ^ Filesystem path of a Unix domain socket.
  deriving (Eq, Show)

-- | The transport a 'Connection' reads from and writes to.
data ConnectionContext
  = NormalHandle Handle     -- ^ Plain socket wrapped in a 'Handle'.
  | TLSContext TLS.Context  -- ^ TLS session set up by 'enableTLS'.

-- | A pipelined connection: the transport plus one shared lazy list of future
--   replies, observed through two references at different offsets.
data Connection = Conn
  { connCtx        :: ConnectionContext
    -- ^ Connection socket-handle.
  , connReplies    :: IORef [Reply]
    -- ^ Reply thunks for unsent requests.
  , connPending    :: IORef [Reply]
    -- ^ Reply thunks for requests "in the pipeline". Refers to the same list
    --   as 'connReplies', but can have an offset.
  , connPendingCnt :: IORef Int
    -- ^ Number of pending replies and thus the difference in length between
    --   'connPending' and 'connReplies':
    --   @length connPending - pendingCount = length connReplies@
  }

-- | Signals that the server connection is gone. In this module, I/O errors
--   raised while sending or reading and a reply stream that fails to parse
--   are both reported as this exception.
data ConnectionLostException = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException

-- | Progress markers recorded by 'connect'. The marker current when a
--   connect timeout fires is carried by 'ConnectTimeout'.
data ConnectPhase
  = PhaseUnknown     -- ^ Initial value; no phase recorded yet.
  | PhaseResolve     -- ^ Host name resolution.
  | PhaseOpenSocket  -- ^ Opening / connecting the socket.
  deriving (Show)

-- | Thrown by 'connect' when its timeout expires, tagged with the
--   'ConnectPhase' last recorded before the timeout.
data ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout

-- | Resolve @hostname@ and @port@ with 'NS.getAddrInfo', restricted to
--   stream sockets. The port is passed as its decimal 'show' rendering.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just streamHints) (Just hostname) (Just (show port))
  where
    streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }

-- | Try each address in order and return the first socket that connects.
--
--   A socket whose connect fails is closed before the next address is tried.
--   If every address fails, the 'IOError' from the last attempt is rethrown.
--   Calling this with an empty list is a programming error.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) = do
  outcome <- attempt
  case outcome of
    Right sock -> return sock
    Left err
      | null rest -> throwIO err
      | otherwise -> connectSocket rest
  where
    -- One connection attempt; 'bracketOnError' closes the socket if an
    -- exception other than the caught 'IOError' escapes.
    attempt :: IO (Either IOError NS.Socket)
    attempt = bracketOnError open NS.close $ \sock -> do
      connected <- try (NS.connect sock (NS.addrAddress addr))
      case connected of
        Right () -> return (Right sock)
        Left err -> do
          NS.close sock
          return (Left err)

    open :: IO NS.Socket
    open = NS.socket (NS.addrFamily addr)
                     (NS.addrSocketType addr)
                     (NS.addrProtocol addr)

-- | Open a plain (non-TLS) connection to a Redis server.
--
--   The socket has keep-alive enabled and is wrapped in a binary-mode
--   'Handle'. Call 'beginReceiving' before issuing requests.
--
--   If @timeoutOpt@ is @Just micros@, the attempt is raced against a
--   @micros@-microsecond delay. If the delay wins, a 'ConnectTimeout' is
--   thrown, carrying the phase the attempt was in: 'PhaseResolve' while the
--   host name is being resolved, and 'PhaseOpenSocket' while the socket is
--   being opened and connected.
connect :: NS.HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    connReplies    <- newIORef []
    connPending    <- newIORef []
    connPendingCnt <- newIORef 0
    let connCtx = NormalHandle h
    return Conn{..}
  where
    -- Run the connection attempt, optionally racing it against the timeout.
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing     -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h   -> return h
            Right () -> do
              -- The MVar holds the last phase the attempt recorded.
              phase <- readMVar phaseMVar
              errConnectTimeout phase

    -- Open and connect the socket, recording progress in @mvar@, then wrap
    -- it in a 'Handle'. The socket is closed if anything here fails.
    hConnect' :: MVar ConnectPhase -> IO Handle
    hConnect' mvar = bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        setPhase :: ConnectPhase -> IO ()
        setPhase = void . swapMVar mvar

        -- Each phase is recorded *before* its potentially slow step. That way
        -- a timeout reports the step that was actually in progress, rather
        -- than a phase set only after the socket was already connected.
        createSock :: IO NS.Socket
        createSock = case portId of
          PortNumber portNumber -> do
            setPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            setPhase PhaseOpenSocket
            connectSocket addrInfo
          UnixSocket addr -> do
            setPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- | Upgrade a connection to TLS. A TLS context is created over the plain
--   handle with the given client parameters, and the handshake is performed.
--   A connection that already uses TLS is returned unchanged.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn =
  case connCtx conn of
    TLSContext _   -> return conn
    NormalHandle h -> do
      ctx <- TLS.contextNew h tlsParams
      TLS.handshake ctx
      return conn { connCtx = TLSContext ctx }

-- | Start the lazy reply stream. Both the unread-reply list and the pending
--   list are pointed at the same fresh list from 'connGetReplies'.
beginReceiving :: Connection -> IO ()
beginReceiving conn@Conn{..} = do
  replies <- connGetReplies conn
  writeIORef connReplies replies
  writeIORef connPending replies

-- | Close the connection.
--
--   A plain handle is closed only if it is still open. For TLS, the peer is
--   notified via 'TLS.bye', and then the TLS context is closed with
--   'TLS.contextClose'. The close runs even when 'TLS.bye' throws (for
--   example, because the peer already dropped the connection), so the
--   context is not leaked. Any exception from 'TLS.bye' still propagates to
--   the caller.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      TLS.bye ctx `finally` TLS.contextClose ctx

-- |Write the request to the socket output buffer, without actually sending.
--  The 'Handle' is 'hFlush'ed when reading replies from the 'connCtx'.
--
--  I/O errors while writing are reported as 'ConnectionLost'. Each call
--  increments the pending-reply counter. Once 1000 or more replies are
--  pending, the oldest pending reply is forced, which makes it read from
--  the network.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  ioErrorToConnLost $ case connCtx of
    NormalHandle h -> S.hPut h s
    TLSContext ctx -> TLS.sendData ctx (L.fromStrict s)

  -- Signal that we expect one more reply from Redis.
  pending <- atomicModifyIORef' connPendingCnt $ \n ->
    let n' = n + 1 in (n', n')

  -- Limit the "pipeline length". This is necessary in long pipelines, to
  -- avoid thunk build-up, and thus space-leaks.
  -- TODO find smallest max pending with good-enough performance.
  when (pending >= 1000) $ do
    -- Force oldest pending reply.
    oldest : _ <- readIORef connPending
    oldest `seq` return ()

-- |Take a reply-thunk from the list of future replies.
--  The reply is not forced here; reading from the network happens only when
--  the returned 'Reply' is evaluated.
recv :: Connection -> IO Reply
recv Conn{..} = do
  next : remaining <- readIORef connReplies
  writeIORef connReplies remaining
  return next

-- | Flush the socket. Normally, the socket is flushed in 'recv' (actually
-- 'connGetReplies'), but for the multithreaded pub/sub code, the sending
-- thread needs to explicitly flush the subscription change requests.
--
-- Unlike 'send', this does not convert I/O errors to 'ConnectionLost'.
flush :: Connection -> IO ()
flush conn =
  case connCtx conn of
    NormalHandle h -> hFlush h
    TLSContext ctx -> TLS.contextFlush ctx

-- |Send a request and receive the corresponding reply.
--  The reply is returned unevaluated; see 'recv'.
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn

-- |A list of all future 'Reply's of the 'Connection'.
--
--  The spine of the list can be evaluated without forcing the replies.
--
--  Evaluating/forcing a 'Reply' from the list will 'unsafeInterleaveIO' the
--  reading and parsing from the 'connCtx'. To ensure correct ordering, each
--  Reply first evaluates (and thus reads from the network) the previous one.
--
--  'unsafeInterleaveIO' only evaluates its result once, making this function
--  thread-safe. 'Handle' as implemented by GHC is also threadsafe, it is safe
--  to call 'hFlush' here. The list constructor '(:)' must be called from
--  /within/ unsafeInterleaveIO, to keep the replies in correct order.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = replyStream S.empty sentinel
  where
    -- Stand-in "previous reply" for the first element; forcing it reads
    -- nothing.
    sentinel :: Reply
    sentinel = SingleLine "previous of first"

    -- Build the (infinite) list lazily: @buffered@ is input already read but
    -- not yet parsed, @prev@ is the reply that must be received first.
    replyStream :: S.ByteString -> Reply -> IO [Reply]
    replyStream buffered prev = do
      -- The irrefutable pattern keeps this bind from demanding the tuple, so
      -- nothing is received until the reply (or leftover input) is needed.
      ~(current, leftover) <- unsafeInterleaveIO (receiveAfter buffered prev)
      later <- unsafeInterleaveIO (replyStream leftover current)
      return (current : later)

    -- Receive one reply, but only after @prev@ has been received.
    receiveAfter :: S.ByteString -> Reply -> IO (Reply, S.ByteString)
    receiveAfter buffered prev = do
      -- Force previous reply for correct order.
      prev `seq` return ()
      result <- Scanner.scanWith readMore reply buffered
      case result of
        Scanner.Done leftover current -> do
          markReceived
          return (current, leftover)
        Scanner.Fail{} -> errConnClosed
        Scanner.More{} -> error "Hedis: parseWith returned Partial"

    -- The reply just parsed is the head of 'connPending': drop it. Then
    -- decrement the pending counter without going negative, which would
    -- otherwise occur during pubsub.
    markReceived :: IO ()
    markReceived = do
      atomicModifyIORef' connPending $ \(_:pending) -> (pending, ())
      atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n - 1), ())

    -- Flush outstanding requests, then read the next chunk of input. I/O
    -- errors become 'ConnectionLost'.
    readMore :: IO S.ByteString
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx

-- | Run an action, replacing any 'IOError' it throws with 'ConnectionLost'.
--   Exceptions of other types pass through unchanged.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)

-- | Abort the current 'IO' action by throwing 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO $ ConnectionLost

-- | Throw a 'ConnectTimeout' tagged with the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout