{-# LANGUAGE RecordWildCards, DeriveDataTypeable, OverloadedStrings #-}
module Database.Redis.ProtocolPipelining (
Connection,
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
ConnectionLostException(..),
HostName, PortID(..)
) where
import Prelude
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.IORef
import Data.Typeable
import Network
import qualified Network.BSD as BSD
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO
import System.IO.Error
import System.IO.Unsafe
import Database.Redis.Protocol
-- | The transport a 'Connection' talks over: either a plain 'Handle'
-- (as created by 'connect') or a TLS context created on top of that handle
-- (as created by 'enableTLS').
data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context
-- | A pipelined connection to a Redis server.
data Connection = Conn
  { connCtx :: ConnectionContext
    -- ^ Transport used for writing requests, flushing, and reading replies.
  , connReplies :: IORef [Reply]
    -- ^ Lazy stream of replies not yet handed out by 'recv'. Starts empty;
    -- 'beginReceiving' installs the stream and 'recv' pops its head.
  , connPending :: IORef [Reply]
    -- ^ Lazy stream of replies not yet parsed off the wire. Its head is
    -- dropped each time a reply is parsed; 'send' forces its head once too
    -- many replies are outstanding.
  , connPendingCnt :: IORef Int
    -- ^ Number of requests sent whose replies have not yet been parsed.
    -- Incremented by 'send', decremented (clamped at 0) per parsed reply.
  }
-- | Thrown when an I/O operation on the connection raises an 'IOError', or
-- when the reply parser fails on the incoming byte stream.
data ConnectionLostException = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | How far 'connect' got before its timeout fired. Carried by
-- 'ConnectTimeout' to show where the connection attempt stalled.
data ConnectPhase
  = PhaseUnknown
    -- ^ Attempt not started yet, or a connect path that records no phases
    -- (any 'PortID' other than 'PortNumber').
  | PhaseResolve
    -- ^ Resolving the host name.
  | PhaseOpenSocket
    -- ^ Connecting the socket to the resolved address.
  deriving (Show)
-- | Thrown by 'connect' when its optional timeout elapses before the
-- connection is established. The payload is the phase that was in progress.
--
-- NOTE(review): this type is not in the module's export list, so callers
-- cannot catch it by name -- confirm whether that is intended.
data ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout
-- | Open a connection to a Redis server.
--
-- For a 'PortNumber', the host name is resolved with 'BSD.getHostByName' and
-- an IPv4 stream socket with keep-alive enabled is connected to it. Any other
-- 'PortID' is handed to 'connectTo'. If a timeout (in microseconds) is given
-- and expires first, 'ConnectTimeout' is thrown with the phase the attempt had
-- reached. The handle is put in binary mode and wrapped in a 'Connection'
-- whose reply buffers start empty ('beginReceiving' fills them).
connect :: HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portID timeoutOpt =
  bracketOnError openHandle hClose $ \h -> do
    hSetBinaryMode h True
    replies <- newIORef []
    pending <- newIORef []
    pendingCnt <- newIORef 0
    return Conn
      { connCtx = NormalHandle h
      , connReplies = replies
      , connPending = pending
      , connPendingCnt = pendingCnt
      }
  where
    -- Run the connection attempt, racing it against a deadline if one is set.
    openHandle = do
      phaseVar <- newMVar PhaseUnknown
      let attempt = openWith portID phaseVar
      maybe attempt (withDeadline phaseVar attempt) timeoutOpt

    -- If the timer wins, 'race' cancels the attempt (whose bracketOnError
    -- closes the socket) and the last recorded phase is reported.
    withDeadline phaseVar attempt micros =
      race attempt (threadDelay micros)
        >>= either return (\() -> readMVar phaseVar >>= errConnectTimeout)

    -- Establish the socket, recording progress in the phase variable.
    openWith (PortNumber port) phaseVar =
      bracketOnError newSocket NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        void $ swapMVar phaseVar PhaseResolve
        host <- BSD.getHostByName hostName
        void $ swapMVar phaseVar PhaseOpenSocket
        NS.connect sock $ NS.SockAddrInet port (BSD.hostAddress host)
        NS.socketToHandle sock ReadWriteMode
    openWith _ _ = connectTo hostName portID

    newSocket = NS.socket NS.AF_INET NS.Stream 0
-- | Upgrade a plain connection to TLS. A TLS context is created over the
-- connection's handle with the given client parameters and the handshake is
-- run. A connection that already uses TLS is returned unchanged.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn =
  case connCtx conn of
    TLSContext _ -> return conn
    NormalHandle h -> do
      ctx <- TLS.contextNew h tlsParams
      TLS.handshake ctx
      return conn { connCtx = TLSContext ctx }
-- | Install the connection's lazy reply stream. The same stream is stored as
-- the consumer view (popped by 'recv') and the pending view (used by 'send' to
-- bound the pipeline). No bytes are read until a reply is demanded.
beginReceiving :: Connection -> IO ()
beginReceiving conn = do
  replyStream <- connGetReplies conn
  forM_ [connReplies conn, connPending conn] $ \ref ->
    writeIORef ref replyStream
-- | Close the connection.
--
-- A plain handle is closed only if it is still open, so repeated calls are
-- harmless. For TLS, the peer is notified with 'TLS.bye' and the context is
-- then closed. The close runs even if 'TLS.bye' throws (e.g. because the peer
-- is already gone), so the underlying transport is not leaked. Any exception
-- from 'TLS.bye' is still re-thrown afterwards.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      -- Without 'finally', a failing 'bye' would skip 'contextClose'.
      TLS.bye ctx `finally` TLS.contextClose ctx
-- | Write a raw request to the server without waiting for its reply.
--
-- I/O errors during the write are rethrown as 'ConnectionLost'. Each call
-- counts one more outstanding reply. Once 'maxPending' or more are
-- outstanding, the oldest pending reply is forced, which parses it off the
-- wire. This keeps the number of unparsed replies bounded, presumably to limit
-- the memory held by the lazy reply stream.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  case connCtx of
    NormalHandle h ->
      ioErrorToConnLost $ S.hPut h s
    TLSContext ctx ->
      ioErrorToConnLost $ TLS.sendData ctx (L.fromStrict s)
  n <- atomicModifyIORef' connPendingCnt $ \cnt -> let cnt' = cnt + 1 in (cnt', cnt')
  when (n >= maxPending) $ do
    pending <- readIORef connPending
    -- The pending stream is only empty before 'beginReceiving' has run.
    -- In that case there is nothing to force, and failing here (after the
    -- request has already been written) would only confuse the caller.
    case pending of
      r : _ -> void (evaluate r)
      []    -> return ()
  where
    -- Upper bound on outstanding replies before the oldest one is forced.
    maxPending :: Int
    maxPending = 1000
-- | Take the next reply from the connection's reply stream.
--
-- The reply is returned lazily: it is parsed off the wire when it is forced,
-- not by this call. The stream installed by 'beginReceiving' never ends.
-- Before that call the stream is empty, and this throws an 'IOError' saying
-- so instead of failing on an opaque pattern match.
recv :: Connection -> IO Reply
recv Conn{..} = do
  replies <- readIORef connReplies
  case replies of
    r : rs -> do
      writeIORef connReplies rs
      return r
    [] -> ioError $ userError "Hedis: recv called before beginReceiving"
-- | Push any buffered outgoing bytes on the connection's transport.
flush :: Connection -> IO ()
flush conn = flushCtx (connCtx conn)
  where
    flushCtx (NormalHandle h) = hFlush h
    flushCtx (TLSContext ctx) = TLS.contextFlush ctx
-- | Send a request with 'send', then take the next reply from the stream
-- with 'recv'.
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn
-- | Build the connection's lazy, never-ending stream of replies.
--
-- Each element and each tail is produced with 'unsafeInterleaveIO', so no
-- bytes are read until a reply is demanded. Before a reply is parsed, the
-- previous reply is forced. Replies are therefore always parsed from the
-- transport in order, whichever element is forced first. Each parsed reply
-- drops one entry from 'connPending' and decrements 'connPendingCnt'
-- (clamped at 0). A parse failure, or an I/O error while reading, throws
-- 'ConnectionLost'.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = go S.empty (SingleLine "previous of first")
  where
    -- rest: bytes left over after parsing the previous reply.
    -- previous: the preceding reply (a placeholder for the first one).
    go rest previous = do
      -- Lazy pattern: r and rest' are only computed when demanded.
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force the previous reply first so the stream is consumed in order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{} -> errConnClosed
          -- NOTE(review): presumably unreachable because scanWith keeps
          -- pulling input via readMore; the message still names the older
          -- parseWith API.
          Scanner.More{} -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            -- Bookkeeping consumed by 'send' to bound the pipeline length.
            atomicModifyIORef' connPending $ \(_:rs) -> (rs, ())
            atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n-1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)

    -- Fetch more input for the scanner. Buffered output is flushed first, so
    -- requests still in the write buffer are sent before we block on a read.
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx
-- | Run an action, replacing any 'IOError' it throws with 'ConnectionLost'.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | Abort the current operation by throwing 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost

-- | Abort a connection attempt with 'ConnectTimeout', reporting the phase
-- the attempt had reached.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout