module Network.Wai.Handler.Warp.Request (
recvRequest
, headerLines
) where
import Control.Applicative
import qualified Control.Concurrent as Conc (yield)
import Control.Exception.Lifted (throwIO)
import Control.Monad.IO.Class (liftIO)
import Data.Array ((!))
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.CaseInsensitive as CI
import Data.Conduit
import qualified Data.IORef as I
import Data.Monoid (mempty)
import qualified Network.HTTP.Types as H
import Network.Socket (SockAddr)
import Network.Wai
import Network.Wai.Handler.Warp.Conduit
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.RequestHeader
import Network.Wai.Handler.Warp.Settings (Settings, settingsNoParsePath)
import qualified Network.Wai.Handler.Warp.Timeout as Timeout
import Network.Wai.Handler.Warp.Types
import Network.Wai.Internal
import Prelude hiding (lines)
-- | Upper bound (50 KiB) on the running header length counter kept by
-- 'push'; exceeding it aborts parsing with 'OverLargeHeader'.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 51200 -- 50 * 1024
-- | Receive one request head from @src0@ and turn it into a WAI 'Request'.
--
-- The header lines are collected with 'headerLines' and parsed with
-- 'parseHeaderLines'. An @Expect@ header is handed to 'handleExpect' before
-- the body source is built with 'bodyAndSource'.
--
-- The result tuple contains:
--
-- * the 'Request', whose body source is wrapped by 'timeoutBody';
-- * the indexed request headers;
-- * the third component returned by 'bodyAndSource' (an action yielding a
--   resumable source);
-- * the @Maybe ByteString@ reported by 'headerLines'.
recvRequest :: Settings
            -> Connection
            -> InternalInfo
            -> SockAddr
            -> Source IO ByteString
            -> IO (Request
                  ,IndexedHeader
                  ,IO (ResumableSource IO ByteString)
                  ,Maybe ByteString)
recvRequest settings conn ii addr src0 = do
    (resumable, (afterHeaders, rawLines)) <- src0 $$+ headerLines
    (method, unparsedPath, path, query, ver, hdrs) <- parseHeaderLines rawLines
    let indexed = indexRequestHeader hdrs
        hdrAt i = indexed ! i
    -- Already in IO, so no lifting is needed here.
    handleExpect conn ver (hdrAt idxExpect)
    (bodySrc, bodyLen, nextSource) <-
        bodyAndSource resumable (hdrAt idxContentLength) (hdrAt idxTransferEncoding)
    let rawPath
            | settingsNoParsePath settings = unparsedPath
            | otherwise = path
        req = Request
            { requestMethod = method
            , httpVersion = ver
            , pathInfo = H.decodePathSegments path
            , rawPathInfo = rawPath
            , rawQueryString = query
            , queryString = H.parseQuery query
            , requestHeaders = hdrs
            , isSecure = False
            , remoteHost = addr
            , requestBody = timeoutBody (threadHandle ii) bodySrc
            , vault = mempty
            , requestBodyLength = bodyLen
            , requestHeaderHost = hdrAt idxHost
            , requestHeaderRange = hdrAt idxRange
            }
    pure (req, indexed, nextSource, afterHeaders)
-- | Sink that reads the request head, one header line at a time, via 'push'.
--
-- Produces the bytes that followed the terminating empty line in the same
-- chunk (if any) together with the collected header lines. Throws
-- @NotEnoughLines []@ when the stream ends before the first chunk arrives.
headerLines :: Sink ByteString IO (Maybe ByteString, [ByteString])
headerLines = do
    firstChunk <- await
    case firstChunk of
        Nothing -> throwIO (NotEnoughLines [])
        Just chunk -> push (THStatus 0 id id) chunk
-- | Answer an @Expect: 100-continue@ request header.
--
-- When the header value equals @100-continue@, ignoring ASCII case (the
-- Expect field-value is case-insensitive per RFC 7231 section 5.1.1), a
-- @100 Continue@ status line carrying the request's HTTP version is sent on
-- the connection. The thread then calls 'Conc.yield'. Any other value, or a
-- missing header, is a no-op.
--
-- NOTE(review): RFC 7231 also says a server MUST ignore this expectation in
-- an HTTP/1.0 request; this function still answers with an HTTP/1.0 status
-- line in that case (unchanged behaviour) -- confirm whether that is wanted.
handleExpect :: Connection
             -> H.HttpVersion
             -> Maybe HeaderValue
             -> IO ()
handleExpect conn ver (Just expect)
    | CI.foldCase expect == "100-continue" = do
        connSendAll conn continue
        Conc.yield
  where
    continue
      | ver == H.http11 = "HTTP/1.1 100 Continue\r\n\r\n"
      | otherwise       = "HTTP/1.0 100 Continue\r\n\r\n"
handleExpect _ _ _ = return ()
-- | Build the request-body source from the Content-Length (@cl@) and
-- Transfer-Encoding (@te@) header values.
--
-- A @chunked@ transfer encoding (compared case-insensitively) takes
-- precedence and yields a 'ChunkedBody'. Otherwise the body is isolated to
-- the Content-Length, treated as 0 when the header is absent, and reported
-- as 'KnownLength'.
--
-- The third component is an action returning a resumable source. For chunked
-- bodies it is the source stored in the decoder's IORef; otherwise it is
-- 'ibsDone'. Presumably this is the connection stream positioned after the
-- body -- confirm against the callers.
bodyAndSource :: ResumableSource IO ByteString
              -> Maybe HeaderValue
              -> Maybe HeaderValue
              -> IO (Source IO ByteString
                    ,RequestBodyLength
                    ,IO (ResumableSource IO ByteString))
bodyAndSource src cl te
    | isChunked te = do
        stateRef <- I.newIORef (src, NeedLen)
        let remaining = fst <$> I.readIORef stateRef
        pure (chunkedSource stateRef, ChunkedBody, remaining)
    | otherwise = do
        limitRef <- I.newIORef (contentLen, src)
        let isolated = IsolatedBSSource limitRef
        pure (ibsIsolate isolated, KnownLength (fromIntegral contentLen), ibsDone isolated)
  where
    contentLen = toLength cl
-- | Interpret an optional Content-Length value with 'readInt'; an absent
-- header means a length of 0.
toLength :: Maybe HeaderValue -> Int
toLength = maybe 0 readInt
-- | True when the Transfer-Encoding value, after case folding, is exactly
-- @chunked@; False for any other value or a missing header.
isChunked :: Maybe HeaderValue -> Bool
isChunked = maybe False ((== "chunked") . CI.foldCase)
-- | Wrap a body source so that the timeout handle is resumed when the source
-- starts running and paused again by 'addCleanup' once it finishes (the
-- cleanup's Bool argument is ignored).
timeoutBody :: Timeout.Handle -> Source IO ByteString -> Source IO ByteString
timeoutBody handle body = liftIO (Timeout.resume handle) >> addCleanup pauseTimer body
  where
    pauseTimer _ = liftIO (Timeout.pause handle)
-- | A function applied to the next input chunk; 'push' uses it to prepend
-- bytes carried over from earlier chunks.
type BSEndo = ByteString -> ByteString
-- | A difference list of header lines: 'push' extends it with
-- @lines . (line:)@ and materialises it with @lines []@.
type BSEndoList = [ByteString] -> [ByteString]
-- | Parser state threaded through 'push'.
data THStatus = THStatus
    !Int       -- ^ running length counter compared against 'maxTotalHeaderLength'
    BSEndoList -- ^ header lines collected so far
    BSEndo     -- ^ prefix to prepend to the next chunk
-- | Abort header parsing because the input ended in the middle of the
-- request head.
close :: Sink ByteString IO a
close = liftIO $ throwIO IncompleteHeaders
-- | Core of the header-line reader: process chunk @bs'@ with the given state.
--
-- The carried-over prefix is applied first, then the buffer is searched for
-- the next LF (byte 10). Depending on what is found:
--
-- * no LF: the whole buffer becomes the prefix for the next chunk and more
--   input is awaited ('close' on end of input);
-- * an LF followed by SP (32) or HT (9) on a non-empty line: a folded
--   continuation; the current line (trailing CR removed) is prepended to the
--   remainder and parsing continues;
-- * an empty line (LF or CR LF alone): end of the head; the bytes after it,
--   if any, are pushed back with 'leftover' and also returned together with
--   the collected lines;
-- * any other line: it is recorded (trailing CR removed) and parsing
--   continues on the remainder or on the next awaited chunk.
--
-- Throws 'OverLargeHeader' when the length counter in the state exceeds
-- 'maxTotalHeaderLength'.
push :: THStatus -> ByteString -> Sink ByteString IO (Maybe ByteString, [ByteString])
push (THStatus len lines prepend) bs'
    | len > maxTotalHeaderLength = throwIO OverLargeHeader
    | otherwise = push' mnl
  where
    bs = prepend bs'
    bsLen = S.length bs
    -- Index of the first LF, paired with "the next line is a continuation".
    -- A line is never continued when it is empty (LF at 0, or CR LF at 0-1).
    -- If the LF is the last byte of the buffer the following byte is not
    -- known yet, so the line is treated as not continued.
    mnl = do
        nl <- S.elemIndex 10 bs
        if bsLen > nl + 1 then
            let c = S.index bs (nl + 1)
                b = case nl of
                    0 -> True
                    1 -> S.index bs 0 == 13
                    _ -> False
            in Just (nl, not b && (c == 32 || c == 9))
        else
            Just (nl, False)
    -- No LF yet: carry the whole buffer forward and wait for more input.
    -- NOTE(review): bsLen includes the carried prefix, which was already
    -- counted when it was carried, so the counter over-counts across repeated
    -- partial chunks (this only makes the limit stricter).
    push' Nothing = await >>= maybe close (push status)
      where
        len' = len + bsLen
        prepend' = S.append bs
        status = THStatus len' lines prepend'
    -- Folded header: join this line (minus CR) onto the following bytes.
    push' (Just (end, True)) = push status rest
      where
        rest = S.drop (end + 1) bs
        prepend' = S.append (SU.unsafeTake (checkCR bs end) bs)
        len' = len + end
        status = THStatus len' lines prepend'
    -- Complete, unfolded line.
    -- NOTE(review): end can be 0 here (LF at the start of the buffer), so
    -- checkCR must accept position 0.
    push' (Just (end, False))
        | S.null line =
            -- Empty line: end of the header block.
            let lines' = lines []
                rest = if start < bsLen then
                           Just (SU.unsafeDrop start bs)
                       else
                           Nothing
            in maybe (return ()) leftover rest >> return (rest, lines')
        | otherwise =
            let len' = len + start
                lines' = lines . (line:)
                status = THStatus len' lines' id
            in if start < bsLen then
                   let bs'' = SU.unsafeDrop start bs
                   in push status bs''
               else
                   await >>= maybe close (push status)
      where
        start = end + 1
        line = SU.unsafeTake (checkCR bs end) bs
-- | Given the index @pos@ of an LF in @bs@, return the length of the line
-- content before it: @pos - 1@ when the byte just before the LF is a CR
-- (13), otherwise @pos@.
--
-- @pos == 0@ (an LF at the start of the buffer, e.g. a bare-LF line
-- terminator at the start of a chunk) is handled explicitly. There is no
-- byte before it, and indexing at -1 would throw from 'S.index'.
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && S.index bs p == 13 = p
    | otherwise = pos
  where
    p = pos - 1