module Network.HTTP.ReverseProxy
(
ProxyDest (..)
, rawProxyTo
, waiProxyTo
, defaultOnExc
, waiProxyToSettings
, WaiProxyResponse (..)
, WaiProxySettings
, def
, wpsOnExc
, wpsTimeout
, wpsSetIpHeader
, wpsProcessBody
, wpsUpgradeToRaw
, SetIpHeader (..)
) where
import BasicPrelude
import Debug.Trace
import Data.Conduit
import Data.Default.Class (def)
import qualified Network.Wai as WAI
import qualified Network.HTTP.Client as HC
import Network.HTTP.Client (BodyReader, brRead)
import Control.Exception (bracketOnError)
import Blaze.ByteString.Builder (fromByteString)
import Data.Word8 (isSpace, _colon, _cr)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Network.HTTP.Types as HT
import qualified Data.CaseInsensitive as CI
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Text.Lazy as TL
import qualified Data.Conduit.Network as DCN
import Control.Concurrent.MVar.Lifted (newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.Lifted (fork, killThread)
import Data.Default.Class (Default (..))
import Network.Wai.Logger (showSockAddr)
import qualified Data.Set as Set
import Data.IORef
#if MIN_VERSION_wai(2, 1, 0)
import qualified Data.ByteString.Lazy as L
import Control.Concurrent.Async (concurrently)
import Blaze.ByteString.Builder (Builder, toLazyByteString)
#else
import Blaze.ByteString.Builder (Builder)
#endif
-- | Host\/port pair identifying the upstream server a connection or request
-- should be forwarded to.
data ProxyDest = ProxyDest
    { pdHost :: !ByteString
      -- ^ Upstream host, handed to the TCP client settings or the HTTP client
      -- request.
    , pdPort :: !Int
      -- ^ Upstream port.
    }
-- | Reverse proxy at the raw TCP level.
--
-- The initial bytes from the client are parsed with 'getHeaders' and the
-- resulting header list is given to the supplied function, which decides
-- what to do with the connection:
--
-- * @Left app@: run @app@ against the client connection, with the
--   already-consumed bytes put back in front of the client source.
--
-- * @Right dest@: open a TCP connection to @dest@ and shuttle bytes in both
--   directions until one direction finishes, then kill the other.
rawProxyTo :: (MonadBaseControl IO m, MonadIO m)
           => (HT.RequestHeaders -> m (Either (DCN.Application m) ProxyDest))
           -> DCN.Application m
rawProxyTo getDest appdata = do
    (resumable, reqHeaders) <- clientSource $$+ getHeaders
    getDest reqHeaders >>= either (runLocally resumable) (runProxied resumable)
  where
    clientSource = DCN.appSource appdata
    clientSink = DCN.appSink appdata

    -- Hand the connection to a local application, re-attaching whatever the
    -- header parser left over.
    runLocally resumable app = do
        (remaining, _) <- unwrapResumable resumable
        app appdata { DCN.appSource = remaining }

    -- Connect to the upstream server and relay traffic.
    runProxied resumable (ProxyDest host port) =
        DCN.runTCPClient (DCN.clientSettings port host) (relay resumable)

    -- One thread per direction; whichever finishes first reports through the
    -- MVar (True = client-to-server, False = server-to-client) and the other
    -- thread is then killed.
    relay resumable upstream = do
        finished <- newEmptyMVar
        uploader <- fork $
            (resumable $$+- DCN.appSink upstream) `finally` putMVar finished True
        downloader <- fork $
            (DCN.appSource upstream $$ clientSink) `finally` putMVar finished False
        uploadEndedFirst <- takeMVar finished
        killThread $ if uploadEndedFirst then downloader else uploader
-- | Default exception handler: answers with a plain-text 502 response whose
-- body contains the rendered exception. The request itself is ignored.
defaultOnExc :: SomeException -> WAI.Application
defaultOnExc exc _ =
    return $ WAI.responseLBS HT.status502 headers body
  where
    headers = [("content-type", "text/plain")]
    body = "Error connecting to gateway:\n\n" ++ TLE.encodeUtf8 (TL.fromStrict (show exc))
-- | What the proxy should do with an incoming WAI request.
data WaiProxyResponse
    = WPRResponse WAI.Response
      -- ^ Do not proxy; answer the client with this response directly.
    | WPRProxyDest ProxyDest
      -- ^ Forward the original request to the given destination.
    | WPRModifiedRequest WAI.Request ProxyDest
      -- ^ Forward the supplied request, instead of the original one, to the
      -- given destination.
-- | Build a WAI reverse-proxy application using the default settings, with
-- only the exception handler overridden.
--
-- Equivalent to calling 'waiProxyToSettings' with @def@ whose 'wpsOnExc'
-- field is replaced by the given handler.
waiProxyTo :: (WAI.Request -> IO WaiProxyResponse)
           -> (SomeException -> WAI.Application)
           -> HC.Manager
           -> WAI.Application
waiProxyTo getDest onError manager =
    waiProxyToSettings getDest settings manager
  where
    settings = def { wpsOnExc = onError }
-- | Tunable behaviour for 'waiProxyToSettings'. Obtain a value via 'def' and
-- override individual fields with record-update syntax.
data WaiProxySettings = WaiProxySettings
    { wpsOnExc :: SomeException -> WAI.Application
      -- ^ Application run in place of the proxied response when opening the
      -- upstream response throws.
    , wpsTimeout :: Maybe Int
      -- ^ Passed through unchanged as the HTTP client's response timeout.
    , wpsSetIpHeader :: SetIpHeader
      -- ^ How (and whether) an @X-Real-IP@ header is added to the upstream
      -- request.
    , wpsProcessBody :: HC.Response () -> Maybe (Conduit ByteString IO (Flush Builder))
      -- ^ Optional transformation of the upstream response body, chosen from
      -- the body-less response. 'Nothing' selects the built-in conduit, which
      -- emits every chunk followed by a flush.
    , wpsUpgradeToRaw :: WAI.Request -> Bool
      -- ^ Predicate deciding whether a request is proxied as a raw
      -- connection rather than through the HTTP client.
    }
-- | Strategy for populating the @X-Real-IP@ header on upstream requests.
data SetIpHeader
    = SIHNone
      -- ^ Do not add the header.
    | SIHFromSocket
      -- ^ Use the peer address of the incoming connection.
    | SIHFromHeader
      -- ^ Copy the value of the incoming @X-Real-IP@ header, falling back to
      -- @X-Forwarded-For@; add nothing if neither is present.
-- | Default proxy settings:
--
-- * exceptions are reported with 'defaultOnExc';
-- * no response timeout is set;
-- * @X-Real-IP@ is taken from the client socket address;
-- * response bodies are passed through unmodified;
-- * a request is upgraded to a raw connection when its @Upgrade@ header
--   equals @websocket@ (compared case-insensitively).
instance Default WaiProxySettings where
    def = WaiProxySettings
        { wpsOnExc = defaultOnExc
        , wpsTimeout = Nothing
        , wpsSetIpHeader = SIHFromSocket
        , wpsProcessBody = const Nothing
          -- The header value is wrapped in 'CI.mk' so that "WebSocket",
          -- "websocket", etc. all match.
        , wpsUpgradeToRaw = \req ->
            (CI.mk <$> lookup "upgrade" (WAI.requestHeaders req)) == Just "websocket"
        }
-- | Proxy the request as a raw connection when 'wpsUpgradeToRaw' accepts it;
-- otherwise run the supplied fallback action.
--
-- For an upgraded request a TCP connection is opened to the given host and
-- port, the request line and fixed-up headers are re-serialised and sent
-- first, and then bytes are relayed concurrently in both directions.
--
-- When built against a WAI version older than 2.1.0 there is no raw response
-- support and the fallback is always used.
tryWebSockets :: WaiProxySettings -> ByteString -> Int -> WAI.Request -> IO WAI.Response -> IO WAI.Response
#if MIN_VERSION_wai(2, 1, 0)
tryWebSockets wps host port req fallback
    | wpsUpgradeToRaw wps req =
        return $ flip WAI.responseRaw backup $ \fromClientBody toClient ->
            DCN.runTCPClient settings $ \server ->
                let toServer = DCN.appSink server
                    fromServer = DCN.appSource server
                    -- The rendered request head goes out first, followed by
                    -- whatever the client sends afterwards.
                    fromClient = do
                        mapM_ yield $ L.toChunks $ toLazyByteString headers
                        fromClientBody
                    headers = renderHeaders req $ fixReqHeaders wps req
                 in void $ concurrently
                        (fromClient $$ toServer)
                        (fromServer $$ toClient)
    | otherwise = fallback
  where
    -- Response handed to 'WAI.responseRaw' as its non-raw alternative.
    backup = WAI.responseLBS HT.status500 [("Content-Type", "text/plain")]
        "http-reverse-proxy detected WebSockets request, but server does not support responseRaw"
    settings = DCN.clientSettings port host

    -- Serialise the request line and headers, terminated by a blank line.
    renderHeaders :: WAI.Request -> HT.RequestHeaders -> Builder
    renderHeaders req headers
        = fromByteString (WAI.requestMethod req)
       <> fromByteString " "
       <> fromByteString (WAI.rawPathInfo req)
       <> fromByteString (WAI.rawQueryString req)
       <> (if WAI.httpVersion req == HT.http11
               then fromByteString " HTTP/1.1"
               else fromByteString " HTTP/1.0")
       <> mconcat (map goHeader headers)
       <> fromByteString "\r\n\r\n"
      where
        -- Each header is preceded (not followed) by CRLF, so the first one
        -- also terminates the request line.
        goHeader (x, y)
            = fromByteString "\r\n"
           <> fromByteString (CI.original x)
           <> fromByteString ": "
           <> fromByteString y
#else
tryWebSockets _ _ _ _ = id
#endif
-- | Header names removed from requests before they are forwarded upstream
-- and from upstream responses before they are returned to the client.
strippedHeaders :: Set HT.HeaderName
strippedHeaders =
    Set.fromList
        [ "content-length"
        , "transfer-encoding"
        , "accept-encoding"
        , "content-encoding"
        ]
-- | Compute the header list sent upstream for a request: the original
-- headers minus everything in 'strippedHeaders', optionally preceded by an
-- @X-Real-IP@ header chosen according to 'wpsSetIpHeader'.
fixReqHeaders :: WaiProxySettings -> WAI.Request -> HT.RequestHeaders
fixReqHeaders wps req =
    addXRealIP (filter keep original)
  where
    original = WAI.requestHeaders req

    keep (name, _) = not (name `Set.member` strippedHeaders)

    prepend ip = (("X-Real-IP", ip) :)

    addXRealIP =
        case wpsSetIpHeader wps of
            SIHNone -> id
            SIHFromSocket -> prepend (S8.pack $ showSockAddr $ WAI.remoteHost req)
            -- Prefer the client-supplied X-Real-IP, then X-Forwarded-For.
            SIHFromHeader ->
                maybe id prepend $
                    lookup "x-real-ip" original <|> lookup "X-Forwarded-For" original
-- | Build a WAI reverse-proxy application from explicit settings.
--
-- For every request the supplied function picks a 'WaiProxyResponse':
--
-- * 'WPRResponse' is returned to the client as is;
-- * 'WPRProxyDest' forwards the original request;
-- * 'WPRModifiedRequest' forwards the supplied request instead.
--
-- Forwarded requests first go through 'tryWebSockets'; if that does not take
-- over, the request is replayed through the HTTP client manager and the
-- upstream response is streamed back. If opening the upstream response
-- throws, 'wpsOnExc' produces the response.
waiProxyToSettings :: (WAI.Request -> IO WaiProxyResponse)
                   -> WaiProxySettings
                   -> HC.Manager
                   -> WAI.Application
waiProxyToSettings getDest wps manager req0 =
    getDest req0 >>= dispatch
  where
    dispatch (WPRResponse response) = return response
    dispatch (WPRProxyDest dest) = forward dest req0
    dispatch (WPRModifiedRequest req dest) = forward dest req

    forward (ProxyDest host port) req =
        tryWebSockets wps host port req $
            -- If anything fails after the upstream response has been opened
            -- but before the WAI response is built, close it again.
            bracketOnError
                (try $ HC.responseOpen upstreamReq manager)
                (either (const $ return ()) HC.responseClose)
                (either (\exc -> wpsOnExc wps exc req) relay)
      where
        upstreamReq = def
            { HC.method = WAI.requestMethod req
            , HC.host = host
            , HC.port = port
            , HC.path = WAI.rawPathInfo req
            , HC.queryString = WAI.rawQueryString req
            , HC.requestHeaders = fixReqHeaders wps req
            , HC.requestBody = upstreamBody
              -- Redirects and non-2xx statuses are passed to the client
              -- rather than handled by the HTTP client.
            , HC.redirectCount = 0
            , HC.checkStatus = \_ _ _ -> Nothing
            , HC.responseTimeout = wpsTimeout wps
            }
        upstreamBody =
            case WAI.requestBodyLength req of
                WAI.KnownLength len ->
                    requestBodySource (fromIntegral len) (WAI.requestBody req)
                WAI.ChunkedBody ->
                    requestBodySourceChunked (WAI.requestBody req)

    -- Stream the upstream response to the client, closing it once the WAI
    -- response has been consumed.
    relay res =
        WAI.responseSourceBracket
            (return ())
            (\() -> HC.responseClose res)
            (\() -> return
                ( HC.responseStatus res
                , filter keepHeader (HC.responseHeaders res)
                , bodyReaderSource (HC.responseBody res) $= bodyConduit
                ))
      where
        keepHeader (key, _) = not (key `Set.member` strippedHeaders)
        bodyConduit = maybe flushEveryChunk id $ wpsProcessBody wps (fmap (const ()) res)
        flushEveryChunk =
            awaitForever $ \chunk -> yield (Chunk $ fromByteString chunk) >> yield Flush
-- | Consume the head of an HTTP request from the stream and return its
-- headers.
--
-- Input is accumulated until it contains a blank line (@\\r\\n\\r\\n@ or
-- @\\n\\n@), exceeds 4096 bytes, or the stream ends. Everything read is then
-- pushed back as a leftover, so a downstream consumer still sees the
-- complete, unmodified byte stream.
--
-- The first line (the request line) is skipped. Parsing stops at the first
-- empty line, where a line consisting solely of a carriage return counts as
-- empty, so that CRLF-terminated requests do not leak the separator or body
-- lines into the header list.
getHeaders :: Monad m => Sink ByteString m HT.RequestHeaders
getHeaders =
    toHeaders <$> go id
  where
    -- @front@ is a difference list of the bytes read so far.
    go front =
        await >>= maybe close push
      where
        close = leftover bs >> return bs
          where
            bs = front empty
        push bs'
            | "\r\n\r\n" `S8.isInfixOf` bs
              || "\n\n" `S8.isInfixOf` bs
              || S8.length bs > 4096 = leftover bs >> return bs
            | otherwise = go $ mappend bs
          where
            bs = front bs'

    toHeaders =
        map toHeader
            . takeWhile (not . S8.null)
            . map stripTrailingCR
            . drop 1
            . S8.lines

    -- 'S8.lines' splits on LF only, so CRLF input leaves a CR at the end of
    -- every line; drop it so the blank separator line is recognised.
    stripTrailingCR line
        | not (S.null line) && S.last line == _cr = S.init line
        | otherwise = line

    toHeader bs =
        (CI.mk key, val)
      where
        (key, bs') = S.break (== _colon) bs
        val = S.takeWhile (/= _cr) $ S.dropWhile isSpace $ S.drop 1 bs'
-- | Wrap a conduit source as an HTTP client streaming request body of the
-- given, known length.
requestBodySource :: Int64 -> Source IO ByteString -> HC.RequestBody
requestBodySource size src = HC.RequestBodyStream size (srcToPopper src)
-- | Wrap a conduit source as an HTTP client streaming request body sent with
-- chunked transfer encoding.
requestBodySourceChunked :: Source IO ByteString -> HC.RequestBody
requestBodySourceChunked src = HC.RequestBodyStreamChunked (srcToPopper src)
-- | Adapt a conduit source to the HTTP client's \"popper\" interface.
--
-- The source is kept as a resumable source in an 'IORef'. Each call to the
-- popper pulls the next chunk, skipping empty chunks, and returns the empty
-- string once the source is exhausted.
srcToPopper :: Source IO ByteString -> HC.GivesPopper ()
srcToPopper src withPopper = do
    (initial, ()) <- src $$+ return ()
    ref <- newIORef initial
    let pop :: IO ByteString
        pop = do
            current <- readIORef ref
            (next, mchunk) <- current $$++ await
            writeIORef ref next
            maybe (return S.empty) nonEmpty mchunk
        -- An empty chunk would be mistaken for end-of-body, so keep pulling.
        nonEmpty chunk
            | S.null chunk = pop
            | otherwise = return chunk
    withPopper pop
-- | Turn an HTTP client 'BodyReader' into a conduit source that yields each
-- chunk read and stops at the first empty chunk.
bodyReaderSource :: MonadIO m => BodyReader -> Source m ByteString
bodyReaderSource br =
    pump
  where
    pump = liftIO (brRead br) >>= emit
    emit chunk
        | S.null chunk = return ()
        | otherwise = yield chunk >> pump