{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Low-level transport for Redis connections: opening a plain TCP or Unix
-- socket, optionally upgrading it to TLS, and sending/receiving raw bytes.
-- 'IOError's raised while sending or receiving are reported as
-- 'ConnectionLost'.
module Database.Redis.ConnectionContext
  ( ConnectionContext(..)
  , ConnectTimeout(..)
  , ConnectionLostException(..)
  , PortID(..)
  , connect
  , disconnect
  , send
  , recv
  , errConnClosed
  , enableTLS
  , flush
  , ioErrorToConnLost
  ) where

import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import Control.Concurrent.MVar (newMVar, readMVar, swapMVar)
import Control.Exception (Exception, bracketOnError, finally, throwIO, try)
import Control.Monad (when)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import Data.Functor (void)
import qualified Data.IORef as IOR
import Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO (Handle, IOMode(..), hClose, hFlush, hIsOpen, hSetBinaryMode)
import System.IO.Error (catchIOError)

-- | The transport underlying a connection.
data ConnectionContext
  = NormalHandle Handle     -- ^ Plain (unencrypted) handle.
  | TLSContext TLS.Context  -- ^ TLS session (see 'enableTLS').

-- | Shows only which transport is in use; the handle/context is opaque.
instance Show ConnectionContext where
  show (NormalHandle _) = "NormalHandle"
  show (TLSContext _) = "TLSContext"

-- | A transport paired with a mutable slot of optional bytes
-- (presumably the last received chunk — confirm against callers).
-- NOTE(review): not exported and not used in this module; confirm it is
-- still needed.
data Connection = Connection
  { ctx :: ConnectionContext
  , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
  }

instance Show Connection where
  show Connection{..} =
    "Connection{ ctx = " ++ show ctx ++ ", lastRecvRef = IORef}"

-- | How far 'connect' had progressed when its timeout fired.
data ConnectPhase
  = PhaseUnknown     -- ^ The attempt had not recorded any progress yet.
  | PhaseResolve     -- ^ Resolving the host name to addresses.
  | PhaseOpenSocket  -- ^ Creating and connecting the socket.
  deriving (Show)

-- | Thrown by 'connect' when its optional timeout expires, carrying the
-- phase the attempt had reached.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout

-- | Thrown (via 'errConnClosed') when an I/O error occurs on a connection.
data ConnectionLostException = ConnectionLost
  deriving (Show)

instance Exception ConnectionLostException

-- | Connection target: a TCP port on the given host, or a Unix domain socket
-- path (in which case the host name is ignored).
data PortID
  = PortNumber NS.PortNumber
  | UnixSocket String
  deriving (Eq, Show)

-- | Open a connection to @hostName@ on @portId@ and return it as a
-- binary-mode 'NormalHandle'. TCP targets are resolved with 'getHostAddrInfo'
-- and each returned address is tried in order. Keep-alive is enabled on the
-- resulting socket.
--
-- If @timeoutOpt@ is @Just micros@, the attempt is abandoned after @micros@
-- microseconds and 'ConnectTimeout' is thrown with the phase that was in
-- progress. Sockets and handles acquired by a failed attempt are closed via
-- 'bracketOnError'.
connect :: NS.HostName -> PortID -> Maybe Int -> IO ConnectionContext
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    return (NormalHandle h)
  where
    -- Run the connection attempt, racing it against the timeout if one is set.
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            -- 'race' has already cancelled doConnect, so the phase is final.
            Right () -> readMVar phaseMVar >>= errConnectTimeout

    -- Open a connected socket and convert it to a handle, recording progress
    -- in @mvar@ so a timeout can report where the attempt was stuck.
    hConnect' mvar =
      bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        setPhase phase = void (swapMVar mvar phase)

        -- Phases are recorded immediately before the step they describe.
        -- The socket must already be connected when 'bracketOnError' hands it
        -- to the body, so recording them there would be too late to be
        -- meaningful.
        createSock = case portId of
          PortNumber portNumber -> do
            setPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            setPhase PhaseOpenSocket
            connectSocket addrInfo
          UnixSocket addr -> do
            -- No name resolution for a Unix socket path.
            setPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- | Resolve @hostname@ and @port@ to stream-socket addresses.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just hints) (Just hostname) (Just (show port))
  where
    hints = NS.defaultHints { NS.addrSocketType = NS.Stream }

-- | Throw 'ConnectTimeout' for the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout

-- | Try each address in order and return the first socket that connects.
-- If every address fails, the connect error of the last one is rethrown.
-- Errors from creating a socket (as opposed to connecting it) propagate
-- immediately.
--
-- NOTE: assumes a non-empty list; 'NS.getAddrInfo' is expected to throw
-- rather than return no addresses.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) =
  tryConnect >>= \case
    Right sock -> return sock
    Left err
      | null rest -> throwIO err
      | otherwise -> connectSocket rest
  where
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = bracketOnError createSock NS.close $ \sock ->
      try (NS.connect sock (NS.addrAddress addr)) >>= \case
        Right () -> return (Right sock)
        -- Close explicitly: bracketOnError only cleans up when an exception
        -- escapes, and this error is returned rather than thrown.
        Left err -> NS.close sock >> return (Left err)

    createSock :: IO NS.Socket
    createSock =
      NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr)

-- | Write all of @requestData@ to the connection. 'IOError's surface as
-- 'ConnectionLost'.
send :: ConnectionContext -> B.ByteString -> IO ()
send (NormalHandle h) requestData = ioErrorToConnLost (B.hPut h requestData)
send (TLSContext tctx) requestData =
  ioErrorToConnLost (TLS.sendData tctx (LB.fromStrict requestData))

-- | Read the next chunk of incoming data (at most 4096 bytes on a plain
-- handle). 'IOError's surface as 'ConnectionLost'.
recv :: ConnectionContext -> IO B.ByteString
recv (NormalHandle h) = ioErrorToConnLost (B.hGetSome h 4096)
-- Wrapped like 'send' so both transports report I/O failures the same way.
recv (TLSContext tctx) = ioErrorToConnLost (TLS.recvData tctx)

-- | Run an action, converting any 'IOError' it throws into 'ConnectionLost'.
-- Other exception types propagate unchanged.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost a = a `catchIOError` const errConnClosed

-- | Throw 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost

-- | Upgrade a plain connection to TLS by creating a TLS context over its
-- handle and running the handshake. A connection that is already TLS is
-- returned unchanged.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS tlsParams (NormalHandle h) = do
  tctx <- TLS.contextNew h tlsParams
  TLS.handshake tctx
  return (TLSContext tctx)
enableTLS _ c@(TLSContext _) = return c

-- | Close the connection. A plain handle is closed only if it is still open.
-- For TLS, 'TLS.bye' is attempted first and the context is then closed.
-- The close runs even if 'TLS.bye' throws, and that exception is re-thrown
-- afterwards.
disconnect :: ConnectionContext -> IO ()
disconnect (NormalHandle h) = do
  open <- hIsOpen h
  when open (hClose h)
disconnect (TLSContext tctx) =
  TLS.bye tctx `finally` TLS.contextClose tctx

-- | Flush any buffered outgoing data.
flush :: ConnectionContext -> IO ()
flush (NormalHandle h) = hFlush h
flush (TLSContext tctx) = TLS.contextFlush tctx