module Network.HTTP.Proxy
( runProxy
, runProxySettings
, Settings (..)
, defaultSettings
)
where
import Prelude hiding (catch, lines)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as L
import Network ( PortID(..) )
import Network.Socket
( accept, Family (..)
, SocketType (Stream), listen, bindSocket, setSocketOption, maxListenQueue
, SockAddr, SocketOption (ReuseAddr)
, AddrInfo(..), AddrInfoFlag(..), defaultHints, getAddrInfo
, Socket, sClose, shutdown, ShutdownCmd(..), HostName, ServiceName, socket, connect
)
import Network.BSD ( getProtocolNumber )
import Network.Wai
import qualified Network.Socket
import qualified Network.Socket.ByteString as Sock
import Control.Applicative
import Control.Exception
( bracket, finally, Exception, SomeException, catch
, fromException, AsyncException (ThreadKilled)
, bracketOnError, IOException, throw
)
import Control.Concurrent (forkIO, ThreadId, killThread)
import Data.Maybe (fromMaybe, isNothing)
import qualified Network.HTTP.Enumerator as HE
import Data.Typeable (Typeable)
import Data.Enumerator (($$), (=$), (>>==))
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL
import qualified Data.Enumerator.Binary as EB
import Blaze.ByteString.Builder
(copyByteString, Builder, toByteString, fromByteString)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import Data.Monoid (mappend, mconcat)
import Control.Monad.IO.Class (liftIO)
import qualified Network.HTTP.Proxy.Timeout as T
import Data.List (delete, foldl')
import Control.Monad (forever, when)
import qualified Network.HTTP.Types as H
import qualified Data.CaseInsensitive as CI
import System.IO (hPutStrLn, stderr)
#if WINDOWS
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.MVar as MV
import Network.Socket (withSocketsDo)
#endif
-- | Create a TCP socket listening on the given port.
--
-- The host string @"*"@ means "no specific host" (passes 'Nothing' to
-- 'getAddrInfo' together with 'AI_PASSIVE'); any other string must be a
-- numeric address. When the lookup yields several addresses, an IPv6 one
-- is preferred, otherwise the first result is used. The socket has
-- 'ReuseAddr' set, is bound and put into listening state; it is closed
-- with 'sCloseX' if any of those steps throws.
bindPort :: Int
         -> String
         -> IO Socket
bindPort p s = do
    let hints = defaultHints
            { addrFlags = [AI_PASSIVE, AI_NUMERICSERV, AI_NUMERICHOST]
            , addrSocketType = Stream
            }
        host | s == "*" = Nothing
             | otherwise = Just s
    addrs <- getAddrInfo (Just hints) host (Just (show p))
    let addr = case filter ((== AF_INET6) . addrFamily) addrs of
            (v6 : _) -> v6
            [] -> head addrs
    bracketOnError
        (Network.Socket.socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
        sCloseX
        (\sock -> do
            setSocketOption sock ReuseAddr 1
            bindSocket sock (addrAddress addr)
            listen sock maxListenQueue
            return sock)
-- | Run the proxy on the given port with all other settings taken from
-- 'defaultSettings'.
runProxy :: Port -> IO ()
runProxy port = runProxySettings (defaultSettings { proxyPort = port })
-- | Run the proxy with the given 'Settings'; does not return normally.
--
-- The listening socket comes from 'bindPort' and is closed with 'sCloseX'
-- when serving ends. On Windows the accept loop runs in a forked thread and
-- the socket is tracked in an MVar so the main thread (which just sleeps)
-- can close it in its 'finally' handler.
runProxySettings :: Settings -> IO ()
#if WINDOWS
runProxySettings set = withSocketsDo $ do
    listening <- MV.newMVar Nothing
    let closeListening = MV.modifyMVar_ listening $ \ms -> do
            maybe (return ()) sCloseX ms
            return Nothing
    _ <- forkIO $ bracket
        (bindPort (proxyPort set) (proxyHost set))
        (const closeListening)
        (\s -> do
            MV.modifyMVar_ listening (const (return (Just s)))
            runSettingsSocket set s)
    forever (threadDelay maxBound) `finally` closeListening
#else
runProxySettings set =
    bracket acquire sCloseX (runSettingsSocket set)
  where
    acquire = bindPort (proxyPort set) (proxyHost set)
#endif
-- | Port number; used for 'proxyPort' and reported as the request's
-- 'serverPort'.
type Port = Int
-- | Accept connections on a listening socket forever, serving each in its
-- own thread.
--
-- One timeout manager (initialised with @proxyTimeout * 1000000@,
-- presumably seconds converted to microseconds — TODO confirm against
-- "Network.HTTP.Proxy.Timeout") and one HTTP client manager are shared by
-- all connections. Each connection thread obtains a handle from
-- 'T.registerKillThread' and cancels it when serving finishes. The cancel
-- runs under 'finally', so it also happens when 'serveConnection' (or the
-- user's exception handler it calls) throws.
runSettingsSocket :: Settings -> Socket -> IO ()
runSettingsSocket set sock = do
    tm <- T.initialize $ proxyTimeout set * 1000000
    mgr <- HE.newManager
    forever $ do
        (conn, sa) <- accept sock
        _ <- forkIO $ serveOne tm mgr conn sa
        return ()
  where
    -- Serve a single accepted connection and always release its timeout handle.
    serveOne tm mgr conn sa = do
        th <- T.registerKillThread tm
        serveConnection th tm (proxyOnException set) (proxyPort set) conn sa mgr
            `finally` T.cancel th
-- | Hard-coded debugging switch: when 'True', 'sCloseX' and 'shutdownX'
-- print a line to stdout for every socket close/shutdown.
verboseSockets :: Bool
verboseSockets = False
-- | Close a socket, logging the close to stdout when 'verboseSockets' is on.
sCloseX :: Socket -> IO ()
sCloseX s = logClose >> sClose s
  where
    logClose = when verboseSockets (putStrLn ("close " ++ show s))
-- | Shut down one or both directions of a socket, logging the call to
-- stdout when 'verboseSockets' is on.
shutdownX :: Socket -> ShutdownCmd -> IO ()
shutdownX s cmd = do
    when verboseSockets $ putStrLn ("shutdown " ++ show s ++ " " ++ label)
    shutdown s cmd
  where
    -- Spelled out explicitly rather than relying on a Show instance.
    label = case cmd of
        ShutdownReceive -> "ShutdownReceive"
        ShutdownSend -> "ShutdownSend"
        ShutdownBoth -> "ShutdownBoth"
-- | Enumerator that emits a rendered, non-chunked response head (status
-- line, headers plus the proxy's Via header, blank line) as one chunk.
mkHeaders :: Monad m
          => H.HttpVersion
          -> H.Status
          -> H.ResponseHeaders
          -> E.Enumerator ByteString m b
mkHeaders ver status hdrs = E.enumList 1 [rendered]
  where
    rendered = toByteString (headers ver status hdrs False)
-- | Serve one accepted client connection.
--
-- Requests are parsed from the client socket and dispatched by method:
--
-- * @GET@ and @POST@ are forwarded with http-enumerator ('proxyPlain'), in a
--   loop until either side asks for the connection to be closed.
-- * @CONNECT host:port@ opens a raw tunnel ('proxyConnect').
-- * Anything else gets an error response ('failRequest').
--
-- The client socket is always closed on exit. If a tunnel was opened, its
-- writer thread is killed and its upstream socket closed. Any exception
-- escaping this handling is passed to @onException@.
serveConnection :: T.Handle
                -> T.Manager
                -> (SomeException -> IO ())
                -> Port -> Socket -> SockAddr
                -> HE.Manager
                -> IO ()
serveConnection th tm onException port conn remoteHost' mgr =
    handleConnection `catch` onException
  where
    -- Run the request loop, then tear down any tunnel it left behind.
    handleConnection = do
        mExtraSocket <- E.run_ (fromClient $$ serveConnection')
            `finally` sCloseX conn
        case mExtraSocket of
            Just (tid, s) -> do
                killThread tid
                sCloseX s
            Nothing -> return ()

    fromClient = enumSocket th bytesPerRead conn

    serveConnection' :: E.Iteratee ByteString IO (Maybe (ThreadId, Socket))
    serveConnection' = do
        req <- parseRequest port remoteHost'
        dispatch req

    dispatch req
        | requestMethod req `elem` ["GET", "POST"] = proxyPlain req
        | requestMethod req == "CONNECT" =
            case parseConnectTarget (rawPathInfo req) of
                Just (h, p) -> proxyConnect th tm onException conn h p req
                Nothing -> failRequest th conn req "Bad request"
                    ("Bad request '" `mappend` rawPathInfo req `mappend` "'.")
        | otherwise =
            failRequest th conn req "Unknown request"
                ("Unknown request '" `mappend` rawPathInfo req `mappend` "'.")

    -- Split a CONNECT target "host:port". The port must be a plain decimal
    -- in 1..65535. Anything else yields Nothing and is answered as a bad
    -- request, instead of crashing the connection with a failed 'read'.
    parseConnectTarget target =
        case B.split ':' target of
            [h, p] | Just (n, rest) <- B.readInt p
                   , S.null rest
                   , n > 0
                   , n < 65536 -> Just (h, n)
            _ -> Nothing

    -- Forward a plain GET/POST to the origin server and relay the reply.
    -- Loops back to 'serveConnection'' unless the connection must close.
    proxyPlain :: Request -> E.Iteratee ByteString IO (Maybe (ThreadId, Socket))
    proxyPlain req = do
        let urlStr = "http://" `mappend` serverName req
                `mappend` rawPathInfo req
                `mappend` H.renderQuery True (queryString req)
            -- Client-requested close: an explicit Connection header wins,
            -- otherwise HTTP/1.0 defaults to close.
            close =
                let hasClose hdrs = (== "close") . CI.mk <$> lookup "connection" hdrs
                    mClose = hasClose (requestHeaders req)
                    defaultClose = httpVersion req == H.HttpVersion 1 0
                in fromMaybe defaultClose mClose
            -- Host and Content-Length are not copied to the upstream request
            -- (presumably http-enumerator derives them from the URL and body).
            outHdrs = [ (n, v) | (n, v) <- requestHeaders req
                               , n `notElem` ["Host", "Content-Length"] ]
        liftIO $ putStrLn $ B.unpack (requestMethod req) ++ " " ++ B.unpack urlStr
        -- NOTE(review): 'read' on the client-supplied Content-Length is
        -- partial; a malformed value aborts the connection via the outer catch.
        let contentLength =
                if requestMethod req == "GET"
                    then 0
                    else read . B.unpack . fromMaybe "0" . lookup "content-length" . requestHeaders $ req
        postBody <- mconcat . L.toChunks <$> EB.take contentLength
        let enumPostBody = E.enumList 1 [fromByteString postBody]
        url <-
            (\u -> u { HE.method = requestMethod req
                     , HE.requestHeaders = outHdrs
                     , HE.rawBody = True
                     , HE.requestBody = HE.RequestBodyEnum (fromIntegral contentLength) enumPostBody
                     })
            <$> liftIO (HE.parseUrl (B.unpack urlStr))
        close' <- liftIO $ E.run_ $ HE.http url (handleHttpReply close) mgr
        if close'
            then return Nothing
            else serveConnection'
      where
        -- Relay the upstream status and headers to the client, then stream
        -- whatever HE.http feeds this iteratee (the response body) to the
        -- client socket. The connection closes if the client asked for it or
        -- the upstream response has no Content-Length. That same decision
        -- drives the Connection header, the send-side shutdown and the
        -- returned flag, so the request loop stops whenever "Close" was
        -- announced.
        handleHttpReply close status hdrs = do
            let remoteClose = isNothing ("content-length" `lookup` hdrs)
                close' = close || remoteClose
                hdrs' = [(n, v) | (n, v) <- hdrs, n `notElem` ["connection", "proxy-connection"]]
                    ++ [("Connection", if close' then "Close" else "Keep-Alive")]
            mkHeaders (httpVersion req) status hdrs' $$ iterSocket th conn close'
            return close'
-- | Send an error response to the client and end the request loop.
--
-- The response has status 500 with @headerMsg@ as its reason phrase, a
-- Content-Length header and @bodyMsg@ as the body. Once it is written, the
-- client socket's send side is shut down. Nothing is consumed from the
-- client stream (@isolate 0@). Always returns 'Nothing', meaning there is
-- no tunnel to clean up.
failRequest :: T.Handle -> Socket -> Request -> ByteString -> ByteString -> E.Iteratee ByteString IO (Maybe (ThreadId, Socket))
failRequest th conn req headerMsg bodyMsg = do
    let respond = mkHeaders (httpVersion req) status [("Content-Length", B.pack (show (B.length bodyMsg)))]
                      $$ iterSocket th conn True
    EB.isolate 0 =$ (E.enumList 1 [bodyMsg] $$ respond)
    return Nothing
  where
    status = H.status500 { H.statusMessage = headerMsg }
-- | Handle @CONNECT host:port@ by tunnelling raw bytes.
--
-- If connecting upstream throws an 'IOException', the client gets a 500
-- response via 'failRequest'. On success, a writer thread is forked. It
-- first sends a 200 response head to the client, then copies upstream
-- bytes to the client. Meanwhile the current iteratee copies the client's
-- remaining bytes to the upstream socket. The writer's thread id and the
-- upstream socket are returned so the caller can clean them up.
proxyConnect :: T.Handle -> T.Manager -> (SomeException -> IO ()) -> Socket -> ByteString -> Int -> Request -> E.Iteratee ByteString IO (Maybe (ThreadId, Socket))
proxyConnect th tm onException conn host prt req = do
    liftIO $ putStrLn $ B.unpack (requestMethod req) ++ " " ++ B.unpack host ++ ":" ++ show prt
    upstream <- liftIO (openUpstream `catch` describeFailure)
    case upstream of
        Left errorMsg ->
            failRequest th conn req errorMsg ("PROXY FAILURE\r\n" `mappend` errorMsg)
        Right s -> tunnel s
  where
    openUpstream = Right <$> connectTo (B.unpack host) (PortNumber (fromIntegral prt))

    describeFailure :: IOException -> IO (Either ByteString Socket)
    describeFailure exc =
        return (Left ("Unable to connect: " `mappend` B.pack (show exc)))

    tunnel s = do
        writer <- liftIO $ forkIO $ do
            wrTh <- T.registerKillThread tm
            E.run_ (enumSocket th 65536 s $$ mkHeaders (httpVersion req) H.statusOK [] $$ iterSocket wrTh conn True)
                `catch` onException
            T.cancel wrTh
        iterSocket th s True
        return (Just (writer, s))
-- | Open a TCP connection to a host given as a 'PortID'. Service names and
-- numeric ports are resolved by 'connect''; Unix sockets are rejected with
-- 'error'.
connectTo :: HostName
          -> PortID
          -> IO Socket
connectTo hostname portId =
    case portId of
        Service serv -> connect' hostname serv
        PortNumber port -> connect' hostname (show port)
        UnixSocket _ -> error "Cannot connect to a UnixSocket"
-- | Resolve @host@/@serv@ as TCP stream addresses (with 'AI_ADDRCONFIG')
-- and connect to the first one that accepts. A socket whose connect fails
-- is closed before the next address is tried.
connect' :: HostName -> ServiceName -> IO Socket
connect' host serv = do
    tcp <- getProtocolNumber "tcp"
    let hints = defaultHints
            { addrFlags = [AI_ADDRCONFIG]
            , addrProtocol = tcp
            , addrSocketType = Stream
            }
    candidates <- getAddrInfo (Just hints) (Just host) (Just serv)
    firstSuccessful (map openAndConnect candidates)
  where
    openAndConnect addr =
        bracketOnError
            (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
            sClose
            (\sock -> connect sock (addrAddress addr) >> return sock)
-- | Run the actions in order and return the result of the first one that
-- does not throw an 'IOException'.
--
-- An 'IOException' from any action except the last is discarded and the
-- next action is tried. The last action runs unguarded, so its exception
-- propagates unchanged (no pure 'throw' inside IO). Exceptions of other
-- types are never caught. Calls 'error' on an empty list.
firstSuccessful :: [IO a] -> IO a
firstSuccessful [] = error "firstSuccessful: empty list"
firstSuccessful [p] = p
firstSuccessful (p:ps) = p `catch` \e -> const (firstSuccessful ps) (e :: IOException)
-- | Read the request head (request line plus header lines) from the client
-- stream and parse it into a 'Request'.
parseRequest :: Port -> SockAddr -> E.Iteratee S.ByteString IO Request
parseRequest port remoteHost' =
    takeHeaders >>= \headerLines -> parseRequest' port headerLines remoteHost'
-- | Number of bytes requested per 'Sock.recv' on the client socket.
bytesPerRead :: Int
bytesPerRead = 4 * 1024

-- | Upper bound, in bytes, on the accumulated request head; exceeding it
-- makes 'takeHeaders'' throw 'OverLargeHeader'.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 50 * 1024
-- | Reasons a request head could not be read or parsed. Thrown with
-- 'E.throwError' by the header reader and parser; the handler in
-- 'defaultSettings' ignores these silently.
data InvalidRequest =
    NotEnoughLines [String]   -- ^ No request line was received at all.
  | BadFirstLine String       -- ^ Request line did not split into exactly three space-separated parts.
  | NonHttp                   -- ^ Version token did not start with @HTTP/@.
  | IncompleteHeaders         -- ^ Stream ended before the blank line terminating the headers.
  | ConnectionClosedByPeer    -- ^ Stream ended before any request data arrived.
  | OverLargeHeader           -- ^ Request head exceeded 'maxTotalHeaderLength'.
  deriving (Show, Typeable, Eq)
instance Exception InvalidRequest
-- | Build a 'Request' from the raw request-head lines.
--
-- The first line is the request line (see 'parseFirst'); the rest are
-- header lines. The request target may be origin-form (@/path@) or
-- absolute-form (@http://host/path@); for absolute-form the authority
-- becomes the server name. An empty path — either an empty target or an
-- absolute URI such as @http://host@ — is normalised to @"/"@. Otherwise
-- the server name is taken from the Host header (empty if absent).
-- Throws 'NotEnoughLines' when there are no lines at all.
parseRequest' :: Port
              -> [ByteString]
              -> SockAddr
              -> E.Iteratee S.ByteString IO Request
parseRequest' _ [] _ = E.throwError $ NotEnoughLines []
parseRequest' port (firstLine:otherLines) remoteHost' = do
    (method, rpath', gets, httpversion) <- parseFirst firstLine
    let (host', rpath) = splitTarget rpath'
        heads = map parseHeaderNoAttr otherLines
        host = case (host', lookup "host" heads) of
            ("", Just h) -> h
            (h, _) -> h
    return Request
        { requestMethod = method
        , httpVersion = httpversion
        , pathInfo = H.decodePathSegments rpath
        , rawPathInfo = rpath
        , rawQueryString = gets
        , queryString = H.parseQuery gets
        , serverName = host
        , serverPort = port
        , requestHeaders = heads
        , isSecure = False
        , remoteHost = remoteHost'
        }
  where
    -- Split a request target into (authority, path); 47 is '/'.
    splitTarget target
        | S.null target = ("", "/")
        | "http://" `S.isPrefixOf` target =
            let (authority, path) = S.breakByte 47 (S.drop 7 target)
            in (authority, if S.null path then "/" else path)
        | otherwise = ("", target)
-- | Parse an HTTP request line into (method, path, query string, version).
--
-- The line must consist of exactly three space-separated parts, otherwise
-- 'BadFirstLine' is thrown. The third part must start with @HTTP/@,
-- otherwise 'NonHttp' is thrown. The query string keeps its leading '?'
-- (63); it is empty if there is none. Version @1.1@ maps to 'H.http11';
-- every other version maps to 'H.http10'.
parseFirst :: ByteString
           -> E.Iteratee S.ByteString IO (ByteString, ByteString, ByteString, H.HttpVersion)
parseFirst s =
    case S.split 32 s of
        [method, query, http'] ->
            let (prefix, version) = B.splitAt 5 http'
            in if prefix /= "HTTP/"
                then E.throwError NonHttp
                else do
                    let (rpath, qstring) = S.breakByte 63 query
                        hv = if version == "1.1" then H.http11 else H.http10
                    return (method, rpath, qstring, hv)
        _ -> E.throwError $ BadFirstLine $ B.unpack s
-- | Literal @HTTP/@ prefix of the status line.
httpBuilder :: Builder
httpBuilder = copyByteString "HTTP/"

-- | A single space.
spaceBuilder :: Builder
spaceBuilder = fromChar ' '

-- | CRLF line terminator.
newlineBuilder :: Builder
newlineBuilder = copyByteString "\r\n"

-- | Chunked transfer-encoding header followed by the blank line ending the head.
transferEncodingBuilder :: Builder
transferEncodingBuilder = copyByteString "Transfer-Encoding: chunked\r\n\r\n"

-- | Separator between a header name and its value.
colonSpaceBuilder :: Builder
colonSpaceBuilder = copyByteString ": "
-- | Render a response head.
--
-- The output is the status line (@HTTP/1.1@ for version 1.1, @HTTP/1.0@
-- otherwise), then the given headers with this proxy's Via entry added by
-- 'serverHeader'. It ends with a blank line; when chunked, a
-- @Transfer-Encoding: chunked@ header comes before that blank line.
headers :: H.HttpVersion -> H.Status -> H.ResponseHeaders -> Bool -> Builder
headers !httpversion !status !responseHeaders !isChunked' =
    let !statusLine = mconcat
            [ httpBuilder
            , copyByteString versionText
            , spaceBuilder
            , fromShow (H.statusCode status)
            , spaceBuilder
            , copyByteString (H.statusMessage status)
            , newlineBuilder
            ]
        !withHeaders = foldl' responseHeaderToBuilder statusLine (serverHeader responseHeaders)
        !terminator = if isChunked' then transferEncodingBuilder else newlineBuilder
    in withHeaders `mappend` terminator
  where
    versionText = case httpversion of
        H.HttpVersion 1 1 -> "1.1"
        _ -> "1.0"
-- | Append one @Name: value\\r\\n@ header line (original name casing) to a
-- builder.
responseHeaderToBuilder :: Builder -> H.Header -> Builder
responseHeaderToBuilder acc (name, value) = mconcat
    [ acc
    , copyByteString (CI.original name)
    , colonSpaceBuilder
    , copyByteString value
    , newlineBuilder
    ]
-- | Split a raw header line into a case-insensitive name and its value.
--
-- The name is everything before the first ':' (58). The value is
-- everything after that colon with leading spaces and tabs removed, so
-- @"Host: a"@, @"Host:a"@ and @"Host:   a"@ all yield the value @"a"@.
-- Previously @"Host:a"@ kept the colon in the value. A line without a
-- colon yields the whole line as the name and an empty value. Trailing
-- whitespace is not stripped here.
parseHeaderNoAttr :: ByteString -> H.Header
parseHeaderNoAttr s = (CI.mk name, value)
  where
    (name, rest) = S.breakByte 58 s
    value = B.dropWhile isOWS (S.drop 1 rest)
    isOWS c = c == ' ' || c == '\t'
-- | Enumerator that feeds bytes read from a socket to an iteratee.
--
-- Each 'Sock.recv' asks for at most @len@ bytes and is followed by
-- 'T.tickle'. An empty read (the peer closed its side) shuts down the
-- socket's read side; a failure there is logged to stdout and ignored.
-- The iteratee is then handed back still in 'E.Continue' state without an
-- EOF, so whoever runs it (e.g. 'E.run_') delivers EOF. Iteratees that are
-- no longer continuing are returned unchanged.
enumSocket :: T.Handle -> Int -> Socket -> E.Enumerator ByteString IO a
enumSocket th len sock = go
  where
    go (E.Continue k) = do
        chunk <- liftIO $ do
            bs <- Sock.recv sock len
            T.tickle th
            return bs
        if S.null chunk
            then do
                liftIO (shutdownX sock ShutdownReceive `catch` \(exc :: IOException) ->
                    putStrLn $ "couldn't shutdown read side of " ++ show sock++": " ++ show exc)
                E.continue k
            else k (E.Chunks [chunk]) >>== go
    go step = E.returnI step
-- | Iteratee that writes every chunk it receives to a socket.
--
-- Around each 'Sock.sendMany' the timeout handle is resumed before the
-- send and paused afterwards. Presumably this arms the kill timer only
-- while writing; confirm against "Network.HTTP.Proxy.Timeout". On EOF the
-- handle is resumed and, if @toClose@ is set, the socket's send side is
-- shut down; a failed shutdown is logged to stdout and ignored. Empty
-- chunk lists are skipped.
iterSocket :: T.Handle
           -> Socket
           -> Bool
           -> E.Iteratee B.ByteString IO ()
iterSocket th sock toClose = E.continue consume
  where
    consume (E.Chunks []) = E.continue consume
    consume (E.Chunks xs) = do
        liftIO $ do
            T.resume th
            Sock.sendMany sock xs
            T.pause th
        E.continue consume
    consume E.EOF = do
        liftIO $ do
            T.resume th
            when toClose $
                shutdownX sock ShutdownSend `catch` \(exc :: IOException) ->
                    putStrLn $ "couldn't shutdown send side of " ++ show sock ++ ": " ++ show exc
        E.yield () E.EOF
-- | Configuration for 'runProxySettings'.
data Settings = Settings
    { proxyPort :: Int                          -- ^ Port to listen on.
    , proxyHost :: String                       -- ^ Numeric address to bind, or @"*"@ for no specific host.
    , proxyOnException :: SomeException -> IO () -- ^ Called with exceptions that escape a connection handler.
    , proxyTimeout :: Int                       -- ^ Connection timeout; multiplied by 1000000 before initialising the timeout manager (presumably seconds — TODO confirm).
    }
-- | Default settings: port 3100, host @"*"@, 'proxyTimeout' 30.
--
-- The exception handler ignores 'InvalidRequest' errors and
-- 'ThreadKilled', and prints every other exception to stderr.
defaultSettings :: Settings
defaultSettings = Settings
    { proxyPort = 3100
    , proxyHost = "*"
    , proxyOnException = reportException
    , proxyTimeout = 30
    }
  where
    reportException e =
        case (fromException e :: Maybe InvalidRequest, fromException e) of
            (Just _, _) -> return ()
            (Nothing, Just ThreadKilled) -> return ()
            _ -> hPutStrLn stderr (show e)
-- | Read the request head: every line before the terminating empty line,
-- with line endings removed. Throws 'ConnectionClosedByPeer' if the stream
-- is already at EOF.
takeHeaders :: E.Iteratee ByteString IO [ByteString]
takeHeaders = do
    firstChunk <- forceHead ConnectionClosedByPeer
    firstChunk `seq` takeHeaders' 0 id id firstChunk
-- | Worker for 'takeHeaders': split the incoming chunks into lines until
-- an empty line is found.
--
-- Arguments:
--
-- * @len@ — bytes consumed so far. Once it exceeds 'maxTotalHeaderLength',
--   'OverLargeHeader' is thrown.
-- * @lines@ — difference list of completed lines.
-- * @prepend@ — difference list of partial-line fragments from earlier
--   chunks that contained no LF.
-- * @bs@ — the current chunk.
--
-- LF (10) ends a line, and a CR immediately before it is dropped. The CR
-- may sit at the end of the previous chunk when the LF starts this one; it
-- is stripped from the joined line too. Without that, a terminating
-- "\\r\\n" split across chunks would read as a one-byte header line and
-- the end of the headers would be missed. Bytes after the empty line are
-- pushed back as leftover input. 'IncompleteHeaders' is thrown if the
-- stream ends first.
takeHeaders' :: Int
             -> ([ByteString] -> [ByteString])
             -> ([ByteString] -> [ByteString])
             -> ByteString
             -> E.Iteratee S.ByteString IO [ByteString]
takeHeaders' !len _ _ _ | len > maxTotalHeaderLength = E.throwError OverLargeHeader
takeHeaders' !len !lines !prepend !bs = do
    let !bsLen = S.length bs
        !mnl = S.elemIndex 10 bs
    case mnl of
        !Nothing -> do
            -- No LF in this chunk: stash it as a fragment and read more.
            let !len' = len + bsLen
            !more <- forceHead IncompleteHeaders
            takeHeaders' len' lines (prepend . (:) bs) more
        Just !nl -> do
            let !end = nl
                !start = nl + 1
                !line =
                    if end > 0
                        then S.concat $! prepend [SU.unsafeTake (checkCR bs end) bs]
                        -- LF is the first byte: any CR is at the end of an earlier fragment.
                        else dropTrailingCR $! S.concat $! prepend []
            if S.null line
                then do
                    -- Empty line: end of headers; push back whatever follows.
                    let !lines' = lines []
                    if start < bsLen
                        then do
                            let !rest = SU.unsafeDrop start bs
                            E.yield lines' $! E.Chunks [rest]
                        else return lines'
                else do
                    let !len' = len + start
                        !lines' = lines . (:) line
                    !more <-
                        if start < bsLen
                            then return $! SU.unsafeDrop start bs
                            else forceHead IncompleteHeaders
                    takeHeaders' len' lines' id more
  where
    -- Drop a single trailing CR (13), if present.
    dropTrailingCR chunk
        | not (S.null chunk) && S.last chunk == 13 = S.init chunk
        | otherwise = chunk
-- | Take the next chunk from the stream, throwing the given
-- 'InvalidRequest' if the stream has ended. The chunk is forced to WHNF
-- before it is returned.
forceHead :: InvalidRequest -> E.Iteratee ByteString IO ByteString
forceHead err = EL.head >>= maybe (E.throwError err) (\x -> x `seq` return x)
-- | Given the index of an LF in @bs@, return the index where the line's
-- content ends. That is the position of a CR directly before the LF if
-- there is one, otherwise the LF position itself.
--
-- The original computed the previous index as @pos 1@, which applies an
-- 'Int' as a function and does not type-check; it is @pos - 1@. A
-- position of 0 is returned unchanged rather than indexing at -1.
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && B.index bs prev == '\r' = prev
    | otherwise = pos
  where
    prev = pos - 1
-- | Add this proxy's identifier (@Proxy/0.0@) to the Via header.
--
-- If a Via header exists, its first occurrence is removed and replaced, at
-- the front of the list, by @existing ++ " Proxy/0.0"@. Otherwise a new
-- @Via: Proxy/0.0@ header is prepended.
serverHeader :: H.RequestHeaders -> H.RequestHeaders
serverHeader hdrs =
    case lookup via hdrs of
        Just existing -> (via, S.concat [existing, " ", proxyVer]) : delete (via, existing) hdrs
        Nothing -> (via, proxyVer) : hdrs
  where
    via = "Via"
    proxyVer = B.pack "Proxy/0.0"