module Network.Wai.Handler.SimpleServer
( run
) where
import Network.Wai
import Network.Wai.Handler.Helper
import qualified System.IO
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B
import Network
( listenOn, accept, sClose, PortID(PortNumber), Socket
, withSocketsDo)
import Control.Exception (bracket, finally, Exception, throwIO)
import System.IO (Handle, hClose)
import Control.Concurrent (forkIO)
import Control.Monad (unless)
import Data.Maybe (isJust, fromJust, fromMaybe)
import Data.Typeable (Typeable)
import Network.Socket.SendFile
import Control.Arrow (first)
-- | Listen on the given TCP port and hand every accepted connection to the
-- supplied 'Application'.
--
-- The listening socket is acquired inside 'bracket', so it is closed with
-- 'sClose' whenever the accept loop terminates (normally or by exception).
run :: Port -> Application -> IO ()
run port app =
    withSocketsDo $
        bracket
            (listenOn (PortNumber (fromIntegral port)))
            sClose
            (serveConnections port app)
-- | TCP port number. 'run' converts it with 'fromIntegral' to build the
-- 'PortNumber' it listens on, and it is also recorded as 'serverPort' in
-- every parsed 'Request'.
type Port = Int
-- | Accept connections on the listening socket indefinitely, serving each
-- one on its own thread created with 'forkIO'.
--
-- The loop has no exit condition of its own; it only stops when 'accept'
-- (or 'forkIO') throws.
serveConnections :: Port -> Application -> Socket -> IO ()
serveConnections port app socket = loop
  where
    loop = do
        (clientHandle, clientHost, _) <- accept socket
        _ <- forkIO (serveConnection port app clientHandle clientHost)
        loop
-- | Serve exactly one request on an accepted connection: parse it, run the
-- application, and write the response back using the request's HTTP
-- version.
--
-- The handle is closed afterwards in all cases, including when parsing or
-- the application throws.
serveConnection :: Port -> Application -> Handle -> String -> IO ()
serveConnection port app conn remoteHost' =
    handleRequest `finally` hClose conn
  where
    handleRequest = do
        request <- hParseRequest port conn remoteHost'
        response <- app request
        sendResponse (httpVersion request) conn response
-- | Read the request line and headers from the connection (everything up
-- to the first blank line) and turn them into a 'Request' via
-- 'parseRequest'.
hParseRequest :: Port -> Handle -> String -> IO Request
hParseRequest port conn remoteHost' =
    takeUntilBlank conn id >>= \headerLines ->
        parseRequest port headerLines conn remoteHost'
-- | Read lines from the handle until an empty line is seen, returning the
-- non-empty lines in the order they were read.
--
-- Each line has a trailing carriage return removed (see 'stripCR') before
-- the emptiness check. The second argument is a difference-list
-- accumulator holding the lines collected so far; callers start with 'id'.
takeUntilBlank :: Handle
               -> ([ByteString] -> [ByteString])
               -> IO [ByteString]
takeUntilBlank h front = do
    line <- fmap stripCR (B.hGetLine h)
    step line
  where
    step line
        | B.null line = return (front [])
        | otherwise = takeUntilBlank h (front . (line :))
-- | Drop a single trailing @\'\\r\'@ from the input, if there is one.
-- Any other input, including the empty string, is returned unchanged.
stripCR :: ByteString -> ByteString
stripCR bs =
    if not (B.null bs) && B.last bs == '\r'
        then B.init bs
        else bs
-- | Reasons a request can be rejected while parsing. These are thrown as
-- exceptions (via 'throwIO') by 'parseRequest' and 'parseFirst'.
data InvalidRequest =
    NotEnoughLines [String]
    -- ^ Fewer than two lines were read before the blank line; carries the
    -- lines that were received.
    | HostNotIncluded
    -- ^ No header in the request mapped to 'Host'.
    | BadFirstLine String
    -- ^ The request line did not consist of exactly three
    -- whitespace-separated fields; carries the offending line.
    | NonHttp
    -- ^ The third field of the request line did not start with @HTTP/@.
    deriving (Show, Typeable)
instance Exception InvalidRequest
-- | Build a 'Request' from the header block already read from a connection.
--
-- @lines'@ is the request line followed by the header lines, without the
-- terminating blank line. The body is not read here: the handle and the
-- parsed content length are passed on to 'requestBodyHandle'. The length
-- defaults to 0 when no header maps to 'ReqContentLength' or its value does
-- not begin with a readable number.
--
-- The path is normalised to always start with @\'/\'@, and the server name
-- is the 'Host' value up to (not including) the first @\':\'@.
--
-- Throws 'NotEnoughLines' when fewer than two lines are supplied,
-- 'HostNotIncluded' when no header maps to 'Host', and whatever
-- 'parseFirst' throws for a malformed request line.
parseRequest :: Port          -- ^ port the server is listening on
             -> [ByteString]  -- ^ request line followed by header lines
             -> Handle        -- ^ connection handle, positioned at the body
             -> String        -- ^ remote host as reported by @accept@
             -> IO Request
parseRequest port lines' handle remoteHost' = do
    -- Destructure once instead of checking the length and then calling the
    -- partial 'head' / 'tail': we need the request line plus at least one
    -- header line.
    (requestLine, headerLines) <-
        case lines' of
            (l : ls@(_ : _)) -> return (l, ls)
            _ -> throwIO $ NotEnoughLines $ map B.unpack lines'
    (method', rpath', gets, httpversion) <- parseFirst requestLine
    let heads = map (first requestHeaderFromBS . parseHeaderNoAttr) headerLines
    -- Total replacement for the former @isJust@ check followed by @fromJust@.
    host <- maybe (throwIO HostNotIncluded) return $ lookup Host heads
    -- Ensure a leading slash without a round-trip through 'String'.
    let rpath
            | B.pack "/" `B.isPrefixOf` rpath' = rpath'
            | otherwise = B.cons '/' rpath'
    let len = fromMaybe 0 $ do
            bs <- lookup ReqContentLength heads
            case reads (B.unpack bs) of
                (x, _) : _ -> Just x
                _ -> Nothing
    let (serverName', _) = B.break (== ':') host
    return $ Request
        { requestMethod = methodFromBS method'
        , httpVersion = httpversion
        , pathInfo = rpath
        , queryString = gets
        , serverName = serverName'
        , serverPort = port
        , requestHeaders = heads
        , urlScheme = HTTP
        , requestBody = requestBodyHandle handle len
        , errorHandler = System.IO.hPutStr System.IO.stderr
        , remoteHost = B.pack remoteHost'
        }
-- | Split an HTTP request line into method, path, query string and
-- protocol version.
--
-- The line must contain exactly three whitespace-separated fields,
-- otherwise 'BadFirstLine' is thrown. The third field must begin with
-- @HTTP/@, otherwise 'NonHttp' is thrown; the remainder of that field is
-- handed to 'httpVersionFromBS'. The second field is split at the first
-- @\'?\'@: the part before it is the path, the rest (still including the
-- @\'?\'@, or empty if there is none) is the query string.
parseFirst :: ByteString
           -> IO (ByteString, ByteString, ByteString, HttpVersion)
parseFirst s =
    case B.words s of
        [method, target, protocol] -> do
            unless (B.pack "HTTP/" `B.isPrefixOf` protocol) $ throwIO NonHttp
            let (rpath, qstring) = B.break (== '?') target
                version = httpVersionFromBS (B.drop 5 protocol)
            return (method, rpath, qstring, version)
        _ -> throwIO $ BadFirstLine $ B.unpack s
-- | Write a complete HTTP response to the handle: status line, headers, a
-- blank line, then the body.
--
-- A @Left@ body is a file path sent with 'unsafeSendFile'; a @Right@ body
-- is an enumerator whose chunks are written to the handle as they are
-- produced. The enumerator's final result is discarded.
sendResponse :: HttpVersion -> Handle -> Response -> IO ()
sendResponse httpversion h res = do
    -- Status line, emitted piece by piece in wire order.
    mapM_ (B.hPut h)
        [ B.pack "HTTP/"
        , httpVersionToBS httpversion
        , B.pack " "
        , B.pack $ show $ statusCode st
        , B.pack " "
        , statusMessage st
        , B.pack "\r\n"
        ]
    mapM_ putHeader $ responseHeaders res
    B.hPut h $ B.pack "\r\n"
    case responseBody res of
        Left fp -> unsafeSendFile h fp
        Right (Enumerator enum) -> do
            _ <- enum myPut h
            return ()
  where
    st = status res
    -- Chunk callback for the enumerator: write the chunk and always answer
    -- @Right h@ (presumably "keep going" -- confirm against the Enumerator
    -- contract).
    myPut _ chunk = B.hPut h chunk >> return (Right h)
    -- One "Name: value\r\n" header line.
    putHeader (name, value) =
        mapM_ (B.hPut h)
            [ responseHeaderToBS name
            , B.pack ": "
            , value
            , B.pack "\r\n"
            ]
-- | Split a raw header line into its name and value.
--
-- The name is everything before the first @\':\'@. The value is everything
-- after that colon with leading spaces and tabs removed, so
-- @\"Host: example.com\"@, @\"Host:example.com\"@ and
-- @\"Host:   example.com\"@ all yield the value @\"example.com\"@.
-- (Previously the colon was only removed when followed by exactly one
-- space, leaving values such as @\":example.com\"@.)
--
-- A line without any colon yields the whole line as the name and an empty
-- value.
parseHeaderNoAttr :: ByteString -> (ByteString, ByteString)
parseHeaderNoAttr s = (name, value)
  where
    -- @rest@ is either empty (no colon present) or begins with the colon.
    (name, rest) = B.span (/= ':') s
    value = B.dropWhile isOptionalSpace (B.drop 1 rest)
    isOptionalSpace c = c == ' ' || c == '\t'