{-# LANGUAGE OverloadedStrings #-}
{-|
Copyright   : (c) 2018-2021 Tim Emiola
SPDX-License-Identifier: BSD3
Maintainer  : Tim Emiola <tim.emiola@gmail.com>

Provides a [WAI](https://hackage.haskell.org/package/wai) middleware that
delegates handling of requests.

Provides 3 combinators that create middleware along with supporting data types.

* 'delegateTo': delegates handling of requests matching a predicate to a
  delegate Application

* 'delegateToProxy': delegates handling of requests matching a predicate to
  a different host

* 'simpleProxy': is a simple reverse proxy, based on proxyApp of http-proxy by
  Erik de Castro Lopo/Michael Snoyman

-}

module Network.Wai.Middleware.Delegate
  ( -- * Middleware
    delegateTo
  , delegateToProxy
  , simpleProxy

    -- * Configuration
  , ProxySettings(..)

    -- * Aliases
  , RequestPredicate
  )

where

import           Control.Exception           (SomeException, handle,
                                              toException)
import           Control.Monad.IO.Class      (MonadIO, liftIO)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Char8       as C8
import qualified Data.ByteString.Lazy.Char8  as LC8
import           Data.Monoid                 ((<>))
import           Data.String                 (IsString)

import           Blaze.ByteString.Builder    (fromByteString)
import           Control.Concurrent.Async    (race_)
import           Data.CaseInsensitive        (mk)
import           Data.Conduit                (ConduitT, Flush (..), Void,
                                              mapOutput, runConduit, yield,
                                              (.|))
import           Data.Conduit.Network        (appSink, appSource)
import           Data.Default                (Default (..))
import           Data.Streaming.Network      (ClientSettings, clientSettingsTCP,
                                              runTCPClient)
import           Network.HTTP.Client         (Manager, Request (..),
                                              Response (..), parseRequest,
                                              withResponse)
import           Network.HTTP.Client.Conduit (bodyReaderSource)
import           Network.HTTP.Conduit        (requestBodySourceChunkedIO,
                                              requestBodySourceIO)
import           Network.HTTP.Types          (hContentType,
                                              internalServerError500, status304,
                                              status500)
import           Network.HTTP.Types.Header   (hHost)
import qualified Network.Wai                 as Wai
import           Network.Wai.Conduit         (responseRawSource, responseSource,
                                              sourceRequestBody)

-- | Type alias for a function that determines if a request should be handled by
-- a delegate.
--
-- When the predicate returns 'True' for a request, 'delegateTo' hands that
-- request to the delegate; otherwise the wrapped application handles it.
type RequestPredicate = Wai.Request -> Bool

-- | Build a middleware that routes each request to one of two applications.
--
-- Requests for which the predicate holds are served by the delegate
-- application (the first argument); every other request is passed on to the
-- application being wrapped by the middleware.
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo delegate shouldDelegate wrapped request =
  if shouldDelegate request
    then delegate request
    else wrapped request

-- | Build a middleware that proxies every request matching the predicate to
-- the host named in the 'ProxySettings'.
--
-- This is 'delegateTo' with 'simpleProxy' as the delegate application, so
-- requests that do not match are served by the wrapped application.
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings mgr shouldProxy =
  delegateTo proxyApp shouldProxy
  where
    -- The proxy application that serves the matching requests.
    proxyApp = simpleProxy settings mgr

-- | Settings that configure the proxy endpoint.
--
-- Use 'def' to obtain the defaults and override individual fields with record
-- update syntax.
data ProxySettings =
  ProxySettings
  { -- | Converts an exception raised while proxying a request into the
    -- response sent back to the client. The default replies with a 500 whose
    -- plain-text body is the shown exception.
    proxyOnException   :: SomeException -> Wai.Response
    -- | Timeout value in seconds. Default value: 15
    --
    -- NOTE(review): nothing in this module reads this field, so it does not
    -- currently bound the upstream request; confirm the intended use.
  , proxyTimeout       :: Int
    -- | The host being proxied. Default value: @localhost@
  , proxyHost          :: BS.ByteString
    -- | The number of redirects to follow. 0 means none, which is the default.
  , proxyRedirectCount :: Int
  }

instance Default ProxySettings where
  -- | The default settings for the Proxy server: exceptions become a
  -- plain-text 500 response, the timeout is 15 seconds, the proxied host is
  -- @localhost@ and no redirects are followed.
  def = ProxySettings
    { proxyOnException = plainTextError
    , proxyTimeout = 15
    , proxyHost = "localhost"
    , proxyRedirectCount = 0
    }
    where
      -- Reply with an internal server error whose body is the shown exception.
      --
      -- NOTE(review): C8.pack keeps only the low 8 bits of each Char, so a
      -- shown exception containing characters above Latin-1 does not yield
      -- valid UTF-8 even though the header declares that charset.
      plainTextError :: SomeException -> Wai.Response
      plainTextError failure =
        Wai.responseLBS internalServerError500 headers body
        where
          headers = [(hContentType, "text/plain; charset=utf-8")]
          body = LC8.fromChunks [C8.pack (show failure)]

-- | A Wai Application that acts as a http/https proxy.
--
-- * @CONNECT@ requests are answered with a raw response driven by
--   @handleConnect@, which tunnels bytes between the client and the requested
--   target. A plain-text 500 response is supplied as the backup response for
--   servers that cannot send raw responses.
--
-- * Any other request is re-issued against 'proxyHost', using the scheme of
--   the incoming request (@https@ when 'Wai.isSecure' holds, @http@
--   otherwise), its raw path and query string, its method, its body and its
--   headers. Headers rejected by @dropUpstreamHeaders@ are removed and a
--   @Host@ header naming 'proxyHost' is added. The upstream response is
--   streamed back with an additional @X-Via-Proxy: yes@ header.
--
-- Exceptions raised while building or performing the upstream request are
-- turned into a response by 'proxyOnException'.
simpleProxy
  :: ProxySettings
  -> Manager
  -> Wai.Application
simpleProxy settings manager req respond
    -- we may connect requests to secure sites, when we do, we will not have
    -- seen their URI properly
    | Wai.requestMethod req == "CONNECT" =
        respond $ responseRawSource (handleConnect req)
                    (Wai.responseLBS status500 [("Content-Type", "text/plain")] "method CONNECT is not supported")
    -- parseRequest throws when the upstream URL is malformed, so it runs
    -- inside 'handle' together with the upstream exchange; that way the
    -- failure is reported through proxyOnException as well.
    --
    -- NOTE(review): 'handle' also covers the call to 'respond', so an
    -- exception raised after the response has started leads to a second call
    -- to 'respond'; this matches the previous behaviour.
    | otherwise = handle (respond . onException) $ do
        proxyReq' <- parseRequest effectiveUrl
        let proxyReq = proxyReq'
              { method = Wai.requestMethod req
              , requestHeaders =
                  addHostHeader $ filter dropUpstreamHeaders $ Wai.requestHeaders req
                -- follow at most proxyRedirectCount redirects; with the
                -- default of 0 every redirect is passed back to the client.
              , redirectCount = proxyRedirectCount settings
              , requestBody =
                  case Wai.requestBodyLength req of
                    Wai.ChunkedBody ->
                      requestBodySourceChunkedIO (sourceRequestBody req)
                    Wai.KnownLength l ->
                      requestBodySourceIO (fromIntegral l) (sourceRequestBody req)
                -- don't modify the response to ensure consistency with the
                -- response headers
              , decompress = const False
              , host = newHost
              }
        withResponse proxyReq manager $ \res -> do
          let body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ responseBody res
              headers = (mk "X-Via-Proxy", "yes") : responseHeaders res
          respond $ responseSource (responseStatus res) headers body
  where
    -- Converts any failure into the response configured in the settings.
    onException :: SomeException -> Wai.Response
    onException = proxyOnException settings . toException

    -- The upstream host, used both in the URL and in the Host header.
    newHost :: BS.ByteString
    newHost = proxyHost settings

    -- The upstream scheme mirrors how the client reached this application.
    scheme :: String
    scheme
      | Wai.isSecure req = "https"
      | otherwise = "http"

    -- Path and query string exactly as received from the client.
    rawUrl :: BS.ByteString
    rawUrl = Wai.rawPathInfo req <> Wai.rawQueryString req

    effectiveUrl :: String
    effectiveUrl = scheme ++ "://" ++ C8.unpack newHost ++ C8.unpack rawUrl

    addHostHeader = (:) (hHost, newHost)

-- | Serve a @CONNECT@ request by tunnelling bytes between the client and the
-- target described by 'toClientSettings'.
--
-- Once the TCP connection to the target is open, the client is sent
-- @HTTP/1.1 200 OK@ followed by a blank line. After that the two directions
-- are relayed concurrently, and the tunnel ends as soon as either direction
-- finishes.
handleConnect
  :: Wai.Request
  -> ConduitT () C8.ByteString IO ()
  -> ConduitT C8.ByteString Void IO ()
  -> IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) $ \upstream -> do
    -- Tell the client the tunnel is established before relaying any data.
    runConduit (yield "HTTP/1.1 200 OK\r\n\r\n" .| toClient)
    let clientToUpstream = runConduit (fromClient .| appSink upstream)
        upstreamToClient = runConduit (appSource upstream .| toClient)
    race_ clientToUpstream upstreamToClient

-- | The port to connect to when the request target does not name a usable
-- one: 443 when the incoming request is secure, otherwise 80, the standard
-- HTTP port.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80

-- | Derive the TCP client settings for the target of a request whose raw path
-- has the form @host@ or @host:port@.
--
-- The port given by 'defaultClientPort' is used when the port is absent or
-- does not start with a number.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req = clientSettingsTCP portNumber hostName
  where
    -- Everything before the first ':' is the host; the remainder, when
    -- present, still carries its leading ':'.
    (hostName, portSuffix) = C8.break (== ':') (Wai.rawPathInfo req)

    -- Only the leading integer is used; anything after it is ignored.
    portNumber =
      maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 portSuffix))

-- | Filter predicate over header pairs: 'True' for headers that may be
-- forwarded upstream, 'False' for the ones that must be dropped.
--
-- Only the header name is inspected; the value is ignored.
dropUpstreamHeaders :: (Eq a, IsString a) => (a, b) -> Bool
dropUpstreamHeaders (name, _) = not (name `elem` withheld)
  where
    -- Names of the headers that are not forwarded.
    withheld =
      [ "content-encoding"
      , "content-length"
      , "host"
      ]