module Network.HTTP.ReverseProxy
(
ProxyDest (..)
, rawProxyTo
, waiProxyTo
, defaultOnExc
, waiProxyToSettings
, WaiProxySettings
, def
, wpsOnExc
, wpsTimeout
, waiToRaw
) where
import ClassyPrelude
import Data.Conduit
import qualified Network.Wai as WAI
import qualified Network.HTTP.Conduit as HC
import Control.Exception.Lifted (try, finally)
import Blaze.ByteString.Builder (fromByteString)
import Data.Word8 (isSpace, _colon, toLower, _cr)
import qualified Data.ByteString.Char8 as S8
import qualified Network.HTTP.Types as HT
import qualified Data.CaseInsensitive as CI
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Data.Conduit.Network as DCN
import Control.Concurrent.MVar.Lifted (newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.Lifted (fork, killThread)
import Control.Monad.Trans.Control (MonadBaseControl)
import Network.Wai.Handler.Warp
import Data.Conduit.Binary (sourceFileRange)
import qualified Data.IORef as I
import Network.Socket (PortNumber (PortNum), SockAddr (SockAddrInet))
import Data.Default (Default (def))
import Data.Version (showVersion)
import qualified Paths_http_reverse_proxy
-- | Upstream endpoint chosen for a request: the host and port that
-- 'rawProxyTo' and 'waiProxyToSettings' connect to.
data ProxyDest = ProxyDest
    { pdHost :: !ByteString -- ^ host name or address to connect to
    , pdPort :: !Int        -- ^ port number to connect to
    }
-- | Proxy a raw TCP connection.
--
-- The start of the client's input is inspected with 'getHeaders'. That step
-- pushes everything it read back as leftover, so nothing is lost. The parsed
-- request headers are then handed to @getDest@:
--
-- * @Left app@: run @app@ on the client connection. Its source is the
--   resumed client source, so @app@ sees the request from the beginning.
--
-- * @Right (ProxyDest host port)@: open a TCP connection to @host:port@ and
--   copy bytes in both directions (client -> server and server -> client)
--   on two forked threads. As soon as either direction finishes (normally
--   or by exception), both threads are killed. They are also killed if the
--   calling thread is interrupted while waiting, so no pump thread is left
--   running detached from this call.
rawProxyTo :: (MonadBaseControl IO m, MonadIO m)
           => (HT.RequestHeaders -> m (Either (DCN.Application m) ProxyDest))
           -> DCN.Application m
rawProxyTo getDest appdata = do
    (rsrc, headers) <- fromClient $$+ getHeaders
    edest <- getDest headers
    case edest of
        Left app -> do
            -- NOTE(review): the finalizer returned by unwrapResumable is
            -- discarded; presumably the client socket is closed by whoever
            -- owns @appdata@ — confirm.
            (fromClient', _) <- unwrapResumable rsrc
            app appdata { DCN.appSource = fromClient' }
        Right (ProxyDest host port) ->
            DCN.runTCPClient (DCN.clientSettings port host) (withServer rsrc)
  where
    fromClient = DCN.appSource appdata
    toClient = DCN.appSink appdata

    -- Pump data both ways until one direction ends, then stop the other.
    withServer rsrc appdataServer = do
        done <- newEmptyMVar
        tid1 <- fork $ (rsrc $$+- toServer) `finally` putMVar done ()
        tid2 <- fork $ (fromServer $$ toClient) `finally` putMVar done ()
        -- Killing the thread that already finished is a no-op, so on the
        -- normal path this stops only the still-running direction. The
        -- `finally` ensures both pumps are also stopped if we are
        -- interrupted while blocked here.
        takeMVar done `finally` (killThread tid1 >> killThread tid2)
      where
        fromServer = DCN.appSource appdataServer
        toServer = DCN.appSink appdataServer
-- | Default handler for a failed upstream request. It ignores the incoming
-- request and answers with a @502@ whose plain-text body contains the
-- 'show'n exception.
defaultOnExc :: SomeException -> WAI.Application
defaultOnExc exc _ =
    return (WAI.responseLBS HT.status502 plainText message)
  where
    plainText :: HT.ResponseHeaders
    plainText = [("content-type", "text/plain")]
    message = "Error connecting to gateway:\n\n" ++ TLE.encodeUtf8 (show exc)
-- | Convenience wrapper around 'waiProxyToSettings'. It uses the default
-- 'WaiProxySettings' with only the exception handler replaced by @onError@.
waiProxyTo :: (WAI.Request -> ResourceT IO (Either WAI.Response ProxyDest))
           -> (SomeException -> WAI.Application)
           -> HC.Manager
           -> WAI.Application
waiProxyTo getDest onError =
    waiProxyToSettings getDest settings
  where
    settings = def { wpsOnExc = onError }
-- | Tunables for 'waiProxyToSettings'. Build one from 'def' with record
-- update syntax.
data WaiProxySettings = WaiProxySettings
    { wpsOnExc :: SomeException -> WAI.Application
      -- ^ Produces the response when the upstream request throws.
    , wpsTimeout :: Maybe Int
      -- ^ Passed through as http-conduit's @responseTimeout@ for every
      -- upstream request. NOTE(review): per http-conduit this is in
      -- microseconds — confirm against the pinned version.
    }
-- | Default settings: upstream exceptions are turned into responses by
-- 'defaultOnExc', and 'wpsTimeout' is 'Nothing'.
instance Default WaiProxySettings where
    def = WaiProxySettings defaultOnExc Nothing
-- | Forward a WAI request to an upstream HTTP server and stream the
-- upstream response back to the client.
--
-- @getDest@ decides what happens to each request:
--
-- * @Left response@: @response@ is returned directly; nothing is sent
--   upstream.
--
-- * @Right (ProxyDest host port)@: the request's method, raw path, raw
--   query string, headers and body are sent to @host:port@ through
--   @manager@.
--
--     * Redirects are not followed (@redirectCount@ is 0).
--     * Every upstream status is passed through unchanged, because
--       @checkStatus@ never reports an error.
--     * 'wpsTimeout' becomes the upstream @responseTimeout@.
--     * If 'HC.http' throws, the response is produced by 'wpsOnExc'
--       instead.
--
-- The headers @content-length@, @transfer-encoding@, @accept-encoding@
-- and @content-encoding@ are removed in both directions.
-- NOTE(review): presumably this is because message framing is recomputed
-- on each hop and because http-conduit may decompress the upstream body —
-- confirm.
waiProxyToSettings :: (WAI.Request -> ResourceT IO (Either WAI.Response ProxyDest))
                   -> WaiProxySettings
                   -> HC.Manager
                   -> WAI.Application
waiProxyToSettings getDest wps manager req = do
    edest <- getDest req
    case edest of
        Left response -> return response
        Right (ProxyDest host port) -> do
            let req' = HC.def
                    { HC.method = WAI.requestMethod req
                    , HC.host = host
                    , HC.port = port
                    , HC.path = WAI.rawPathInfo req
                    , HC.queryString = WAI.rawQueryString req
                    , HC.requestHeaders = stripHeaders $ WAI.requestHeaders req
                    , HC.requestBody = body
                    , HC.redirectCount = 0
#if MIN_VERSION_http_conduit(1, 9, 0)
                    , HC.checkStatus = \_ _ _ -> Nothing
#else
                    , HC.checkStatus = \_ _ -> Nothing
#endif
                    , HC.responseTimeout = wpsTimeout wps
                    }
                bodySrc = mapOutput fromByteString $ WAI.requestBody req
                bodyChunked = HC.RequestBodySourceChunked bodySrc
#if MIN_VERSION_wai(1, 4, 0)
                -- Forward a known request length; otherwise stream chunked.
                body =
                    case WAI.requestBodyLength req of
                        WAI.KnownLength i -> HC.RequestBodySource (fromIntegral i) bodySrc
                        WAI.ChunkedBody -> bodyChunked
#else
                body = bodyChunked
#endif
            ex <- try $ HC.http req' manager
            case ex of
                Left e -> wpsOnExc wps e req
                Right res -> do
                    -- NOTE(review): the finalizer from unwrapResumable is
                    -- discarded; presumably the upstream connection is
                    -- released by the enclosing ResourceT — confirm.
                    (src, _) <- unwrapResumable $ HC.responseBody res
                    return $ WAI.ResponseSource
                        (HC.responseStatus res)
                        (stripHeaders $ HC.responseHeaders res) $ do
                            -- Flush at the start and after every chunk so
                            -- upstream data reaches the client promptly.
                            yield Flush
                            src =$= awaitForever (\bs -> yield (Chunk $ fromByteString bs) >> yield Flush)
  where
    -- Remove every header whose name is listed in strippedHeaders; used for
    -- both the outgoing request and the returned response.
    stripHeaders :: [HT.Header] -> [HT.Header]
    stripHeaders = filter (\(key, _) -> not $ key `member` strippedHeaders)
    strippedHeaders = asSet $ fromList ["content-length", "transfer-encoding", "accept-encoding", "content-encoding"]
    asSet :: Set a -> Set a
    asSet = id
-- | Inspect the beginning of a raw HTTP request and parse its header lines.
--
-- Input is buffered until one of the following happens:
--
-- * the buffer contains a blank line (@\\r\\n\\r\\n@ or @\\n\\n@);
-- * the buffer grows past 4096 bytes;
-- * the stream ends.
--
-- The whole buffer is then pushed back with 'leftover', so the next consumer
-- still sees the complete request.
--
-- The first line (the request line) is skipped. Every following line up to
-- the first empty one is split at its first colon into a case-insensitive
-- name and a value with leading whitespace removed. Both @\\r\\n@ and @\\n@
-- line endings are accepted.
getHeaders :: Monad m => Sink ByteString m HT.RequestHeaders
getHeaders =
    toHeaders <$> go id
  where
    -- @front@ prepends everything read so far to its argument.
    go front =
        await >>= maybe close push
      where
        close = leftover bs >> return bs
          where
            bs = front empty
        push bs'
            | "\r\n\r\n" `S8.isInfixOf` bs
              || "\n\n" `S8.isInfixOf` bs
              || length bs > 4096 = leftover bs >> return bs
            | otherwise = go $ append bs
          where
            bs = front bs'

    -- S8.lines splits on '\n' only, so CRLF lines keep a trailing '\r'.
    -- It must be removed *before* testing for the empty terminator line.
    -- Otherwise that line is "\r" (non-empty), and parsing would continue
    -- into the request body.
    toHeaders = map toHeader . takeWhile (not . null) . map stripCR . drop 1 . S8.lines

    stripCR = fst . S8.spanEnd (== '\r')

    toHeader bs =
        (CI.mk key, val)
      where
        (key, bs') = break (== _colon) bs
        val = takeWhile (/= _cr) $ dropWhile isSpace $ drop 1 bs'
-- | Serve a WAI 'WAI.Application' over a raw conduit-network connection.
-- Requests are parsed with 'parseRequest' and responses written with
-- 'sendResponse'.
--
-- Requests are handled in a loop on the same connection. Each iteration:
--
-- * parses one request;
-- * runs @app@ and sends the response;
-- * continues with the remaining client input only if 'sendResponse'
--   returned 'True' (keep-alive).
--
-- If request parsing throws any exception, the loop ends quietly and the
-- exception is not propagated.
waiToRaw :: WAI.Application -> DCN.Application IO
waiToRaw app appdata0 =
    loop $ transPipe lift fromClient0
  where
    fromClient0 = DCN.appSource appdata0
    toClient = DCN.appSink appdata0
    -- Handle one request, then recurse on the leftover client source if
    -- the connection is to be kept alive.
    loop fromClient = do
        mfromClient <- runResourceT $ do
            -- Any parse failure ends the connection silently.
            ex <- try $ parseRequest conn 0 dummyAddr fromClient
            case ex of
                Left (_ :: SomeException) -> return Nothing
                Right (req, fromClient') -> do
                    res <- app req
                    keepAlive <- sendResponse
#if MIN_VERSION_warp(1, 3, 8)
                        -- Advertise both Warp and this package in the
                        -- Server header.
                        defaultSettings
                            { settingsServerName = S8.pack $ concat
                                [ "Warp/"
                                , warpVersion
                                , " + http-reverse-proxy/"
                                , showVersion Paths_http_reverse_proxy.version
                                ]
                            }
#endif
                        dummyCleaner req conn res
                    (fromClient'', _) <- liftIO fromClient' >>= unwrapResumable
                    return $ if keepAlive then Just fromClient'' else Nothing
        maybe (return ()) loop mfromClient
    -- Placeholder peer address (port 0, address 0) given to parseRequest.
    dummyAddr = SockAddrInet (PortNum 0) 0
    -- Warp Connection whose writes all go to the client's conduit sink.
    conn = Connection
        { connSendMany = \bss -> mapM_ yield bss $$ toClient
        , connSendAll = \bs -> yield bs $$ toClient
          -- Send the header chunks, then the requested byte range of the file.
        , connSendFile = \fp offset len _th headers _cleaner ->
            let src1 = mapM_ yield headers
                src2 = sourceFileRange fp (Just offset) (Just len)
             in runResourceT
                    $ (src1 >> src2)
                    $$ transPipe lift toClient
        , connClose = return ()
          -- Input reaches Warp as the conduit source given to parseRequest;
          -- this field errors if it is ever called.
        , connRecv = error "connRecv should not be used"
        }