{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards  #-}
{-# LANGUAGE RankNTypes  #-}

module Network.HTTP2.Client.RawConnection (
      RawHttp2Connection (..)
    , newRawHttp2Connection
    ) where

import           Control.Concurrent.Async (Async, async, cancel, pollSTM)
import           Control.Concurrent.STM (STM, atomically, retry, throwSTM)
import           Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO, readTVar, writeTVar)
import           Control.Exception (bracketOnError, finally)
import           Control.Monad (forever, when)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import           Data.ByteString.Lazy (fromChunks)
import           Data.Monoid ((<>))
import qualified Network.HTTP2 as HTTP2
import           Network.Socket hiding (recv)
import           Network.Socket.ByteString
import qualified Network.TLS as TLS
import           System.IO.Error (eofErrorType, mkIOError)

-- | Handle on a raw, byte-level connection to an HTTP2 server.
--
-- TODO: catch connection errors
data RawHttp2Connection = RawHttp2Connection {
    _sendRaw :: [ByteString] -> IO ()
  -- ^ Function to send raw data to the server. The connections built in this
  -- module queue the chunks for a background write worker instead of sending
  -- them synchronously.
  , _nextRaw :: Int -> IO ByteString
  -- ^ Function to block reading a datachunk of a given size from the server.
  , _close   :: IO ()
  -- ^ Action closing the connection. The connections built in this module
  -- cancel their worker threads and then close the socket (for TLS, after
  -- sending a TLS bye and by closing the TLS context).
  }

-- | Initiates a RawHttp2Connection with a server.
--
-- Resolves the host, opens a TCP connection (optionally wrapped in TLS) and
-- queues the HTTP2 connection preface for sending.
--
-- If any step after the socket is created throws (connecting, the TLS
-- handshake, ...), the socket is closed before the exception is propagated.
-- Once a connection is returned, the caller is responsible for releasing it
-- with '_close'.
newRawHttp2Connection :: HostName
                      -- ^ Server's hostname.
                      -> PortNumber
                      -- ^ Server's port to connect to.
                      -> Maybe TLS.ClientParams
                      -- ^ TLS parameters. The 'TLS.onSuggestALPN' hook is
                      -- overwritten to always return ["h2", "h2-17"].
                      -> IO RawHttp2Connection
newRawHttp2Connection host port mparams = do
    -- Resolves the server address.
    let hints = defaultHints { addrFlags = [AI_NUMERICSERV], addrSocketType = Stream }
    addrs <- getAddrInfo (Just hints) (Just host) (Just $ show port)
    -- Explicit match rather than a partial pattern bind, so that an empty
    -- result produces a meaningful error.
    addr <- case addrs of
        (found:_) -> return found
        []        -> ioError $ userError $
            "newRawHttp2Connection: no address found for " ++ host

    let newSkt = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)

    -- 'bracketOnError' only closes the socket when the setup fails; on
    -- success the socket stays open and is owned by the returned connection.
    bracketOnError newSkt close $ \skt -> do
        -- Connects to TCP.
        connect skt (addrAddress addr)

        -- Prepare structure with abstract API.
        conn <- maybe (plainTextRaw skt) (tlsRaw skt) mparams

        -- Initializes the HTTP2 stream.
        _sendRaw conn [HTTP2.connectionPreface]

        return conn

-- | Builds a 'RawHttp2Connection' exchanging clear-text bytes directly over
-- an already-connected socket.
--
-- Starts one write worker and one read worker; closing the connection
-- cancels the read worker, then the write worker, then closes the socket.
plainTextRaw :: Socket -> IO RawHttp2Connection
plainTextRaw skt = do
    (writer, enqueue) <- startWriteWorker (sendMany skt)
    (reader, dequeue) <- startReadWorker (recv skt)
    return RawHttp2Connection
        { _sendRaw = atomically . enqueue
        , _nextRaw = atomically . dequeue
        , _close   = cancel reader >> cancel writer >> close skt
        }

-- | Builds a 'RawHttp2Connection' exchanging bytes over TLS on top of an
-- already-connected socket.
--
-- Performs the TLS handshake, then starts one write worker and one read
-- worker on the TLS context. The 'TLS.onSuggestALPN' hook of the given
-- parameters is replaced so that "h2" and "h2-17" are always suggested.
tlsRaw :: Socket -> TLS.ClientParams -> IO RawHttp2Connection
tlsRaw skt params = do
    -- Connects to SSL
    tlsContext <- TLS.contextNew skt (modifyParams params)
    TLS.handshake tlsContext

    (b,putRaw) <- startWriteWorker (TLS.sendData tlsContext . fromChunks)
    (a,getRaw) <- startReadWorker (const $ TLS.recvData tlsContext)
    -- 'TLS.bye' writes to the peer and may throw when the connection is
    -- already broken; 'finally' guarantees the context is closed regardless,
    -- so closing a dead connection does not leak it.
    let doClose = (cancel a >> cancel b >> TLS.bye tlsContext)
                      `finally` TLS.contextClose tlsContext

    return $ RawHttp2Connection (atomically . putRaw) (atomically . getRaw) doClose
  where
    -- Forces the ALPN suggestion while keeping every other client hook.
    modifyParams prms = prms {
        TLS.clientHooks = (TLS.clientHooks prms) {
            TLS.onSuggestALPN = return $ Just [ "h2", "h2-17" ]
          }
      }

-- | Spawns the background write worker.
--
-- Returns the worker's 'Async' handle together with an STM action that
-- appends chunks at the end of the outgoing queue; the worker hands the
-- queued chunks to the given sending function.
startWriteWorker
  :: ([ByteString] -> IO ())
  -> IO (Async (), [ByteString] -> STM ())
startWriteWorker sendChunks = do
    queue  <- newTVarIO []
    worker <- async (writeWorkerLoop queue sendChunks)
    return (worker, \chunks -> modifyTVar' queue (++ chunks))

-- | Body of the write worker: forever waits for the outgoing queue to be
-- non-empty, atomically takes everything queued so far (leaving the queue
-- empty), and passes the whole batch to the sending function.
writeWorkerLoop :: TVar [ByteString] -> ([ByteString] -> IO ()) -> IO ()
writeWorkerLoop outQ sendChunks = forever (takeAllQueued >>= sendChunks)
  where
    -- Blocks (via 'retry') while nothing is queued.
    takeAllQueued = atomically $ do
        pending <- readTVar outQ
        case pending of
            [] -> retry
            _  -> writeTVar outQ [] >> return pending

-- | Spawns the background read worker.
--
-- Returns the worker's 'Async' handle together with an STM action that
-- extracts a given number of bytes from the buffer filled by the worker.
startReadWorker
  :: (Int -> IO ByteString)
  -> IO (Async (), (Int -> STM ByteString))
startReadWorker get = do
    buffer <- newTVarIO ByteString.empty
    worker <- async (readWorkerLoop buffer get)
    return (worker, getRawWorker worker buffer)

-- | Body of the read worker: repeatedly asks the given receiving function
-- for up to 4096 bytes and appends what it gets to the shared buffer.
--
-- Throws an 'IOError' of type 'eofErrorType' (recognizable with
-- 'System.IO.Error.isEOFError') when the receiving function returns an empty
-- chunk. Any exception thrown by the receiving function is propagated
-- unchanged.
readWorkerLoop :: TVar ByteString -> (Int -> IO ByteString) -> IO ()
readWorkerLoop buf next = forever $ do
    dat <- next 4096
    -- An empty chunk is how the receiving functions used in this module
    -- signal that the peer closed the connection. Looping again would spin
    -- forever without ever blocking, so the loop is terminated with an
    -- explicit EOF error instead.
    when (ByteString.null dat) $
        ioError $ mkIOError eofErrorType
                            "readWorkerLoop: connection closed by peer"
                            Nothing
                            Nothing
    atomically $ modifyTVar' buf (\bs -> bs <> dat)

-- | STM action extracting exactly @amount@ bytes from the read buffer.
--
-- * When the buffer holds at least @amount@ bytes, they are removed from the
--   buffer and returned, even if the read worker has already terminated, so
--   data received before a failure stays readable.
-- * Otherwise, if the read worker terminated with an exception, that
--   exception is re-thrown.
-- * Otherwise the transaction retries, i.e. blocks until more data arrives
--   or the worker's status changes.
getRawWorker :: Async () -> TVar ByteString -> Int -> STM ByteString
getRawWorker a buf amount = do
    dat <- readTVar buf
    if amount <= ByteString.length dat
    then do
        -- Enough data: consume it from the front of the buffer.
        let (chunk, remaining) = ByteString.splitAt amount dat
        writeTVar buf remaining
        return chunk
    else do
        -- Not enough data: only wait if the worker can still provide more;
        -- if it died, re-throw its original exception.
        asyncStatus <- pollSTM a
        case asyncStatus of
            Just (Left e) -> throwSTM e
            _             -> retry