{-# LANGUAGE LambdaCase, OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Database.Franz.Network
  ( defaultPort
  , Connection
  , withConnection
  , connect
  , disconnect
  , Query(..)
  , ItemRef(..)
  , RequestType(..)
  , defQuery
  , Response
  , awaitResponse
  , SomeIndexMap
  , Contents
  , fetch
  , fetchSimple
  , atomicallyWithin
  , FranzException(..)) where

import Control.Arrow ((&&&))
import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.Delay (cancelDelay, newDelay, waitDelay)
import Control.Exception
import Control.Monad
import qualified Data.ByteString.Char8 as B
import Data.ConcurrentResourceMap
import qualified Data.HashMap.Strict as HM
import Data.IORef
import Data.IORef.Unboxed
import Data.Int (Int64)
import qualified Data.IntMap.Strict as IM
import Data.Serialize hiding (getInt64le)
import qualified Data.Vector as V
import qualified Data.Vector.Generic.Mutable as VGM
import Database.Franz.Internal
import Database.Franz.Protocol
import qualified Network.Socket as S
import qualified Network.Socket.ByteString as SB

-- The protocol
--
-- Client                     Server
---  | ---- Archive prefix ---> |  Mounts P if possible
---  | <--- apiVersion -------- |
---  | ---- RawRequest i p ---> |
---  | ---- RawRequest j q ---> |
---  | ---- RawRequest k r ---> |
---  | <--- ResponseInstant i - |
---  | <--- result for p -----  |
---  | <--- ResponseWait j ---- |
---  | <--- ResponseWait k ---- |
---  | <--- ResponseDelayed j - |
---  | <--- result for q -----  |
--   | ----  RawClean i ---->   |
--   | ----  RawClean j ---->   |
--   | ----  RawClean k ---->   |

-- | 'IM.IntMap'-backed storage for a 'ConcurrentResourceMap' whose keys
-- are request ids ('Int').
newtype ConnStateMap v = ConnStateMap (IM.IntMap v)

instance ResourceMap ConnStateMap where
  type Key ConnStateMap = Int
  empty = ConnStateMap mempty
  lookup key (ConnStateMap states) = IM.lookup key states
  insert key val (ConnStateMap states) = ConnStateMap $ IM.insert key val states
  delete key (ConnStateMap states) = ConnStateMap $ IM.delete key states

-- | A connection to a Franz server, created by 'connect' and released by
-- 'disconnect'.
data Connection = Connection
  { connSocket :: !(MVar S.Socket)
  -- ^ The server socket. Sending requests and closing the socket both
  -- happen while holding this 'MVar', so concurrent writers do not
  -- interleave their messages.
  , connReqId :: !Counter
  -- ^ Counter from which each 'fetch' takes a fresh request id.
  , connStates :: !(ConcurrentResourceMap
                     ConnStateMap
                     (TVar (ResponseStatus Contents)))
  -- ^ State of each outstanding request, keyed by request id. It is
  -- shared between 'fetch' and the receiver thread.
  , connThread :: !ThreadId
  -- ^ Receiver thread that decodes server messages and updates
  -- 'connStates'.
  }

-- | Client-side state of a single request. It is kept in a 'TVar' shared
-- between 'fetch' and the connection's receiver thread.
data ResponseStatus a
  = WaitingInstant
    -- ^ Initial state: no reply from the server has been processed yet.
  | WaitingDelayed
    -- ^ The server answered with 'ResponseWait'; the result comes later.
  | Errored !FranzException
    -- ^ The server reported an error for this request.
  | Available !a
    -- ^ The response contents have arrived.
  | RequestFinished
    -- ^ The 'fetch' call that owns the request has exited. Reading the
    -- state after this point results in a 'ClientError'.
  deriving (Show, Functor)

-- | Open a connection with 'connect', run @body@ with it, and always
-- 'disconnect' afterwards, even if @body@ throws.
withConnection :: String -> S.PortNumber -> B.ByteString -> (Connection -> IO r) -> IO r
withConnection host port dir body =
  bracket (connect host port dir) disconnect body

-- | Connect to a Franz server at @host@:@port@ and mount the archive
-- prefix @dir@.
--
-- The handshake sends @dir@ and expects the server to reply with
-- 'apiVersion'. If the reply decodes as a 'ResponseError', that error is
-- thrown. Any other reply is thrown as a 'ClientError'. The socket is
-- closed if connecting or the handshake fails.
--
-- On success, a receiver thread is forked. It decodes server messages and
-- updates the state of the matching outstanding request. Release the
-- connection with 'disconnect', or use 'withConnection'.
connect :: String -> S.PortNumber -> B.ByteString -> IO Connection
connect host port dir = do
  let hints = S.defaultHints { S.addrFlags = [S.AI_NUMERICSERV], S.addrSocketType = S.Stream }
  addr:_ <- S.getAddrInfo (Just hints) (Just host) (Just $ show port)
  -- If any step below throws, close the socket. The caller never receives
  -- a 'Connection', so it could not 'disconnect' the socket itself.
  sock <- bracketOnError
    (S.socket (S.addrFamily addr) S.Stream (S.addrProtocol addr))
    S.close
    $ \s -> do
      S.setSocketOption s S.NoDelay 1
      S.connect s $ S.addrAddress addr
      SB.sendAll s $ encode dir
      -- NOTE(review): this assumes the whole reply arrives in a single
      -- recv call. A fragmented reply would be reported as unexpected.
      readyMsg <- SB.recv s 4096
      unless (readyMsg == apiVersion) $ case decode readyMsg of
        Right (ResponseError _ e) -> throwIO e
        e -> throwIO $ ClientError $ "Database.Franz.Network.connect: Unexpected response: " ++ show e
      pure s

  connSocket <- newMVar sock
  connReqId <- newCounter 0
  connStates <- newResourceMap
  buf <- newIORef B.empty
  let -- Get a reference to shared state for the request if it exists.
      withRequest i f = withInitialisedResource connStates i (\_ -> pure ()) $ \case
        Nothing ->
          -- If it throws an exception on no value, great, it will
          -- float out here. If it returns a value, it'll just be
          -- ignored as we can't do anything with it anyway.
          void $ atomically (f Nothing)
        Just reqVar -> atomically $ readTVar reqVar >>= \case
          -- If request is finished, do nothing to the content.
          RequestFinished -> void $ f Nothing
          s -> f (Just s) >>= mapM_ (writeTVar reqVar)

      -- Decode one value from the socket, turning a decoding failure
      -- into a 'ClientError'.
      runGetThrow :: Get a -> IO a
      runGetThrow g = runGetRecv buf sock g
        >>= either (throwIO . ClientError) pure

  -- Receiver loop. Any exception here ends the loop (and the thread).
  -- That includes decoding failures, unexpected request states, and an
  -- error for an unknown or finished request.
  -- NOTE(review): nothing then wakes requests that are still waiting.
  -- Their STM actions keep retrying; confirm that this is acceptable.
  connThread <- flip forkFinally (either throwIO pure) $ forever $ runGetThrow get >>= \case
    Response i -> do
      -- Always consume the payload, even if the request is no longer
      -- tracked, so the stream stays aligned.
      resp <- runGetThrow getResponse
      withRequest i . traverse $ \case
        WaitingInstant -> pure (Available resp)
        WaitingDelayed -> pure (Available resp)
        e -> throwSTM $ ClientError $ "Unexpected state on ResponseInstant " ++ show i ++ ": " ++ show e
    ResponseWait i -> withRequest i . traverse $ \case
      WaitingInstant -> pure WaitingDelayed
      e -> throwSTM $ ClientError $ "Unexpected state on ResponseWait " ++ show i ++ ": " ++ show e
    ResponseError i e -> withRequest i $ \case
      -- Unknown or finished request: rethrow, terminating this thread.
      Nothing -> throwSTM e
      Just{} -> pure $ Just (Errored e)
  return Connection{..}

-- | Stop the connection's receiver thread, then close its socket while
-- holding the socket lock.
disconnect :: Connection -> IO ()
disconnect conn = do
  killThread (connThread conn)
  withMVar (connSocket conn) S.close

-- | A 'Query' on stream @name@. Both bounds are @'BySeqNum' 0@ and the
-- request type is 'AllItems'. Override fields with record update syntax.
defQuery :: B.ByteString -> Query
defQuery name = Query
  { reqStream = name
  , reqFrom = origin
  , reqTo = origin
  , reqType = AllItems
  }
  where
    origin = BySeqNum 0

-- | Index values attached to one item, keyed by index name.
type SomeIndexMap = HM.HashMap IndexName Int64

-- | Items returned for a query, as @(seqno, indices, payload)@ triples.
type Contents = V.Vector (Int, SomeIndexMap, B.ByteString)

-- | What a 'fetch' yields. 'Left' carries contents that were available
-- immediately. 'Right' carries an STM action that retries until delayed
-- contents arrive. That action rethrows a server error, and throws a
-- 'ClientError' once the request is finished.
type Response = Either Contents (STM Contents)

-- | Run an STM action producing a 'Response'-like value and reduce it to
-- the contents. A 'Left' result is returned as is; for a 'Right' result,
-- the inner STM action is run as well.
awaitResponse :: STM (Either a (STM a)) -> STM a
awaitResponse action = do
  response <- action
  case response of
    Left ready -> pure ready
    Right later -> later

-- | Decode response contents. The wire layout is a 'PayloadHeader', then
-- one index entry per item, then the concatenated item payloads.
--
-- With header fields @s0@, @s1@ and @p0@, there are @s1 - s0@ items,
-- numbered @s0 + 1@ to @s1@. If @s1 <= s0@, nothing follows the header.
-- Each index entry is the end offset of the item's payload followed by
-- one value per index name. The first item starts at offset @p0@, and
-- each later item starts where the previous one ends.
getResponse :: Get Contents
getResponse = do
    PayloadHeader s0 s1 p0 names <- get
    let count = s1 - s0
    if count <= 0
        then pure mempty
        else do
            entries <- V.replicateM count $
                (,) <$> getInt64le <*> traverse (const getInt64le) names
            payload <- getByteString $ fst (V.unsafeLast entries) - p0
            pure $ V.create $ do
                out <- VGM.unsafeNew count
                -- The fold carries the previous item's end offset, which
                -- is where the current item starts.
                let step start (i, (end, indices)) = do
                        let !m = HM.fromList $ zip names indices
                            !bs = B.take (end - start) $ B.drop (start - p0) payload
                            !num = s0 + i + 1
                        VGM.unsafeWrite out i (num, m, bs)
                        pure end
                foldM_ step p0 (V.indexed entries)
                pure out

-- | Fetch requested data from the server.
--
-- The request is registered under a fresh request id and sent, and then
-- the continuation runs. When the continuation returns or throws, the
-- request is marked finished. If no reply had arrived yet, a 'RawClean'
-- is sent. So termination of the continuation cancels the request,
-- allowing flexible control of its lifetime.
fetch
  :: Connection
  -> Query
  -> (STM Response -> IO r)
  -- ^ Wait for the response in a blocking manner. You should only run
  -- the continuation inside a 'fetch' block: leaking the STM action
  -- and running it outside will result in a 'ClientError' exception.
  -> IO r
fetch Connection{..} req cont = do
  reqId <- atomicAddCounter connReqId 1

  let -- Write pre-encoded bytes while holding the socket lock, so that
      -- messages from concurrent callers are not interleaved.
      sendBytes bytes = withMVar connSocket $ \sock -> SB.sendAll sock bytes

      -- True while the server may still answer this request.
      pending = \case
        WaitingInstant -> True
        WaitingDelayed -> True
        _ -> False

      finishedErr = ClientError "request already finished"

      -- Runs when the request leaves scope. It marks the state finished,
      -- so a leaked STM action fails instead of waiting forever, and asks
      -- the server to drop the request if it was still pending.
      release reqVar = do
        wasPending <- atomically $
          stateTVar reqVar $ pending &&& const RequestFinished
        when wasPending $ sendBytes $ encode $ RawClean reqId

      -- Wait for the contents of a request the server has deferred.
      awaitDelayed reqVar = readTVar reqVar >>= \case
        WaitingDelayed -> retry
        Available xs -> return xs
        Errored e -> throwSTM e
        RequestFinished -> throwSTM finishedErr
        WaitingInstant -> throwSTM $ ClientError $
          "fetch/WaitingDelayed: unexpected state WaitingInstant"

      -- Wait for the server's first answer: either the contents, or a
      -- notice that they will come later.
      awaitFirst reqVar = readTVar reqVar >>= \case
        WaitingInstant -> retry
        WaitingDelayed -> pure $ Right (awaitDelayed reqVar)
        Available xs -> pure $ Left xs
        Errored e -> throwSTM e
        RequestFinished -> throwSTM finishedErr

  -- The state TVar is registered in the shared map before the request is
  -- sent, so the receiver thread can always find it. 'withSharedResource'
  -- removes it and calls 'release' once the continuation is done.
  withSharedResource connStates reqId (newTVarIO WaitingInstant) release
    $ \reqVar -> do
      sendBytes $ encode $ RawRequest reqId req
      cont $ awaitFirst reqVar

-- | Send a single query and wait up to @timeout@ microseconds for its
-- result. If nothing arrives in time, the returned 'Contents' is empty.
-- Errors reported by the server are rethrown.
fetchSimple :: Connection
  -> Int -- ^ timeout in microseconds
  -> Query
  -> IO Contents
fetchSimple conn timeout req = fetch conn req $ \response -> do
  outcome <- atomicallyWithin timeout (awaitResponse response)
  pure $ case outcome of
    Nothing -> mempty
    Just contents -> contents

-- | Run an STM transaction, giving up after @timeout@ microseconds.
--
-- Returns @'Just' x@ if the transaction commits with @x@ before the delay
-- expires, and 'Nothing' otherwise. Exceptions from the transaction
-- propagate. The delay is cancelled when this function returns or throws,
-- so its timer does not stay registered until it would have expired.
atomicallyWithin :: Int -- ^ timeout in microseconds
  -> STM a
  -> IO (Maybe a)
atomicallyWithin timeout m =
  bracket (newDelay timeout) cancelDelay $ \d ->
    atomically $ fmap Just m `orElse` (Nothing <$ waitDelay d)