module Network.Wai.Handler.Warp
(
run
, runSettings
, runSettingsSocket
, Settings
, defaultSettings
, settingsPort
, settingsHost
, settingsOnException
, settingsTimeout
, settingsIntercept
, settingsManager
, Connection (..)
, runSettingsConnection
, Port
, InvalidRequest (..)
, Manager
, withManager
, parseRequest
, sendResponse
, registerKillThread
, bindPort
, pause
, resume
, T.cancel
, T.register
, T.initialize
#if TEST
, takeHeaders
, readInt
#endif
) where
import Prelude hiding (catch, lines)
import Network.Wai
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as SU
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as L
import Network (sClose, Socket)
import Network.Socket
( accept, Family (..)
, SocketType (Stream), listen, bindSocket, setSocketOption, maxListenQueue
, SockAddr, SocketOption (ReuseAddr)
, AddrInfo(..), AddrInfoFlag(..), defaultHints, getAddrInfo
)
import qualified Network.Socket
import qualified Network.Socket.ByteString as Sock
import Control.Exception
( bracket, finally, Exception, SomeException, catch
, fromException, AsyncException (ThreadKilled)
, bracketOnError
)
import Control.Concurrent (forkIO)
import qualified Data.Char as C
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Control.Monad.Trans.Resource (ResourceT, runResourceT)
import qualified Data.Conduit as C
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.List as CL
import Data.Conduit.Blaze (builderToByteString)
import Control.Exception.Lifted (throwIO)
import Blaze.ByteString.Builder.HTTP
(chunkedTransferEncoding, chunkedTransferTerminator)
import Blaze.ByteString.Builder
(copyByteString, Builder, toLazyByteString, toByteStringIO)
import Blaze.ByteString.Builder.Char8 (fromChar, fromShow)
import Data.Monoid (mappend, mempty)
import Network.Sendfile
import qualified System.PosixCompat.Files as P
import Control.Monad.IO.Class (liftIO)
import qualified Timeout as T
import Timeout (Manager, registerKillThread, pause, resume)
import Data.Word (Word8)
import Data.List (foldl')
import Control.Monad (forever, when)
import qualified Network.HTTP.Types as H
import qualified Data.CaseInsensitive as CI
import System.IO (hPutStrLn, stderr)
#if WINDOWS
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.MVar as MV
import Network.Socket (withSocketsDo)
#endif
import Data.Version (showVersion)
import qualified Paths_warp
-- | Version of this package as a dotted string, taken from the
-- Cabal-generated @Paths_warp@ module. It is embedded in the @Server@
-- header that 'serverHeader' adds to responses.
warpVersion :: String
warpVersion = showVersion pkgVersion
  where
    pkgVersion = Paths_warp.version
-- | The operations Warp needs on a client connection. 'socketConnection'
-- builds one from a socket; callers of 'runSettingsConnection' can supply
-- their own implementation.
data Connection = Connection
    { connSendMany :: [B.ByteString] -> IO ()
      -- ^ Send a list of chunks; used here for rendered response headers.
    , connSendAll :: B.ByteString -> IO ()
      -- ^ Send one chunk completely; used for response bodies.
    , connSendFile :: FilePath -> Integer -> Integer -> IO () -> IO ()
      -- ^ Send part of a file, given its path, byte offset and byte count.
      -- The last argument is an action the implementation may run while
      -- sending (Warp passes a timeout tickle, and the socket version hands
      -- it to sendfile); how often it runs depends on the implementation.
    , connClose :: IO ()
      -- ^ Close the connection; called once serving it has finished.
    , connRecv :: IO B.ByteString
      -- ^ Receive the next chunk of input. An empty result is treated as
      -- end of input.
    }
-- | Wrap a connected socket as a 'Connection'. Receives ask for at most
-- 'bytesPerRead' bytes, file ranges are sent with 'sendfile', and closing
-- the connection closes the socket.
socketConnection :: Socket -> Connection
socketConnection sock = Connection
    { connRecv = Sock.recv sock bytesPerRead
    , connSendAll = Sock.sendAll sock
    , connSendMany = Sock.sendMany sock
    , connSendFile = sendRange
    , connClose = sClose sock
    }
  where
    -- Send @count@ bytes of @path@ starting at @offset@.
    sendRange path offset count hook =
        sendfile sock path (PartOfFile offset count) hook
-- | Create a listening TCP socket bound to the given port and host.
--
-- A host of @\"*\"@ passes no host to 'getAddrInfo', which together with
-- 'AI_PASSIVE' binds to all addresses. Because 'AI_NUMERICHOST' is set,
-- any other host must be a numeric address rather than a name.
--
-- If the lookup returns any IPv6 address, the first one is used; otherwise
-- the first result is used. The socket gets 'ReuseAddr', and it is closed
-- again if binding or listening throws.
bindPort :: Int -> String -> IO Socket
bindPort p s = do
    addrs <- getAddrInfo (Just hints) mhost (Just (show p))
    let isV6 ai = addrFamily ai == AF_INET6
        -- IPv6 candidates first, then every result as the fallback.
        addr = head (filter isV6 addrs ++ addrs)
        open = Network.Socket.socket (addrFamily addr)
                                     (addrSocketType addr)
                                     (addrProtocol addr)
    bracketOnError open sClose $ \sock -> do
        setSocketOption sock ReuseAddr 1
        bindSocket sock (addrAddress addr)
        listen sock maxListenQueue
        return sock
  where
    hints = defaultHints
        { addrFlags = [AI_PASSIVE, AI_NUMERICSERV, AI_NUMERICHOST]
        , addrSocketType = Stream
        }
    mhost
        | s == "*" = Nothing
        | otherwise = Just s
-- | Serve an 'Application' on the given port, using 'defaultSettings' for
-- every other option.
run :: Port -> Application -> IO ()
run port app = runSettings settings app
  where
    settings = defaultSettings { settingsPort = port }
-- | Bind a listening socket using settingsHost and settingsPort, serve the
-- application on it, and close the socket when serving stops.
runSettings :: Settings -> Application -> IO ()
#if WINDOWS
-- On Windows the server runs in a forked thread while this thread sleeps,
-- presumably so the main thread stays interruptible while the server thread
-- blocks in accept (TODO confirm). The bound socket is recorded in an MVar.
-- Whichever cleanup path runs first (bracket release or finally) closes it
-- and clears the MVar, so the second cleanup is a no-op.
runSettings set app = withSocketsDo $ do
    var <- MV.newMVar Nothing
    -- Close the recorded socket, if any, and forget it.
    let clean = MV.modifyMVar_ var $ \s -> maybe (return ()) sClose s >> return Nothing
    _ <- forkIO $ bracket
        (bindPort (settingsPort set) (settingsHost set))
        (const clean)
        (\s -> do
            MV.modifyMVar_ var (\_ -> return $ Just s)
            runSettingsSocket set s app)
    -- Block forever; close the socket when this thread is interrupted.
    forever (threadDelay maxBound) `finally` clean
#else
-- Point-free form of:
--   runSettings set app = bracket (bindPort ...) sClose
--                                 (\sock -> runSettingsSocket set sock app)
runSettings set =
    bracket
        (bindPort (settingsPort set) (settingsHost set))
        sClose .
        (flip (runSettingsSocket set))
#endif
-- | A TCP port number.
type Port = Int
-- | Serve an 'Application' on an already-listening socket by accepting
-- connections from it forever (see 'runSettingsConnection'). The listening
-- socket is not closed here.
runSettingsSocket :: Settings -> Socket -> Application -> IO ()
runSettingsSocket set listenSock app =
    runSettingsConnection set acceptNext app
  where
    acceptNext = fmap wrap (accept listenSock)
    wrap (clientSock, peer) = (socketConnection clientSock, peer)
-- | Run the server loop with a custom source of connections. @getConn@ is
-- called repeatedly to obtain the next client 'Connection' and its address,
-- and each connection is served on its own thread. Never returns normally.
--
-- Each connection gets a timeout handle registered with the manager from
-- 'settingsManager'. When that is 'Nothing', a new manager is created with
-- @settingsTimeout * 1000000@ (presumably seconds converted to microseconds;
-- TODO confirm against the Timeout module).
runSettingsConnection :: Settings -> IO (Connection, SockAddr) -> Application -> IO ()
runSettingsConnection set getConn app = do
    manager <- case settingsManager set of
        Just m -> return m
        Nothing -> T.initialize (settingsTimeout set * 1000000)
    forever $ do
        (conn, addr) <- getConn
        _ <- forkIO $ serveOne manager conn addr
        return ()
  where
    -- Serve a single client, then cancel its timeout handle.
    serveOne manager conn addr = do
        th <- T.registerKillThread manager
        serveConnection set th (settingsOnException set) (settingsPort set) app conn addr
        T.cancel th
-- | Handle every request that arrives on one client connection.
--
-- Requests are parsed and answered in a loop for as long as 'sendResponse'
-- reports that the connection may be kept alive. If 'settingsIntercept'
-- claims a request, its handler receives the buffered client source and the
-- connection instead, and the loop ends.
--
-- The connection is always closed afterwards. Any exception, including one
-- raised while closing, is passed to the given handler rather than
-- propagated.
--
-- The timeout handle is paused while the application runs and resumed
-- before the response is written. For an intercepted request it is left
-- paused.
serveConnection :: Settings
                -> T.Handle
                -> (SomeException -> IO ())
                -> Port -> Application -> Connection -> SockAddr -> IO ()
serveConnection settings th onException port app conn remoteHost' =
    (runResourceT serveConnection' `finally` connClose conn)
        `catch` onException
  where
    serveConnection' :: ResourceT IO ()
    serveConnection' =
        C.bufferSource (C.Source (return (connSource conn th))) >>= serveLoop

    -- One iteration per request read from this connection.
    serveLoop fromClient = do
        req <- parseRequest port remoteHost' fromClient
        case settingsIntercept settings req of
            Just intercept -> do
                liftIO $ T.pause th
                intercept fromClient conn
            Nothing -> do
                liftIO $ T.pause th
                res <- app req
                -- Consume whatever the application left unread of the body.
                requestBody req C.$$ CL.sinkNull
                liftIO $ T.resume th
                keepAlive <- sendResponse th req conn res
                when keepAlive $ serveLoop fromClient
-- | Read one request's header section from the buffered client source
-- (using 'takeHeaders') and turn it into a 'Request'. The request body
-- continues to read from the same source.
parseRequest :: Port -> SockAddr
             -> C.BufferedSource IO S.ByteString
             -> ResourceT IO Request
parseRequest port remoteHost' src =
    (src C.$$ takeHeaders) >>= \headerLines ->
        parseRequest' port headerLines remoteHost' src
-- | Maximum number of bytes requested from the socket per receive call.
bytesPerRead :: Int
bytesPerRead = 4096

-- | Limit (50 KiB) on the header section. 'takeHeaders' rejects further
-- input with 'OverLargeHeader' once more than this many bytes have been
-- consumed without reaching the end of the headers.
maxTotalHeaderLength :: Int
maxTotalHeaderLength = 50 * 1024
-- | Errors raised while reading or parsing the header section of a request.
data InvalidRequest =
    NotEnoughLines [String]
      -- ^ No lines at all were read for the request (thrown by
      -- parseRequest' with an empty list).
    | BadFirstLine String
      -- ^ The request line did not split into exactly three
      -- single-space-separated parts; carries the offending line.
    | NonHttp
      -- ^ The third part of the request line does not start with HTTP/.
    | IncompleteHeaders
      -- ^ Input ended before the empty line that terminates the headers.
    | OverLargeHeader
      -- ^ More than maxTotalHeaderLength header bytes were consumed.
    deriving (Show, Typeable, Eq)
instance Exception InvalidRequest
-- | Build a WAI 'Request' from the header lines of one HTTP request.
--
-- The first line must be the request line; the remaining lines are parsed
-- with 'parseHeaderNoAttr'. Throws @NotEnoughLines []@ when there are no
-- lines, and propagates anything 'parseFirst' throws.
--
-- * The path is @/@ when the request target is empty. It is also @/@ when
--   the target is an absolute @http://host@ URI without a path component.
-- * 'serverName' comes from the @Host@ header, falling back to the host part
--   of an absolute-URI target, with any @:port@ suffix removed.
-- * The body is the next @Content-Length@ bytes of @src@. Only the leading
--   digits of that header are used, and the body is empty when the header
--   is absent or zero.
parseRequest' :: Port
              -> [ByteString]
              -> SockAddr
              -> C.BufferedSource IO S.ByteString
              -> ResourceT IO Request
parseRequest' _ [] _ _ = throwIO $ NotEnoughLines []
parseRequest' port (firstLine:otherLines) remoteHost' src = do
    (method, rpath', gets, httpversion) <- parseFirst firstLine
    let (host', rpath) = splitTarget rpath'
    let heads = map parseHeaderNoAttr otherLines
    let host = fromMaybe host' $ lookup "host" heads
    let len =
            case lookup "content-length" heads of
                Nothing -> 0
                Just bs -> fromIntegral $ B.foldl' (\i c -> i * 10 + C.digitToInt c) 0 $ B.takeWhile C.isDigit bs
    -- 58 is the byte for a colon: drop any port suffix from the host name.
    let serverName' = takeUntil 58 host
    rbody <- C.prepareSource $
        if len == 0
            then mempty
            else src C.$= CB.isolate len
    return Request
        { requestMethod = method
        , httpVersion = httpversion
        , pathInfo = H.decodePathSegments rpath
        , rawPathInfo = rpath
        , rawQueryString = gets
        , queryString = H.parseQuery gets
        , serverName = serverName'
        , serverPort = port
        , requestHeaders = heads
        , isSecure = False
        , remoteHost = remoteHost'
        , requestBody = C.Source $ return rbody
        , vault = mempty
        }
  where
    -- Split a request target into (authority, path). Only an absolute
    -- "http://" target has an authority part. An empty path is normalised
    -- to "/" in both the plain and the absolute form.
    splitTarget :: ByteString -> (ByteString, ByteString)
    splitTarget t
        | S.null t = ("", "/")
        | "http://" `S.isPrefixOf` t =
            -- 47 is the byte for a slash.
            let (authority, path) = S.breakByte 47 (S.drop 7 t)
            in (authority, if S.null path then "/" else path)
        | otherwise = ("", t)
-- | The prefix of @bs@ before the first occurrence of byte @c@, or all of
-- @bs@ when @c@ does not occur.
takeUntil :: Word8 -> ByteString -> ByteString
takeUntil c bs = before
  where
    (before, _) = S.break (== c) bs
-- | Split a request line into (method, path, query string, HTTP version).
--
-- The query string keeps its leading @?@, and is empty when there is none.
-- Throws 'BadFirstLine' unless the line has exactly three parts separated
-- by single spaces, and 'NonHttp' if the third part does not start with
-- @HTTP/@. Any version other than exactly @1.1@ is treated as HTTP/1.0.
parseFirst :: ByteString
           -> ResourceT IO (ByteString, ByteString, ByteString, H.HttpVersion)
parseFirst s =
    -- 32 is the byte for a space.
    case S.split 32 s of
        [method, query, http'] ->
            case B.splitAt 5 http' of
                ("HTTP/", ver) ->
                    -- 63 is the byte for a question mark.
                    let (rpath, qstring) = S.breakByte 63 query
                        hv = if ver == "1.1" then H.http11 else H.http10
                    in return (method, rpath, qstring, hv)
                _ -> throwIO NonHttp
        _ -> throwIO $ BadFirstLine $ B.unpack s
-- | Literal @HTTP/@ that starts a status line.
httpBuilder :: Builder
httpBuilder = copyByteString "HTTP/"

-- | A single space separating status-line fields.
spaceBuilder :: Builder
spaceBuilder = fromChar ' '

-- | CRLF line terminator.
newlineBuilder :: Builder
newlineBuilder = copyByteString "\r\n"

-- | Chunked transfer-encoding header followed by the empty line that ends
-- the header block.
transferEncodingBuilder :: Builder
transferEncodingBuilder = copyByteString "Transfer-Encoding: chunked\r\n\r\n"

-- | Separator between a header name and its value.
colonSpaceBuilder :: Builder
colonSpaceBuilder = copyByteString ": "
-- | Render the status line and header block of a response.
--
-- Only HTTP/1.1 is written as such; any other version is written as
-- HTTP/1.0. A @Server@ header is added when absent (see 'serverHeader').
-- When the last argument is True, the block ends with the chunked
-- @Transfer-Encoding@ header followed by the empty line. Otherwise it ends
-- with just the empty line.
headers :: H.HttpVersion -> H.Status -> H.ResponseHeaders -> Bool -> Builder
headers !httpversion !status !responseHeaders !isChunked' =
    let versionText = case httpversion of
            H.HttpVersion 1 1 -> "1.1"
            _ -> "1.0"
        !statusLine = httpBuilder
            `mappend` copyByteString versionText
            `mappend` spaceBuilder
            `mappend` fromShow (H.statusCode status)
            `mappend` spaceBuilder
            `mappend` copyByteString (H.statusMessage status)
            `mappend` newlineBuilder
        !withHeaders =
            foldl' responseHeaderToBuilder statusLine (serverHeader responseHeaders)
        !terminator =
            if isChunked' then transferEncodingBuilder else newlineBuilder
    in withHeaders `mappend` terminator
-- | Append one header line to an accumulated 'Builder'. The line is the
-- header name in its original case, a colon and space, the value, and
-- CRLF.
responseHeaderToBuilder :: Builder -> H.Header -> Builder
responseHeaderToBuilder acc (name, value) = acc `mappend` rendered
  where
    rendered = copyByteString (CI.original name)
        `mappend` colonSpaceBuilder
        `mappend` copyByteString value
        `mappend` newlineBuilder
-- | Decide whether the connection may stay open after this request, based
-- on its HTTP version and @Connection@ header.
--
-- * HTTP/1.1: persistent unless the header lists the @close@ option.
-- * Any other version: persistent only if the header lists @keep-alive@.
--
-- The header value is treated as a comma-separated list of options
-- (RFC 7230 section 6.1). Each option is compared case-insensitively after
-- trimming spaces and tabs, so values such as @Upgrade, close@ are honoured.
checkPersist :: Request -> Bool
checkPersist req
    | ver == H.http11 = not (hasOption "close")
    | otherwise = hasOption "keep-alive"
  where
    ver = httpVersion req
    conn = lookup "connection" $ requestHeaders req
    -- 44 is the byte for a comma, which separates connection options.
    hasOption opt =
        maybe False (any ((== opt) . CI.foldCase . trimOWS) . S.split 44) conn
    -- Remove surrounding optional whitespace (spaces and tabs).
    trimOWS = B.dropWhile isOWS . fst . B.spanEnd isOWS
    isOWS ch = ch == ' ' || ch == '\t'
-- | True exactly for HTTP/1.1, the only version for which chunked response
-- bodies are produced.
isChunked :: H.HttpVersion -> Bool
isChunked v = v == H.http11
-- | Whether a response with status @s@ to @req@ may carry a body. It may
-- not for 1xx, 204 (No Content) or 304 (Not Modified) statuses, nor for
-- @HEAD@ requests.
--
-- Statuses are compared by numeric code only. A custom reason phrase (for
-- example a 204 with an empty message) therefore cannot change the outcome,
-- whatever 'Eq' instance 'H.Status' has.
hasBody :: H.Status -> Request -> Bool
hasBody s req =
    code >= 200 && code /= 204 && code /= 304 && requestMethod req /= "HEAD"
  where
    code = H.statusCode s
-- | Write response @r@ for request @req@ to the connection. The result says
-- whether the connection may be reused for another request; the caller
-- keeps looping while it is True.
--
-- * ResponseFile: a Content-Length header is added when missing, computed
--   from the requested file part or from the file size. The file (or part)
--   is then sent with connSendFile, unless the status or method forbids a
--   body.
-- * ResponseBuilder: headers and body are rendered together. Chunked
--   transfer encoding is used for HTTP/1.1 responses without
--   Content-Length.
-- * ResponseSource: headers and body are streamed through connSink, again
--   chunked for HTTP/1.1 responses without Content-Length.
--
-- A response with a body keeps the connection alive only if the request
-- allowed it (checkPersist) and the body length is delimited, either by
-- Content-Length or by chunking.
--
-- NOTE(review): no Connection header is added to the response here. HTTP/1.0
-- clients generally need "Connection: keep-alive" echoed back to reuse a
-- connection; confirm whether that is handled elsewhere.
sendResponse :: T.Handle
             -> Request -> Connection -> Response -> ResourceT IO Bool
sendResponse th req conn r = sendResponse' r
  where
    version = httpVersion req
    isPersist = checkPersist req
    isChunked' = isChunked version
    -- Chunk the body when HTTP/1.1 is used and no length was given.
    needsChunked hs = isChunked' && not (hasLength hs)
    -- Reusable only if the request allows it and the body is delimited.
    isKeepAlive hs = isPersist && (isChunked' || hasLength hs)
    hasLength hs = lookup "content-length" hs /= Nothing
    sendResponse' :: Response -> ResourceT IO Bool
    sendResponse' (ResponseFile s hs fp mpart) = liftIO $ do
        -- An application-supplied Content-Length is trusted as-is (parsed
        -- with readInt); otherwise one is added from the part or file size.
        (lengthyHeaders, cl) <-
            case (readInt `fmap` lookup "content-length" hs, mpart) of
                (Just cl, _) -> return (hs, cl)
                (Nothing, Nothing) -> do
                    cl <- P.fileSize `fmap` P.getFileStatus fp
                    return $ addClToHeaders cl
                (Nothing, Just part) -> do
                    let cl = filePartByteCount part
                    return $ addClToHeaders cl
        connSendMany conn $ L.toChunks $ toLazyByteString $
            headers version s lengthyHeaders False
        T.tickle th
        if not (hasBody s req) then return isPersist else do
            -- The tickle action keeps the timeout from firing during a
            -- long transfer.
            case mpart of
                Nothing -> connSendFile conn fp 0 cl (T.tickle th)
                Just part -> connSendFile conn fp (filePartOffset part) (filePartByteCount part) (T.tickle th)
            T.tickle th
            return isPersist
      where
        -- Prepend a Content-Length header and return the length as well.
        addClToHeaders cl = (("Content-Length", B.pack $ show cl):hs, fromIntegral cl)
    sendResponse' (ResponseBuilder s hs b)
        | hasBody s req = liftIO $ do
            -- Stream the rendered headers and body, tickling per chunk.
            toByteStringIO (\bs -> do
                connSendAll conn bs
                T.tickle th) body
            return (isKeepAlive hs)
        | otherwise = liftIO $ do
            -- Headers only (e.g. HEAD, 204, 304).
            connSendMany conn
                $ L.toChunks
                $ toLazyByteString
                $ headers' False
            T.tickle th
            return isPersist
      where
        headers' = headers version s hs
        needsChunked' = needsChunked hs
        body = if needsChunked'
                   then (headers' needsChunked')
                        `mappend` chunkedTransferEncoding b
                        `mappend` chunkedTransferTerminator
                   else (headers' False) `mappend` b
    sendResponse' (ResponseSource s hs body) =
        response
      where
        headers' = headers version s hs
        response
            | not (hasBody s req) = liftIO $ do
                -- Headers only; the application source is not run.
                connSendMany conn
                    $ L.toChunks $ toLazyByteString
                    $ headers' False
                T.tickle th
                return (checkPersist req)
            | otherwise = do
                -- Emit the headers first, then the (possibly chunked) body,
                -- all rendered to bytes and written through connSink.
                let src =
                        CL.sourceList [headers' needsChunked'] `mappend`
                        (if needsChunked' then body C.$= chunk else body)
                src C.$$ builderToByteString C.=$ connSink conn th
                return $ isKeepAlive hs
        needsChunked' = needsChunked hs
        -- Conduit wrapping each builder in chunked framing and emitting the
        -- chunk terminator when the body ends.
        chunk :: C.Conduit Builder IO Builder
        chunk = C.Conduit $ return $ C.PreparedConduit
            { C.conduitPush = push
            , C.conduitClose = close
            }
        push x = return $ C.Producing [chunkedTransferEncoding x]
        close = return [chunkedTransferTerminator]
-- | Split a raw header line into its case-insensitive name and its value.
--
-- The name is everything before the first colon. The value is everything
-- after that colon, with leading and trailing optional whitespace (spaces
-- and tabs) removed, as in RFC 7230 section 3.2.4. So @Content-Length:5@,
-- @Content-Length: 5@ and @Content-Length:  5 @ all have the value @5@.
-- A line without a colon yields an empty value.
parseHeaderNoAttr :: ByteString -> H.Header
parseHeaderNoAttr s =
    let (k, rest) = S.break (== 58) s -- 58 is the byte for a colon
        afterColon = S.drop 1 rest
        value = fst (S.spanEnd isOWS (S.dropWhile isOWS afterColon))
    in (CI.mk k, value)
  where
    -- 32 is a space and 9 is a horizontal tab.
    isOWS w = w == 32 || w == 9
-- | Source of raw bytes read from the client with 'connRecv'. An empty read
-- ends the source. Reads of 2048 bytes or more also tickle the timeout
-- handle, presumably so a client that is actively sending a large body is
-- not timed out.
connSource :: Connection -> T.Handle -> C.PreparedSource IO ByteString
connSource Connection { connRecv = recv } th = C.PreparedSource
    { C.sourcePull = liftIO recv >>= deliver
    , C.sourceClose = return ()
    }
  where
    deliver received
        | S.null received = return C.Closed
        | otherwise = do
            when (S.length received >= 2048) $ liftIO $ T.tickle th
            return (C.Open received)
-- | Sink that writes every chunk to the client with 'connSendAll'.
-- The timeout handle is resumed around each write and paused again
-- afterwards. Closing the sink leaves it resumed.
connSink :: Connection -> T.Handle -> C.Sink B.ByteString IO ()
connSink Connection { connSendAll = send } th =
    C.Sink $ return $ C.SinkData
        { C.sinkPush = pushChunk
        , C.sinkClose = liftIO (T.resume th)
        }
  where
    pushChunk bytes = do
        liftIO $ do
            T.resume th
            send bytes
            T.pause th
        return C.Processing
-- | Server configuration. Start from 'defaultSettings' and override the
-- fields you need with record update syntax.
data Settings = Settings
    { settingsPort :: Int
      -- ^ Port to listen on (default 3000). It is also reported as the
      -- serverPort of every request.
    , settingsHost :: String
      -- ^ Host handed to 'bindPort'. The default, @\"*\"@, binds without
      -- a specific host.
    , settingsOnException :: SomeException -> IO ()
      -- ^ Receives any exception that escapes the handling of a
      -- connection (see 'defaultSettings' for the default handler).
    , settingsTimeout :: Int
      -- ^ Timeout, multiplied by 1000000 before being given to the timeout
      -- manager (default 30; presumably seconds, so confirm against the
      -- Timeout module). Not used when 'settingsManager' is set.
    , settingsIntercept :: Request -> Maybe (C.BufferedSource IO S.ByteString -> Connection -> ResourceT IO ())
      -- ^ Consulted for every parsed request. When it returns a handler,
      -- that handler gets the buffered client source and the connection
      -- instead of the application running. No further requests are then
      -- read from the connection.
    , settingsManager :: Maybe Manager
      -- ^ Timeout manager to use. When 'Nothing' (the default),
      -- 'runSettingsConnection' creates a new one.
    }
-- | Default configuration:
--
-- * port 3000 on all interfaces (@\"*\"@);
-- * a timeout of 30 (see 'settingsTimeout');
-- * no interception;
-- * a freshly created timeout manager;
-- * an exception handler that prints exceptions to stderr, except
--   'InvalidRequest' errors and 'ThreadKilled', which are ignored.
defaultSettings :: Settings
defaultSettings = Settings
    { settingsPort = 3000
    , settingsHost = "*"
    , settingsOnException = reportException
    , settingsTimeout = 30
    , settingsIntercept = const Nothing
    , settingsManager = Nothing
    }
  where
    reportException :: SomeException -> IO ()
    reportException e
        | ignorable = return ()
        | otherwise = hPutStrLn stderr (show e)
      where
        isInvalidRequest = case (fromException e :: Maybe InvalidRequest) of
            Just _ -> True
            Nothing -> False
        isThreadKilled = fromException e == Just ThreadKilled
        ignorable = isInvalidRequest || isThreadKilled
-- | A function that prepends accumulated bytes to its argument, used as a
-- difference list for the current, unfinished header line.
type BSEndo = ByteString -> ByteString
-- | A difference list of completed header lines.
type BSEndoList = [ByteString] -> [ByteString]
-- | State of the takeHeaders sink.
data THStatus = THStatus
    !Int       -- number of header bytes consumed so far
    BSEndoList -- completed lines, in order, as a difference list
    BSEndo     -- prepends the bytes of the current unfinished line
-- | Sink that consumes the request line and the header lines, up to and
-- including the terminating empty line. It returns the lines in order,
-- without their line endings. Bytes after the empty line are handed back to
-- the stream as leftovers.
--
-- Lines may end in LF or CRLF. A CR directly before the LF is removed even
-- when it arrived at the end of one chunk and the LF at the start of the
-- next. This matters in particular for the empty line that ends the header
-- section, which must not be mistaken for a one-byte header line.
--
-- Throws 'IncompleteHeaders' if the stream ends first. Throws
-- 'OverLargeHeader' when more input arrives after more than
-- 'maxTotalHeaderLength' bytes have been consumed.
takeHeaders :: C.Sink ByteString IO [ByteString]
takeHeaders =
    C.sinkState (THStatus 0 id id) takeHeadersPush close
  where
    -- The stream ended before the empty line that terminates the headers.
    close _ = throwIO IncompleteHeaders
    takeHeadersPush :: THStatus
                    -> ByteString
                    -> ResourceT IO (THStatus, C.SinkResult ByteString [ByteString])
    takeHeadersPush (THStatus len _ _ ) _
        | len > maxTotalHeaderLength = throwIO OverLargeHeader
    takeHeadersPush (THStatus len lines prepend) bs =
        case mnl of
            -- No LF in this chunk: it all belongs to the current line.
            Nothing -> do
                let len' = len + bsLen
                return (THStatus len' lines (prepend . S.append bs), C.Processing)
            Just nl -> do
                let end = nl
                    start = nl + 1
                    line
                        | end > 0 = prepend (SU.unsafeTake (checkCR bs end) bs)
                        -- The LF is the first byte of this chunk, so a CR
                        -- ending the line (if any) is the last byte carried
                        -- over from earlier chunks.
                        | otherwise = dropTrailingCR (prepend S.empty)
                if S.null line
                    then do
                        -- Empty line: end of headers. The returned state is
                        -- never used again because the sink is done.
                        let lines' = lines []
                        if start < bsLen
                            then do
                                let rest = SU.unsafeDrop start bs
                                return (undefined, C.Done (Just rest) lines')
                            else return (undefined, C.Done Nothing lines')
                    else do
                        let len' = len + start
                            lines' = lines . (:) line
                        if start < bsLen
                            then do
                                -- More lines may follow in this same chunk.
                                let more = SU.unsafeDrop start bs
                                takeHeadersPush (THStatus len' lines' id) more
                            else return (THStatus len' lines' id, C.Processing)
      where
        bsLen = S.length bs
        mnl = S.elemIndex 10 bs -- 10 is LF
    -- Remove one trailing CR (byte 13), if present.
    dropTrailingCR b
        | not (S.null b) && S.last b == 13 = S.init b
        | otherwise = b
-- | Given the index @pos@ of a line feed in @bs@, return the index at which
-- the line's content ends. That is @pos - 1@ when the byte just before the
-- LF is a carriage return, and @pos@ itself otherwise (including when
-- @pos@ is 0, so there is no preceding byte in @bs@).
checkCR :: ByteString -> Int -> Int
checkCR bs pos =
    let p = pos - 1
    in if pos > 0 && '\r' == B.index bs p
        then p
        else pos
-- | Parse the leading run of ASCII decimal digits in a ByteString as an
-- unbounded 'Integer'.
--
-- Parsing stops at the first non-digit byte, so trailing bytes such as
-- whitespace or a CR are ignored. Input with no leading digits yields 0.
-- This matches how Content-Length is read in parseRequest'.
readInt :: S.ByteString -> Integer
readInt = S.foldl' step 0 . S.takeWhile isDigitByte
  where
    -- Bytes 48 to 57 are the characters 0 to 9.
    isDigitByte w = w >= 48 && w <= 57
    step acc w = acc * 10 + fromIntegral (w - 48)
-- | Create a timeout 'Manager' (passing @timeout@ straight to the Timeout
-- module's initialize) and run @f@ with it, returning @f@'s result.
--
-- NOTE(review): despite the with- prefix, nothing here shuts the manager
-- down after @f@ returns. Confirm whether the Timeout module offers a way
-- to stop it.
withManager :: Int
            -> (Manager -> IO a)
            -> IO a
withManager timeout f = T.initialize timeout >>= f
-- | Prepend a @Server: Warp/VERSION@ header (VERSION is 'warpVersion'),
-- unless a header named @Server@ is already present. Header names compare
-- case-insensitively.
serverHeader :: H.RequestHeaders -> H.RequestHeaders
serverHeader hdrs
    | any ((== serverKey) . fst) hdrs = hdrs
    | otherwise = (serverKey, serverValue) : hdrs
  where
    serverKey = "Server"
    serverValue = B.pack ("Warp/" ++ warpVersion)