module Network.Wai.Handler.Warp
(
run
, runEx
, serveConnections
, runSettings
, Settings
, defaultSettings
, settingsPort
, settingsOnException
, settingsTimeout
, Port
, InvalidRequest (..)
, sendResponse
, parseRequest
#if TEST
, takeHeaders
#endif
) where
import Prelude hiding (catch, lines)
import Network.Wai
import qualified System.IO
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as L
import Network (listenOn, sClose, PortID(PortNumber), Socket)
import Network.Socket
( accept, SockAddr
)
import qualified Network.Socket.ByteString as Sock
import Control.Exception
( bracket, finally, Exception, SomeException, catch
, fromException
)
import Control.Concurrent (forkIO, threadWaitWrite)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Enumerator (($$), (>>==))
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL
import qualified Data.Enumerator.Binary as EB
import Blaze.ByteString.Builder.Enumerator (builderToByteString)
import Blaze.ByteString.Builder.HTTP
(chunkedTransferEncoding, chunkedTransferTerminator)
import Blaze.ByteString.Builder
(copyByteString, Builder, toLazyByteString, toByteStringIO)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import Data.Monoid (mappend, mconcat)
import Network.Socket.SendFile (sendFileIterWith, Iter (..))
import Control.Monad.IO.Class (liftIO)
import qualified Timeout as T
import Data.Word (Word8)
import Data.List (foldl')
import Control.Monad (forever)
#if WINDOWS
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.MVar as MV
import Network.Socket (withSocketsDo)
#endif
-- | Serve an 'Application' on the given port, silently discarding every
-- exception that is reported while handling a connection.
--
-- This is 'runEx' with an exception handler that does nothing.
run :: Port -> Application -> IO ()
run port app = runEx ignoreException port app
  where
    ignoreException :: SomeException -> IO ()
    ignoreException _ = return ()
-- | Serve an 'Application' on the given port, passing every exception that is
-- reported while handling a connection to the supplied handler.
--
-- The timeout setting is fixed at 30; use 'runSettings' to choose another.
runEx :: (SomeException -> IO ()) -> Port -> Application -> IO ()
runEx onE port app = runSettings settings app
  where
    settings = Settings
        { settingsTimeout = 30
        , settingsOnException = onE
        , settingsPort = port
        }
-- | Open a listening socket on 'settingsPort' and serve the application on it
-- with 'serveConnections''. The listening socket is closed when serving stops.
runSettings :: Settings -> Application -> IO ()
#if WINDOWS
runSettings set app = withSocketsDo $ do
    -- Holds the listening socket once it has been opened, so that whichever
    -- thread runs the cleanup first closes it and later calls become no-ops.
    sockVar <- MV.newMVar Nothing
    let closeOnce = MV.modifyMVar_ sockVar $ \mSock -> do
            maybe (return ()) sClose mSock
            return Nothing
        serve sock = do
            MV.modifyMVar_ sockVar (const $ return $ Just sock)
            serveConnections' set app sock
    -- The server itself runs in a forked thread; the calling thread only
    -- sleeps and performs the cleanup when it is interrupted.
    -- NOTE(review): presumably this keeps the main thread interruptible on
    -- Windows -- confirm.
    _ <- forkIO $ bracket
        (listenOn $ PortNumber $ fromIntegral $ settingsPort set)
        (const closeOnce)
        serve
    forever (threadDelay maxBound) `finally` closeOnce
#else
runSettings set app =
    bracket openListener sClose (serveConnections' set app)
  where
    openListener = listenOn $ PortNumber $ fromIntegral $ settingsPort set
#endif
-- | TCP port number, as accepted by 'run', 'runEx' and 'serveConnections'.
type Port = Int
-- | Serve connections accepted from an already-listening socket.
--
-- Uses 'defaultSettings' with the exception handler and port replaced by the
-- given arguments. The port is only recorded in the settings; the socket is
-- supplied by the caller.
serveConnections :: (SomeException -> IO ())
                 -> Port -> Application -> Socket -> IO ()
serveConnections onE port app socket = serveConnections' settings app socket
  where
    settings = defaultSettings
        { settingsPort = port
        , settingsOnException = onE
        }
-- | Accept connections on the socket forever, handling each one in its own
-- thread created with 'forkIO'.
--
-- A single timeout manager is created up front from 'settingsTimeout'
-- (multiplied by 1000000); every connection thread registers itself with it
-- before serving and cancels its handle once 'serveConnection' returns.
serveConnections' :: Settings
                  -> Application -> Socket -> IO ()
serveConnections' set app socket = do
    manager <- T.initialize (settingsTimeout set * 1000000)
    forever (acceptOne manager)
  where
    -- Accept one connection and hand it to a fresh worker thread.
    acceptOne manager = do
        (conn, peer) <- accept socket
        _ <- forkIO (worker manager conn peer)
        return ()

    -- Body of a connection thread.
    worker manager conn peer = do
        th <- T.registerKillThread manager
        serveConnection th (settingsOnException set) (settingsPort set) app conn peer
        T.cancel th
-- | Serve every request arriving on one client connection.
--
-- Requests are read from the socket with 'enumSocket', parsed, passed to the
-- application and answered with 'sendResponse'. The loop repeats for as long
-- as 'sendResponse' returns 'True'. The connection socket is always closed
-- afterwards, and any exception is passed to the given handler.
serveConnection :: T.Handle
                -> (SomeException -> IO ())
                -> Port -> Application -> Socket -> SockAddr -> IO ()
serveConnection th onException port app conn remoteHost' =
    (E.run_ (requests $$ loop) `finally` sClose conn) `catch` onException
  where
    requests = enumSocket th bytesPerRead conn

    loop = do
        (bodyLimiter, req) <- parseRequest port remoteHost'
        -- The timeout handle is paused around the application call and
        -- resumed before the response is sent.
        liftIO $ T.pause th
        response <- E.joinI $ bodyLimiter $$ app req
        liftIO $ T.resume th
        again <- liftIO $ sendResponse th req (httpVersion req) conn response
        if again then loop else return ()
-- | Read the header block of one request from the input stream and turn it
-- into a 'Request', together with an enumeratee that bounds the request body.
--
-- Header lines are collected by 'takeHeaders' and interpreted by
-- 'parseRequest''.
parseRequest :: Port -> SockAddr -> E.Iteratee S.ByteString IO (E.Enumeratee ByteString ByteString IO a, Request)
parseRequest port remoteHost' =
    takeHeaders >>= \headerLines -> parseRequest' port headerLines remoteHost'
-- | Maximum number of bytes requested from the client socket per read.
bytesPerRead :: Int
bytesPerRead = 4096

-- | Number of accumulated header bytes above which header parsing fails with
-- 'OverLargeHeader'.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 50 * 1024
-- | Count argument passed to 'sendFileIterWith' when sending a file response.
sendFileCount :: Integer
sendFileCount = 64 * 1024
-- | Errors raised while reading or parsing a request.
data InvalidRequest =
    NotEnoughLines [String]
      -- ^ Raised by the request parser when it is given no header lines at all.
    | BadFirstLine String
      -- ^ The request line did not split into exactly three space-separated
      -- pieces; carries the offending line.
    | NonHttp
      -- ^ The third piece of the request line did not start with @HTTP/@.
    | TooManyHeaders
      -- ^ Not raised anywhere in the code visible in this module.
    | IncompleteHeaders
      -- ^ The input ended before the header block was complete.
    | OverLargeHeader
      -- ^ The accumulated header length exceeded 'maxTotalHeaderLength'.
    | SocketTimeout
      -- ^ Raised by the socket enumerator when a read returns no bytes.
    deriving (Show, Typeable, Eq)

-- Allows 'InvalidRequest' values to be thrown and caught as exceptions.
instance Exception InvalidRequest
-- | Build a 'Request' from the already-split header lines of one request.
--
-- The first line is the request line; the remaining lines are headers. The
-- returned enumeratee isolates as many body bytes as the @content-length@
-- header announces. A missing or unparsable header counts as 0.
--
-- Fails with 'NotEnoughLines' when the list of lines is empty, and with
-- whatever 'parseFirst' raises for a malformed request line.
parseRequest' :: Port
              -> [ByteString]
              -> SockAddr
              -> E.Iteratee S.ByteString IO (E.Enumeratee S.ByteString S.ByteString IO a, Request)
parseRequest' _ [] _ = E.throwError $ NotEnoughLines []
parseRequest' port (requestLine:headerLines) remoteHost' = do
    (method, rawPath, query, version) <- parseFirst requestLine
    return (EB.isolate bodyLength, Request
        { requestMethod = method
        , httpVersion = version
        -- An empty path is normalised to the root path.
        , pathInfo = if S.null rawPath then "/" else rawPath
        , queryString = query
        -- Host header up to (not including) the first ':' (byte 58).
        , serverName = takeUntil 58 hostHeader
        , serverPort = port
        , requestHeaders = parsedHeaders
        , isSecure = False
        , errorHandler = System.IO.hPutStr System.IO.stderr
        , remoteHost = remoteHost'
        })
  where
    parsedHeaders = map parseHeaderNoAttr headerLines
    hostHeader = fromMaybe "" $ lookup "host" parsedHeaders
    bodyLength = maybe 0 parseLength $ lookup "content-length" parsedHeaders
    parseLength raw =
        case reads (B.unpack raw) of
            (n, _):_ -> n
            [] -> 0
-- | Prefix of the byte string up to, but not including, the first occurrence
-- of the given byte. The whole string is returned when the byte is absent.
takeUntil :: Word8 -> ByteString -> ByteString
takeUntil c bs = maybe bs (`SU.unsafeTake` bs) (S.elemIndex c bs)
-- | Parse a request line of the form @METHOD TARGET HTTP\/VERSION@.
--
-- Returns the method, the target up to the first @?@, the remainder of the
-- target starting at that @?@ (empty when there is none), and the text that
-- follows the @HTTP\/@ prefix.
--
-- Fails with 'BadFirstLine' unless the line splits into exactly three pieces
-- on the space byte (32), and with 'NonHttp' when the third piece does not
-- start with @HTTP\/@.
parseFirst :: ByteString
           -> E.Iteratee S.ByteString IO (ByteString, ByteString, ByteString, HttpVersion)
parseFirst s =
    case S.split 32 s of
        [method, target, protocol] ->
            let (prefix, version) = B.splitAt 5 protocol
                (rpath, qstring) = B.break (== '?') target
            in if prefix /= "HTTP/"
                then E.throwError NonHttp
                else return (method, rpath, qstring, version)
        _ -> E.throwError $ BadFirstLine $ B.unpack s
-- | The literal @HTTP\/@ that starts a status line.
httpBuilder :: Builder
httpBuilder = copyByteString "HTTP/"

-- | A single space character.
spaceBuilder :: Builder
spaceBuilder = fromChar ' '

-- | A CRLF line terminator.
newlineBuilder :: Builder
newlineBuilder = copyByteString "\r\n"

-- | The chunked @Transfer-Encoding@ header followed by the blank line that
-- ends the header block.
transferEncodingBuilder :: Builder
transferEncodingBuilder = copyByteString "Transfer-Encoding: chunked\r\n\r\n"

-- | Separator between a header name and its value.
colonSpaceBuilder :: Builder
colonSpaceBuilder = copyByteString ": "
-- | Render the status line and header block of a response.
--
-- The output is the status line, one line per response header, and then
-- either a blank line or, when the last argument is 'True', the chunked
-- @Transfer-Encoding@ header followed by a blank line.
headers :: HttpVersion -> Status -> ResponseHeaders -> Bool -> Builder
headers !httpversion !status !responseHeaders !isChunked' =
    let !statusLine = mconcat
            [ httpBuilder
            , copyByteString httpversion
            , spaceBuilder
            , fromShow (statusCode status)
            , spaceBuilder
            , copyByteString (statusMessage status)
            , newlineBuilder
            ]
        !withHeaders = foldl' responseHeaderToBuilder statusLine responseHeaders
        !terminator =
            case isChunked' of
                True -> transferEncodingBuilder
                False -> newlineBuilder
    in withHeaders `mappend` terminator
-- | Append one @Name: value@ header line, terminated by CRLF, to a builder.
-- The header name is written in its original casing.
responseHeaderToBuilder :: Builder -> (CIByteString, ByteString) -> Builder
responseHeaderToBuilder acc (name, value) = mconcat
    [ acc
    , copyByteString (ciOriginal name)
    , colonSpaceBuilder
    , copyByteString value
    , newlineBuilder
    ]
-- | Whether chunked transfer encoding may be used for this protocol version:
-- true exactly when the version equals 'http11'.
isChunked :: HttpVersion -> Bool
isChunked version = version == http11
-- | Whether a response body should be sent.
--
-- No body is sent when the status compares equal to @Status 204 \"\"@ or when
-- the request method is @HEAD@.
hasBody :: Status -> Request -> Bool
hasBody s req
    | s == Status 204 "" = False
    | otherwise = requestMethod req /= "HEAD"
-- | Write a response to the client socket and report whether the connection
-- may be reused for another request.
--
-- This clause handles file responses. The headers are sent first. When the
-- response has a body, the file is sent with 'sendFileIterWith', and the
-- result is 'True' only if the response headers contain @content-length@.
-- Without a body the result is always 'True'.
sendResponse :: T.Handle
             -> Request -> HttpVersion -> Socket -> Response -> IO Bool
sendResponse th req hv socket (ResponseFile s hs fp) = do
    Sock.sendMany socket headerChunks
    if not (hasBody s req)
        then return True
        else do
            sendFileIterWith pump socket fp sendFileCount
            return lengthKnown
  where
    headerChunks = L.toChunks $ toLazyByteString $ headers hv s hs False
    lengthKnown = lookup "content-length" hs /= Nothing

    -- Drive the send-file iteration to completion. The timeout handle is
    -- tickled after each reported send; when the socket would block, wait
    -- until the descriptor is writable before continuing.
    pump step = step >>= \progress ->
        case progress of
            Done _ -> return ()
            Sent _ next -> T.tickle th >> pump next
            WouldBlock _ fd next -> threadWaitWrite fd >> pump next
-- Builder responses. Without a body only the headers are sent and the
-- connection is kept alive. With a body, the headers and body are streamed to
-- the socket piece by piece, tickling the timeout handle after each piece.
-- The body is chunk-encoded when the protocol version allows it and no
-- content-length header is present. The connection is kept alive when the
-- body was chunked or had a content-length.
sendResponse th req hv socket (ResponseBuilder s hs b)
    | not (hasBody s req) = do
        Sock.sendMany socket
            $ L.toChunks
            $ toLazyByteString
            $ headers hv s hs False
        T.tickle th
        return True
    | otherwise = do
        toByteStringIO sendPiece payload
        return keepAlive
  where
    sendPiece bs = Sock.sendAll socket bs >> T.tickle th
    hasLength = lookup "content-length" hs /= Nothing
    chunked = isChunked hv && not hasLength
    keepAlive = chunked || hasLength
    payload = headers hv s hs chunked `mappend` body
    body
        | chunked = chunkedTransferEncoding b `mappend` chunkedTransferTerminator
        | otherwise = b
-- Enumerator responses: the response supplies a function that is given a
-- callback ('go') taking the status and headers and returning an iteratee
-- over Builders whose result is the keep-alive flag.
sendResponse th req hv socket (ResponseEnumerator res) =
    res go
  where
    -- No body: send only the headers and keep the connection alive.
    go s hs
        | not (hasBody s req) = do
            liftIO $ Sock.sendMany socket
                   $ L.toChunks $ toLazyByteString
                   $ headers hv s hs False
            return True
    -- With a body: the header Builder is fed first into a pipeline that
    -- converts Builders to ByteStrings and writes them to the socket with
    -- 'iterSocket'; the pipeline finally yields the keep-alive flag.
    -- 'chunk'' is applied to the iteratee that results after the headers
    -- have been fed.
    go s hs =
            chunk'
          $ E.enumList 1 [headers hv s hs isChunked']
         $$ E.joinI $ builderToByteString
         $$ (iterSocket th socket >> return isKeepAlive)
      where
        hasLength = lookup "content-length" hs /= Nothing
        -- Chunked encoding is used only when the protocol version allows it
        -- and the response did not declare a content-length.
        isChunked' = isChunked hv && not hasLength
        isKeepAlive = isChunked' || hasLength
        -- Wrap the iteratee with the chunk-encoding enumeratee when needed.
        chunk' i =
            if isChunked'
                then E.joinI $ chunk $$ i
                else i
        -- Enumeratee that chunk-encodes each incoming batch of Builders and
        -- emits the chunked terminator on EOF.
        chunk :: E.Enumeratee Builder Builder IO Bool
        chunk = E.checkDone $ E.continue . step
        step k E.EOF = k (E.Chunks [chunkedTransferTerminator]) >>== return
        -- An empty batch is skipped without emitting anything.
        step k (E.Chunks []) = E.continue $ step k
        step k (E.Chunks [x]) = k (E.Chunks [chunkedTransferEncoding x]) >>== chunk
        -- Several Builders in one batch are combined into a single chunk.
        step k (E.Chunks xs) = k (E.Chunks [chunkedTransferEncoding $ mconcat xs]) >>== chunk
-- | Split a raw header line into a case-insensitive name and its value.
--
-- The name is everything before the first @:@ (byte 58). When the remainder
-- starts with @\": \"@, those two bytes are dropped; otherwise the remainder is
-- returned as is, including the colon if there is one.
parseHeaderNoAttr :: ByteString -> (CIByteString, ByteString)
parseHeaderNoAttr s = (mkCIByteString name, value)
  where
    (name, afterName) = S.breakByte 58 s
    value
        | ": " `S.isPrefixOf` afterName = SU.unsafeDrop 2 afterName
        | otherwise = afterName
-- | Enumerator that feeds an iteratee with data read from a socket.
--
-- While the iteratee wants more input, up to @len@ bytes are read from the
-- socket, the timeout handle is tickled, and the bytes are passed on as a
-- single chunk. A read that returns no bytes fails with 'SocketTimeout'.
-- Once the iteratee is no longer in the 'E.Continue' state, its step is
-- returned unchanged.
enumSocket :: T.Handle -> Int -> Socket -> E.Enumerator ByteString IO a
enumSocket th len socket =
    inner
  where
    inner (E.Continue k) = do
        bs <- liftIO $ Sock.recv socket len
        liftIO $ T.tickle th
        -- The empty read is handled here, so the non-empty branch can feed
        -- the chunk directly and loop; no second emptiness check is needed.
        if S.null bs
            then E.throwError SocketTimeout
            else k (E.Chunks [bs]) >>== inner
    inner step = E.returnI step
-- | Iteratee that writes every chunk it receives to the socket.
--
-- For each non-empty batch, the timeout handle is resumed, the byte strings
-- are sent with 'Sock.sendMany', and the handle is paused again. On EOF the
-- handle is resumed and the iteratee finishes.
iterSocket :: T.Handle
           -> Socket
           -> E.Iteratee B.ByteString IO ()
iterSocket th sock = E.continue consume
  where
    consume stream =
        case stream of
            E.EOF -> liftIO (T.resume th) >> E.yield () E.EOF
            E.Chunks [] -> E.continue consume
            E.Chunks xs -> do
                liftIO $ do
                    T.resume th
                    Sock.sendMany sock xs
                    T.pause th
                E.continue consume
-- | Server configuration consumed by 'runSettings'.
data Settings = Settings
    { settingsPort :: Int
      -- ^ Port to listen on. 'defaultSettings' uses 3000.
    , settingsOnException :: SomeException -> IO ()
      -- ^ Handler invoked with any exception caught while serving a
      -- connection.
    , settingsTimeout :: Int
      -- ^ Timeout value. It is multiplied by 1000000 before being passed to
      -- the timeout manager, so it is presumably expressed in seconds --
      -- TODO confirm the unit the manager expects. 'defaultSettings' uses 30.
    }
-- | Default configuration: port 3000, timeout 30, and an exception handler
-- that ignores 'InvalidRequest' exceptions and prints every other exception.
defaultSettings :: Settings
defaultSettings = Settings
    { settingsPort = 3000
    , settingsOnException = report
    , settingsTimeout = 30
    }
  where
    report :: SomeException -> IO ()
    report e =
        case (fromException e :: Maybe InvalidRequest) of
            Just _ -> return ()
            Nothing -> print e
-- | Read the header block of a request from the input stream and return its
-- lines, without line terminators and without the terminating blank line.
--
-- Fetches the first input chunk and hands it to 'takeHeaders'' with an empty
-- accumulator.
takeHeaders :: E.Iteratee ByteString IO [ByteString]
takeHeaders = forceHead >>= takeHeaders' 0 id id
-- | Worker loop for 'takeHeaders'.
--
-- Arguments, in order:
--
-- * the number of header bytes accounted for so far;
--
-- * the completed lines, as a function that prepends them to a list;
--
-- * the fragments of the current, still incomplete line (from earlier chunks),
--   also as a prepending function;
--
-- * the input chunk currently being examined.
--
-- Fails with 'OverLargeHeader' when the byte count exceeds
-- 'maxTotalHeaderLength' (checked on entry to each iteration), and with
-- 'IncompleteHeaders' (via 'forceHead') when the input runs out.
takeHeaders' :: Int
             -> ([ByteString] -> [ByteString])
             -> ([ByteString] -> [ByteString])
             -> ByteString
             -> E.Iteratee S.ByteString IO [ByteString]
takeHeaders' !len _ _ _ | len > maxTotalHeaderLength = E.throwError OverLargeHeader
takeHeaders' !len !lines !prepend !bs = do
    let !bsLen = S.length bs
        -- Index of the first LF (byte 10) in this chunk, if any.
        !mnl = S.elemIndex 10 bs
    case mnl of
        -- No newline in this chunk: remember the whole chunk as a fragment of
        -- the current line and continue with the next chunk.
        !Nothing -> do
            let !len' = len + bsLen
            !more <- forceHead
            takeHeaders' len' lines (prepend . (:) bs) more
        Just !nl -> do
            let !end = nl
                !start = nl + 1
                -- The finished line: earlier fragments plus this chunk up to
                -- the newline, with a CR directly before the LF dropped.
                -- NOTE(review): when the LF is the first byte of the chunk
                -- (end == 0) only the earlier fragments are used and
                -- 'checkCR' is not applied, so a CR ending the previous
                -- chunk appears to be kept -- confirm whether that matters.
                !line =
                    if end > 0
                        then S.concat $! prepend [SU.unsafeTake (checkCR bs end) bs]
                        else S.concat $! prepend []
            if S.null line
                -- An empty line ends the header block.
                then do
                    let !lines' = lines []
                    if start < bsLen
                        then do
                            -- Bytes after the blank line are handed back to
                            -- the stream as leftover input.
                            let !rest = SU.unsafeDrop start bs
                            E.yield lines' $! E.Chunks [rest]
                        else return lines'
                -- Otherwise record the line and carry on with the remainder
                -- of this chunk, or with the next chunk if this one is used up.
                else do
                    let !len' = len + start
                        !lines' = lines . (:) line
                    !more <-
                        if start < bsLen
                            then return $! SU.unsafeDrop start bs
                            else forceHead
                    takeHeaders' len' lines' id more
-- | Take the next chunk from the input stream, failing with
-- 'IncompleteHeaders' when the stream has ended.
forceHead :: E.Iteratee ByteString IO ByteString
forceHead = EL.head >>= maybe (E.throwError IncompleteHeaders) (return $!)
-- | Given the index of a line feed in a byte string, return the index at
-- which the line content ends: one position earlier when the byte directly
-- before the line feed is a carriage return, otherwise the index unchanged.
--
-- The index is also returned unchanged when there is no preceding byte
-- (@pos <= 0@) or when @pos@ lies beyond the end of the string, so the
-- function never indexes out of bounds.
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && pos <= B.length bs && B.index bs p == '\r' = p
    | otherwise = pos
  where
    -- Position of the byte immediately before the line feed.
    p = pos - 1