{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.ConnectionContext (
    ConnectionContext(..)
  , ConnectTimeout(..)
  , ConnectionLostException(..)
  , PortID(..)
  , connect
  , disconnect
  , send
  , recv
  , errConnClosed
  , enableTLS
  , flush
  , ioErrorToConnLost
) where

import           Control.Concurrent (threadDelay)
import           Control.Concurrent.MVar (newMVar, readMVar, swapMVar)
import           Control.Exception (Exception, bracketOnError, finally, throwIO, try)
import           Control.Monad (when)
import           Data.Functor (void)
import qualified Data.IORef as IOR
import           Data.Typeable
import           System.IO (Handle, IOMode(..), hClose, hFlush, hIsOpen, hSetBinaryMode)
import           System.IO.Error (catchIOError)

import           Control.Concurrent.Async (race)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Network.Socket as NS
import qualified Network.TLS as TLS

-- | The transport underneath a connection: either a plain 'Handle', or a
-- 'TLS.Context' (which 'enableTLS' builds on top of such a handle).
data ConnectionContext
  = NormalHandle Handle
  | TLSContext TLS.Context

-- | Renders only which transport is in use; the wrapped handle or context
-- itself is not shown.
instance Show ConnectionContext where
  show c = case c of
    NormalHandle _ -> "NormalHandle"
    TLSContext _   -> "TLSContext"

-- | A 'ConnectionContext' paired with a mutable slot holding an optional
-- 'B.ByteString'. (Not exported from this module.)
data Connection = Connection
  { ctx         :: ConnectionContext
  , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
  }

-- | Shows the context; the 'IOR.IORef' cannot be rendered, so a fixed
-- placeholder is printed in its place.
instance Show Connection where
  show conn =
    concat ["Connection{ ctx = ", show (ctx conn), ", lastRecvRef = IORef}"]

-- | Progress marker that 'connect' keeps in an 'MVar' while it is
-- establishing a connection. If a connect timeout fires, the value read at
-- that moment is reported inside 'ConnectTimeout'.
data ConnectPhase
  = PhaseUnknown     -- ^ initial value, before any phase has been recorded
  | PhaseResolve     -- ^ see 'connect' for exactly when this is recorded
  | PhaseOpenSocket  -- ^ see 'connect' for exactly when this is recorded
  deriving Show

-- | Thrown by 'connect' when its optional timeout expires before the
-- connection is established. Carries the 'ConnectPhase' that was recorded
-- when the timer fired.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Typeable, Show)

instance Exception ConnectTimeout

-- | Raised (via 'errConnClosed') when I/O on the connection fails. Helpers
-- such as 'ioErrorToConnLost' replace an 'IOError' with this value.
data ConnectionLostException = ConnectionLost
  deriving (Show)

instance Exception ConnectionLostException

-- | Where 'connect' should reach the server.
data PortID
  = PortNumber NS.PortNumber  -- ^ TCP port; combined with the host name passed to 'connect'
  | UnixSocket String         -- ^ filesystem path of a Unix-domain socket
  deriving (Show, Eq)

-- | Open a plain (non-TLS) connection and wrap it as a 'NormalHandle'.
--
-- * @hostName@ is only consulted for a 'PortNumber' target. A 'UnixSocket'
--   target connects to the given socket path instead.
-- * When @timeoutOpt@ is @Just micros@, the attempt races a 'threadDelay'
--   of @micros@ microseconds. If the delay wins, the attempt is abandoned
--   and a 'ConnectTimeout' is thrown. It carries the 'ConnectPhase' that
--   was in progress at that moment.
-- * On success the handle is switched to binary mode. If a later step
--   throws, the socket or handle created so far is closed
--   ('bracketOnError').
--
-- Other failures propagate unchanged, e.g. the 'IOError' that
-- 'connectSocket' rethrows when every resolved address fails.
connect :: NS.HostName -> PortID -> Maybe Int -> IO ConnectionContext
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    return $ NormalHandle h
  where
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> do
              -- The timer won: report the step that was in progress.
              phase <- readMVar phaseMVar
              errConnectTimeout phase

    -- Each phase is recorded *before* the step it names starts. That way a
    -- timeout during host resolution or socket connect reports that
    -- phase. Recording the phases only after 'createSock' returned (when
    -- both steps were already finished) made every timeout report
    -- 'PhaseUnknown'.
    hConnect' phaseVar =
      bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        setPhase = void . swapMVar phaseVar

        createSock = case portId of
          PortNumber portNumber -> do
            setPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            setPhase PhaseOpenSocket
            connectSocket addrInfo
          UnixSocket addr -> do
            setPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\s -> NS.connect s (NS.SockAddrUnix addr) >> return s)

-- | Look up stream-socket address candidates for a host and port.
--
-- The port is passed to 'NS.getAddrInfo' as a service string, produced by
-- 'show'. The hints restrict results to 'NS.Stream' sockets.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just streamHints) (Just hostname) (Just service)
  where
    service = show port
    streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }

-- | Throw a 'ConnectTimeout' that records the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout

-- | Try each resolved address in order and return the first socket that
-- connects.
--
-- A socket whose connect fails is closed before the next candidate is
-- tried. If every candidate fails, the 'IOError' from the last one is
-- rethrown. An empty list is treated as a programming error ('error').
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr : rest) = do
  outcome <- attempt
  case outcome of
    Right sock -> return sock
    Left err
      | null rest -> throwIO err
      | otherwise -> connectSocket rest
  where
    -- Connect failures are returned as 'Left' (after closing the socket).
    -- Any other exception closes the socket via 'bracketOnError' and
    -- propagates.
    attempt :: IO (Either IOError NS.Socket)
    attempt =
      bracketOnError openSock NS.close $ \sock -> do
        connected <- try (NS.connect sock (NS.addrAddress addr))
        case connected of
          Right () -> return (Right sock)
          Left err -> do
            NS.close sock
            return (Left err)

    openSock =
      NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr)

-- | Write a request to the connection.
--
-- A plain handle is written with 'B.hPut'. A TLS context is written with
-- 'TLS.sendData', after converting the data to a lazy 'LB.ByteString'.
-- On either transport an 'IOError' becomes 'ConnectionLost'.
send :: ConnectionContext -> B.ByteString -> IO ()
send conn requestData = ioErrorToConnLost $ case conn of
  NormalHandle h    -> B.hPut h requestData
  TLSContext tlsCtx -> TLS.sendData tlsCtx (LB.fromStrict requestData)

-- | Receive the next chunk of data from the connection.
--
-- A plain handle is read with @'B.hGetSome' h 4096@, i.e. at most 4096
-- bytes. A TLS context returns whatever 'TLS.recvData' produces.
--
-- Both branches go through 'ioErrorToConnLost', as 'send' does. As a
-- result, an 'IOError' surfaces as 'ConnectionLost' regardless of
-- transport, instead of leaking a raw 'IOError' only on the TLS path.
recv :: ConnectionContext -> IO B.ByteString
recv (NormalHandle h)    = ioErrorToConnLost $ B.hGetSome h 4096
recv (TLSContext tlsCtx) = ioErrorToConnLost $ TLS.recvData tlsCtx


-- | Run an action. Any 'IOError' it raises is discarded and replaced by
-- 'ConnectionLost' (thrown via 'errConnClosed'). Other exception types pass
-- through untouched.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)

-- | Throw 'ConnectionLost' in 'IO'. It never returns, so it fits any
-- result type.
errConnClosed :: IO a
errConnClosed = throwIO (ConnectionLost :: ConnectionLostException)


-- | Upgrade a plain connection to TLS.
--
-- For a 'NormalHandle', a 'TLS.Context' is created over the handle with the
-- given client parameters and the handshake is run before it is returned
-- as a 'TLSContext'. A connection that is already a 'TLSContext' is
-- returned unchanged, and the parameters are ignored.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS _ c@(TLSContext _) = return c
enableTLS tlsParams (NormalHandle h) =
  TLS.contextNew h tlsParams >>= \tlsCtx ->
    TLSContext tlsCtx <$ TLS.handshake tlsCtx

-- | Close a connection.
--
-- A plain handle is closed only if 'hIsOpen' reports it as still open.
--
-- For TLS, 'TLS.bye' is called and then 'TLS.contextClose'. The close runs
-- under 'finally', so the context is still closed when 'TLS.bye' throws.
-- The exception from 'TLS.bye' is then rethrown. Previously such a throw
-- skipped 'TLS.contextClose' and leaked the context.
disconnect :: ConnectionContext -> IO ()
disconnect (NormalHandle h) = do
  stillOpen <- hIsOpen h
  when stillOpen $ hClose h
disconnect (TLSContext tlsCtx) =
  TLS.bye tlsCtx `finally` TLS.contextClose tlsCtx

-- | Flush pending output: 'hFlush' for a plain handle, 'TLS.contextFlush'
-- for a TLS context.
flush :: ConnectionContext -> IO ()
flush conn = case conn of
  NormalHandle h    -> hFlush h
  TLSContext tlsCtx -> TLS.contextFlush tlsCtx