{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
--------------------------------------------------------------------------------
-- |
-- Module : Database.EventStore.Internal.Connection
-- Copyright : (C) 2015 Yorick Laupa
-- License : (see the file LICENSE)
--
-- Maintainer : Yorick Laupa <yo.eight@gmail.com>
-- Stability : provisional
-- Portability : non-portable
--
--------------------------------------------------------------------------------
module Database.EventStore.Internal.Connection
    ( InternalConnection
    , ConnectionException(..)
    , connUUID
    , connClose
    , connSend
    , connRecv
    , connIsClosed
    , newConnection
    ) where

--------------------------------------------------------------------------------
import Text.Printf

--------------------------------------------------------------------------------
import ClassyPrelude
import Data.Serialize
import Data.UUID
import Data.UUID.V4
import Network.Connection

--------------------------------------------------------------------------------
import Database.EventStore.Internal.Command
import Database.EventStore.Internal.Discovery
import Database.EventStore.Internal.Types
import Database.EventStore.Logging

--------------------------------------------------------------------------------
-- | Type of connection issue that can arise during the communication with the
--   server.
data ConnectionException
    = MaxAttemptConnectionReached
      -- ^ The max reconnection attempt threshold has been reached.
    | ClosedConnection
      -- ^ Use of a close 'Connection'.
    | WrongPackageFraming
      -- ^ TCP package sent by the server had a wrong framing.
    | PackageParsingError String
      -- ^ Server sent a malformed TCP package.
    deriving (Show, Typeable)

--------------------------------------------------------------------------------
instance Exception ConnectionException

--------------------------------------------------------------------------------
data In a where
    Id    :: In UUID
    Close :: In ()
    Send  :: Package -> In ()
    Recv  :: In Package

--------------------------------------------------------------------------------
-- | Represents connection logic action to carry out.
data Status a where
    Noop :: Status ()
    WithConnection :: UUID -> Connection -> In a -> Status a
    CreateConnection :: In a -> Status a
    Errored :: ConnectionException -> Status a

--------------------------------------------------------------------------------
-- | Internal representation of a connection with the server.
data InternalConnection =
    InternalConnection
    { _var    :: TMVar State
    , _last   :: IORef (Maybe EndPoint)
    , _disc   :: Discovery
    , _setts  :: Settings
    , _ctx    :: ConnectionContext
    }

--------------------------------------------------------------------------------
data State
    = Offline
    | Online !UUID !Connection
    | Closed

--------------------------------------------------------------------------------
-- | Creates a new 'InternalConnection'.
newConnection :: Settings -> Discovery -> IO InternalConnection
newConnection setts disc = do
    ctx <- initConnectionContext
    var <- newTMVarIO Offline
    ref <- newIORef Nothing
    return $ InternalConnection var ref disc setts ctx

--------------------------------------------------------------------------------
-- | Gets current 'InternalConnection' 'UUID'.
connUUID :: InternalConnection -> IO UUID
connUUID conn = execute conn Id

--------------------------------------------------------------------------------
-- | Closes the 'InternalConnection'. It will not retry to reconnect after that
--   call. it means a new 'InternalConnection' has to be created.
--   'ClosedConnection' exception will be raised if the same
--   'InternalConnection' object is used after a 'connClose' call.
connClose :: InternalConnection -> IO ()
connClose conn = execute conn Close

--------------------------------------------------------------------------------
-- | Sends 'Package' to the server.
connSend :: InternalConnection -> Package -> IO ()
connSend conn pkg = execute conn (Send pkg)

--------------------------------------------------------------------------------
-- | Asks the requested amount of bytes from the 'handle'.
connRecv :: InternalConnection -> IO Package
connRecv conn = execute conn Recv

--------------------------------------------------------------------------------
-- | Returns True if the connection is in closed state.
connIsClosed :: InternalConnection -> STM Bool
connIsClosed InternalConnection{..} = do
    r <- readTMVar _var
    case r of
        Closed -> return True
        _      -> return False

--------------------------------------------------------------------------------
-- Connection Logic
--------------------------------------------------------------------------------
onlineLogic :: forall a. TMVar State
            -> UUID
            -> Connection
            -> In a
            -> STM (Status a)
onlineLogic var uuid conn input =
    let status = WithConnection uuid conn input
        state =
            case input of
                Close -> Closed
                _ -> Online uuid conn in
    status <$ putTMVar var state

--------------------------------------------------------------------------------
offlineLogic :: forall a. TMVar State -> In a -> STM (Status a)
offlineLogic var Close = Noop <$ putTMVar var Closed
offlineLogic _ other = return $ CreateConnection other

--------------------------------------------------------------------------------
closedLogic :: forall a. TMVar State -> In a -> STM (Status a)
closedLogic var input = do
    putTMVar var Closed
    case input of
        Close -> return Noop
        _ -> return $ Errored ClosedConnection

--------------------------------------------------------------------------------
connectionLogic :: forall a. TMVar State -> In a -> STM (Status a)
connectionLogic var input = do
    state <- takeTMVar var
    case state of
        Online uuid conn -> onlineLogic var uuid conn input
        Offline -> offlineLogic var input
        Closed -> closedLogic var input

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
handleInput :: forall a. UUID -> Connection -> In a -> IO a
handleInput _ conn (Send pkg) = send conn pkg
handleInput _ conn Recv = recv conn
handleInput uuid _ Id = return uuid
handleInput _ conn Close = liftIO $ connectionClose conn

--------------------------------------------------------------------------------
-- | Main connection logic. It will automatically reconnect to the server when
--   a exception occured while the 'Handle' is accessed.
execute :: forall a. InternalConnection -> In a -> IO a
execute iconn input = do
    res <- atomically $ connectionLogic (_var iconn) input

    case res of
        Noop -> return ()
        Errored e -> throwIO e
        WithConnection uuid conn op -> handleInput uuid conn op
        CreateConnection op -> do
            (uuid, conn) <- openConnection iconn
            atomically $ putTMVar (_var iconn) (Online uuid conn)
            handleInput uuid conn op

--------------------------------------------------------------------------------
reachedMaxAttempt :: Retry -> Int -> Bool
reachedMaxAttempt KeepRetrying _ = False
reachedMaxAttempt (AtMost n) cur = n <= cur

--------------------------------------------------------------------------------
openConnection :: InternalConnection -> IO (UUID, Connection)
openConnection InternalConnection{..} = attempt 1
  where
    delay = s_reconnect_delay_secs _setts * secs

    handleFailure trialCount = do
        threadDelay delay
        when (reachedMaxAttempt (s_retry _setts) trialCount) $ do
            atomically $ putTMVar _var Closed
            throwIO MaxAttemptConnectionReached
        attempt (trialCount + 1)

    attempt trialCount = do
        _settingsLog _setts (Info $ Connecting trialCount)
        old <- readIORef _last
        ept_opt <- runDiscovery _disc old
        case ept_opt of
            Nothing -> handleFailure trialCount
            Just ept -> do
                let host = endPointIp ept
                    port = endPointPort ept
                res <- tryAny $ connect _setts _ctx host port
                case res of
                    Left _ -> handleFailure trialCount
                    Right st -> st <$ writeIORef _last (Just ept)

--------------------------------------------------------------------------------
secs :: Int
secs = 1000000

--------------------------------------------------------------------------------
connect :: Settings
        -> ConnectionContext
        -> String
        -> Int
        -> IO (UUID, Connection)
connect sett ctx host port = do
    let params = ConnectionParams host (fromIntegral port) (s_ssl sett) Nothing
    conn <- connectTo ctx params
    uuid <- nextRandom
    _settingsLog sett (Info $ Connected uuid)
    return  (uuid, conn)

--------------------------------------------------------------------------------
-- Binary operations
--------------------------------------------------------------------------------
recv :: Connection -> IO Package
recv con = do
    header_bs <- connectionGetExact con 4
    case runGet getLengthPrefix header_bs of
        Left _              -> throwIO WrongPackageFraming
        Right length_prefix -> do
            bs <- connectionGetExact con length_prefix
            case runGet getPackage bs of
                Left e    -> throwIO $ PackageParsingError e
                Right pkg -> return pkg

--------------------------------------------------------------------------------
send :: Connection -> Package -> IO ()
send  con pkg = connectionPut con bs
  where
    bs = runPut $ putPackage pkg

--------------------------------------------------------------------------------
-- Serialization
--------------------------------------------------------------------------------
-- | Serializes a 'Package' into raw bytes.
putPackage :: Package -> Put
putPackage pkg = do
    putWord32le length_prefix
    putWord8 (cmdWord8 $ packageCmd pkg)
    putWord8 flag_word8
    putLazyByteString corr_bytes
    for_ cred_m $ \(Credentials login passw) -> do
        putWord8 $ fromIntegral $ olength login
        putByteString login
        putWord8 $ fromIntegral $ olength passw
        putByteString passw
    putByteString pack_data
  where
    pack_data     = packageData pkg
    cred_len      = maybe 0 credSize cred_m
    length_prefix = fromIntegral (olength pack_data + mandatorySize + cred_len)
    cred_m        = packageCred pkg
    flag_word8    = maybe 0x00 (const 0x01) cred_m
    corr_bytes    = toByteString $ packageCorrelation pkg

--------------------------------------------------------------------------------
credSize :: Credentials -> Int
credSize (Credentials login passw) = olength login + olength passw + 2

--------------------------------------------------------------------------------
-- | The minimun size a 'Package' should have. It's basically a command byte,
--   correlation bytes ('UUID') and a 'Flag' byte.
mandatorySize :: Int
mandatorySize = 18

--------------------------------------------------------------------------------
-- Parsing
--------------------------------------------------------------------------------
getLengthPrefix :: Get Int
getLengthPrefix = fmap fromIntegral getWord32le

--------------------------------------------------------------------------------
getPackage :: Get Package
getPackage = do
    cmd  <- getWord8
    flg  <- getFlag
    col  <- getUUID
    cred <- getCredentials flg
    rest <- remaining
    dta  <- getBytes rest

    let pkg = Package
              { packageCmd         = Command cmd
              , packageCorrelation = col
              , packageData        = dta
              , packageCred        = cred
              }

    return pkg

--------------------------------------------------------------------------------
getFlag :: Get Flag
getFlag = do
    wd <- getWord8
    case wd of
        0x00 -> return None
        0x01 -> return Authenticated
        _    -> fail $ printf "TCP: Unhandled flag value 0x%x" wd

--------------------------------------------------------------------------------
getCredEntryLength :: Get Int
getCredEntryLength = fmap fromIntegral getWord8

--------------------------------------------------------------------------------
getCredentials :: Flag -> Get (Maybe Credentials)
getCredentials None = return Nothing
getCredentials _ = do
    loginLen <- getCredEntryLength
    login    <- getBytes loginLen
    passwLen <- getCredEntryLength
    passw    <- getBytes passwLen
    return $ Just $ credentials login passw

--------------------------------------------------------------------------------
getUUID :: Get UUID
getUUID = do
    bs <- getLazyByteString 16
    case fromByteString bs of
        Just uuid -> return uuid
        _         -> fail "TCP: Wrong UUID format"