{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Provides instances to be able to use combinators from
-- "Servant.API.NamedArgs" with "Servant.Server", so that handlers take
-- their inputs as named parameters from "Named"
module Servant.Server.NamedArgs where

import Named ((:!), (:?), arg, argF, Name(..), (!), NamedF(..))
import Named.Internal (pattern Arg)
import Data.Functor.Identity (Identity)
import Servant.API ( (:>), SBoolI, FromHttpApiData, toQueryParam, toUrlPiece
                   , parseQueryParam, parseHeader, If, SBool(..), sbool)
import Servant.API.Modifiers (FoldRequired, FoldLenient)
import Servant.API.NamedArgs ( foldRequiredNamedArgument, NamedCapture', NamedFlag
                             , NamedParam, NamedParams, RequiredNamedArgument
                             , NamedCaptureAll, RequestNamedArgument, NamedHeader'
                             , unfoldRequestNamedArgument, NamedBody')
import Servant.API.ContentTypes (AllCTUnrender(..))
import Data.Either (partitionEithers)
import Data.Maybe (mapMaybe, fromMaybe)
import Servant.Server (HasServer(..), errBody, err400, err415)
import Servant.Server.Internal ( passToServer, addParameterCheck, withRequest, delayedFailFatal
                               , Router'(..), addCapture, delayedFail, DelayedIO, addHeaderCheck
                               , addBodyCheck)
import Web.HttpApiData (parseUrlPieceMaybe, parseUrlPieces)
import Data.String.Conversions (cs)
import Network.HTTP.Types (parseQueryText, hContentType)
import Network.Wai (rawQueryString, Request, requestHeaders, lazyRequestBody)
import Data.Text (Text)
import qualified Data.Text as T
import Data.String (IsString(..))
import Control.Monad (join)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Proxy (Proxy(..))
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString.Lazy as BL

-- | 'NamedFlag's are converted to required named arguments
--
-- The handler receives 'True' when the flag is present in the query string
-- with no value, or with one of the values @true@, @1@ or the empty string;
-- it receives 'False' when the flag is absent or carries any other value.
instance (KnownSymbol name, HasServer api context)
    => HasServer (NamedFlag name :> api) context where

  type ServerT (NamedFlag name :> api) m =
    (name :! Bool) -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context (passToServer subserver (Arg . param))
    where
      queryText = parseQueryText . rawQueryString
      param r = case lookup paramName (queryText r) of
        Just Nothing  -> True       -- param is there, with no value
        Just (Just v) -> examine v  -- param with a value
        Nothing       -> False      -- param not in the query string
      paramName = cs $ symbolVal (Proxy @name)
      -- Values that count as "flag set" when one is given explicitly.
      examine v = v `elem` ["true", "1", ""]

-- | 'NamedCapture''s are converted to required named arguments
--
-- A path segment that cannot be parsed as @a@ is rejected with 'err400'
-- through 'delayedFail'.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedCapture' mods name a :> api) context where

  type ServerT (NamedCapture' mods name a :> api) m =
    (name :! a) -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
    CaptureRouter $
      route (Proxy @api) context $
        addCapture d $ \txt ->
          case parseUrlPieceMaybe txt of
            Nothing -> delayedFail err400
            Just v  -> pure $ Arg v

-- | 'NamedCaptureAll's are converted to required named arguments, taking a list
--
-- If any of the captured path segments cannot be parsed as @a@ the request
-- is rejected with 'err400' through 'delayedFail'.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedCaptureAll name a :> api) context where

  type ServerT (NamedCaptureAll name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
    CaptureAllRouter $
      route (Proxy @api) context $
        addCapture d $ \txts ->
          case parseUrlPieces txts of
            Left _  -> delayedFail err400
            Right v -> pure $ Arg v

-- | 'NamedParams's are converted to required named arguments, taking a list
--
-- Both @name=...@ and @name[]=...@ query items are collected. If any of the
-- collected values fails to parse, the request is rejected with 'err400'
-- through 'delayedFailFatal', and the body lists every parse error.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedParams name a :> api) context where

  type ServerT (NamedParams name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context $
      subserver `addParameterCheck` withRequest paramsCheck
    where
      paramname = cs $ symbolVal (Proxy @name)
      paramsCheck req =
        case partitionEithers $ fmap parseQueryParam params of
          ([], parsed) -> pure $ Arg parsed
          (errs, _)    -> delayedFailFatal err400
            { errBody = cs $ "Error parsing query parameter(s) "
                          <> paramname <> " failed: "
                          <> T.intercalate ", " errs
            }
        where
          -- Values of every matching query item; items without a value
          -- are dropped by 'mapMaybe'.
          params :: [Text]
          params = mapMaybe snd
                 . filter (looksLikeParam . fst)
                 . parseQueryText
                 . rawQueryString
                 $ req
          looksLikeParam name =
            name == paramname || name == (paramname <> "[]")

-- | 'NamedHeader''s are converted to required or optional named arguments
-- depending on the 'Servant.API.Modifiers.Required' and
-- 'Servant.API.Modifiers.Optional' modifiers, of type a or 'Either' 'Text'
-- a depending on the 'Servant.API.Modifiers.Strict' and
-- 'Servant.API.Modifiers.Lenient' modifiers
instance
  ( KnownSymbol name
  , FromHttpApiData a
  , HasServer api context
  , SBoolI (FoldRequired mods)
  , SBoolI (FoldLenient mods)
  ) => HasServer (NamedHeader' mods name a :> api) context where

  type ServerT (NamedHeader' mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context $
      subserver `addHeaderCheck` withRequest headerCheck
    where
      -- Kept polymorphic because the name is used at several string types
      -- (header lookup key and the two error bodies).
      headerName :: IsString n => n
      headerName = fromString $ symbolVal (Proxy @name)

      headerCheck :: Request -> DelayedIO (RequestNamedArgument mods name a)
      headerCheck req =
        unfoldRequestNamedArgument @mods @name errReq errSt mev
        where
          -- 'Nothing' when the header is absent, otherwise the parse result.
          mev :: Maybe (Either T.Text a)
          mev = parseHeader <$> lookup headerName (requestHeaders req)
          errReq = delayedFailFatal err400
            { errBody = "Header " <> headerName <> " is required" }
          errSt e = delayedFailFatal err400
            { errBody = cs $ "Error parsing header "
                          <> headerName <> " failed: " <> e
            }

-- | 'NamedParam's are converted to required or optional named arguments
-- depending on the 'Servant.API.Modifiers.Required' and
-- 'Servant.API.Modifiers.Optional' modifiers, of type a or 'Either' 'Text'
-- a depending on the 'Servant.API.Modifiers.Strict' and
-- 'Servant.API.Modifiers.Lenient' modifiers
instance
  ( KnownSymbol name
  , FromHttpApiData a
  , HasServer api context
  , SBoolI (FoldRequired mods)
  , SBoolI (FoldLenient mods)
  ) => HasServer (NamedParam mods name a :> api) context where

  type ServerT (NamedParam mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver = route (Proxy @api) context delayed
    where
      queryText = parseQueryText . rawQueryString
      paramName = cs $ symbolVal (Proxy @name)
      errReq = delayedFailFatal err400
        { errBody = cs $ "Query parameter " <> paramName <> " is required" }
      errSt e = delayedFailFatal err400
        { errBody = cs $ "Error parsing query parameter "
                      <> paramName <> " failed: " <> e
        }
      -- 'join' makes a parameter that is present without a value look the
      -- same as a parameter that is absent altogether.
      mev :: Request -> Maybe (Either T.Text a)
      mev req = parseQueryParam <$> join (lookup paramName (queryText req))
      parseParam :: Request -> DelayedIO (RequestNamedArgument mods name a)
      parseParam req =
        unfoldRequestNamedArgument @mods @name errReq errSt (mev req)
      delayed = addParameterCheck subserver (withRequest parseParam)

-- | 'NamedBody''s are converted to required named arguments, of type a or
-- 'Either' 'String' a depending on the 'Servant.API.Modifiers.Strict' and
-- 'Servant.API.Modifiers.Lenient' modifiers
--
-- A request without a @Content-Type@ header is treated as
-- @application/octet-stream@. When none of the content types in @list@ can
-- handle the request's content type the request is rejected with 'err415'
-- through 'delayedFail'. When the body fails to decode and the argument is
-- not lenient, the request is rejected with 'err400' through
-- 'delayedFailFatal', with the decoding error as the response body.
instance
  ( KnownSymbol name
  , AllCTUnrender list a
  , HasServer api context
  , SBoolI (FoldLenient mods)
  ) => HasServer (NamedBody' mods name list a :> api) context where

  type ServerT (NamedBody' mods name list a :> api) m =
    (name :! (If (FoldLenient mods) (Either String a) a)) -> ServerT api m

  hoistServerWithContext _ pc nt s =
    hoistServerWithContext (Proxy @api) pc nt . s

  route _ context subserver =
    route (Proxy @api) context $
      addBodyCheck subserver ctCheck bodyCheck
    where
      -- Content-Type check: selects the decoder before the body is read.
      ctCheck = withRequest $ \req -> do
        let contentTypeH = fromMaybe "application/octet-stream"
                         $ lookup hContentType $ requestHeaders req
        case canHandleCTypeH (Proxy @list) (cs contentTypeH)
               :: Maybe (BL.ByteString -> Either String a) of
          Nothing -> delayedFail err415
          Just f  -> pure f

      -- Body check: receives the decoder chosen by 'ctCheck'.
      bodyCheck f = withRequest $ \req -> do
        mrqbody <- f <$> liftIO (lazyRequestBody req)
        case sbool :: SBool (FoldLenient mods) of
          STrue  -> pure . Arg $ mrqbody
          SFalse -> case mrqbody of
            Left e  -> delayedFailFatal err400 { errBody = cs e }
            Right v -> pure . Arg $ v