{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeFamilies          #-}

module Servant.Haxl.Client.Internal
  ( ServantResponse(..)
  , ServantRequest(..)
  , initServantClientState
  ) where

import           Control.Concurrent.Async
import           Control.Concurrent.QSem
import           Control.Exception
import           Control.Monad
import           Control.Monad.Catch                (MonadThrow)
import           Control.Monad.IO.Class
import           Control.Monad.Trans.Either
import           Data.ByteString.Lazy               hiding (elem, filter, map,
                                                     null, pack)
import           Data.Hashable
import           Data.String
import           Data.String.Conversions
import           Data.Text.Encoding
import           Haxl.Core                          hiding (Request, catch)
import           Network.HTTP.Client
import           Network.HTTP.Client.TLS
import           Network.HTTP.Media
import           Network.HTTP.Types
import qualified Network.HTTP.Types.Header          as HTTP
import           Network.URI
import           Servant.Haxl.Client.BaseUrl
import           Servant.Haxl.Client.Internal.Error
import           Servant.Haxl.Client.Types

import qualified Network.HTTP.Client                as Client

data ServantResponse = ServantResponse (Response ByteString) deriving Show

reqToRequest :: MonadThrow m => Req -> BaseUrl -> m Request
reqToRequest req (BaseUrl reqScheme reqHost reqPort) =
    (setheaders . setAccept . setrqb . setQS) <$> parseUrl url

  where url = show $ nullURI { uriScheme = case reqScheme of
                                  Http  -> "http:"
                                  Https -> "https:"
                             , uriAuthority = Just
                                 URIAuth { uriUserInfo = ""
                                         , uriRegName = reqHost
                                         , uriPort = ":" ++ show reqPort
                                         }
                             , uriPath = reqPath req
                             }

        setrqb r = case reqBody req of
                     Nothing -> r
                     Just (b,t) -> r { requestBody = RequestBodyLBS b
                                     , requestHeaders = requestHeaders r
                                                     ++ [(hContentType, cs . show $ t)] }
        setQS = setQueryString $ queryTextToQuery (qs req)
        setheaders r = r { requestHeaders = requestHeaders r
                                         <> fmap toProperHeader (headers req) }
        setAccept r = r { requestHeaders = filter ((/= "Accept") . fst) (requestHeaders r)
                                        <> [("Accept", renderHeader $ reqAccept req)
                                              | not . null . reqAccept $ req] }
        toProperHeader (name, val) =
          (fromString name, encodeUtf8 val)


performRequest_ :: Manager -> Method -> Req -> WantedStatusCodes -> BaseUrl
               -> EitherT ServantError IO ( Int, ByteString, MediaType
                                          , [HTTP.Header], ServantResponse)
performRequest_ manager reqMethod req wantedStatus reqHost = do
  partialRequest <- liftIO $ reqToRequest req reqHost

  let request = partialRequest { Client.method = reqMethod
                               , checkStatus = \ _status _headers _cookies -> Nothing
                               }

  eResponse <- liftIO $ catchHttpException $ Client.httpLbs request manager
  case eResponse of
    Left err ->
      left $ ConnectionError $ ServantConnectionError err

    Right response -> do
      let status = Client.responseStatus response
          body = Client.responseBody response
          hrds = Client.responseHeaders response
          status_code = statusCode status
      ct <- case lookup "Content-Type" $ Client.responseHeaders response of
                 Nothing -> pure $ "application"//"octet-stream"
                 Just t -> case parseAccept t of
                   Nothing -> left $ InvalidContentTypeHeader (cs t) body
                   Just t' -> pure t'
      unless (wantedStatus `wants` status_code) $
        left $ FailureResponse status ct body
      return (status_code, body, ct, hrds, ServantResponse response)
      where
        wants AllCodes _ = True
        wants (SelectCodes codes) status_code = status_code `elem` codes

catchHttpException :: IO a -> IO (Either HttpException a)
catchHttpException action =
  catch (Right <$> action) (pure . Left)

data ServantRequest a where
  ServantRequest :: Method -> Req -> WantedStatusCodes -> BaseUrl ->
    ServantRequest (Int, ByteString, MediaType, [HTTP.Header], ServantResponse)

deriving instance Show (ServantRequest a)
deriving instance Eq (ServantRequest a)

instance Show1 ServantRequest where
  show1 = show

instance Hashable (ServantRequest a) where
  hashWithSalt s (ServantRequest m r w h) = hashWithSalt s (m, r, w, h)

instance StateKey ServantRequest where
  data State ServantRequest = ServantRequestState Int Manager

instance DataSourceName ServantRequest where
  dataSourceName _ = "ServantRequest"

instance DataSource () ServantRequest where
  fetch (ServantRequestState numThreads manager) _ () requests = AsyncFetch $ \inner -> do
    sem <- newQSem numThreads
    asyncs <- mapM (handler sem) requests
    inner
    mapM_ wait asyncs
    where
      handler :: QSem -> BlockedFetch ServantRequest -> IO (Async ())
      handler sem (BlockedFetch (ServantRequest met req wantedStatus reqHost) rvar) =
        async $ bracket_ (waitQSem sem) (signalQSem sem) $ do
          e <- runEitherT $ performRequest_ manager met req wantedStatus reqHost
          case e of
            Left err -> putFailure rvar err
            Right a -> putSuccess rvar a
          return ()

initServantClientState :: Int -> IO (State ServantRequest)
initServantClientState numThreads =
  ServantRequestState numThreads <$> newManager tlsManagerSettings