{-# LANGUAGE RecordWildCards #-}

{-|
Module      : Database.Memcache.Server
Description : Server Handling
Copyright   : (c) David Terei, 2016
License     : BSD
Maintainer  : code@davidterei.com
Stability   : stable
Portability : GHC

Handles the connections between a Memcached client and a single server.

Memcached expected errors (part of protocol) are returned in the Response,
unexpected errors (e.g., network failure) are thrown as exceptions. While
the Server datatype carries a 'failed' timestamp for managing retries, it's
up to consumers to use it.
-}
module Database.Memcache.Server (
      -- * Server
        Server(sid, failed), newServer, sendRecv, withSocket, close
    ) where

import Database.Memcache.SASL
import Database.Memcache.Socket

import Control.Exception
import Data.Hashable
import Data.IORef
import Data.Pool
import Data.Time.Clock (NominalDiffTime)
import Data.Time.Clock.POSIX (POSIXTime)

import Network.Socket (getAddrInfo, HostName, ServiceName)
import qualified Network.Socket as S

-- Connection pool constants, handed to 'createPool' by 'newServer' in the
-- order stripes, keep-alive, connections.
-- TODO: make configurable

-- | Stripe count and connection count given to the connection pool.
sSTRIPES, sCONNECTIONS :: Int
sSTRIPES     = 1
sCONNECTIONS = 1

-- | Keep-alive period given to the connection pool (a 'NominalDiffTime', so
-- the literal is in seconds).
sKEEPALIVE :: NominalDiffTime
sKEEPALIVE = 300

-- | Memcached server connection.
--
-- Equality and ordering of servers are defined on 'sid' alone.
data Server = Server {
        -- | ID of server for consistent hashing.
        sid      :: {-# UNPACK #-} !Int,
        -- | Connection pool to server.
        pool     :: Pool Socket,
        -- | Hostname of server.
        addr     :: !HostName,
        -- | Port number of server.
        port     :: !ServiceName,
        -- | Credentials for server.
        auth     :: !Authentication,
        -- | When did the server fail? 0 if it is alive.
        failed   :: IORef POSIXTime

        -- TODO:
        -- weight    :: Double
        -- transport :: Transport (UDP vs. TCP)
        -- poolLim   :: Int (pooled connection limit)
        -- cnxnBuf   :: IORef ByteString
    }

-- | Renders a server as @Server [\<sid\>] \<addr\>:\<port\>@.
--
-- Note that 'port' is passed through 'show', so it appears as a quoted
-- string literal in the output.
instance Show Server where
  show Server{..} =
    "Server [" ++ show sid ++ "] " ++ addr ++ ":" ++ show port

-- | Two servers are equal exactly when their 'sid's are equal; no other
-- field is inspected.
instance Eq Server where
    (==) x y = sid x == sid y

-- | Servers are ordered by their 'sid' only.
instance Ord Server where
    compare x y = compare (sid x) (sid y)

-- | Create a new Memcached server connection.
--
-- Builds a connection pool for the given host and port (sized by
-- 'sSTRIPES', 'sKEEPALIVE' and 'sCONNECTIONS'), initialises the 'failed'
-- timestamp to 0, and derives the server ID by hashing the @(host, port)@
-- pair. Sockets are opened by the pool through @connectSocket@, not
-- directly by this function.
newServer :: HostName -> ServiceName -> Authentication -> IO Server
newServer host port auth = do
    fat <- newIORef 0
    pSock <- createPool connectSocket releaseSocket
                sSTRIPES sKEEPALIVE sCONNECTIONS
    return Server
        { sid      = serverHash
        , pool     = pSock
        , addr     = host
        , port     = port
        , auth     = auth
        , failed   = fat
        }
  where
    -- Server ID: hash of the host name and service name together.
    serverHash :: Int
    serverHash = hash (host, port)

    -- Resolve the server address, open a stream socket to the first result,
    -- enable keep-alive and no-delay, then authenticate. If any step after
    -- socket creation throws, the socket is closed before the exception
    -- propagates.
    connectSocket :: IO Socket
    connectSocket = do
        let hints = S.defaultHints { S.addrSocketType = S.Stream }
        addrs <- getAddrInfo (Just hints) (Just host) (Just port)
        case addrs of
            -- Explicit failure instead of an anonymous pattern-match error;
            -- still an IOException built from 'userError'.
            [] -> throwIO . userError $
                "newServer: no address found for " ++ host ++ ":" ++ port
            (ai:_) -> bracketOnError
                (S.socket (S.addrFamily ai)
                          (S.addrSocketType ai)
                          (S.addrProtocol ai))
                releaseSocket
                (\s -> do
                    S.connect s $ S.addrAddress ai
                    S.setSocketOption s S.KeepAlive 1
                    S.setSocketOption s S.NoDelay 1
                    authenticate s auth
                    return s
                )

    -- Destructor used both by the pool and on failed connection attempts.
    releaseSocket :: Socket -> IO ()
    releaseSocket = S.close

-- | Send and receive a single request/response pair to the Memcached server.
--
-- The request is sent and the response read on the same socket, obtained
-- through 'withSocket'.
sendRecv :: Server -> Request -> IO Response
{-# INLINE sendRecv #-}
sendRecv svr msg = withSocket svr $ \s -> do
    send s msg
    recv s

-- | Run a function with access to a server socket for using 'send' and
-- 'recv'.
--
-- The socket is taken from the server's connection pool via 'withResource'.
withSocket :: Server -> (Socket -> IO a) -> IO a
{-# INLINE withSocket #-}
withSocket svr = withResource $ pool svr

-- | Close the server connection. If you perform another operation after this,
-- the connection will be re-established.
--
-- Implemented by calling 'destroyAllResources' on the server's connection
-- pool.
close :: Server -> IO ()
{-# INLINE close #-}
close srv = destroyAllResources $ pool srv