{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- |A module for automatic, optimal protocol pipelining.
--
--  Protocol pipelining is a technique in which multiple requests are written
--  out to a single socket without waiting for the corresponding responses.
--  The pipelining of requests results in a dramatic improvement in protocol
--  performance.
--
--  [Optimal Pipelining] uses the least number of network packets possible
--
--  [Automatic Pipelining] means that requests are implicitly pipelined as much
--      as possible, i.e. as long as a request's response is not used before any
--      subsequent requests.
--
module Database.Redis.ProtocolPipelining (
  Connection,
  connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
  ConnectionLostException(..),
  PortID(..)
) where

import           Prelude
import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async (race)
import           Control.Concurrent.MVar
import           Control.Exception
import           Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import           Data.IORef
import           Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import           System.IO
import           System.IO.Error
import           System.IO.Unsafe

import           Database.Redis.Protocol

-- |Where to connect to: a TCP port on the given host, or a Unix domain
--  socket identified by its path.
data PortID = PortNumber NS.PortNumber
            | UnixSocket String
  deriving Show

-- |The transport a 'Connection' reads from and writes to.
data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context

data Connection = Conn
  { connCtx        :: ConnectionContext
    -- ^ Connection socket-handle.
  , connReplies    :: IORef [Reply]
    -- ^ Reply thunks for unsent requests.
  , connPending    :: IORef [Reply]
    -- ^ Reply thunks for requests "in the pipeline". Refers to the same list as
    --   'connReplies', but can have an offset.
  , connPendingCnt :: IORef Int
    -- ^ Number of pending replies and thus the difference length between
    --   'connReplies' and 'connPending'.
    --   length connPending  - pendingCount = length connReplies
  }

-- |Thrown when reading from or writing to the connection fails.
data ConnectionLostException = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException

-- |How far 'connect' got before its timeout fired.
data ConnectPhase
  = PhaseUnknown
  | PhaseResolve
  | PhaseOpenSocket
  deriving (Show)

-- |Thrown by 'connect' when the optional timeout expires.
data ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout

-- |Upper bound on the number of unread replies 'send' tolerates before it
--  forces the oldest pending reply.
--  TODO find smallest max pending with good-enough performance.
maxPendingReplies :: Int
maxPendingReplies = 1000

-- |Resolve a host name and port to the stream-socket addresses to try.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just hints) (Just hostname) (Just $ show port)
  where
    hints = NS.defaultHints
      { NS.addrSocketType = NS.Stream }

-- |Try each address in order and return the first socket that connects.
--
--  When an attempt fails with an 'IOError' the socket created for it is
--  closed before the next address is tried. If every address fails, the
--  'IOError' of the last attempt is rethrown.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) = tryConnect >>= \case
  Right sock -> return sock
  Left err
    | null rest -> throwIO err
    | otherwise -> connectSocket rest
  where
    -- 'try' wraps the whole 'bracketOnError' so that a failed 'NS.connect'
    -- still propagates through the bracket and triggers 'NS.close'. Catching
    -- inside the bracket body would leave the failed socket open.
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = try $ bracketOnError createSock NS.close $ \sock -> do
      NS.connect sock (NS.addrAddress addr)
      return sock

    createSock = NS.socket (NS.addrFamily addr)
                           (NS.addrSocketType addr)
                           (NS.addrProtocol addr)

-- |Open a plain (non-TLS) connection.
--
--  The third argument is an optional timeout in microseconds. When it
--  expires before the socket is connected, a 'ConnectTimeout' carrying the
--  phase that was in progress is thrown.
connect :: NS.HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    connReplies <- newIORef []
    connPending <- newIORef []
    connPendingCnt <- newIORef 0
    let connCtx = NormalHandle h
    return Conn{..}
  where
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> do
              phase <- readMVar phaseMVar
              errConnectTimeout phase

    hConnect' mvar =
      bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        -- Record the phase *before* starting it, so that a timeout reports
        -- the step that was actually in progress.
        setPhase phase = void (swapMVar mvar phase)

        createSock = case portId of
          PortNumber portNumber -> do
            setPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            setPhase PhaseOpenSocket
            connectSocket addrInfo
          UnixSocket addr -> do
            setPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- |Upgrade a plain connection to TLS by performing a handshake on its
--  handle. A connection that already uses TLS is returned unchanged.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn@Conn{..} =
  case connCtx of
    NormalHandle h -> do
      ctx <- TLS.contextNew h tlsParams
      TLS.handshake ctx
      return $ conn { connCtx = TLSContext ctx }
    TLSContext _ -> return conn

-- |Install the lazy list of future replies into both 'connReplies' and
--  'connPending'.
beginReceiving :: Connection -> IO ()
beginReceiving conn = do
  rs <- connGetReplies conn
  writeIORef (connReplies conn) rs
  writeIORef (connPending conn) rs

-- |Close the connection.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      -- Close the context even when sending the close notification fails.
      TLS.bye ctx `finally` TLS.contextClose ctx

-- |Write the request to the socket output buffer, without actually sending.
--  The 'Handle' is 'hFlush'ed when reading replies from the 'connCtx'.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  case connCtx of
    NormalHandle h ->
      ioErrorToConnLost $ S.hPut h s
    TLSContext ctx ->
      ioErrorToConnLost $ TLS.sendData ctx (L.fromStrict s)

  -- Signal that we expect one more reply from Redis.
  n <- atomicModifyIORef' connPendingCnt $ \n -> let n' = n+1 in (n', n')
  -- Limit the "pipeline length". This is necessary in long pipelines, to avoid
  -- thunk build-up, and thus space-leaks.
  when (n >= maxPendingReplies) $
    -- Force oldest pending reply, if there is one.
    readIORef connPending >>= \case
      r:_ -> r `seq` return ()
      []  -> return ()

-- |Take a reply-thunk from the list of future replies.
--
--  Throws an 'IOError' when the list of replies is empty.
recv :: Connection -> IO Reply
recv Conn{..} =
  readIORef connReplies >>= \case
    r:rs -> do
      writeIORef connReplies rs
      return r
    [] ->
      ioError $ userError
        "Database.Redis.ProtocolPipelining.recv: no replies available"

-- | Flush the socket.  Normally, the socket is flushed in 'recv' (actually
-- 'connGetReplies'), but for the multithreaded pub/sub code, the sending
-- thread needs to explicitly flush the subscription change requests.
flush :: Connection -> IO ()
flush Conn{..} =
  case connCtx of
    NormalHandle h -> hFlush h
    TLSContext ctx -> TLS.contextFlush ctx

-- |Send a request and receive the corresponding reply
request :: Connection -> S.ByteString -> IO Reply
request conn req = send conn req >> recv conn

-- |A list of all future 'Reply's of the 'Connection'.
--
--  The spine of the list can be evaluated without forcing the replies.
--
--  Evaluating/forcing a 'Reply' from the list will 'unsafeInterleaveIO' the
--  reading and parsing from the 'connCtx'. To ensure correct ordering, each
--  Reply first evaluates (and thus reads from the network) the previous one.
--
--  'unsafeInterleaveIO' only evaluates its result once, making this function
--  thread-safe. 'Handle' as implemented by GHC is also threadsafe, it is safe
--  to call 'hFlush' here. The list constructor '(:)' must be called from
--  /within/ unsafeInterleaveIO, to keep the replies in correct order.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = go S.empty (SingleLine "previous of first")
  where
    go rest previous = do
      -- lazy pattern match to actually delay the receiving
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force previous reply for correct order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{}       -> errConnClosed
          Scanner.More{}    -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            -- r is the same as 'head' of 'connPending'. Since we just
            -- received r, we remove it from the pending list.
            atomicModifyIORef' connPending $ \pending -> (drop 1 pending, ())
            -- We now expect one less reply from Redis. We don't count to
            -- negative, which would otherwise occur during pubsub.
            atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n-1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)

    -- Flush buffered requests, then read the next chunk of input.
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx

-- |Run an action, turning any 'IOError' it throws into 'ConnectionLost'.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost a = a `catchIOError` const errConnClosed

-- |Throw 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost

-- |Throw 'ConnectTimeout' for the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout phase = throwIO $ ConnectTimeout phase