{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE OverloadedStrings     #-}

-- | Provides 'HasServer' instances for the combinators from
-- "Servant.API.NamedArgs", so that "Servant.Server" handlers receive their
-- parameters as named arguments from "Named".
module Servant.Server.NamedArgs where

import Named ((:!), (:?), arg, argF, Name(..), (!), NamedF(..))
import Named.Internal (pattern Arg)
import Data.Functor.Identity (Identity)
import Servant.API ( (:>), SBoolI, FromHttpApiData, toQueryParam, toUrlPiece
                   , parseQueryParam, parseHeader )
import Servant.API.Modifiers (FoldRequired, FoldLenient)
import Servant.API.NamedArgs ( foldRequiredNamedArgument, NamedCapture', NamedFlag
                             , NamedParam, NamedParams, RequiredNamedArgument
                             , NamedCaptureAll, RequestNamedArgument, NamedHeader'
                             , unfoldRequestNamedArgument)
import Data.Either (partitionEithers)
import Data.Maybe (mapMaybe)
import Servant.Server (HasServer(..), errBody, err400)
import Servant.Server.Internal ( passToServer, addParameterCheck, withRequest, delayedFailFatal
                               , Router'(..), addCapture, delayedFail, DelayedIO, addHeaderCheck)
import Web.HttpApiData (parseUrlPieceMaybe, parseUrlPieces)
import Data.String.Conversions (cs)
import Network.HTTP.Types (parseQueryText)
import Network.Wai (rawQueryString, Request, requestHeaders)
import Data.Text (Text)
import qualified Data.Text as T
import Data.String (IsString(..))
import Control.Monad (join)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Proxy (Proxy(..))

-- | A 'NamedFlag' is handed to the handler as a required named 'Bool'
-- argument.
--
-- Only the first occurrence of the parameter in the query string is
-- consulted. The flag is 'True' when that occurrence has no value, or when
-- its value is @true@, @1@ or empty. It is 'False' when the parameter is
-- absent or carries any other value.
instance (KnownSymbol name, HasServer api context)
    => HasServer (NamedFlag name :> api) context where

  type ServerT (NamedFlag name :> api) m
      = (name :! Bool) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
      route (Proxy @api) context (passToServer subserver flagArg)
    where
      -- Query-parameter key, taken from the type-level symbol.
      flagName :: Text
      flagName = cs (symbolVal (Proxy @name))

      -- Wrap the decoded flag as the named argument the handler expects.
      flagArg :: Request -> (name :! Bool)
      flagArg = Arg . flagValue

      -- Outer Maybe: is the key present at all?
      -- Inner Maybe: does it carry a value?
      flagValue :: Request -> Bool
      flagValue req =
          maybe False (maybe True isTruthy)
                (lookup flagName (parseQueryText (rawQueryString req)))

      -- Values that switch the flag on; anything else switches it off.
      isTruthy :: Text -> Bool
      isTruthy v = v `elem` ["true", "1", ""]

-- | A 'NamedCapture'' is handed to the handler as a required named argument
-- holding the decoded path segment.
--
-- If 'parseUrlPieceMaybe' cannot decode the segment, the route fails with
-- 'delayedFail' and 'err400' (a non-fatal failure with no body). The @mods@
-- parameter is not consulted by this instance.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
      => HasServer (NamedCapture' mods name a :> api) context where

  type ServerT (NamedCapture' mods name a :> api) m =
     (name :! a) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
      CaptureRouter (route (Proxy @api) context (addCapture d decodePiece))
    where
      -- Decode one captured path segment into the handler's named argument.
      decodePiece :: Text -> DelayedIO (name :! a)
      decodePiece piece =
          maybe (delayedFail err400) (pure . Arg) (parseUrlPieceMaybe piece)

-- | A 'NamedCaptureAll' is handed to the handler as a required named
-- argument holding a list with every remaining path segment decoded.
--
-- If 'parseUrlPieces' rejects any segment, the route fails with
-- 'delayedFail' and 'err400' (a non-fatal failure with no body).
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
      => HasServer (NamedCaptureAll name a :> api) context where

  type ServerT (NamedCaptureAll name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
      CaptureAllRouter (route (Proxy @api) context (addCapture d decodePieces))
    where
      -- Decode all captured segments; the parse error text is discarded.
      decodePieces :: [Text] -> DelayedIO (name :! [a])
      decodePieces pieces =
          either (const (delayedFail err400)) (pure . Arg) (parseUrlPieces pieces)

-- | 'NamedParams' is handed to the handler as a required named argument
-- holding a list.
--
-- Every query entry whose key is @name@ or @name[]@ and that carries a value
-- is parsed with 'parseQueryParam', preserving query-string order. Entries
-- without a value are skipped. If any value fails to parse, the request is
-- rejected with a fatal 400 whose body lists all the parse errors.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedParams name a :> api) context where

  type ServerT (NamedParams name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver = route (Proxy @api) context $
      subserver `addParameterCheck` withRequest collectParams
    where
      paramKey :: Text
      paramKey = cs (symbolVal (Proxy @name))

      -- Accept both the plain key and the bracketed array form.
      matchesKey :: Text -> Bool
      matchesKey k = k == paramKey || k == (paramKey <> "[]")

      collectParams :: Request -> DelayedIO (name :! [a])
      collectParams req =
          let rawValues = [ v | (k, Just v) <- parseQueryText (rawQueryString req)
                              , matchesKey k ]
          in case partitionEithers (map parseQueryParam rawValues) of
               ([], parsed) -> pure (Arg parsed)
               -- Report every failure, not just the first one.
               (errs, _)    -> delayedFailFatal err400
                   { errBody = cs $ "Error parsing query parameter(s) "
                                    <> paramKey <> " failed: "
                                    <> T.intercalate ", " errs
                   }

-- | A 'NamedHeader'' is handed to the handler as a 'RequestNamedArgument'.
-- Depending on the 'Servant.API.Modifiers.Required' /
-- 'Servant.API.Modifiers.Optional' modifiers, this is a required or optional
-- named argument. Depending on the 'Servant.API.Modifiers.Strict' /
-- 'Servant.API.Modifiers.Lenient' modifiers, its value has type @a@ or
-- @'Either' 'Text' a@.
--
-- The header is looked up in 'requestHeaders' and parsed with 'parseHeader'.
-- The result is passed to 'unfoldRequestNamedArgument', together with the
-- fatal 400 responses for a missing header and for an unparseable one.
-- How @mods@ selects between them is decided there.
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedHeader' mods name a :> api) context where

  type ServerT (NamedHeader' mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver = route (Proxy @api) context $
      subserver `addHeaderCheck` withRequest checkHeader
    where
      -- Polymorphic so it serves both as the header lookup key and as text
      -- inside the error bodies.
      hdrName :: IsString s => s
      hdrName = fromString (symbolVal (Proxy @name))

      checkHeader :: Request -> DelayedIO (RequestNamedArgument mods name a)
      checkHeader req =
          unfoldRequestNamedArgument @mods @name missing malformed parsed
        where
          -- Annotated because 'RequestNamedArgument' alone cannot fix @a@.
          parsed :: Maybe (Either T.Text a)
          parsed = parseHeader <$> lookup hdrName (requestHeaders req)

          missing = delayedFailFatal err400
            { errBody = "Header " <> hdrName <> " is required"
            }

          malformed e = delayedFailFatal err400
            { errBody = cs $ "Error parsing header " <> hdrName <> " failed: " <> e
            }

-- | A 'NamedParam' is handed to the handler as a 'RequestNamedArgument'.
-- Depending on the 'Servant.API.Modifiers.Required' /
-- 'Servant.API.Modifiers.Optional' modifiers, this is a required or optional
-- named argument. Depending on the 'Servant.API.Modifiers.Strict' /
-- 'Servant.API.Modifiers.Lenient' modifiers, its value has type @a@ or
-- @'Either' 'Text' a@.
--
-- Only the first occurrence of the parameter is used. An occurrence without
-- a value (@?name@) is treated the same as an absent parameter. The parse
-- result is passed to 'unfoldRequestNamedArgument', together with the fatal
-- 400 responses for a missing parameter and for an unparseable one.
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedParam mods name a :> api) context where

  type ServerT (NamedParam mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver = route (Proxy @api) context $
      addParameterCheck subserver (withRequest decodeParam)
    where
      key :: Text
      key = cs (symbolVal (Proxy @name))

      -- 'join' collapses "absent" and "present without a value" into Nothing.
      lookupRaw :: Request -> Maybe Text
      lookupRaw req = join (lookup key (parseQueryText (rawQueryString req)))

      decodeParam :: Request -> DelayedIO (RequestNamedArgument mods name a)
      decodeParam req =
          unfoldRequestNamedArgument @mods @name onMissing onBadValue raw
        where
          -- Annotated because 'RequestNamedArgument' alone cannot fix @a@.
          raw :: Maybe (Either T.Text a)
          raw = parseQueryParam <$> lookupRaw req

      onMissing = delayedFailFatal err400
        { errBody = cs $ "Query parameter " <> key <> " is required"
        }

      onBadValue e = delayedFailFatal err400
        { errBody = cs $ "Error parsing query parameter "
                           <> key
                           <> " failed: " <> e
        }