module Network.Wai.Handler.Warp
(
run
, runSettings
, runSettingsSocket
, Settings
, defaultSettings
, settingsPort
, settingsHost
, settingsOnException
, settingsTimeout
, Port
, InvalidRequest (..)
#if TEST
, takeHeaders
#endif
) where
import Prelude hiding (catch, lines)
import Network.Wai
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as L
import Network (sClose, Socket)
import Network.Socket
( accept, SockAddr (SockAddrInet), Family (AF_INET)
, SocketType (Stream), listen, bindSocket, setSocketOption
, SocketOption (ReuseAddr), iNADDR_ANY, inet_addr, SockAddr (SockAddrInet)
)
import qualified Network.Socket
import qualified Network.Socket.ByteString as Sock
import Control.Exception
    ( bracket, bracketOnError, finally, Exception, SomeException, catch
    , fromException
    )
import Control.Concurrent (forkIO, threadWaitWrite)
import qualified Data.Char as C
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Enumerator (($$), (>>==))
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL
import qualified Data.Enumerator.Binary as EB
import Blaze.ByteString.Builder.Enumerator (builderToByteString)
import Blaze.ByteString.Builder.HTTP
(chunkedTransferEncoding, chunkedTransferTerminator)
import Blaze.ByteString.Builder
(copyByteString, Builder, toLazyByteString, toByteStringIO)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import Data.Monoid (mappend, mconcat)
import Network.Socket.SendFile (sendFileIterWith,sendFileIterWith', Iter (..))
import Control.Monad.IO.Class (liftIO)
import qualified Timeout as T
import Data.Word (Word8)
import Data.List (foldl')
import Control.Monad (forever)
import qualified Network.HTTP.Types as H
import qualified Data.CaseInsensitive as CI
import System.IO (hPutStrLn, stderr)
#if WINDOWS
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.MVar as MV
import Network.Socket (withSocketsDo)
#endif
-- | Create an IPv4 TCP socket bound to the given port and host, and start
-- listening on it with a backlog of 150.
--
-- The host string @"*"@ binds to 'iNADDR_ANY'; any other string is converted
-- with 'inet_addr'.
--
-- If any step after the socket is created fails (address conversion, setting
-- the option, binding, listening), the socket is closed before the exception
-- propagates, so a failed bind does not leak a file descriptor.
bindPort :: Int -> String -> IO Socket
bindPort p s =
    bracketOnError (Network.Socket.socket AF_INET Stream 0) sClose $ \sock -> do
        h <- if s == "*" then return iNADDR_ANY else inet_addr s
        let addr = SockAddrInet (fromIntegral p) h
        setSocketOption sock ReuseAddr 1
        bindSocket sock addr
        listen sock 150
        return sock
-- | Serve an 'Application' on the given port, taking every other option
-- from 'defaultSettings'.
run :: Port -> Application -> IO ()
run port app = runSettings settings app
  where
    settings = defaultSettings { settingsPort = port }
-- | Bind a listening socket according to the given 'Settings' and serve the
-- 'Application' on it.
--
-- On Windows the accept loop runs in a forked thread while the calling
-- thread sleeps. The bound socket is recorded in an 'MV.MVar' and closed
-- either when the serving thread's 'bracket' exits or when the sleeping
-- thread is interrupted.
-- NOTE(review): presumably this keeps the calling thread interruptible while
-- a blocking accept is in progress on Windows; confirm.
--
-- Elsewhere the socket is acquired and released with a plain 'bracket'
-- around 'runSettingsSocket'.
runSettings :: Settings -> Application -> IO ()
#if WINDOWS
runSettings set app = withSocketsDo $ do
    boundSocket <- MV.newMVar Nothing
    let closeIfBound = MV.modifyMVar_ boundSocket $ \ms -> do
            maybe (return ()) sClose ms
            return Nothing
        serve sock = do
            MV.modifyMVar_ boundSocket (const (return (Just sock)))
            runSettingsSocket set sock app
    _ <- forkIO $ bracket listenOnConfiguredPort (const closeIfBound) serve
    forever (threadDelay maxBound) `finally` closeIfBound
  where
    listenOnConfiguredPort = bindPort (settingsPort set) (settingsHost set)
#else
runSettings set app =
    bracket listenOnConfiguredPort sClose $ \sock ->
        runSettingsSocket set sock app
  where
    listenOnConfiguredPort = bindPort (settingsPort set) (settingsHost set)
#endif
-- | TCP port number the server listens on (see 'settingsPort').
type Port = Int
-- | Run the accept loop on an already-listening socket, forever.
--
-- A timeout manager is created from 'settingsTimeout'. The value is
-- multiplied by 1000000 before it is handed to 'T.initialize'; presumably
-- this converts seconds to microseconds.
--
-- Each accepted connection is handled in its own thread:
--
-- * a kill-timer is registered with the manager;
-- * the connection is served;
-- * the timer is cancelled once 'serveConnection' returns.
runSettingsSocket :: Settings -> Socket -> Application -> IO ()
runSettingsSocket set socket app = do
    manager <- T.initialize $ settingsTimeout set * 1000000
    let handleClient (conn, peer) = do
            timer <- T.registerKillThread manager
            serveConnection timer (settingsOnException set) (settingsPort set)
                app conn peer
            T.cancel timer
    forever $ do
        client <- accept socket
        _ <- forkIO $ handleClient client
        return ()
-- | Serve every request that arrives on one client connection.
--
-- Requests are read from the socket and answered in a loop. The loop
-- continues for as long as 'sendResponse' reports that the connection may be
-- kept alive.
--
-- The socket is always closed afterwards. Any exception raised while serving
-- or while closing the socket is passed to the supplied handler.
--
-- The timeout handle is paused while the application computes its response
-- and resumed before the response is sent.
serveConnection :: T.Handle
                -> (SomeException -> IO ())
                -> Port -> Application -> Socket -> SockAddr -> IO ()
serveConnection th onException port app conn remoteHost' =
    (E.run_ (fromClient $$ requestLoop) `finally` sClose conn)
        `catch` onException
  where
    fromClient = enumSocket th bytesPerRead conn
    requestLoop = do
        (bodyLength, request) <- parseRequest port remoteHost'
        liftIO $ T.pause th
        -- The application sees exactly the declared request body.
        response <- E.joinI $ EB.isolate bodyLength $$ app request
        liftIO $ T.resume th
        keepAlive <- liftIO $
            sendResponse th request (httpVersion request) conn response
        if keepAlive then requestLoop else return ()
-- | Read one request's header block from the stream and turn it into a
-- 'Request', paired with the length of the request body.
parseRequest :: Port -> SockAddr -> E.Iteratee S.ByteString IO (Integer, Request)
parseRequest port remoteHost' =
    takeHeaders >>= \headerLines -> parseRequest' port headerLines remoteHost'
-- | Maximum number of bytes requested from the socket in a single
-- 'Sock.recv' call.
bytesPerRead :: Int
bytesPerRead = 4 * 1024

-- | Upper bound on the accumulated size of the request header block.
-- Exceeding it raises 'OverLargeHeader'.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 50 * 1024

-- | Count argument passed to 'sendFileIterWith' / 'sendFileIterWith''.
-- Presumably this is the number of bytes sent per sendfile step.
sendFileCount :: Integer
sendFileCount = 64 * 1024
-- | Errors raised while reading and parsing a request's header block.
data InvalidRequest
    = NotEnoughLines [String]
      -- ^ The header block contained no lines at all.
    | BadFirstLine String
      -- ^ The request line did not split into exactly three space-separated
      -- fields; carries the offending line.
    | NonHttp
      -- ^ The protocol field of the request line did not start with HTTP\/.
    | IncompleteHeaders
      -- ^ The input ended before the header block was terminated.
    | OverLargeHeader
      -- ^ The header block grew beyond 'maxTotalHeaderLength'.
    deriving (Eq, Show, Typeable)

instance Exception InvalidRequest
-- | Build a 'Request' from the header lines of one request.
--
-- The first line is the request line. The remaining lines are parsed as
-- headers with 'parseHeaderNoAttr'.
--
-- Host handling:
--
-- * a request target beginning with "http:\/\/" supplies a fallback host
--   name;
-- * a @Host@ header, when present, overrides that fallback;
-- * anything from the first @:@ onwards is dropped to form 'serverName'.
--
-- Returns the body length read from the leading decimal digits of
-- @Content-Length@ (0 when the header is absent), paired with the request.
-- Throws 'NotEnoughLines' for an empty header list, plus whatever
-- 'parseFirst' throws for a malformed request line.
parseRequest' :: Port
              -> [ByteString]
              -> SockAddr
              -> E.Iteratee S.ByteString IO (Integer, Request)
parseRequest' _ [] _ = E.throwError $ NotEnoughLines []
parseRequest' port (firstLine:otherLines) remoteHost' = do
    (method, rpath', gets, httpversion) <- parseFirst firstLine
    let (host', rpath) = splitTarget rpath'
        heads = map parseHeaderNoAttr otherLines
        host = fromMaybe host' $ lookup "host" heads
        len = maybe 0 parseContentLength $ lookup "content-length" heads
        serverName' = takeUntil 58 host -- 58 is ':'; drops any ":port"
    return (len, Request
        { requestMethod = method
        , httpVersion = httpversion
        , pathInfo = H.decodePathSegments rpath
        , rawPathInfo = rpath
        , rawQueryString = gets
        , queryString = H.parseQuery gets
        , serverName = serverName'
        , serverPort = port
        , requestHeaders = heads
        , isSecure = False
        , remoteHost = remoteHost'
        })
  where
    -- Split the request target into (host from an absolute URI, path).
    -- An empty path is normalised to "/". This covers both an empty target
    -- and an absolute URI with no path component, such as
    -- "http://example.com".
    splitTarget target
        | "http://" `S.isPrefixOf` target =
            let (h, p) = S.breakByte 47 $ S.drop 7 target -- 47 is '/'
            in (h, nonEmptyPath p)
        | otherwise = ("", nonEmptyPath target)
    nonEmptyPath p = if S.null p then "/" else p
    -- Leading decimal digits of the header value, accumulated directly in
    -- Integer. Folding in Int and converting afterwards would silently wrap
    -- for very long values.
    parseContentLength =
        B.foldl' (\acc c -> acc * 10 + toInteger (C.digitToInt c)) 0
            . B.takeWhile C.isDigit
-- | The prefix of a 'ByteString' that precedes the first occurrence of the
-- given byte. The whole string is returned when the byte does not occur.
takeUntil :: Word8 -> ByteString -> ByteString
takeUntil c bs = maybe bs (`SU.unsafeTake` bs) (S.elemIndex c bs)
-- | Split a request line @METHOD TARGET PROTOCOL@ into:
--
-- * the method;
-- * the path part of the target;
-- * the query string (starting at the first @?@, if any);
-- * the HTTP version.
--
-- Only a version field of exactly @1.1@ is reported as HTTP/1.1; anything
-- else is treated as HTTP/1.0.
--
-- Throws 'NonHttp' when the protocol field does not start with HTTP\/.
-- Throws 'BadFirstLine' when the line does not have exactly three
-- space-separated fields.
parseFirst :: ByteString
           -> E.Iteratee S.ByteString IO (ByteString, ByteString, ByteString, H.HttpVersion)
parseFirst s = case S.split 32 s of -- 32 is ' '
    [method, query, http']
        | prefix == "HTTP/" -> return (method, rpath, qstring, version)
        | otherwise -> E.throwError NonHttp
      where
        (prefix, versionText) = B.splitAt 5 http'
        (rpath, qstring) = S.breakByte 63 query -- 63 is '?'
        version = if versionText == "1.1" then H.http11 else H.http10
    _ -> E.throwError $ BadFirstLine $ B.unpack s
-- | Literal prefix of every status line.
httpBuilder :: Builder
httpBuilder = copyByteString "HTTP/"

-- | A single space character.
spaceBuilder :: Builder
spaceBuilder = fromChar ' '

-- | CRLF line terminator.
newlineBuilder :: Builder
newlineBuilder = copyByteString "\r\n"

-- | Chunked transfer-encoding header, followed by the empty line that
-- closes the header block.
transferEncodingBuilder :: Builder
transferEncodingBuilder = copyByteString "Transfer-Encoding: chunked\r\n\r\n"

-- | Separator between a header name and its value.
colonSpaceBuilder :: Builder
colonSpaceBuilder = copyByteString ": "
-- | Render the status line and the response headers.
--
-- Only HTTP/1.1 is echoed back as such; every other version is written as
-- 1.0.
--
-- When the final argument is 'True', a chunked Transfer-Encoding header is
-- written before the empty line that ends the header block. Otherwise only
-- that empty line is written.
headers :: H.HttpVersion -> H.Status -> H.ResponseHeaders -> Bool -> Builder
headers !httpversion !status !responseHeaders !isChunked' =
    let !statusLine = mconcat
            [ httpBuilder
            , copyByteString versionText
            , spaceBuilder
            , fromShow (H.statusCode status)
            , spaceBuilder
            , copyByteString (H.statusMessage status)
            , newlineBuilder
            ]
        !withHeaders = foldl' responseHeaderToBuilder statusLine responseHeaders
        !terminator = if isChunked' then transferEncodingBuilder else newlineBuilder
    in withHeaders `mappend` terminator
  where
    versionText = case httpversion of
        H.HttpVersion 1 1 -> "1.1"
        _ -> "1.0"
-- | Append one header line (name, ": ", value, CRLF) to an existing builder.
-- The header name is written with its original casing.
responseHeaderToBuilder :: Builder -> H.Header -> Builder
responseHeaderToBuilder acc (name, value) = mconcat
    [ acc
    , copyByteString (CI.original name)
    , colonSpaceBuilder
    , copyByteString value
    , newlineBuilder
    ]
-- | Whether chunked transfer encoding may be used, i.e. whether the request
-- was made with HTTP/1.1.
isChunked :: H.HttpVersion -> Bool
isChunked version = version == H.http11
-- | Whether a response body may be sent.
--
-- No body is sent in reply to a @HEAD@ request. None is sent either for
-- statuses that must not carry one: 1xx, 204 (No Content) and
-- 304 (Not Modified).
--
-- Statuses are compared by numeric code only, so the result does not depend
-- on the status message text. Excluding 304 matters on HTTP/1.1: otherwise a
-- bodyless 304 without Content-Length would be followed by a chunked
-- terminator, which the client would misread as the start of the next
-- response.
hasBody :: H.Status -> Request -> Bool
hasBody s req =
    code >= 200 && code /= 204 && code /= 304
        && requestMethod req /= "HEAD"
  where
    code = H.statusCode s
-- | Write a response to the client socket and report whether the connection
-- may be kept open for another request.
--
-- For every response type, the result is 'True' when no body is sent (see
-- 'hasBody'). Otherwise the connection is kept alive only when the body is
-- delimited: either a @Content-Length@ header is present, or the body is
-- sent with chunked transfer encoding (HTTP/1.1 and no @Content-Length@).
--
-- File responses are never chunked. File and builder bodies tickle the
-- timeout as data is written; enumerator bodies are written through
-- 'iterSocket'.
sendResponse :: T.Handle
             -> Request -> H.HttpVersion -> Socket -> Response -> IO Bool
sendResponse th req hv socket (ResponseFile s hs fp mpart) = do
    Sock.sendMany socket $ L.toChunks $ toLazyByteString $ headers hv s hs False
    if hasBody s req
        then do
            -- Send the whole file, or only the requested part of it.
            case mpart of
                Nothing -> sendFileIterWith tickler socket fp sendFileCount
                Just part ->
                    sendFileIterWith' tickler socket fp sendFileCount
                        (filePartOffset part)
                        (filePartByteCount part)
            return $ lookup "content-length" hs /= Nothing
        else return True
  where
    -- Drive the sendfile iterator to completion. The timeout is tickled
    -- after each send; when a send would block, wait until the descriptor
    -- is writable and retry.
    tickler iter = do
        r <- iter
        case r of
            Done _ -> return ()
            Sent _ cont -> T.tickle th >> tickler cont
            WouldBlock _ fd cont -> do
                threadWaitWrite fd
                tickler cont
sendResponse th req hv socket (ResponseBuilder s hs b)
    | hasBody s req = do
        toByteStringIO (\bs -> do
            Sock.sendAll socket bs
            T.tickle th) b'
        return isKeepAlive
    | otherwise = do
        Sock.sendMany socket
            $ L.toChunks
            $ toLazyByteString
            $ headers hv s hs False
        T.tickle th
        return True
  where
    headers' = headers hv s hs isChunked'
    -- Headers followed by the body, chunk-encoded and terminated when
    -- chunking is in use.
    b' = if isChunked'
            then headers'
                `mappend` chunkedTransferEncoding b
                `mappend` chunkedTransferTerminator
            else headers hv s hs False `mappend` b
    hasLength = lookup "content-length" hs /= Nothing
    -- Chunk only when the client speaks HTTP/1.1 and no length was given.
    isChunked' = isChunked hv && not hasLength
    isKeepAlive = isChunked' || hasLength
sendResponse th req hv socket (ResponseEnumerator res) =
    res go
  where
    -- Called with the status and headers; returns the iteratee that
    -- consumes the body builders.
    go s hs | not (hasBody s req) = do
        liftIO $ Sock.sendMany socket
            $ L.toChunks $ toLazyByteString
            $ headers hv s hs False
        return True
    -- Prepend the headers to the body stream, optionally chunk-encode it,
    -- render builders to bytes and write them to the socket.
    go s hs = chunk'
        $ E.enumList 1 [headers hv s hs isChunked']
        $$ E.joinI $ builderToByteString
        $$ (iterSocket th socket >> return isKeepAlive)
      where
        hasLength = lookup "content-length" hs /= Nothing
        isChunked' = isChunked hv && not hasLength
        isKeepAlive = isChunked' || hasLength
        chunk' i = if isChunked'
            then E.joinI $ chunk $$ i
            else i
        -- Enumeratee that wraps each batch of builders in one chunk and
        -- emits the chunked terminator at end of input.
        chunk :: E.Enumeratee Builder Builder IO Bool
        chunk = E.checkDone $ E.continue . step
        step k E.EOF = k (E.Chunks [chunkedTransferTerminator]) >>== return
        step k (E.Chunks []) = E.continue $ step k
        step k (E.Chunks [x]) = k (E.Chunks [chunkedTransferEncoding x]) >>== chunk
        step k (E.Chunks xs) = k (E.Chunks [chunkedTransferEncoding $ mconcat xs]) >>== chunk
-- | Split a raw header line into its case-insensitive name and its value.
--
-- The name is everything before the first @:@. The value is everything after
-- that colon, with leading spaces and horizontal tabs removed. So
-- @Name: v@, @Name:v@ and @Name:   v@ all yield @v@. A line without a colon
-- yields an empty value.
parseHeaderNoAttr :: ByteString -> H.Header
parseHeaderNoAttr s =
    let (k, rest) = S.breakByte 58 s -- 58 is ':'
        -- Drop the colon itself, then any optional whitespace before the
        -- value. Stripping only an exact ": " would leave the colon in
        -- values written without a following space (e.g. "Content-Length:5"
        -- would become ":5" and parse as length 0).
        rest' = S.dropWhile isOWS $ S.drop 1 rest
    in (CI.mk k, rest')
  where
    isOWS w = w == 32 || w == 9 -- ' ' or '\t'
-- | Enumerate the bytes received on a socket. Each 'Sock.recv' call reads at
-- most @len@ bytes, and the timeout is tickled after every read.
--
-- An empty read ends the enumeration and leaves the iteratee in its
-- 'E.Continue' state. Presumably an empty read means the peer closed the
-- connection.
enumSocket :: T.Handle -> Int -> Socket -> E.Enumerator ByteString IO a
enumSocket th len socket = pump
  where
    pump step = case step of
        E.Continue k -> do
            received <- liftIO $ Sock.recv socket len
            liftIO $ T.tickle th
            if S.null received
                then E.continue k
                else k (E.Chunks [received]) >>== pump
        _ -> E.returnI step
-- | Iteratee that writes every chunk it receives to the socket.
--
-- The timeout handle is resumed before each write and paused again after it.
-- At end of input the handle is resumed and left running.
iterSocket :: T.Handle
           -> Socket
           -> E.Iteratee B.ByteString IO ()
iterSocket th sock = E.continue consume
  where
    consume input = case input of
        E.EOF -> do
            liftIO $ T.resume th
            E.yield () E.EOF
        E.Chunks [] -> E.continue consume
        E.Chunks xs -> do
            liftIO $ do
                T.resume th
                Sock.sendMany sock xs
                T.pause th
            E.continue consume
-- | Server configuration. Start from 'defaultSettings' and override fields
-- with record-update syntax.
data Settings = Settings
    { settingsPort :: Int
      -- ^ Port to listen on.
    , settingsHost :: String
      -- ^ Address to bind to; @"*"@ binds to every interface.
    , settingsOnException :: SomeException -> IO ()
      -- ^ Called with any exception that escapes while serving a connection.
    , settingsTimeout :: Int
      -- ^ Connection timeout. It is multiplied by 1000000 before being
      -- passed to the timeout manager, so presumably it is in seconds.
    }
-- | Default configuration:
--
-- * port 3000 on every interface;
-- * a timeout of 30 (see 'settingsTimeout');
-- * an exception handler that silently ignores 'InvalidRequest' errors and
--   prints every other exception to stderr.
defaultSettings :: Settings
defaultSettings = Settings
    { settingsPort = 3000
    , settingsHost = "*"
    , settingsOnException = reportUnlessInvalidRequest
    , settingsTimeout = 30
    }
  where
    reportUnlessInvalidRequest e =
        case fromException e :: Maybe InvalidRequest of
            Just _ -> return ()
            Nothing -> hPutStrLn stderr $ show e
-- | Read a request header block from the stream and return its lines in
-- order. The empty line that terminates the block is not included, and any
-- bytes after it are left in the stream.
takeHeaders :: E.Iteratee ByteString IO [ByteString]
takeHeaders = forceHead >>= \firstChunk ->
    firstChunk `seq` takeHeaders' 0 id id firstChunk
-- | Accumulate header lines until the empty line that ends the header block.
--
-- Arguments:
--
-- * the number of bytes consumed so far, checked against
--   'maxTotalHeaderLength' ('OverLargeHeader' when exceeded);
-- * a difference list of the complete lines seen so far;
-- * a difference list of chunks holding the start of the current,
--   not-yet-terminated line;
-- * the chunk currently being scanned.
--
-- Each line is returned without its LF and without a CR immediately before
-- that LF. This holds even when the CR and the LF arrive in different chunks.
-- Input left over after the empty line is pushed back into the stream.
takeHeaders' :: Int
             -> ([ByteString] -> [ByteString])
             -> ([ByteString] -> [ByteString])
             -> ByteString
             -> E.Iteratee S.ByteString IO [ByteString]
takeHeaders' !len _ _ _ | len > maxTotalHeaderLength = E.throwError OverLargeHeader
takeHeaders' !len !lines !prepend !bs = do
    let !bsLen = S.length bs
        !mnl = S.elemIndex 10 bs -- 10 is '\n'
    case mnl of
        !Nothing -> do
            -- No LF in this chunk: stash it and read the next one.
            let !len' = len + bsLen
            !more <- forceHead
            takeHeaders' len' lines (prepend . (:) bs) more
        Just !nl -> do
            let !end = nl
                !start = nl + 1
                !line =
                    if end > 0
                        then S.concat $! prepend [SU.unsafeTake (checkCR bs end) bs]
                        -- The LF is the first byte of this chunk, so a CR
                        -- belonging to it would be the last byte of the
                        -- stashed chunks; strip it there.
                        else dropTrailingCR $! S.concat $! prepend []
            if S.null line
                then do
                    -- Empty line: the header block is complete.
                    let !lines' = lines []
                    if start < bsLen
                        then do
                            let !rest = SU.unsafeDrop start bs
                            E.yield lines' $! E.Chunks [rest]
                        else return lines'
                else do
                    let !len' = len + start
                        !lines' = lines . (:) line
                    !more <-
                        if start < bsLen
                            then return $! SU.unsafeDrop start bs
                            else forceHead
                    takeHeaders' len' lines' id more
  where
    -- Remove a single trailing CR (byte 13), if present.
    dropTrailingCR l
        | not (S.null l) && S.last l == 13 = S.init l
        | otherwise = l
-- | Take the next chunk from the stream, evaluated to WHNF. Throws
-- 'IncompleteHeaders' when the stream has already ended.
forceHead :: E.Iteratee ByteString IO ByteString
forceHead =
    EL.head >>= maybe (E.throwError IncompleteHeaders) (\x -> x `seq` return x)
-- | Given the index of an LF byte in a chunk, return the index at which the
-- line's content ends.
--
-- When the byte just before the LF (in the same chunk) is a CR, the result
-- is one position earlier. Otherwise the LF's own index is returned.
-- Position 0 has no preceding byte in the chunk and is returned unchanged.
checkCR :: ByteString -> Int -> Int
checkCR bs pos
    | pos > 0 && B.index bs p == '\r' = p
    | otherwise = pos
  where
    p = pos - 1