{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE CPP #-}
module Network.HTTP.ReverseProxy
(
ProxyDest (..)
, rawProxyTo
, rawTcpProxyTo
, waiProxyTo
, defaultOnExc
, waiProxyToSettings
, WaiProxyResponse (..)
, WaiProxySettings
, defaultWaiProxySettings
, wpsOnExc
, wpsTimeout
, wpsSetIpHeader
, wpsProcessBody
, wpsUpgradeToRaw
, wpsGetDest
, SetIpHeader (..)
, LocalWaiProxySettings
, defaultLocalWaiProxySettings
, setLpsTimeBound
) where
import Blaze.ByteString.Builder (Builder, fromByteString,
toLazyByteString)
import Control.Applicative ((<$>), (<|>))
import Control.Monad (unless)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import qualified Data.CaseInsensitive as CI
import Data.Conduit
import qualified Data.Conduit.List as CL
import qualified Data.Conduit.Network as DCN
import Data.Functor.Identity (Identity (..))
import Data.IORef
import Data.Maybe (fromMaybe, listToMaybe)
import Data.Monoid (mappend, mconcat, (<>))
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Streaming.Network (AppData, readLens)
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Data.Word8 (isSpace, _colon, _cr)
import GHC.Generics (Generic)
import Network.HTTP.Client (BodyReader, brRead)
import qualified Network.HTTP.Client as HC
import qualified Network.HTTP.Types as HT
import qualified Network.Wai as WAI
import Network.Wai.Logger (showSockAddr)
import UnliftIO (MonadIO, liftIO, MonadUnliftIO, timeout, SomeException, try, bracket, concurrently_)
-- | Host\/port combination to which a request or connection is proxied.
data ProxyDest = ProxyDest
    { pdHost :: !ByteString
      -- ^ Host, passed to 'DCN.clientSettings' (raw TCP) or 'HC.host' (HTTP).
    , pdPort :: !Int
      -- ^ TCP port of the destination.
    } deriving (Read, Show, Eq, Ord, Generic)
-- | Proxy a raw TCP connection, choosing the destination from the request
-- headers.
--
-- The head of the client's stream is parsed with 'getHeaders'. Those bytes
-- are pushed back as leftovers, so they are not lost: they are still
-- forwarded to the destination, or made readable by the local app.
-- The @getDest@ callback then decides what happens:
--
-- * @Left app@: run @app@ locally on a copy of the client 'AppData' whose
--   read action first drains the already-buffered bytes.
-- * @Right dest@: open a TCP connection to @dest@ and copy data in both
--   directions concurrently.
rawProxyTo :: MonadUnliftIO m
           => (HT.RequestHeaders -> m (Either (DCN.AppData -> m ()) ProxyDest))
           -> AppData -> m ()
rawProxyTo getDest appdata = do
    (rsrc, headers) <- liftIO $ fromClient $$+ getHeaders
    edest <- getDest headers
    case edest of
        Left app -> do
            -- Keep the resumable source in a mutable cell so that each call
            -- to readData continues where the previous call stopped.
            irsrc <- liftIO $ newIORef rsrc
            let readData = do
                    rsrc1 <- readIORef irsrc
                    (rsrc2, mbs) <- rsrc1 $$++ await
                    writeIORef irsrc rsrc2
                    -- End of input is reported to the app as an empty chunk.
                    return $ fromMaybe "" mbs
            -- Hand the app an AppData whose read action is readData.
            app $ runIdentity (readLens (const (Identity readData)) appdata)
        Right (ProxyDest host port) -> liftIO $ DCN.runTCPClient (DCN.clientSettings port host) (withServer rsrc)
  where
    fromClient = DCN.appSource appdata
    toClient = DCN.appSink appdata
    -- Copy client -> server (starting with the bytes still held by rsrc)
    -- and server -> client at the same time.
    withServer rsrc appdataServer = concurrently_
        (rsrc $$+- toServer)
        (runConduit $ fromServer .| toClient)
      where
        fromServer = DCN.appSource appdataServer
        toServer = DCN.appSink appdataServer
-- | Proxy a raw TCP connection to a fixed destination.
--
-- Opens a TCP connection to the given 'ProxyDest' and copies bytes in both
-- directions at the same time. It returns when both directions have finished,
-- or propagates the first exception from either of them.
rawTcpProxyTo :: MonadIO m
              => ProxyDest
              -> AppData
              -> m ()
rawTcpProxyTo (ProxyDest host port) client =
    liftIO $ DCN.runTCPClient (DCN.clientSettings port host) $ \server ->
        concurrently_ (pipe client server) (pipe server client)
  where
    -- Stream everything read from the first connection into the second.
    pipe :: AppData -> AppData -> IO ()
    pipe from to = runConduit $ DCN.appSource from .| DCN.appSink to
-- | Default exception handler: responds with @502 Bad Gateway@ and a
-- plain-text body containing the 'show'n exception.
defaultOnExc :: SomeException -> WAI.Application
defaultOnExc exc _req respond =
    respond $ WAI.responseLBS HT.status502 headers body
  where
    headers :: HT.ResponseHeaders
    headers = [("content-type", "text/plain")]
    body :: L.ByteString
    body = "Error connecting to gateway:\n\n" <> TLE.encodeUtf8 (TL.pack (show exc))
-- | What to do with an incoming WAI request, as decided by the @getDest@
-- callback of 'waiProxyTo' \/ 'waiProxyToSettings' (or by 'wpsGetDest').
data WaiProxyResponse = WPRResponse WAI.Response
                        -- ^ Respond directly with this response; nothing is proxied.
                      | WPRProxyDest ProxyDest
                        -- ^ Proxy the original request to the destination with 'HC.secure' off.
                      | WPRProxyDestSecure ProxyDest
                        -- ^ Like 'WPRProxyDest', but with 'HC.secure' on (HTTPS).
                      | WPRModifiedRequest WAI.Request ProxyDest
                        -- ^ Proxy the given request instead of the original one ('HC.secure' off).
                      | WPRModifiedRequestSecure WAI.Request ProxyDest
                        -- ^ Like 'WPRModifiedRequest', but with 'HC.secure' on (HTTPS).
                      | WPRApplication WAI.Application
                        -- ^ Handle the original request with this application.
-- | Create a WAI application that proxies requests through http-client.
--
-- Shorthand for 'waiProxyToSettings' with 'defaultWaiProxySettings', except
-- that @onError@ is installed as 'wpsOnExc'.
waiProxyTo :: (WAI.Request -> IO WaiProxyResponse)
           -> (SomeException -> WAI.Application)
           -> HC.Manager
           -> WAI.Application
waiProxyTo getDest onError =
    waiProxyToSettings getDest settings
  where
    settings = defaultWaiProxySettings { wpsOnExc = onError }
-- | Per-request settings, returned together with the 'WaiProxyResponse' by
-- 'wpsGetDest'.
data LocalWaiProxySettings = LocalWaiProxySettings
    { lpsTimeBound :: Maybe Int
      -- ^ Time limit in microseconds, or 'Nothing' for no limit. It is used as
      -- the upstream response timeout when proxying. It is also applied as a
      -- 'timeout' around locally handled 'WPRResponse' \/ 'WPRApplication'
      -- requests.
    }
-- | Per-request settings with no time limit.
defaultLocalWaiProxySettings :: LocalWaiProxySettings
defaultLocalWaiProxySettings = LocalWaiProxySettings { lpsTimeBound = Nothing }
-- | Replace the per-request time limit (microseconds; 'Nothing' disables it).
setLpsTimeBound :: Maybe Int -> LocalWaiProxySettings -> LocalWaiProxySettings
setLpsTimeBound bound settings = settings { lpsTimeBound = bound }
-- | Settings for 'waiProxyToSettings'. Start from 'defaultWaiProxySettings'
-- and override fields with record-update syntax.
data WaiProxySettings = WaiProxySettings
    { wpsOnExc :: SomeException -> WAI.Application
      -- ^ Handler run when opening the upstream response throws.
      -- Default: 'defaultOnExc'.
    , wpsTimeout :: Maybe Int
      -- ^ Time limit in microseconds. It is used as 'lpsTimeBound' only when
      -- 'wpsGetDest' is not set. Default: 'Nothing'.
    , wpsSetIpHeader :: SetIpHeader
      -- ^ How to set the @X-Real-IP@ header sent upstream.
      -- Default: 'SIHFromSocket'.
    , wpsProcessBody :: WAI.Request -> HC.Response () -> Maybe (ConduitT ByteString (Flush Builder) IO ())
      -- ^ Optional transformer for the upstream response body. With 'Nothing',
      -- each chunk is forwarded unchanged and followed by a flush.
      -- Default: always 'Nothing'.
    , wpsUpgradeToRaw :: WAI.Request -> Bool
      -- ^ Whether to tunnel the request over a raw TCP connection instead of
      -- proxying it with http-client. Default: 'True' when the @Upgrade@
      -- header is @websocket@ (compared case-insensitively).
    , wpsGetDest :: Maybe (WAI.Request -> IO (LocalWaiProxySettings, WaiProxyResponse))
      -- ^ Alternative destination callback that also returns per-request
      -- settings. When set, it takes precedence over the @getDest@ argument
      -- of 'waiProxyToSettings'. Default: 'Nothing'.
    }
-- | How to set the @X-Real-IP@ request header forwarded upstream.
data SetIpHeader = SIHNone
                   -- ^ Add no @X-Real-IP@ header (client headers are left as they are).
                 | SIHFromSocket
                   -- ^ Use the client socket address ('WAI.remoteHost').
                 | SIHFromHeader
                   -- ^ Use the client's @X-Real-IP@ header. If it is absent,
                   -- use the first @X-Forwarded-For@ entry. Otherwise fall
                   -- back to the socket address.
-- | Default settings for 'waiProxyToSettings':
--
-- * exceptions are answered by 'defaultOnExc' (a 502 response);
-- * no global time limit;
-- * @X-Real-IP@ is taken from the client socket;
-- * upstream response bodies are streamed unchanged;
-- * requests whose @Upgrade@ header is @websocket@ (case-insensitively) are
--   tunnelled over raw TCP;
-- * the destination comes from the @getDest@ argument.
defaultWaiProxySettings :: WaiProxySettings
defaultWaiProxySettings = WaiProxySettings
    { wpsOnExc = defaultOnExc
    , wpsTimeout = Nothing
    , wpsSetIpHeader = SIHFromSocket
    , wpsProcessBody = noBodyProcessing
    , wpsUpgradeToRaw = isWebSocketUpgrade
    , wpsGetDest = Nothing
    }
  where
    noBodyProcessing _ _ = Nothing
    isWebSocketUpgrade req =
        fmap CI.mk (lookup "upgrade" (WAI.requestHeaders req)) == Just "websocket"
-- | Serialise a request head, as used by 'tryWebSockets' for the raw tunnel.
-- The output is the request line (method, path, query string and HTTP
-- version), then each header on its own CRLF-separated line, then the blank
-- line that ends the head.
--
-- Any HTTP version other than 1.1 is written as @HTTP/1.0@.
renderHeaders :: WAI.Request -> HT.RequestHeaders -> Builder
renderHeaders req headers = mconcat
    [ fromByteString (WAI.requestMethod req)
    , fromByteString " "
    , fromByteString (WAI.rawPathInfo req)
    , fromByteString (WAI.rawQueryString req)
    , fromByteString versionText
    , mconcat (map renderHeader headers)
    , fromByteString "\r\n\r\n"
    ]
  where
    versionText :: ByteString
    versionText
        | WAI.httpVersion req == HT.http11 = " HTTP/1.1"
        | otherwise = " HTTP/1.0"
    -- Each header line is preceded (not followed) by its line break.
    renderHeader (name, value) = mconcat
        [ fromByteString "\r\n"
        , fromByteString (CI.original name)
        , fromByteString ": "
        , fromByteString value
        ]
-- | Tunnel the request over raw TCP when 'wpsUpgradeToRaw' accepts it (by
-- default, WebSocket upgrades). Otherwise run @fallback@.
--
-- In the raw case, the response is a 'WAI.responseRaw' that opens a TCP
-- connection to @host@:@port@. It first writes the request head (built by
-- 'renderHeaders' from 'fixReqHeaders'), then relays the client's remaining
-- bytes. Upstream bytes are copied back to the client concurrently. @backup@
-- is the 500 response WAI falls back to when the server does not support
-- raw responses.
--
-- NOTE(review): this path always uses a plain TCP connection and ignores the
-- secure flag of 'WPRProxyDestSecure' \/ 'WPRModifiedRequestSecure'.
tryWebSockets :: WaiProxySettings -> ByteString -> Int -> WAI.Request -> (WAI.Response -> IO b) -> IO b -> IO b
tryWebSockets wps host port req sendResponse fallback
    | wpsUpgradeToRaw wps req =
        sendResponse $ flip WAI.responseRaw backup $ \fromClientBody toClient ->
            DCN.runTCPClient settings $ \server ->
                let toServer = DCN.appSink server
                    fromServer = DCN.appSource server
                    -- Send the re-rendered request head, then relay the raw
                    -- client stream until it yields an empty chunk.
                    fromClient = do
                        mapM_ yield $ L.toChunks $ toLazyByteString headers
                        let loop = do
                                bs <- liftIO fromClientBody
                                unless (S.null bs) $ do
                                    yield bs
                                    loop
                        loop
                    toClient' = awaitForever $ liftIO . toClient
                    headers = renderHeaders req $ fixReqHeaders wps req
                 in concurrently_
                        (runConduit $ fromClient .| toServer)
                        (runConduit $ fromServer .| toClient')
    | otherwise = fallback
  where
    backup = WAI.responseLBS HT.status500 [("Content-Type", "text/plain")]
        "http-reverse-proxy detected WebSockets request, but server does not support responseRaw"
    settings = DCN.clientSettings port host
-- | Headers that are not copied across the proxy. They are removed from
-- forwarded requests ('fixReqHeaders') and from relayed upstream responses
-- ('waiProxyToSettings').
--
-- Presumably this is because the proxy re-streams bodies itself, so the
-- original framing and encoding headers would no longer be accurate.
strippedHeaders :: Set HT.HeaderName
strippedHeaders = Set.fromList $ map CI.mk
    [ "content-length"
    , "transfer-encoding"
    , "accept-encoding"
    , "content-encoding"
    ]
-- | Build the request headers sent upstream for a proxied request.
--
-- Starting from the client's headers, this drops everything in
-- 'strippedHeaders' and any @Connection: close@ header. The @close@ value is
-- compared case-insensitively. Depending on 'wpsSetIpHeader', it then
-- prepends an @X-Real-IP@ header:
--
-- * 'SIHFromSocket': the client socket address;
-- * 'SIHFromHeader': the client's @X-Real-IP@ header, else the first entry of
--   @X-Forwarded-For@, else the client socket address;
-- * 'SIHNone': nothing is added.
--
-- NOTE(review): a client-supplied @X-Real-IP@ header is not removed, so the
-- upstream may receive two of them (ours first). Confirm that upstreams read
-- the first occurrence.
fixReqHeaders :: WaiProxySettings -> WAI.Request -> HT.RequestHeaders
fixReqHeaders wps req =
    addXRealIP $ filter keep $ WAI.requestHeaders req
  where
    -- Connection options are case-insensitive (RFC 7230 section 6.1), so
    -- "Close" must be stripped just like "close".
    keep (key, value) =
        not $ key `Set.member` strippedHeaders
           || (key == "connection" && CI.mk value == "close")
    fromSocket = (("X-Real-IP", S8.pack $ showSockAddr $ WAI.remoteHost req):)
    -- First comma-separated entry of X-Forwarded-For, whitespace-trimmed.
    -- The header is client-controlled and may contain invalid UTF-8. The
    -- total decodeUtf8' treats such input as if the header were absent,
    -- whereas decodeUtf8 would throw from pure code.
    fromForwardedFor = do
        h <- lookup "x-forwarded-for" (WAI.requestHeaders req)
        txt <- either (const Nothing) Just (TE.decodeUtf8' h)
        listToMaybe $ map (TE.encodeUtf8 . T.strip) $ T.splitOn "," txt
    addXRealIP =
        case wpsSetIpHeader wps of
            SIHFromSocket -> fromSocket
            SIHFromHeader ->
                case lookup "x-real-ip" (WAI.requestHeaders req) <|> fromForwardedFor of
                    Nothing -> fromSocket
                    Just ip -> (("X-Real-IP", ip):)
            SIHNone -> id
-- | Create a WAI application that proxies requests with http-client.
--
-- Each request's destination is chosen by 'wpsGetDest' when it is set.
-- Otherwise it comes from @getDest@, combined with 'wpsTimeout'. The
-- resulting 'WaiProxyResponse' is handled as follows:
--
-- * 'WPRResponse' \/ 'WPRApplication' are handled locally, under the optional
--   'lpsTimeBound' time limit.
-- * For proxy destinations, requests accepted by 'wpsUpgradeToRaw' are
--   tunnelled by 'tryWebSockets'. All others are sent upstream through
--   @manager@, and the upstream status, headers (minus 'strippedHeaders') and
--   streamed body are relayed back. If opening the upstream response throws,
--   'wpsOnExc' handles the exception.
waiProxyToSettings :: (WAI.Request -> IO WaiProxyResponse)
                   -> WaiProxySettings
                   -> HC.Manager
                   -> WAI.Application
waiProxyToSettings getDest wps' manager req0 sendResponse = do
    -- When no wpsGetDest is configured, fall back to getDest paired with the
    -- global wpsTimeout.
    let wps = wps'{wpsGetDest = wpsGetDest wps' <|> Just (fmap (LocalWaiProxySettings $ wpsTimeout wps',) . getDest)}
    -- wpsGetDest wps is always Just after the line above, so this 500
    -- default is never actually used.
    (lps, edest') <- fromMaybe
        (const $ return (defaultLocalWaiProxySettings, WPRResponse $ WAI.responseLBS HT.status500 [] "proxy not setup"))
        (wpsGetDest wps)
        req0
    -- Left: handle locally. Right: (destination, request to forward, secure flag).
    let edest =
            case edest' of
                WPRResponse res -> Left $ \_req -> ($ res)
                WPRProxyDest pd -> Right (pd, req0, False)
                WPRProxyDestSecure pd -> Right (pd, req0, True)
                WPRModifiedRequest req pd -> Right (pd, req, False)
                WPRModifiedRequestSecure req pd -> Right (pd, req, True)
                WPRApplication app -> Left app
        -- Run f with a limit of us microseconds, sending a 500 on timeout.
        -- NOTE(review): if f already called sendResponse before the timeout
        -- fired, a second response is sent here. Confirm this is acceptable.
        timeBound us f =
            timeout us f >>= \case
                Just res -> return res
                Nothing -> sendResponse $ WAI.responseLBS HT.status500 [] "timeBound"
    case edest of
        Left app -> maybe id timeBound (lpsTimeBound lps) $ app req0 sendResponse
        Right (ProxyDest host port, req, secure) -> tryWebSockets wps host port req sendResponse $ do
            -- Non-2xx statuses are not treated as errors, so every upstream
            -- status is relayed. redirectCount = 0 passes redirects through
            -- instead of following them.
            let req' =
#if MIN_VERSION_http_client(0, 5, 0)
                    HC.defaultRequest
                        { HC.checkResponse = \_ _ -> return ()
                        , HC.responseTimeout = maybe HC.responseTimeoutNone HC.responseTimeoutMicro $ lpsTimeBound lps
#else
                    def
                        { HC.checkStatus = \_ _ _ -> Nothing
                        , HC.responseTimeout = lpsTimeBound lps
#endif
                        , HC.method = WAI.requestMethod req
                        , HC.secure = secure
                        , HC.host = host
                        , HC.port = port
                        , HC.path = WAI.rawPathInfo req
                        , HC.queryString = WAI.rawQueryString req
                        , HC.requestHeaders = fixReqHeaders wps req
                        , HC.requestBody = body
                        , HC.redirectCount = 0
                        }
                -- Stream the client body upstream, keeping its known length
                -- when WAI reports one and using chunked encoding otherwise.
                body =
                    case WAI.requestBodyLength req of
                        WAI.KnownLength i -> HC.RequestBodyStream
                            (fromIntegral i)
                            ($ WAI.requestBody req)
                        WAI.ChunkedBody -> HC.RequestBodyStreamChunked ($ WAI.requestBody req)
            -- Only a successfully opened response needs closing.
            bracket
                (try $ HC.responseOpen req' manager)
                (either (const $ return ()) HC.responseClose)
                $ \case
                    Left e -> wpsOnExc wps e req sendResponse
                    Right res -> do
                        -- Default body pipeline: forward each chunk and flush after it.
                        let conduit = fromMaybe
                                (awaitForever (\bs -> yield (Chunk $ fromByteString bs) >> yield Flush))
                                (wpsProcessBody wps req $ const () <$> res)
                            src = bodyReaderSource $ HC.responseBody res
                        sendResponse $ WAI.responseStream
                            (HC.responseStatus res)
                            (filter (\(key, _) -> not $ key `Set.member` strippedHeaders) $ HC.responseHeaders res)
                            (\sendChunk flush -> runConduit $ src .| conduit .| CL.mapM_ (\mb ->
                                case mb of
                                    Flush -> flush
                                    Chunk b -> sendChunk b))
-- | Read the head of a raw HTTP request from the stream and parse its headers.
--
-- Input is accumulated until one of these happens: the buffer contains a
-- blank line (@\"\\r\\n\\r\\n\"@ or @\"\\n\\n\"@), it exceeds 4096 bytes, or
-- the stream ends. The whole accumulated buffer is then pushed back with
-- 'leftover', so downstream consumers still see every byte: the request line,
-- the headers, and any body bytes that arrived in the same chunks.
--
-- For parsing, the first line (the request line) is skipped and parsing stops
-- at the first empty line. Lines may end in either @\\r\\n@ or @\\n@.
getHeaders :: Monad m => ConduitT ByteString o m HT.RequestHeaders
getHeaders =
    toHeaders <$> go id
  where
    -- front is a difference list of the bytes buffered so far.
    go front =
        await >>= maybe close push
      where
        close = leftover bs >> return bs
          where
            bs = front S8.empty
        push bs'
            | "\r\n\r\n" `S8.isInfixOf` bs
              || "\n\n" `S8.isInfixOf` bs
              || S8.length bs > 4096 = leftover bs >> return bs
            | otherwise = go $ mappend bs
          where
            bs = front bs'
    -- S8.lines splits on '\n' only, so CRLF lines keep a trailing '\r'.
    -- The CR must be stripped before the blank-line test. Otherwise the
    -- terminating line is "\r" rather than "", parsing never stops there,
    -- and "\r" plus any following body bytes come out as bogus headers.
    toHeaders = map toHeader . takeWhile (not . S8.null) . map stripCR . drop 1 . S8.lines
    stripCR bs = fromMaybe bs (S.stripSuffix "\r" bs)
    toHeader bs =
        (CI.mk key, val)
      where
        (key, bs') = S.break (== _colon) bs
        val = S.takeWhile (/= _cr) $ S.dropWhile isSpace $ S.drop 1 bs'
-- | Turn an http-client 'BodyReader' into a conduit source. It yields each
-- chunk until 'brRead' returns an empty chunk, which marks the end of the body.
bodyReaderSource :: MonadIO m => BodyReader -> ConduitT i ByteString m ()
bodyReaderSource br = go
  where
    go = liftIO (brRead br) >>= \chunk ->
        if S.null chunk
            then return ()
            else yield chunk >> go