{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.Redis.ConnectionContext (
    ConnectionContext(..)
  , ConnectTimeout(..)
  , ConnectionLostException(..)
  , PortID(..)
  , connect
  , disconnect
  , send
  , recv
  , errConnClosed
  , enableTLS
  , flush
  , ioErrorToConnLost
) where

import           Control.Concurrent (threadDelay)
import           Control.Concurrent.MVar (newMVar, readMVar, swapMVar)
import           Control.Exception (Exception, bracketOnError, finally, throwIO, try)
import           Control.Monad (when)
import           Data.Functor (void)
import qualified Data.IORef as IOR
import           Data.Typeable
import           System.IO (Handle, IOMode(..), hClose, hFlush, hIsOpen, hSetBinaryMode)
import           System.IO.Error (catchIOError)

import           Control.Concurrent.Async (race)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Network.Socket as NS
import qualified Network.TLS as TLS

-- | The transport underneath a Redis connection: either a plain 'Handle',
-- or a TLS session that 'enableTLS' builds on top of such a handle.
data ConnectionContext
  = NormalHandle Handle
  | TLSContext TLS.Context

-- | Shows only which transport variant is in use; the wrapped handle or
-- TLS context is not rendered.
instance Show ConnectionContext where
    show = \case
        NormalHandle _ -> "NormalHandle"
        TLSContext _   -> "TLSContext"

-- | A transport together with a mutable slot holding an optional
-- 'B.ByteString'.
data Connection = Connection
    { ctx         :: ConnectionContext
      -- ^ the underlying plain or TLS transport
    , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
      -- ^ NOTE(review): presumably the most recently received, not yet
      -- consumed chunk — confirm against the code that reads/writes it
    }

-- | Renders the transport through its own 'Show' instance. The 'IORef' is
-- never read; a fixed @IORef@ placeholder is printed in its place.
instance Show Connection where
    show conn =
        concat ["Connection{ ctx = ", show (ctx conn), ", lastRecvRef = IORef}"]

-- | Which step 'connect' had reached; carried by 'ConnectTimeout' so a
-- timeout can say where it happened.
data ConnectPhase = PhaseUnknown | PhaseResolve | PhaseOpenSocket
  deriving Show

-- | Thrown by 'connect' when its optional timeout elapses before the
-- connection is established. Carries the 'ConnectPhase' that had been
-- recorded when the timer fired.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Typeable, Show)

instance Exception ConnectTimeout

-- | Raised in place of a raw 'IOError' when the transport fails
-- (see 'ioErrorToConnLost' and 'errConnClosed').
data ConnectionLostException = ConnectionLost
  deriving (Show)

instance Exception ConnectionLostException

-- | Where 'connect' should reach the server.
data PortID
  = PortNumber NS.PortNumber  -- ^ TCP port, combined with the host name
  | UnixSocket String         -- ^ filesystem path of a Unix domain socket
  deriving (Eq, Show)

-- | Open a plain (non-TLS) connection and return it as a 'NormalHandle'
-- in binary mode.
--
-- * 'PortNumber': the host name is resolved and each returned address is
--   tried in order (see 'connectSocket').
-- * 'UnixSocket': the given path is connected directly.
--
-- When @timeoutOpt@ is @Just micros@ and that many microseconds pass first,
-- the attempt is cancelled and a 'ConnectTimeout' is thrown carrying the
-- phase that was in progress. SO_KEEPALIVE is enabled on the socket.
connect :: NS.HostName -> PortID -> Maybe Int -> IO ConnectionContext
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    return $ NormalHandle h
  where
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> race doConnect (threadDelay micros) >>= \case
          Left h   -> return h
          Right () -> readMVar phaseMVar >>= errConnectTimeout

    -- The phase is recorded in the MVar *before* each step starts (inside
    -- 'openSocket'). Previously both updates ran only after the socket was
    -- already connected, so a timeout always reported 'PhaseUnknown'.
    hConnect' mvar =
      bracketOnError (openSocket (void . swapMVar mvar)) NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode

    -- Resolve (TCP only) and connect, announcing each phase via @setPhase@.
    openSocket :: (ConnectPhase -> IO ()) -> IO NS.Socket
    openSocket setPhase = case portId of
      PortNumber portNumber -> do
        setPhase PhaseResolve
        addrInfo <- getHostAddrInfo hostName portNumber
        setPhase PhaseOpenSocket
        connectSocket addrInfo
      UnixSocket addr -> do
        setPhase PhaseOpenSocket
        bracketOnError
          (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
          NS.close
          (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- | Resolve a host name to stream-socket addresses for the given port.
-- The port is passed to 'NS.getAddrInfo' as a service string built with
-- 'show'.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  let streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }
      service     = show port
  in NS.getAddrInfo (Just streamHints) (Just hostname) (Just service)

-- | Throw a 'ConnectTimeout' tagged with the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout

-- | Try each address in order and return the first socket that connects.
--
-- A failed attempt's socket is closed before moving on. If every address
-- fails with an 'IOError', the error from the last attempt is rethrown.
-- An empty list is reported as an 'IOError' rather than via 'error'.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = ioError (userError "connectSocket: unexpected empty list")
connectSocket (addr:rest) = tryConnect >>= \case
  Right sock -> return sock
  Left err
    | null rest -> throwIO err
    | otherwise -> connectSocket rest
  where
    -- Socket *creation* is inside 'try' too. Creating a socket for a family
    -- the host does not support (e.g. IPv6 on an IPv4-only machine) must
    -- fall through to the next address instead of aborting the connect.
    -- 'bracketOnError' closes the socket if 'NS.connect' throws.
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = try $ bracketOnError createSock NS.close $ \sock -> do
      NS.connect sock (NS.addrAddress addr)
      return sock

    createSock :: IO NS.Socket
    createSock = NS.socket (NS.addrFamily addr)
                           (NS.addrSocketType addr)
                           (NS.addrProtocol addr)

-- | Write a strict chunk to the transport. I/O errors are converted to
-- 'ConnectionLost'.
send :: ConnectionContext -> B.ByteString -> IO ()
send conn requestData = ioErrorToConnLost $ case conn of
  NormalHandle h    -> B.hPut h requestData
  TLSContext tlsCtx -> TLS.sendData tlsCtx (LB.fromStrict requestData)

-- | Read the next available chunk from the transport (at most 4096 bytes
-- for a plain handle).
--
-- I/O errors on either transport are converted to 'ConnectionLost'. The
-- TLS branch was previously unwrapped, so its I/O failures escaped as raw
-- 'IOError's, unlike 'send' and the plain-handle branch.
recv :: ConnectionContext -> IO B.ByteString
recv conn = ioErrorToConnLost $ case conn of
  NormalHandle h    -> B.hGetSome h 4096
  TLSContext tlsCtx -> TLS.recvData tlsCtx


-- | Run an action. If it throws an 'IOError', discard that error and throw
-- 'ConnectionLost' instead. Other exception types pass through unchanged.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ioErr -> errConnClosed)

-- | Throw 'ConnectionLost'. Its polymorphic result lets it stand in for an
-- action of any type.
errConnClosed :: IO a
errConnClosed = throwIO lost
  where
    lost = ConnectionLost


-- | Upgrade a plain connection to TLS by creating a client context over
-- its handle and running the handshake. A context that is already TLS is
-- returned unchanged and the parameters are ignored.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS _ conn@(TLSContext _) = return conn
enableTLS tlsParams (NormalHandle h) = do
  tlsCtx <- TLS.contextNew h tlsParams
  TLS.handshake tlsCtx
  return (TLSContext tlsCtx)

-- | Close the transport.
--
-- A plain handle is closed only if it is still open. For TLS, a
-- 'TLS.bye' is attempted first and the context is then closed.
disconnect :: ConnectionContext -> IO ()
disconnect (NormalHandle h) = do
  open <- hIsOpen h
  when open $ hClose h
disconnect (TLSContext tlsCtx) =
  -- 'TLS.bye' can throw when the peer has already gone away. 'finally'
  -- guarantees 'TLS.contextClose' still runs, so the context is not leaked.
  -- Any exception from 'bye' is then re-raised to the caller.
  TLS.bye tlsCtx `finally` TLS.contextClose tlsCtx

-- | Flush any buffered output on the transport.
flush :: ConnectionContext -> IO ()
flush conn = case conn of
  NormalHandle h    -> hFlush h
  TLSContext tlsCtx -> TLS.contextFlush tlsCtx