--------------------------------------------------------------------------------
-- Rate Limiting Middleware for Servant                                       --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}

module Servant.RateLimit.Server where

--------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.IO.Class

import Network.Wai.RateLimit.Backend
import Network.Wai.RateLimit.Strategy

import Servant
import Servant.RateLimit.Types
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.DelayedIO

--------------------------------------------------------------------------------

-- | Server support for the 'RateLimit' combinator.
--
-- @RateLimit strategy policy :> api@ is served exactly like @api@, except
-- that an extra check runs on each request reaching one of @api@'s endpoints.
-- The check builds a 'Strategy' from:
--
--   * the 'Backend' retrieved from the Servant 'Context', and
--   * the client-identification function provided by @policy@.
--
-- It then asks that strategy whether the request may proceed. A rejected
-- request fails fatally with HTTP status 429 and reason phrase
-- @"Rate limit exceeded"@. The handler type is unchanged.
instance
    ( HasServer api ctx
    , HasContextEntry ctx (Backend key)
    , HasRateLimitStrategy strategy
    , HasRateLimitPolicy policy
    , key ~ RateLimitPolicyKey policy
    ) => HasServer (RateLimit strategy policy :> api) ctx
    where

    -- the combinator adds no arguments to the handlers of the wrapped API
    type ServerT (RateLimit strategy policy :> api) m = ServerT api m

    -- hoisting is delegated unchanged to the wrapped API
    hoistServerWithContext _ pc nt s =
        hoistServerWithContext (Proxy :: Proxy api) pc nt s

    -- route the wrapped API, attaching the rate check to its endpoints;
    -- it is registered as an accept check, so it runs in Servant's
    -- accept-check phase of request processing
    route _ context subserver =
        route (Proxy :: Proxy api) context $
            subserver `addAcceptCheck` rateCheck
      where
        -- storage backend for the rate limiter, taken from the Servant context
        backend :: Backend key
        backend = getContextEntry context

        -- the strategy, configured with the backend and the policy's
        -- client identifier (the type application selects which
        -- 'HasRateLimitPolicy' instance to use)
        rateStrategy :: Strategy
        rateStrategy =
            strategyValue @strategy @key backend (policyGetIdentifier @policy)

        -- consult the strategy for the current request and abort with a
        -- fatal 429 error if it does not allow the request
        rateCheck :: DelayedIO ()
        rateCheck = withRequest $ \req -> do
            allowed <- liftIO $ strategyOnRequest rateStrategy req
            unless allowed $ delayedFailFatal rateLimitExceeded

        -- response sent when the rate limit has been exceeded
        rateLimitExceeded :: ServerError
        rateLimitExceeded = ServerError
            { errHTTPCode     = 429
            , errReasonPhrase = "Rate limit exceeded"
            , errBody         = mempty
            , errHeaders      = []
            }

--------------------------------------------------------------------------------