module Network.Wai.Handler.Warp.Run where
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative ((<$>))
#endif
import Control.Arrow (first)
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent as Conc (yield)
import Control.Exception as E
import Control.Monad (when, unless, void)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.Char (chr)
import Data.IP (toHostAddress, toHostAddress6)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Streaming.Network (bindPortTCP)
import Network (sClose, Socket)
import Network.Socket (accept, withSocketsDo, SockAddr(SockAddrInet, SockAddrInet6))
import qualified Network.Socket.ByteString as Sock
import Network.Wai
import Network.Wai.HTTP2 (HTTP2Application, promoteApplication)
import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Counter
import qualified Network.Wai.Handler.Warp.Date as D
import qualified Network.Wai.Handler.Warp.FdCache as F
import Network.Wai.Handler.Warp.HTTP2 (http2, isHTTP2)
import Network.Wai.Handler.Warp.Header
import Network.Wai.Handler.Warp.ReadInt
import Network.Wai.Handler.Warp.Recv
import Network.Wai.Handler.Warp.Request
import Network.Wai.Handler.Warp.Response
import Network.Wai.Handler.Warp.SendFile
import Network.Wai.Handler.Warp.Settings
import qualified Network.Wai.Handler.Warp.Timeout as T
import Network.Wai.Handler.Warp.Types
import Network.Wai.Internal (ResponseReceived (ResponseReceived))
import System.Environment (getEnvironment)
import System.IO.Error (isFullErrorType, ioeGetErrorType)
#if WINDOWS
import Network.Wai.Handler.Warp.Windows
#else
import System.Posix.IO (FdOption(CloseOnExec), setFdOption)
import Network.Socket (fdSocket)
#endif
-- | Wrap a connected 'Socket' in a 'Connection' record.
--
-- A fresh buffer pool (used for receiving) and a write buffer of
-- 'bufferSize' bytes are allocated per connection.  The write buffer is
-- shared by 'connSendFile' and exposed as 'connWriteBuffer'.
--
-- 'connClose' closes the socket and frees the write buffer.  The buffer is
-- released with 'finally' so that it is not leaked when closing the socket
-- throws.
socketConnection :: Socket -> IO Connection
socketConnection s = do
    bufferPool <- newBufferPool
    writeBuf <- allocateBuffer bufferSize
    let sendall = Sock.sendAll s
    return Connection {
        connSendMany = Sock.sendMany s
      , connSendAll = sendall
      , connSendFile = sendFile s writeBuf bufferSize sendall
        -- Always free the write buffer, even if 'sClose' raises.
      , connClose = sClose s `finally` freeBuffer writeBuf
      , connRecv = receive s bufferPool
      , connRecvBuf = receiveBuf s
      , connWriteBuffer = writeBuf
      , connBufferSize = bufferSize
      }
#if __GLASGOW_HASKELL__ < 702
-- | Compatibility definition, only compiled when @__GLASGOW_HASKELL__ < 702@:
-- run a no-op action inside 'unblock'.
allowInterrupt :: IO ()
allowInterrupt = unblock (return ())
#endif
-- | Compose a unary function after a binary one:
-- @(f .: g) x y == f (g x y)@.
(.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d
(f .: g) x y = f (g x y)
-- | Serve an 'Application' on the given port.  The application is wrapped
-- with 'serveDefault' and handed to 'runServe'.
run :: Port -> Application -> IO ()
run port app = runServe port (serveDefault app)
-- | Like 'run', but takes an explicit 'HTTP2Application' in addition to the
-- ordinary 'Application'; the two are combined with 'serveHTTP2'.
runHTTP2 :: Port -> HTTP2Application -> Application -> IO ()
runHTTP2 port app2 app = runServe port (serveHTTP2 app2 app)
-- | Run a 'ServeConnection' using 'defaultSettings' with only the port
-- overridden.
runServe :: Port -> ServeConnection -> IO ()
runServe p serveConn = runServeSettings settings serveConn
  where
    settings = defaultSettings { settingsPort = p }
-- | Like 'run', but goes through 'runServeEnv', so the @PORT@ environment
-- variable (when set) takes precedence over the given port.
runEnv :: Port -> Application -> IO ()
runEnv port app = runServeEnv port (serveDefault app)
-- | Like 'runHTTP2', but goes through 'runServeEnv', so the @PORT@
-- environment variable (when set) takes precedence over the given port.
runHTTP2Env :: Port -> HTTP2Application -> Application -> IO ()
runHTTP2Env port app2 app = runServeEnv port (serveHTTP2 app2 app)
-- | Run a 'ServeConnection' on the port named by the @PORT@ environment
-- variable, falling back to the given port when the variable is unset.
--
-- The variable is parsed with 'reads'; if no leading value can be parsed
-- the action fails with an @Invalid value in $PORT@ message.
runServeEnv :: Port -> ServeConnection -> IO ()
runServeEnv p serveConn = do
    env <- getEnvironment
    case lookup "PORT" env of
        Nothing -> runServe p serveConn
        Just sp ->
            case reads sp of
                ((p', _):_) -> runServe p' serveConn
                _ -> fail $ "Invalid value in $PORT: " ++ sp
-- | Serve an 'Application' with the given 'Settings'.  The application is
-- wrapped with 'serveDefault' and handed to 'runServeSettings'.
runSettings :: Settings -> Application -> IO ()
runSettings set app = runServeSettings set (serveDefault app)
-- | Like 'runSettings', but with an explicit 'HTTP2Application' combined
-- with the ordinary 'Application' via 'serveHTTP2'.
runHTTP2Settings :: Settings -> HTTP2Application -> Application -> IO ()
runHTTP2Settings set app2 app = runServeSettings set (serveHTTP2 app2 app)
-- | Bind a TCP listen socket on 'settingsPort' / 'settingsHost' and serve
-- connections from it with 'runServeSettingsSocket'.
--
-- The socket is acquired and released with 'bracket', so it is closed with
-- 'sClose' when serving ends, normally or by exception.  The close-on-exec
-- flag is set on the listen socket before serving starts.
runServeSettings :: Settings -> ServeConnection -> IO ()
runServeSettings set serveConn =
    withSocketsDo $ bracket bindListener sClose serveOn
  where
    bindListener = bindPortTCP (settingsPort set) (settingsHost set)
    serveOn listener = do
        setSocketCloseOnExec listener
        runServeSettingsSocket set listener serveConn
-- | Serve an 'Application' on an already-bound listen 'Socket'.  The
-- application is wrapped with 'serveDefault' and handed to
-- 'runServeSettingsSocket'.
runSettingsSocket :: Settings -> Socket -> Application -> IO ()
runSettingsSocket set socket app =
    runServeSettingsSocket set socket (serveDefault app)
-- | Like 'runSettingsSocket', but with an explicit 'HTTP2Application'
-- combined with the ordinary 'Application' via 'serveHTTP2'.
runHTTP2SettingsSocket :: Settings
                       -> Socket
                       -> HTTP2Application
                       -> Application
                       -> IO ()
runHTTP2SettingsSocket set socket app2 app =
    runServeSettingsSocket set socket (serveHTTP2 app2 app)
-- | Serve connections accepted from the given listen 'Socket'.
--
-- First registers an action that closes the listen socket through
-- 'settingsInstallShutdownHandler', then delegates to
-- 'runServeSettingsConnection' with an accept action.  Every accepted
-- socket gets its close-on-exec flag set and is wrapped with
-- 'socketConnection'.
runServeSettingsSocket :: Settings -> Socket -> ServeConnection -> IO ()
runServeSettingsSocket set socket serveConn = do
    settingsInstallShutdownHandler set (sClose socket)
    runServeSettingsConnection set acceptConn serveConn
  where
    acceptConn = do
        (client, peer) <- acceptClient
        setSocketCloseOnExec client
        conn <- socketConnection client
        return (conn, peer)
    -- On Windows the accept call is wrapped in 'windowsThreadBlockHack'.
#if WINDOWS
    acceptClient = windowsThreadBlockHack $ accept socket
#else
    acceptClient = accept socket
#endif
-- | Serve an 'Application' over connections produced by the supplied
-- accept action.  The application is wrapped with 'serveDefault' and handed
-- to 'runServeSettingsConnection'.
runSettingsConnection :: Settings -> IO (Connection, SockAddr) -> Application -> IO ()
runSettingsConnection set getConn app =
    runServeSettingsConnection set getConn (serveDefault app)
-- | Adapt an accept action that yields a ready 'Connection' to the
-- \"connection maker\" interface of 'runServeSettingsConnectionMaker', by
-- wrapping the already-built connection in 'return'.
runServeSettingsConnection :: Settings
                           -> IO (Connection, SockAddr)
                           -> ServeConnection
                           -> IO ()
runServeSettingsConnection set getConn =
    runServeSettingsConnectionMaker set deferred
  where
    deferred = do
        (conn, peer) <- getConn
        return (return conn, peer)
-- | Serve an 'Application' over connections whose construction is deferred:
-- the accept action returns an @IO Connection@ alongside the peer address.
-- The application is wrapped with 'serveDefault' and handed to
-- 'runServeSettingsConnectionMaker'.
runSettingsConnectionMaker :: Settings -> IO (IO Connection, SockAddr) -> Application -> IO ()
runSettingsConnectionMaker set getConnMaker app =
    runServeSettingsConnectionMaker set getConnMaker (serveDefault app)
-- | Adapt a plain connection maker to
-- 'runServeSettingsConnectionMakerSecure' by tagging every connection it
-- builds with the 'TCP' transport.
runServeSettingsConnectionMaker :: Settings
                                -> IO (IO Connection, SockAddr)
                                -> ServeConnection
                                -> IO ()
runServeSettingsConnectionMaker set getConnMaker =
    runServeSettingsConnectionMakerSecure set (fmap tagTCP getConnMaker)
  where
    -- Rewrite only the first tuple component (the deferred connection).
    tagTCP = first (fmap (\conn -> (conn, TCP)))
-- | Serve an 'Application' over connections whose maker also reports the
-- 'Transport' in use.  The application is wrapped with 'serveDefault' and
-- handed to 'runServeSettingsConnectionMakerSecure'.
runSettingsConnectionMakerSecure :: Settings -> IO (IO (Connection, Transport), SockAddr) -> Application -> IO ()
runSettingsConnectionMakerSecure set getConnMaker app =
    runServeSettingsConnectionMakerSecure set getConnMaker (serveDefault app)
-- | Set up the shared server state and enter the accept loop.
--
-- Runs 'settingsBeforeMainLoop', creates the connection 'Counter', then
-- brings up the date cache, the file-descriptor cache and the timeout
-- manager before calling 'acceptConnection'.
--
-- If 'settingsManager' supplies a timeout manager it is used as is and not
-- stopped here; otherwise one is created for the duration of the server and
-- stopped afterwards via 'bracket'.
runServeSettingsConnectionMakerSecure :: Settings
                                      -> IO (IO (Connection, Transport), SockAddr)
                                      -> ServeConnection
                                      -> IO ()
runServeSettingsConnectionMakerSecure set getConnMaker serveConn = do
    settingsBeforeMainLoop set
    counter <- newCounter
    D.withDateCache $ \dc ->
        F.withFdCache fdCacheDuration $ \fc ->
            withTimeoutManager $ \tm ->
                acceptConnection set getConnMaker serveConn dc fc tm counter
  where
    -- Both settings are scaled by 1000000 (presumably seconds to
    -- microseconds -- confirm against the FdCache / Timeout modules).
    fdCacheDuration = settingsFdCacheDuration set * 1000000
    timeoutDuration = settingsTimeout set * 1000000
    withTimeoutManager f = maybe ownManager f (settingsManager set)
      where
        ownManager = bracket (T.initialize timeoutDuration) T.stopManager f
-- | The accept loop: repeatedly obtain a new connection maker from
-- @getConnMaker@ and hand it to 'fork', until accepting fails with an error
-- that is not a \"resource full\" error.  After the loop ends,
-- 'gracefulShutdown' is called with the connection counter.
acceptConnection :: Settings
                 -> IO (IO (Connection, Transport), SockAddr)
                 -> ServeConnection
                 -> D.DateCache
                 -> Maybe F.MutableFdCache
                 -> T.Manager
                 -> Counter
                 -> IO ()
acceptConnection set getConnMaker serveConn dc fc tm counter = do
    -- The whole loop runs with asynchronous exceptions masked ('mask_').
    -- They are only let through at the explicit 'allowInterrupt' below (and
    -- inside any interruptible operation performed while accepting).
    void $ mask_ acceptLoop
    gracefulShutdown counter
  where
    acceptLoop = do
        -- Give a pending asynchronous exception the chance to be delivered
        -- once per iteration, before we start accepting.
        allowInterrupt
        -- 'acceptNewConnection' returns Nothing when the loop should stop.
        mx <- acceptNewConnection
        case mx of
            Nothing             -> return ()
            Just (mkConn, addr) -> do
                fork set mkConn addr serveConn dc fc tm counter
                acceptLoop
    acceptNewConnection = do
        -- Only IO errors are caught here: the handler below inspects the
        -- exception with 'ioeGetErrorType'.
        ex <- try getConnMaker
        case ex of
            Right x -> return $ Just x
            Left  e  -> do
                -- Report the failure; there is no request at this point.
                settingsOnException set Nothing $ toException e
                if isFullErrorType (ioeGetErrorType e) then do
                    -- A \"full\" error (resource exhaustion) is treated as
                    -- transient: wait one second ('threadDelay' takes
                    -- microseconds) and try again.
                    threadDelay 1000000
                    acceptNewConnection
                  else
                    -- Any other IO error ends the accept loop.
                    return Nothing
-- | Start handling one accepted connection via 'settingsFork'.
--
-- The forked action builds the connection, registers a timeout handle, and
-- runs @serveConn@, with resource cleanup guaranteed by nested 'bracket's.
fork :: Settings
     -> IO (Connection, Transport)
     -> SockAddr
     -> ServeConnection
     -> D.DateCache
     -> Maybe F.MutableFdCache
     -> T.Manager
     -> Counter
     -> IO ()
fork set mkConn addr serveConn dc fc tm counter = settingsFork set $ \ unmask ->
    -- 'settingsFork' hands us an @unmask@ function; presumably the action
    -- starts with asynchronous exceptions masked (confirm against the
    -- 'settingsFork' implementation).  The two outer brackets below run
    -- before @unmask@ is applied.
    --
    -- Build the connection and make sure it is closed with 'connClose' once
    -- handling ends, however it ends.
    bracket mkConn closeConn $ \(conn, transport) ->
    -- Register this thread with the timeout manager; the handle is
    -- cancelled on exit.
    bracket (T.registerKillThread tm) T.cancel $ \th ->
    let ii = InternalInfo th tm fc dc
        -- Only the code below runs under @unmask@.
    in unmask .
       -- Exceptions escaping the handler are passed to 'settingsOnException'
       -- (with no request) instead of propagating to the outer brackets.
       handle (settingsOnException set Nothing) .
       -- Bump the connection counter and call the user's open/close hooks.
       bracket (onOpen addr) (onClose addr) $ \goingon ->
       -- 'settingsOnOpen' returns a Bool; the connection is served only if
       -- it is True.
       when goingon $ serveConn conn ii addr transport set
  where
    closeConn (conn, _transport) = connClose conn
    onOpen adr    = increase counter >> settingsOnOpen  set adr
    onClose adr _ = decrease counter >> settingsOnClose set adr
-- | A handler for a single accepted connection.
--
-- Arguments, in order: the 'Connection' to talk over, the per-connection
-- 'InternalInfo', the peer's 'SockAddr', the 'Transport' in use and the
-- server 'Settings'.
type ServeConnection = Connection
                    -> InternalInfo
                    -> SockAddr
                    -> Transport
                    -> Settings
                    -> IO ()
-- | Build a 'ServeConnection' from a plain 'Application'.  The HTTP\/2 side
-- is derived from the same application with 'promoteApplication'.
serveDefault :: Application -> ServeConnection
serveDefault app = serveHTTP2 app2 app
  where
    app2 = promoteApplication app
-- | Serve one connection, dispatching to the HTTP\/2 handler ('http2' with
-- @app2@) or to the HTTP\/1 request loop (with @app@).
--
-- HTTP\/2 is chosen when 'settingsHTTP2Enabled' is set and either the
-- transport reports HTTP\/2 ('isHTTP2') or the first bytes received start
-- with @\"PRI \"@.  Otherwise the HTTP\/1 loop runs, optionally preceded by
-- PROXY protocol header parsing as selected by 'settingsProxyProtocol'.
serveHTTP2 :: HTTP2Application -> Application -> ServeConnection
serveHTTP2 app2 app conn ii origAddr transport settings = do
    -- Decide whether the peer speaks HTTP/2.  @bs@ holds whatever bytes were
    -- already read from the connection while sniffing (empty when the
    -- transport itself says HTTP/2).
    (h2,bs) <- if isHTTP2 transport then
                   return (True, "")
                 else do
                   bs0 <- connRecv conn
                   if S.length bs0 >= 4 && "PRI " `S.isPrefixOf` bs0 then
                       return (True, bs0)
                     else
                       return (False, bs0)
    if settingsHTTP2Enabled settings && h2 then do
        -- The sniffed bytes are given to the receiver as initial input.
        recvN <- makeReceiveN bs (connRecv conn) (connRecvBuf conn)
        -- NOTE: the PROXY protocol address is not consulted on this path;
        -- the original peer address is passed through.
        http2 conn ii origAddr transport settings recvN app2
      else do
        -- @istatus@ controls whether 'sendErrorResponse' actually sends
        -- anything: it is set to True here and whenever 'wrappedRecv'
        -- receives data, and to False once the application has produced a
        -- response.
        istatus <- newIORef False
        src <- mkSource (wrappedRecv conn th istatus (settingsSlowlorisSize settings))
        writeIORef istatus True
        -- Push the sniffed bytes back so the request parser sees them.
        leftoverSource src bs
        addr <- getProxyProtocolAddr src
        -- On any exception, try to send an error response and then rethrow.
        http1 addr istatus src `E.catch` \e -> do
            sendErrorResponse addr istatus e
            throwIO (e :: SomeException)
  where
    -- Determine the client address to report, honouring the PROXY protocol
    -- setting.
    getProxyProtocolAddr src =
        case settingsProxyProtocol settings of
            ProxyProtocolNone ->
                return origAddr
            ProxyProtocolRequired -> do
                seg <- readSource src
                parseProxyProtocolHeader src seg
            ProxyProtocolOptional -> do
                seg <- readSource src
                if S.isPrefixOf "PROXY " seg
                    then parseProxyProtocolHeader src seg
                    else do leftoverSource src seg
                            return origAddr
    -- Parse a textual PROXY header line from the start of @seg@.  Throws
    -- 'BadProxyHeader' when the line does not match a recognised form.
    parseProxyProtocolHeader src seg = do
        let (header,seg') = S.break (== 0x0d) seg -- 0x0d == CR
            maybeAddr = case S.split 0x20 header of -- 0x20 == space
                -- Third field is the client address, fifth the client port.
                ["PROXY","TCP4",clientAddr,_,clientPort,_] ->
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet (readInt clientPort)
                                                       (toHostAddress a))
                        _ -> Nothing
                ["PROXY","TCP6",clientAddr,_,clientPort,_] ->
                    case [x | (x, t) <- reads (decodeAscii clientAddr), null t] of
                        [a] -> Just (SockAddrInet6 (readInt clientPort)
                                                        0
                                                        (toHostAddress6 a)
                                                        0)
                        _ -> Nothing
                -- Unknown protocol: keep the address of the socket peer.
                ("PROXY":"UNKNOWN":_) ->
                    Just origAddr
                _ ->
                    Nothing
        case maybeAddr of
            Nothing -> throwIO (BadProxyHeader (decodeAscii header))
            Just a -> do leftoverSource src (S.drop 2 seg') -- drop the CR LF
                         return a
    -- Byte-wise conversion of a ByteString to a String.
    decodeAscii = map (chr . fromEnum) . S.unpack
    th = threadHandle ii
    -- Send the response built by 'settingsOnExceptionResponse', but only
    -- while @istatus@ is True.
    sendErrorResponse addr istatus e = do
        status <- readIORef istatus
        when status $ void $
            sendResponse
                (settingsServerName settings)
                conn ii (dummyreq addr) defaultIndexRequestHeader (return S.empty) (errorResponse e)
    dummyreq addr = defaultRequest { remoteHost = addr }
    errorResponse e = settingsOnExceptionResponse settings e
    -- The HTTP/1 loop: receive one request, process it, and repeat while
    -- the connection is to be kept alive.
    http1 addr istatus src = do
        (req', mremainingRef, idxhdr, nextBodyFlush) <- recvRequest settings conn ii addr src
        let req = req' { isSecure = isTransportSecure transport }
        keepAlive <- processRequest istatus src req mremainingRef idxhdr nextBodyFlush
            `E.catch` \e -> do
                -- An exception while processing: send an error response if
                -- still allowed, report it together with the request ...
                sendErrorResponse addr istatus e
                settingsOnException settings (Just req) e
                -- ... and do not reuse the connection.
                return False
        when keepAlive $ http1 addr istatus src
    -- Run the application on one request and send its response.  Returns
    -- whether the connection may be kept alive.
    processRequest istatus src req mremainingRef idxhdr nextBodyFlush = do
        -- Pause the timeout handle while the application runs; it is
        -- resumed when the application produces its response.
        T.pause th
        -- Filled in by the response callback below.  If the application
        -- never invokes the callback, forcing the value read back from this
        -- ref raises the error.
        keepAliveRef <- newIORef $ error "keepAliveRef not filled"
        _ <- app req $ \res -> do
            T.resume th
            -- From here on a response is being sent, so no error response
            -- must be sent afterwards.
            writeIORef istatus False
            keepAlive <- sendResponse
                (settingsServerName settings)
                conn ii req idxhdr (readSource src) res
            writeIORef keepAliveRef keepAlive
            return ResponseReceived
        keepAlive <- readIORef keepAliveRef
        -- Let other Haskell threads run before this one goes on to the next
        -- request.  NOTE(review): rationale is not visible in this file;
        -- presumably the next request is unlikely to have arrived yet, so
        -- yielding avoids an immediate blocking receive.
        Conc.yield
        if not keepAlive then
            return False
          else
            -- Before reusing the connection, the unread remainder of the
            -- request body has to be consumed, subject to
            -- 'settingsMaximumBodyFlush'.
            case settingsMaximumBodyFlush settings of
                Nothing -> do
                    flushEntireBody nextBodyFlush
                    T.resume th
                    return True
                Just maxToRead -> do
                    let tryKeepAlive = do
                            -- Keep the connection only if the rest of the
                            -- body fits within @maxToRead@ bytes.
                            isComplete <- flushBody nextBodyFlush maxToRead
                            if isComplete then do
                                T.resume th
                                return True
                              else
                                return False
                    case mremainingRef of
                        -- When a remaining-size ref is available (presumably
                        -- the unread body length), skip flushing altogether
                        -- if it already exceeds the limit.
                        Just ref -> do
                            remaining <- readIORef ref
                            if remaining <= maxToRead then
                                tryKeepAlive
                              else
                                return False
                        Nothing -> tryKeepAlive
-- | Drain a chunk source: keep reading until it yields an empty
-- 'ByteString', discarding everything read.
flushEntireBody :: IO ByteString -> IO ()
flushEntireBody src = do
    chunk <- src
    if S.null chunk
        then return ()
        else flushEntireBody src
-- | Drain a chunk source, but give up once more than the given number of
-- bytes has been read.
--
-- Returns 'True' if the source reported its end (an empty chunk) without
-- the byte budget being exceeded, and 'False' as soon as the running total
-- of bytes read goes over the budget.
flushBody :: IO ByteString -- ^ source of body chunks; empty chunk means end
          -> Int -- ^ maximum number of bytes to read
          -> IO Bool -- ^ 'True' if the whole body was consumed
flushBody src =
    loop
  where
    loop toRead = do
        bs <- src
        -- Remaining budget after accounting for this chunk.
        let toRead' = toRead - S.length bs
        case () of
            ()
                | S.null bs -> return True
                | toRead' >= 0 -> loop toRead'
                | otherwise -> return False
-- | Receive one chunk from the connection, with bookkeeping.
--
-- When the chunk is non-empty the status flag is set to 'True', and the
-- timeout handle is tickled if the chunk holds at least @slowlorisSize@
-- bytes.  The chunk is returned unchanged.
wrappedRecv :: Connection -> T.Handle -> IORef Bool -> Int -> IO ByteString
wrappedRecv Connection { connRecv = recvChunk } th istatus slowlorisSize = do
    chunk <- recvChunk
    if S.null chunk
        then return ()
        else do
            writeIORef istatus True
            -- Small chunks do not reset the timeout.
            when (S.length chunk >= slowlorisSize) (T.tickle th)
    return chunk
-- | Set the close-on-exec flag on the socket's file descriptor.
-- On Windows builds this is a no-op.
setSocketCloseOnExec :: Socket -> IO ()
#if WINDOWS
setSocketCloseOnExec = const (return ())
#else
setSocketCloseOnExec sock = setFdOption fd CloseOnExec True
  where
    fd = fromIntegral (fdSocket sock)
#endif
-- | Shutdown step run after the accept loop ends: delegates to
-- 'waitForZero' on the connection counter.
gracefulShutdown :: Counter -> IO ()
gracefulShutdown = waitForZero