{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards  #-}
{-# LANGUAGE RankNTypes  #-}
module Network.HTTP2.Client.RawConnection (
      RawHttp2Connection (..)
    , newRawHttp2Connection
    ) where
import           Control.Concurrent.Async (Async, async, cancel, pollSTM)
import           Control.Concurrent.STM (STM, atomically, retry, throwSTM)
import           Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO, readTVar, writeTVar)
import           Control.Exception (bracketOnError, finally)
import           Control.Monad (forever, when)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import           Data.ByteString.Lazy (fromChunks)
import           Data.Monoid ((<>))
import qualified Network.HTTP2 as HTTP2
import           Network.Socket hiding (recv)
import           Network.Socket.ByteString
import qualified Network.TLS as TLS
import           System.IO.Error (eofErrorType, mkIOError)
-- | A raw byte-level connection to an HTTP/2 peer, exposed as three
-- actions. In this module it is built by 'newRawHttp2Connection' on top of
-- background reader/writer threads.
data RawHttp2Connection = RawHttp2Connection
    { _sendRaw :: [ByteString] -> IO ()
      -- ^ Hand a batch of chunks over for sending.
    , _nextRaw :: Int -> IO ByteString
      -- ^ Obtain the next @n@ bytes received from the peer.
    , _close   :: IO ()
      -- ^ Tear the connection down.
    }
-- | Resolve @host@, open a TCP connection to it, optionally wrap it in TLS,
-- and send the HTTP/2 client connection preface.
--
-- Only the first address returned by 'getAddrInfo' is tried. If connecting
-- or setting up the transport (e.g. the TLS handshake) throws, the socket
-- is closed before the exception propagates.
newRawHttp2Connection :: HostName
                      -- ^ Server host name or numeric address.
                      -> PortNumber
                      -- ^ Server TCP port.
                      -> Maybe TLS.ClientParams
                      -- ^ TLS client parameters, or 'Nothing' for a
                      -- cleartext TCP connection.
                      -> IO RawHttp2Connection
newRawHttp2Connection host port mparams = do
    -- AI_NUMERICSERV: the service string is always a numeric port.
    let hints = defaultHints { addrFlags = [AI_NUMERICSERV], addrSocketType = Stream }
    -- getAddrInfo throws instead of returning an empty list, so this
    -- pattern does not fail in practice.
    addr:_ <- getAddrInfo (Just hints) (Just host) (Just $ show port)
    let acquireSocket = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
    -- bracketOnError closes the socket only when connecting or building the
    -- transport throws; on success the returned connection owns the socket
    -- and releases it through '_close'.
    conn <- bracketOnError acquireSocket close $ \skt -> do
        connect skt (addrAddress addr)
        maybe (plainTextRaw skt) (tlsRaw skt) mparams
    -- The client must open every HTTP/2 connection with the preface.
    _sendRaw conn [HTTP2.connectionPreface]
    return conn
-- | Build a 'RawHttp2Connection' directly on a connected, unencrypted
-- socket. Outgoing chunks are written with 'sendMany' and incoming bytes
-- are read with 'recv', each from its own worker thread.
--
-- Closing cancels the reader, then the writer (chunks still queued for
-- writing are dropped), and finally closes the socket.
plainTextRaw :: Socket -> IO RawHttp2Connection
plainTextRaw skt = do
    (writer, enqueue) <- startWriteWorker (sendMany skt)
    (reader, dequeue) <- startReadWorker (recv skt)
    let shutdown = do
            cancel reader
            cancel writer
            close skt
    pure RawHttp2Connection
        { _sendRaw = atomically . enqueue
        , _nextRaw = atomically . dequeue
        , _close   = shutdown
        }
-- | Run a TLS client handshake over an already-connected socket, offering
-- HTTP/2 via ALPN. Then build a 'RawHttp2Connection' whose reads and writes
-- go through the TLS session.
--
-- Closing cancels both workers, sends a TLS close notification, and always
-- closes the TLS context, even if the notification fails.
tlsRaw :: Socket -> TLS.ClientParams -> IO RawHttp2Connection
tlsRaw skt params = do
    tlsContext <- TLS.contextNew skt (modifyParams params)
    TLS.handshake tlsContext
    (b,putRaw) <- startWriteWorker (TLS.sendData tlsContext . fromChunks)
    -- TLS.recvData returns whatever application data arrives next, so the
    -- size hint supplied by the read worker is ignored.
    (a,getRaw) <- startReadWorker (const $ TLS.recvData tlsContext)
    -- TLS.bye may throw if the peer has already gone away; 'finally' makes
    -- sure contextClose still runs so the underlying socket is not leaked.
    let doClose = cancel a >> cancel b >>
            (TLS.bye tlsContext `finally` TLS.contextClose tlsContext)
    return $ RawHttp2Connection (atomically . putRaw) (atomically . getRaw) doClose
  where
    -- Only the ALPN hook is overridden: offer "h2" and "h2-17" (presumably
    -- the pre-RFC draft identifier, for older servers).
    modifyParams prms = prms {
        TLS.clientHooks = (TLS.clientHooks prms) {
            TLS.onSuggestALPN = return $ Just [ "h2", "h2-17" ]
          }
      }
-- | Fork a thread that drains an outgoing chunk queue with @sendChunks@.
--
-- Returns the worker handle and an STM action that appends chunks to the
-- queue. If the worker has already died with an exception (e.g. the send
-- action threw), enqueueing rethrows that exception. Otherwise data would
-- pile up in a queue nobody will ever drain. This mirrors how the read side
-- surfaces reader-thread failures.
startWriteWorker
  :: ([ByteString] -> IO ())
  -> IO (Async (), [ByteString] -> STM ())
startWriteWorker sendChunks = do
    outQ <- newTVarIO []
    b <- async $ writeWorkerLoop outQ sendChunks
    let putRaw chunks = do
            status <- pollSTM b
            case status of
                Just (Left e) -> throwSTM e
                _             -> modifyTVar' outQ (++ chunks)
    return (b, putRaw)
-- | Never returns normally. Each round waits until the queue is non-empty,
-- atomically takes everything queued so far, and passes it to
-- @sendChunks@ as one batch. An exception from @sendChunks@ ends the loop.
writeWorkerLoop :: TVar [ByteString] -> ([ByteString] -> IO ()) -> IO ()
writeWorkerLoop outQ sendChunks = loop
  where
    loop = do
        batch <- atomically takeAll
        sendChunks batch
        loop
    takeAll = do
        queued <- readTVar outQ
        case queued of
            [] -> retry
            _  -> writeTVar outQ [] >> return queued
-- | Fork a thread that keeps pulling data with the given receive action
-- into a shared buffer. Returns the worker handle and an STM action
-- (built by 'getRawWorker') that removes the next @n@ bytes from that
-- buffer.
startReadWorker
  :: (Int -> IO ByteString)
  -> IO (Async (), (Int -> STM ByteString))
startReadWorker get = do
    pending <- newTVarIO ByteString.empty
    reader  <- async (readWorkerLoop pending get)
    pure (reader, getRawWorker reader pending)
-- | Repeatedly read up to 4096 bytes with @next@ and append them to @buf@.
--
-- An empty read means the peer closed the stream. In that case the loop
-- stops by throwing an EOF 'IOError' rather than spinning forever appending
-- empty chunks. 'startReadWorker' runs this loop in an 'Async', so
-- 'getRawWorker' can observe the failure and report it to readers.
readWorkerLoop :: TVar ByteString -> (Int -> IO ByteString) -> IO ()
readWorkerLoop buf next = forever $ do
    dat <- next 4096
    when (ByteString.null dat) $
        ioError $ mkIOError eofErrorType "readWorkerLoop: connection closed by peer" Nothing Nothing
    atomically $ modifyTVar' buf (<> dat)

-- | Remove and return exactly @amount@ bytes from @buf@, blocking (via STM
-- 'retry') until that many are available.
--
-- Buffered data is served first, even if the reader worker @a@ has already
-- stopped, so bytes received just before the connection ended are not lost.
-- The worker's exception is rethrown only when the buffer is too short.
-- Otherwise the transaction retries and wakes up when the buffer or the
-- worker's status changes.
getRawWorker :: Async () -> TVar ByteString -> Int -> STM ByteString
getRawWorker a buf amount = do
    dat <- readTVar buf
    if amount <= ByteString.length dat
    then do
        let (chunk, rest) = ByteString.splitAt amount dat
        writeTVar buf rest
        return chunk
    else do
        asyncStatus <- pollSTM a
        case asyncStatus of
            Just (Left e) -> throwSTM e
            _             -> retry