{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | 'HasServer' instances for the combinators from "Servant.API.NamedArgs",
-- so that they can be used with "Servant.Server". Handlers for these
-- combinators receive their inputs as named parameters from "Named".
module Servant.Server.NamedArgs where

import Named ((:!), (:?), arg, argF, Name(..), (!), NamedF(..))
import Named.Internal (pattern Arg)
import Data.Functor.Identity (Identity)
import Servant.API
  ( (:>), SBoolI, FromHttpApiData, toQueryParam, toUrlPiece
  , parseQueryParam, parseHeader
  )
import Servant.API.Modifiers (FoldRequired, FoldLenient)
import Servant.API.NamedArgs
  ( foldRequiredNamedArgument, NamedCapture', NamedFlag
  , NamedParam, NamedParams, RequiredNamedArgument
  , NamedCaptureAll, RequestNamedArgument, NamedHeader'
  , unfoldRequestNamedArgument
  )
import Data.Either (partitionEithers)
import Data.Maybe (mapMaybe)
import Servant.Server (HasServer(..), errBody, err400)
import Servant.Server.Internal
  ( passToServer, addParameterCheck, withRequest, delayedFailFatal
  , Router'(..), addCapture, delayedFail, DelayedIO, addHeaderCheck
  )
import Web.HttpApiData (parseUrlPieceMaybe, parseUrlPieces)
import Data.String.Conversions (cs)
import Network.HTTP.Types (parseQueryText)
import Network.Wai (rawQueryString, Request, requestHeaders)
import Data.Text (Text)
import qualified Data.Text as T
import Data.String (IsString(..))
import Control.Monad (join)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Proxy (Proxy(..))

-- | A 'NamedFlag' becomes a required named 'Bool' argument.
--
-- The argument is 'True' when the query string contains the flag with no
-- value, or with one of the values @true@, @1@ or the empty string. It is
-- 'False' when the flag is absent or carries any other value.
instance (KnownSymbol name, HasServer api context)
    => HasServer (NamedFlag name :> api) context where

  type ServerT (NamedFlag name :> api) m = (name :! Bool) -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context subserver =
    route (Proxy @api) context (passToServer subserver (Arg . flagIsSet))
    where
      flagName = cs $ symbolVal (Proxy @name)

      -- Look the flag up in the parsed query string of the request.
      flagIsSet req =
        case lookup flagName (parseQueryText (rawQueryString req)) of
          Nothing       -> False     -- flag not in the query string
          Just Nothing  -> True      -- flag present, without a value
          Just (Just v) -> truthy v  -- flag present, with a value

      truthy v = v `elem` ["true", "1", ""]

-- | A 'NamedCapture'' becomes a required named argument.
--
-- The captured path segment is parsed with 'parseUrlPieceMaybe'; when
-- parsing fails the route fails with 'err400' via 'delayedFail'.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedCapture' mods name a :> api) context where

  type ServerT (NamedCapture' mods name a :> api) m =
    (name :! a) -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context d =
    CaptureRouter $
      route (Proxy @api) context $
        addCapture d $ \segment ->
          maybe (delayedFail err400) (pure . Arg) (parseUrlPieceMaybe segment)

-- | A 'NamedCaptureAll' becomes a required named argument holding a list.
--
-- The captured segments are parsed with 'parseUrlPieces'; when parsing
-- fails the route fails with 'err400' via 'delayedFail'.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedCaptureAll name a :> api) context where

  type ServerT (NamedCaptureAll name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context d =
    CaptureAllRouter $
      route (Proxy @api) context $
        addCapture d $ \segments ->
          either (const (delayedFail err400)) (pure . Arg) (parseUrlPieces segments)

-- | A 'NamedParams' becomes a required named argument holding a list.
--
-- Every query-string entry whose key is @name@ or @name[]@ and which has a
-- value is parsed with 'parseQueryParam'. If any value fails to parse, the
-- request is rejected with 'err400' via 'delayedFailFatal', with the parse
-- errors listed in the response body.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedParams name a :> api) context where

  type ServerT (NamedParams name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context subserver =
    route (Proxy @api) context $
      subserver `addParameterCheck` withRequest checkParams
    where
      paramName = cs $ symbolVal (Proxy @name)

      checkParams req =
        case partitionEithers (map parseQueryParam rawValues) of
          ([], values) -> pure (Arg values)
          (errs, _) ->
            delayedFailFatal err400
              { errBody = cs $ "Error parsing query parameter(s) "
                            <> paramName <> " failed: "
                            <> T.intercalate ", " errs
              }
        where
          -- Values of all matching keys, in query-string order; matching
          -- keys that carry no value are dropped.
          rawValues :: [Text]
          rawValues =
            [ v
            | (k, Just v) <- parseQueryText (rawQueryString req)
            , matchesName k
            ]

          matchesName k = k `elem` [paramName, paramName <> "[]"]

-- | A 'NamedHeader'' becomes a required or optional named argument,
-- depending on the 'Servant.API.Modifiers.Required' and
-- 'Servant.API.Modifiers.Optional' modifiers. Its type is @a@ or
-- @'Either' 'Text' a@, depending on the 'Servant.API.Modifiers.Strict' and
-- 'Servant.API.Modifiers.Lenient' modifiers.
--
-- The header is looked up in the request headers and parsed with
-- 'parseHeader'; 'unfoldRequestNamedArgument' then decides, based on the
-- modifiers, how the lookup result is turned into the handler argument.
instance
  ( KnownSymbol name
  , FromHttpApiData a
  , HasServer api context
  , SBoolI (FoldRequired mods)
  , SBoolI (FoldLenient mods)
  ) => HasServer (NamedHeader' mods name a :> api) context where

  type ServerT (NamedHeader' mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context subserver =
    route (Proxy @api) context $
      subserver `addHeaderCheck` withRequest checkHeader
    where
      -- Polymorphic because the name is needed both as a header-name
      -- lookup key and as text inside the error bodies.
      headerName :: IsString n => n
      headerName = fromString $ symbolVal (Proxy @name)

      checkHeader :: Request -> DelayedIO (RequestNamedArgument mods name a)
      checkHeader req =
        unfoldRequestNamedArgument @mods @name
          missingHeader malformedHeader parsedHeader
        where
          parsedHeader :: Maybe (Either T.Text a)
          parsedHeader = parseHeader <$> lookup headerName (requestHeaders req)

          missingHeader =
            delayedFailFatal err400
              { errBody = "Header " <> headerName <> " is required" }

          malformedHeader e =
            delayedFailFatal err400
              { errBody = cs $ "Error parsing header "
                            <> headerName <> " failed: " <> e
              }

-- | A 'NamedParam' becomes a required or optional named argument,
-- depending on the 'Servant.API.Modifiers.Required' and
-- 'Servant.API.Modifiers.Optional' modifiers. Its type is @a@ or
-- @'Either' 'Text' a@, depending on the 'Servant.API.Modifiers.Strict' and
-- 'Servant.API.Modifiers.Lenient' modifiers.
--
-- The parameter is looked up in the parsed query string and parsed with
-- 'parseQueryParam'. A parameter that is present without a value is
-- treated the same as an absent one.
instance
  ( KnownSymbol name
  , FromHttpApiData a
  , HasServer api context
  , SBoolI (FoldRequired mods)
  , SBoolI (FoldLenient mods)
  ) => HasServer (NamedParam mods name a :> api) context where

  type ServerT (NamedParam mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt server =
    hoistServerWithContext (Proxy @api) pc nt . server

  route Proxy context subserver =
    route (Proxy @api) context $
      subserver `addParameterCheck` withRequest checkParam
    where
      paramName = cs $ symbolVal (Proxy @name)

      missingParam =
        delayedFailFatal err400
          { errBody = cs $ "Query parameter " <> paramName <> " is required" }

      malformedParam e =
        delayedFailFatal err400
          { errBody = cs $ "Error parsing query parameter "
                        <> paramName <> " failed: " <> e
          }

      -- 'join' collapses "key absent" and "key present without a value"
      -- into a single 'Nothing'.
      parsedParam :: Request -> Maybe (Either T.Text a)
      parsedParam req =
        parseQueryParam
          <$> join (lookup paramName (parseQueryText (rawQueryString req)))

      checkParam :: Request -> DelayedIO (RequestNamedArgument mods name a)
      checkParam req =
        unfoldRequestNamedArgument @mods @name
          missingParam malformedParam (parsedParam req)