{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE DeriveAnyClass    #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE StrictData        #-}
{-# LANGUAGE TypeApplications  #-}
module Database.Redis.Sentinel
  (
    
    SentinelConnectInfo(..)
  , SentinelConnection
  , connect
    
  , runRedis
  , RedisSentinelException(..)
    
  , module Database.Redis
  ) where
import           Control.Concurrent
import           Control.Exception     (Exception, IOException, bracket,
                                        evaluate, throwIO)
import           Control.Monad
import           Control.Monad.Catch   (Handler (..), MonadCatch, catches, throwM)
import           Control.Monad.Except
import           Data.ByteString       (ByteString)
import qualified Data.ByteString       as BS
import qualified Data.ByteString.Char8 as BS8
import           Data.Foldable         (toList)
import           Data.List             (delete)
import           Data.List.NonEmpty    (NonEmpty (..))
import           Data.Typeable         (Typeable)
import           Data.Unique
import           Network.Socket        (HostName)
import           Database.Redis hiding (Connection, connect, runRedis)
import qualified Database.Redis as Redis
-- | Run a 'Redis' action on the master currently known to this connection.
--
-- If an earlier call flagged a possible failover, the sentinels are queried
-- again first (via 'updateMaster', while the connection state is locked).
-- A new master connection is opened only when the reported host or port
-- differs from the current one. If no sentinel answers, 'NoSentinels' is
-- thrown and the stored state is left as it was.
--
-- A failover check is flagged for the next call when:
--
-- * the action throws an 'IOException' or 'ConnectionLostException'
--   (the exception is rethrown to the caller), or
-- * the action returns an error reply that starts with @READONLY @.
--
-- The reply itself is returned unchanged.
runRedis :: SentinelConnection
         -> Redis (Either Reply a)
         -> IO (Either Reply a)
runRedis (SentinelConnection connMVar) action = do
  -- Under the lock: refresh the master if a check was flagged, and remember
  -- the token of the state this call is about to run against.
  (baseConn, startToken) <- modifyMVar connMVar refreshIfNeeded
  reply <- (Redis.runRedis baseConn action >>= evaluate)
    `catchRedisRethrow` const (flagFailover startToken)
  -- NOTE(review): presumably a READONLY error means the node we reached is
  -- no longer the master (e.g. demoted by a failover); confirm.
  when (isReadOnlyError reply) $
    flagFailover startToken
  return reply
  where
    -- | Does the reply carry a Redis error beginning with "READONLY "?
    isReadOnlyError :: Either Reply b -> Bool
    isReadOnlyError (Left (Error msg)) = "READONLY " `BS.isPrefixOf` msg
    isReadOnlyError _                  = False

    -- | Two connect infos point at the same server if host and port match.
    sameHost :: Redis.ConnectInfo -> Redis.ConnectInfo -> Bool
    sameHost l r = (connectHost l, connectPort l) == (connectHost r, connectPort r)

    -- | Return the state unchanged unless a failover check is pending. In
    -- that case, re-resolve the master, reuse the existing connection if the
    -- master did not move, and install a fresh token. The second component
    -- is the connection and token the caller should run against.
    refreshIfNeeded
      :: SentinelConnection'
      -> IO (SentinelConnection', (Redis.Connection, Unique))
    refreshIfNeeded st
      | not (rcCheckFailover st) = return (st, (rcBaseConnection st, rcToken st))
      | otherwise = do
          (sentinelInfo', masterInfo') <- updateMaster (rcSentinelConnectInfo st)
          token' <- newUnique
          (masterInfo, baseConn) <-
            if sameHost masterInfo' (rcMasterConnectInfo st)
              then return (rcMasterConnectInfo st, rcBaseConnection st)
              else (,) masterInfo' <$> Redis.connect masterInfo'
          let st' = SentinelConnection'
                { rcCheckFailover = False
                , rcToken = token'
                , rcSentinelConnectInfo = sentinelInfo'
                , rcMasterConnectInfo = masterInfo
                , rcBaseConnection = baseConn
                }
          return (st', (baseConn, token'))

    -- | Flag a failover check, but only if the state has not changed since
    -- the failing call read it (same token). A new token is installed, so
    -- further failures from calls that saw the old token do nothing.
    flagFailover :: Unique -> IO ()
    flagFailover seenToken = modifyMVar_ connMVar $ \st ->
      if rcToken st /= seenToken
        then return st
        else do
          token' <- newUnique
          return st { rcToken = token', rcCheckFailover = True }
-- | Open a 'SentinelConnection'. The configured sentinels are asked for the
-- current master address (via 'updateMaster'), and a connection to that
-- master is created. No failover check is pending initially.
--
-- Throws 'NoSentinels' if none of the sentinels report a master address.
connect :: SentinelConnectInfo -> IO SentinelConnection
connect origConnectInfo = do
  (sentinelInfo, masterInfo) <- updateMaster origConnectInfo
  baseConn <- Redis.connect masterInfo
  initialToken <- newUnique
  let initialState = SentinelConnection'
        { rcCheckFailover = False
        , rcToken = initialToken
        , rcSentinelConnectInfo = sentinelInfo
        , rcMasterConnectInfo = masterInfo
        , rcBaseConnection = baseConn
        }
  SentinelConnection <$> newMVar initialState
-- | Ask each sentinel in 'connectSentinels', in order, for the address of
-- the master named 'connectMasterName'. Stop at the first one that answers
-- with a two-element @[host, port]@ reply.
--
-- Returns two things:
--
-- * the connect info, with the answering sentinel moved to the front of
--   'connectSentinels' so it is asked first next time, and
-- * 'connectBaseInfo' with its host and port replaced by the reported
--   master address.
--
-- A sentinel is skipped if it throws an 'IOException' or
-- 'ConnectionLostException', or if it returns any other reply.
--
-- Throws 'NoSentinels' if no sentinel produced a usable answer.
updateMaster :: SentinelConnectInfo
             -> IO (SentinelConnectInfo, Redis.ConnectInfo)
updateMaster sci@SentinelConnectInfo{..} = do
    -- 'ExceptT' is used only as an early exit. The first sentinel that
    -- reports a master 'throwError's the result, which stops the 'forM_'.
    resultEither <- runExceptT $ forM_ connectSentinels $ \(host, port) ->
      trySentinel host port `catchRedis` (\_ -> return ())
    case resultEither of
        -- Move the sentinel that answered to the front of the list.
        Left (conn, sentinelPair) -> return
          ( sci
            { connectSentinels = sentinelPair :| delete sentinelPair (toList connectSentinels)
            }
          , conn
          )
        Right () -> throwIO $ NoSentinels connectSentinels
  where
    trySentinel :: HostName -> PortID -> ExceptT (Redis.ConnectInfo, (HostName, PortID)) IO ()
    trySentinel sentinelHost sentinelPort = do
      -- The connection to the sentinel is opened only for this one query
      -- and released with 'Redis.disconnect' afterwards, so repeated master
      -- lookups (e.g. failover re-checks) do not accumulate sentinel
      -- connections. The reply is forced ('evaluate', plus the bang) inside
      -- the bracket so any exception it raises surfaces here, where
      -- 'catchRedis' turns it into "skip this sentinel".
      !replyE <- liftIO $ bracket
        (Redis.connect $ Redis.defaultConnectInfo
            { connectHost = sentinelHost
            , connectPort = sentinelPort
            , connectMaxConnections = 1
            })
        Redis.disconnect
        (\sentinelConn ->
            Redis.runRedis sentinelConn (sendRequest
              ["SENTINEL", "get-master-addr-by-name", connectMasterName])
              >>= evaluate)
      case replyE of
        Right [host, port] ->
          throwError
            ( connectBaseInfo
              { connectHost = BS8.unpack host
              , connectPort =
                  -- NOTE(review): if the reported port does not parse, this
                  -- falls back to 26379, which is Redis Sentinel's default
                  -- port rather than a master port. Confirm this is intended.
                  maybe
                    (PortNumber 26379)
                    (PortNumber . fromIntegral . fst)
                    $ BS8.readInt port
              }
            , (sentinelHost, sentinelPort)
            )
        _ -> return ()
-- | Run an action. If it throws an 'IOException' or a
-- 'ConnectionLostException', pass that exception's 'show' text to the
-- handler and then rethrow the original exception. Any other exception
-- propagates without calling the handler.
catchRedisRethrow :: MonadCatch m => m a -> (String -> m ()) -> m a
catchRedisRethrow action handler =
  catches action
    [ Handler (reportThenRethrow @IOException handler)
    , Handler (reportThenRethrow @ConnectionLostException handler)
    ]
  where
    -- | Report the exception via the callback, then throw it again.
    reportThenRethrow :: (Exception e, MonadCatch n) => (String -> n ()) -> e -> n b
    reportThenRethrow report ex = report (show ex) >> throwM ex
-- | Run an action. If it throws an 'IOException' or a
-- 'ConnectionLostException', recover by running the handler on that
-- exception's 'show' text. Any other exception propagates.
catchRedis :: MonadCatch m => m a -> (String -> m a) -> m a
catchRedis action handler = catches action recoveries
  where
    recoveries =
      [ Handler (handler . show @IOException)
      , Handler (handler . show @ConnectionLostException)
      ]
-- | A connection to a Redis master that was located through Redis Sentinel.
-- The connection state lives in an 'MVar'. 'runRedis' takes it while it
-- checks for a pending master change (and re-resolves the master if one is
-- flagged), and updates it when a failure flags a new check.
newtype SentinelConnection = SentinelConnection (MVar SentinelConnection')
-- | Mutable state behind a 'SentinelConnection'.
data SentinelConnection'
  = SentinelConnection'
      { rcCheckFailover       :: Bool
        -- ^ When 'True', the next 'runRedis' re-queries the sentinels for
        -- the master address before running its action.
      , rcToken               :: Unique
        -- ^ Replaced whenever the state is refreshed or a check is flagged.
        -- A failing call flags a check only if this still equals the token
        -- it read before running.
      , rcSentinelConnectInfo :: SentinelConnectInfo
        -- ^ Sentinel settings, with the sentinel that last answered first.
      , rcMasterConnectInfo   :: Redis.ConnectInfo
        -- ^ Connect info for the current master.
      , rcBaseConnection      :: Redis.Connection
        -- ^ Connection to the master described by 'rcMasterConnectInfo'.
      }
-- | Configuration for 'connect': where the sentinels are, which master to
-- ask them about, and how to connect to that master.
data SentinelConnectInfo
  = SentinelConnectInfo
      { connectSentinels  :: NonEmpty (HostName, PortID)
        -- ^ Sentinels to ask for the master address, tried in order. In the
        -- copy kept by a 'SentinelConnection', the sentinel that last
        -- answered is moved to the front.
      , connectMasterName :: ByteString
        -- ^ Master name sent in @SENTINEL get-master-addr-by-name@.
      , connectBaseInfo   :: Redis.ConnectInfo
        -- ^ Settings for the master connection. Its host and port are
        -- replaced by the address a sentinel reports. The sentinel
        -- connections themselves are built from 'Redis.defaultConnectInfo',
        -- not from this value.
      }
  deriving (Show)
-- | Exceptions thrown by this module.
data RedisSentinelException
  = NoSentinels (NonEmpty (HostName, PortID))
    -- ^ None of the listed sentinels reported a master address. Each one
    -- either failed with a connection error or gave an unexpected reply.
    -- Carries the sentinels that were tried.
  deriving (Show, Typeable, Exception)