{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Database.Redis.ProtocolPipelining (
  Connection,
  connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
  ConnectionLostException(..),
  PortID(..)
) where
import           Prelude
import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async (race)
import           Control.Concurrent.MVar
import           Control.Exception
import           Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import           Data.IORef
import           Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import           System.IO
import           System.IO.Error
import           System.IO.Unsafe
import           Database.Redis.Protocol
-- | Where 'connect' should reach the server: a TCP port on the given host
-- name, or the filesystem path of a Unix-domain socket (in which case the
-- host name passed to 'connect' is not used).
data PortID
  = PortNumber NS.PortNumber
  | UnixSocket String
  deriving (Show)
-- | The transport a 'Connection' talks over: either the raw socket 'Handle'
-- produced by 'connect', or a TLS context layered on that handle by
-- 'enableTLS'.
data ConnectionContext
  = NormalHandle Handle
  | TLSContext TLS.Context
-- | State of a single connection to the server.
data Connection = Conn
  { connCtx        :: ConnectionContext -- ^ Transport: plain handle or TLS context.
  , connReplies    :: IORef [Reply] -- ^ Lazily-read replies not yet handed out by 'recv'.
  , connPending    :: IORef [Reply]
    -- ^ Replies not yet parsed off the wire. 'beginReceiving' points it at
    -- the same lazy list as 'connReplies'; 'connGetReplies' drops its head
    -- each time it parses a reply, and 'send' forces its head to bound the
    -- number of outstanding replies.
  , connPendingCnt :: IORef Int
    -- ^ Number of requests sent whose replies have not been parsed yet.
    -- Incremented by 'send'; decremented (never below 0) by
    -- 'connGetReplies' after each parsed reply.
  }
-- | Thrown (via 'errConnClosed') when an I/O error occurs while sending or
-- receiving, or when the reply parser fails on the incoming data.
data ConnectionLostException
  = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | Progress marker written by 'connect' while it establishes the socket.
-- It is reported inside 'ConnectTimeout' so that a timeout says which step
-- was in progress. 'PhaseUnknown' is the initial value, before any step
-- has been recorded.
data ConnectPhase
  = PhaseUnknown
  | PhaseResolve
  | PhaseOpenSocket
  deriving (Show)
-- | Thrown by 'connect' when its timeout fires before the connection is
-- established. It carries the 'ConnectPhase' that was recorded at that
-- moment.
data ConnectTimeout
  = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout
-- | Resolve a host name and port to candidate addresses, restricted to
-- stream (TCP) sockets. The port is passed to the resolver as its decimal
-- 'show' string. Resolver failures propagate as thrown by
-- 'NS.getAddrInfo'.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just streamHints) (Just hostname) (Just (show port))
  where
    streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }
-- | Try each address in order and return the first socket that connects.
--
-- An 'IOError' from 'NS.connect' moves on to the next address. When the
-- list is exhausted, the error from the last attempt is rethrown. Errors
-- from creating the socket itself are not caught and abort immediately.
-- A socket whose connect attempt failed is always closed, never leaked.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) = tryConnect >>= \case
  Right sock -> return sock
  Left err   -> if null rest
                then throwIO err
                else connectSocket rest
  where
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = bracketOnError createSock NS.close $ \sock ->
      try (NS.connect sock $ NS.addrAddress addr) >>= \case
        Right () -> return (Right sock)
        -- 'try' turned the failure into a normal return, so
        -- 'bracketOnError' will not release the socket; close it here.
        Left err -> NS.close sock >> return (Left err)
      where
        createSock = NS.socket (NS.addrFamily addr)
                               (NS.addrSocketType addr)
                               (NS.addrProtocol addr)
-- | Open a connection to the server.
--
-- * For 'PortNumber', the host name is resolved and each resulting address
--   is tried in turn. For 'UnixSocket', the path is connected directly and
--   the host name is ignored.
-- * The socket gets @SO_KEEPALIVE@ and is wrapped in a binary-mode
--   'Handle'.
-- * With @Just micros@, the attempt is raced against @threadDelay micros@.
--   If the delay wins, 'ConnectTimeout' is thrown carrying the phase that
--   was in progress.
--
-- Replies are not read until 'beginReceiving' is called on the result.
connect :: NS.HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    connReplies <- newIORef []
    connPending <- newIORef []
    connPendingCnt <- newIORef 0
    let connCtx = NormalHandle h
    return Conn{..}
  where
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> do
              phase <- readMVar phaseMVar
              errConnectTimeout phase
    -- The socket is closed if anything after its creation fails or the
    -- attempt is cancelled by the timeout.
    hConnect' mvar = bracketOnError (createSock mvar) NS.close $ \sock -> do
      NS.setSocketOption sock NS.KeepAlive 1
      NS.socketToHandle sock ReadWriteMode
    -- Record the step *before* running it, so that a timeout reports the
    -- step that was actually in progress.
    setPhase mvar = void . swapMVar mvar
    createSock mvar = case portId of
      PortNumber portNumber -> do
        setPhase mvar PhaseResolve
        addrInfo <- getHostAddrInfo hostName portNumber
        setPhase mvar PhaseOpenSocket
        connectSocket addrInfo
      UnixSocket addr -> do
        setPhase mvar PhaseOpenSocket
        bracketOnError
          (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
          NS.close
          (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)
-- | Upgrade a plain connection to TLS. A TLS context is created over the
-- existing handle with the given client parameters and the handshake is
-- performed. A connection that already uses TLS is returned unchanged.
-- Handshake failures are not caught here.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn = upgrade (connCtx conn)
  where
    upgrade (TLSContext _) = return conn
    upgrade (NormalHandle h) = do
      tlsCtx <- TLS.contextNew h tlsParams
      TLS.handshake tlsCtx
      return conn { connCtx = TLSContext tlsCtx }
-- | Install the lazy reply stream from 'connGetReplies'. The same list
-- becomes both the replies to hand out ('connReplies') and the replies
-- still pending ('connPending').
beginReceiving :: Connection -> IO ()
beginReceiving conn = do
  replies <- connGetReplies conn
  forM_ [connReplies conn, connPending conn] $ \ref ->
    writeIORef ref replies
-- | Close the connection.
--
-- A plain handle is closed only if it is still open. A TLS connection sends
-- close_notify ('TLS.bye') and then closes the context. The context is
-- closed even if 'TLS.bye' throws (e.g. because the peer already went away),
-- so the underlying transport is not leaked. The exception from 'TLS.bye'
-- still propagates.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      TLS.bye ctx `finally` TLS.contextClose ctx
-- | Write a request to the transport without flushing. I/O errors are
-- reported as 'ConnectionLost'.
--
-- Every send counts one more outstanding reply. Once 1000 or more are
-- outstanding, the oldest pending reply is forced. This parses it off the
-- wire and keeps the number of unevaluated reply thunks bounded on long
-- pipelines.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  ioErrorToConnLost $ case connCtx of
    NormalHandle h -> S.hPut h s
    TLSContext ctx -> TLS.sendData ctx (L.fromStrict s)
  outstanding <- atomicModifyIORef' connPendingCnt $ \c ->
    let next = c + 1 in (next, next)
  when (outstanding >= 1000) $ do
    oldest : _ <- readIORef connPending
    void (evaluate oldest)
-- | Take the next reply from the lazy reply stream, blocking on the
-- transport if it has not arrived yet. The pattern match fails if the
-- stream is empty, e.g. before 'beginReceiving' has been called.
recv :: Connection -> IO Reply
recv Conn{connReplies = replies} = do
  next : later <- readIORef replies
  writeIORef replies later
  pure next
-- | Push any buffered outgoing bytes to the transport.
flush :: Connection -> IO ()
flush conn = flushCtx (connCtx conn)
  where
    flushCtx (NormalHandle h)   = hFlush h
    flushCtx (TLSContext tlsCx) = TLS.contextFlush tlsCx
-- | Send a request, then return the next unconsumed reply. That reply
-- belongs to this request only if every earlier 'send' has already been
-- matched by a 'recv'.
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn
-- | Build the lazy list of all replies that will be read from the connection.
--
-- Each reply and each tail of the list is produced with 'unsafeInterleaveIO',
-- so nothing is read from the transport until a reply is demanded. Before
-- parsing a reply, the previous reply is forced, so replies are parsed in
-- order. After each successful parse, the head of 'connPending' is dropped
-- and 'connPendingCnt' is decremented (clamped at 0).
--
-- A parse failure raises 'ConnectionLost'. I/O errors while reading are
-- also turned into 'ConnectionLost' by 'ioErrorToConnLost'.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = go S.empty (SingleLine "previous of first")
  where
    -- rest: input left over after the previous parse.
    -- previous: the reply preceding the one produced here. The initial
    -- placeholder is only forced, never returned.
    go rest previous = do
      -- The lazy (~) pattern keeps this bind from forcing the tuple, so the
      -- interleaved read does not run until r or rest' is demanded.
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force the preceding reply first so replies are parsed in order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{}       -> errConnClosed
          -- NOTE(review): presumably unreachable because scanWith keeps
          -- pulling input via readMore; the message mentions
          -- parseWith/Partial and looks stale — confirm.
          Scanner.More{}    -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            -- A reply has just been parsed, so drop the head of the
            -- pending list (assumed to be this same reply).
            atomicModifyIORef' connPending $ \(_:rs) -> (rs, ())
            -- One fewer outstanding reply, never counting below 0.
            -- NOTE(review): the clamp presumably covers replies that no
            -- 'send' accounted for — confirm.
            atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n-1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)
    -- Fetch more input for the parser. Buffered output is flushed first, so
    -- any queued requests are actually sent before we block on a read.
    -- Plain handles read up to 4096 bytes; TLS uses 'TLS.recvData'.
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx
-- | Run an action, replacing any 'IOError' it raises with 'ConnectionLost'.
-- Exceptions of other types pass through untouched.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | Throw 'ConnectionLost' in 'IO'.
errConnClosed :: IO a
errConnClosed = throwIO (ConnectionLost :: ConnectionLostException)
-- | Throw 'ConnectTimeout' tagged with the phase that was in progress.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout