{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An event loop ('simpleEventLoop') that runs, on each capability, one
-- accept thread per listen socket and serves every accepted connection in a
-- thread of its own.
module Snap.Internal.Http.Server.SimpleBackend
  ( simpleEventLoop
  ) where

------------------------------------------------------------------------------
import Control.Monad.Trans

import Control.Concurrent hiding (yield)
import Control.Concurrent.Extended (forkOnLabeledWithUnmaskBs)
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as SC
import Data.ByteString.Internal (c2w)
import Foreign hiding (new)
import Foreign.C
#if MIN_VERSION_base(4,4,0)
import GHC.Conc (forkOn, labelThread)
#else
import GHC.Conc (forkOnIO, labelThread)
#endif
import Network.Socket
#if !MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
------------------------------------------------------------------------------
import Snap.Internal.Debug
import Snap.Internal.Http.Server.Address
import Snap.Internal.Http.Server.Backend
import Snap.Internal.Http.Server.Date
import qualified Snap.Internal.Http.Server.ListenHelpers as Listen
import Snap.Internal.Http.Server.TimeoutManager (TimeoutManager)
import qualified Snap.Internal.Http.Server.TimeoutManager as TM
import Snap.Iteratee hiding (map)

-- sendfile(2) support is only compiled in when HAS_SENDFILE is defined;
-- otherwise files are streamed through the response iteratee.
#if defined(HAS_SENDFILE)
import System.Posix.IO
import System.Posix.Types (Fd (..))
import qualified System.SendFile as SF
#endif

------------------------------------------------------------------------------
-- Compatibility shim: before base 4.4 this function was called 'forkOnIO'.
--
-- NOTE(review): neither 'forkOn' nor 'labelThread' appears to be referenced
-- elsewhere in this module (threads are spawned through
-- 'forkOnLabeledWithUnmaskBs'); confirm and consider dropping the shim and
-- the GHC.Conc import.
#if !MIN_VERSION_base(4,4,0)
forkOn :: Int -> IO () -> IO ThreadId
forkOn = forkOnIO
#endif

------------------------------------------------------------------------------
-- | For each CPU (capability), we store:
-- * A list of accept threads, one per listen socket.
-- * A 'TimeoutManager', shared by every connection accepted on this CPU.
-- * An 'MVar' that is filled as soon as any of this CPU's accept threads
--   exits; 'simpleEventLoop' blocks on it.
data EventLoopCpu = EventLoopCpu
    { _boundCpu       :: Int
    , _acceptThreads  :: [ThreadId]
    , _timeoutManager :: TimeoutManager
    , _exitMVar       :: !(MVar ())
    }


------------------------------------------------------------------------------
-- | Starts one 'EventLoopCpu' per capability (@0 .. cap-1@), runs the
-- @initial@ action, then blocks until every loop's exit MVar has been
-- filled.
--
-- Cleanup (stopping every loop, then closing every listen socket) runs no
-- matter how this returns: normally, because @initial@ threw, or because
-- this thread received an exception while waiting.
simpleEventLoop :: EventLoop
simpleEventLoop defaultTimeout sockets cap elog initial handler = do
    loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
                          [0 .. (cap - 1)]

    let waitForLoops = do
            debug "simpleEventLoop: waiting for mvars"
            -- wait for all threads to exit
            Prelude.mapM_ (takeMVar . _exitMVar) loops

        shutdownLoops = do
            debug "simpleEventLoop: killing all threads"
            mapM_ stopLoop loops
            mapM_ Listen.closeSocket sockets

    -- 'initial' sits inside the protected region so that an exception from
    -- it cannot escape while the accept threads keep running and the listen
    -- sockets stay open.
    (initial >> waitForLoops) `finally` shutdownLoops


------------------------------------------------------------------------------
-- | Sets up the loop for capability @cpu@: a fresh 'TimeoutManager' plus one
-- accept thread, pinned to @cpu@, per listen socket. When any of those
-- accept threads exits (for whatever reason) the loop's exit MVar is filled.
newLoop :: Int
        -> [ListenSocket]
        -> SessionHandler
        -> (S.ByteString -> IO ())
        -> Int
        -> IO EventLoopCpu
newLoop defaultTimeout sockets handler elog cpu = do
    tmgr       <- TM.initialize defaultTimeout getCurrentDateTime
    exit       <- newEmptyMVar
    accThreads <- forM sockets $ \p -> do
        let label = S.concat [ "snap-server: ", SC.pack (show p)
                             , " on capability: ", SC.pack (show cpu)
                             ]
        forkOnLabeledWithUnmaskBs label cpu $ \unmask ->
            acceptThread defaultTimeout handler tmgr elog cpu p unmask
              `finally` (tryPutMVar exit () >> return ())
    return $! EventLoopCpu cpu accThreads tmgr exit


------------------------------------------------------------------------------
-- | Stops one capability's loop: shuts down its 'TimeoutManager' and kills
-- each of its accept threads, with asynchronous exceptions masked.
stopLoop :: EventLoopCpu -> IO ()
stopLoop loop = mask_ $ do
    TM.stop $ _timeoutManager loop
    Prelude.mapM_ killThread $ _acceptThreads loop


------------------------------------------------------------------------------
-- | Accept loop for one listen socket, pinned to capability @cpu@.
--
-- Each accepted connection is handed to 'runSession' in a new thread on the
-- same capability. An 'AsyncException' is re-thrown and ends the thread; any
-- other exception is logged through @elog@ and the loop restarts after a
-- 10ms pause.
--
-- @unmask@ is the unmasking function supplied by
-- 'forkOnLabeledWithUnmaskBs'; the accept loop body runs under it.
acceptThread :: Int
             -> SessionHandler
             -> TimeoutManager
             -> (S.ByteString -> IO ())
             -> Int
             -> ListenSocket
             -> (forall a. IO a -> IO a)
             -> IO ()
acceptThread defaultTimeout handler tmgr elog cpu sock unmask = loop
  where
    loop = do
        unmask (forever acceptAndFork) `catches` acceptHandler
        loop

    acceptAndFork = do
        debug $ "acceptThread: calling accept() on socket " ++ show sock
        (s,addr) <- accept $ Listen.listenSocket sock
        -- Until the session thread exists nothing else owns the new socket,
        -- so if configuring it or forking fails we must close it here;
        -- otherwise every such failure leaks a file descriptor.
        handOff s addr `onException` eatException (sClose s)

    -- Configure the accepted socket and fork the thread that serves it.
    handOff s addr = do
        setSocketOption s NoDelay 1
        debug $ "acceptThread: accepted connection from remote: " ++ show addr
        let label = S.concat [ "snap-server: connection from: "
                             , SC.pack (show addr)
                             , " on socket: "
                             , SC.pack (show (fdSocket s))
                             , "\0"
                             ]
        _ <- forkOnLabeledWithUnmaskBs label cpu $ \unmask' ->
               unmask' (runSession defaultTimeout handler tmgr sock s addr)
                 `catches` cleanup
        return ()

    acceptHandler =
        [ Handler $ \(e :: AsyncException) -> throwIO e
        , Handler $ \(e :: SomeException) -> do
              elog $ S.concat [ "SimpleBackend.acceptThread: accept threw: "
                              , S.pack . map c2w $ show e ]
              -- The most likely cause is running out of file descriptors,
              -- which isn't likely to get better immediately; sleep for
              -- 10ms to avoid spamming the error log.
              threadDelay $ 10000
        ]

    -- Handlers for exceptions escaping a session thread.
    cleanup =
        [ Handler $ \(e :: AsyncException) ->
              case e of
                ThreadKilled  -> return ()
                UserInterrupt -> return ()
                _             -> throwIO e
                -- This ensures all other asynchronous exceptions
                -- (StackOverflow and HeapOverflow) are logged to
                -- stderr by forkIO.
        , Handler $ \(e :: SomeException) ->
              elog $ S.concat [ "SimpleBackend.acceptThread: "
                              , S.pack . map c2w $ show e]
        ]


------------------------------------------------------------------------------
-- | Serves one accepted connection.
--
-- It gathers the local/remote address information and registers a timeout
-- whose action kills this thread. It then creates the session with
-- 'Listen.createSession' and runs @handler@, passing it an enumerator, a
-- writer, a sendfile action and a timeout-modification callback.
--
-- The connection socket is shut down and closed on every exit path: the
-- handler finished, it failed, it was killed, or the session could not be
-- set up in the first place.
runSession :: Int
           -> SessionHandler
           -> TimeoutManager
           -> ListenSocket
           -> Socket
           -> SockAddr
           -> IO ()
runSession defaultTimeout handler tmgr lsock sock addr =
    -- The socket is closed in this outer 'finally' rather than in the
    -- bracket's release action. The bracket only takes effect once
    -- createSession has succeeded, so a failure in address lookup, timeout
    -- registration or createSession itself would otherwise leak the socket.
    serve `finally` closeSocket
  where
    fd = fdSocket sock

    -- Best-effort teardown of the connection socket.
    closeSocket = do
        eatException $ shutdown sock ShutdownBoth
        eatException $ sClose sock

    serve = do
        curId <- myThreadId

        debug $ "Backend.withConnection: running session: " ++ show addr

        (rport,rhost) <- getAddress addr
        (lport,lhost) <- getSocketName sock >>= getAddress

        let sinfo = SessionInfo lhost lport rhost rport $ Listen.isSecure lsock

        timeoutHandle <- TM.register (killThread curId) tmgr
        let modifyTimeout = TM.modify timeoutHandle
        let tickleTimeout = modifyTimeout . max

        bracket (Listen.createSession lsock 8192 fd
                   (threadWaitRead $ fromIntegral fd))
                (\session -> mask_ $ do
                     debug "thread killed, closing socket"

                     -- cancel thread timeout before tearing the session down
                     TM.cancel timeoutHandle

                     eatException $ Listen.endSession lsock session
                )
                (\s -> let writeEnd = writeOut lsock s sock
                                               (tickleTimeout defaultTimeout)
                       in handler sinfo
                                  (enumerate lsock s sock)
                                  writeEnd
                                  (sendFile lsock (tickleTimeout defaultTimeout)
                                            fd writeEnd)
                                  modifyTimeout
                )


------------------------------------------------------------------------------
-- | Runs an action for its effects and discards its result. Any exception
-- it throws is silently ignored. Used for best-effort cleanup steps.
eatException :: IO a -> IO ()
eatException act = (act >> return ()) `catch` \(_::SomeException) -> return ()


------------------------------------------------------------------------------
-- | Sends @sz@ bytes of file @fp@, starting at offset @start@, to the
-- connection.
--
-- Arguments, in order:
--
-- * the listen socket the connection arrived on;
-- * an action run after each partial sendfile transfer (the caller passes a
--   timeout tickle);
-- * the raw connection descriptor;
-- * the iteratee that writes to the connection;
-- * the file path, start offset and byte count.
--
-- With HAS_SENDFILE, plain 'ListenHttp' connections use sendfile(2). All
-- other cases, and all builds without HAS_SENDFILE, stream the file through
-- @writeEnd@.
sendFile :: ListenSocket
         -> IO ()
         -> CInt
         -> Iteratee ByteString IO ()
         -> FilePath
         -> Int64
         -> Int64
         -> IO ()
#if defined(HAS_SENDFILE)
sendFile lsock tickle sock writeEnd fp start sz =
    case lsock of
        ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
                                (closeFd)
                                (go start sz)
        _ -> do
            step <- runIteratee writeEnd
            run_ $ enumFilePartial fp (start,start+sz) step
  where
    go off bytes fd
      | bytes == 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile (threadWaitWrite $ fromIntegral sock)
                                sfd fd off bytes
            -- sendfile(2) transfers 0 bytes only when the input is at end of
            -- file, i.e. the file is shorter than the requested range.
            -- Retrying would spin forever and tickle the timeout on every
            -- pass, so the connection would never be reaped; fail instead.
            if sent == 0
              then ioError $ userError $
                       "sendFile: unexpected end of file " ++ fp
              else if sent < bytes
                then tickle >> go (off+sent) (bytes-sent) fd
                else return ()

    sfd = Fd sock
#else
sendFile _ _ _ writeEnd fp start sz = do
    -- no need to count bytes
    step <- runIteratee writeEnd
    run_ $ enumFilePartial fp (start,start+sz) step
#endif


------------------------------------------------------------------------------
-- | An enumerator that reads from the connection with 'Listen.recv'. It feeds
-- each chunk to the iteratee for as long as the iteratee asks for more input.
-- When a read returns 'Nothing' (or an empty chunk) the iteratee is sent EOF.
enumerate :: (MonadIO m)
          => ListenSocket
          -> NetworkSession
          -> Socket
          -> Enumerator ByteString m a
enumerate port session sock = loop
  where
    dbg s = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
                  ++ "): " ++ s

    loop (Continue k) = do
        dbg "reading from socket"
        s <- liftIO $ timeoutRecv
        case s of
          Nothing -> do
              dbg "got EOF from socket"
              sendOne k ""
          Just s' -> do
              dbg $ "got " ++ Prelude.show (S.length s')
                      ++ " bytes from read end"
              sendOne k s'
    loop x = returnI x

    -- Feed one chunk (or EOF, for an empty chunk) to the continuation.
    sendOne k s
      | S.null s = do
          dbg "sending EOF to continuation"
          enumEOF $ Continue k
      | otherwise = do
          dbg $ "sending " ++ show s ++ " to continuation"
          step <- lift $ runIteratee $ k $ Chunks [s]
          case step of
            (Yield x st) -> do
                dbg $ "got yield, remainder is " ++ show st
                yield x st
            r@(Continue _) -> do
                dbg $ "got continue"
                loop r
            (Error e) -> throwError e

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutRecv = Listen.recv port sock (threadWaitRead $ fromIntegral fd) session
#else
    timeoutRecv = Listen.recv port (threadWaitRead $ fromIntegral fd) session
#endif


------------------------------------------------------------------------------
-- | An iteratee that writes every chunk it receives to the connection via
-- 'Listen.send', and yields @()@ at end of input. Exceptions from the send
-- are re-raised inside the iteratee with 'throwError'.
--
-- @tickle@ is passed straight through to 'Listen.send'. Presumably it
-- extends the connection timeout as data goes out; confirm in ListenHelpers.
writeOut :: (MonadIO m)
         => ListenSocket
         -> NetworkSession
         -> Socket
         -> (IO ())
         -> Iteratee ByteString m ()
writeOut port session sock tickle = loop
  where
    dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
                  ++ "): " ++ s

    loop = continue k

    k EOF = yield () EOF
    k (Chunks xs) = do
        let s = S.concat xs
        let n = S.length s
        dbg $ "got chunk with " ++ show n ++ " bytes"
        ee <- liftIO $ try $ timeoutSend s
        case ee of
          (Left (e::SomeException)) -> do
              dbg $ "timeoutSend got error " ++ show e
              throwError e
          (Right _) -> do
              let last10 = S.drop (n-10) s
              dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
              loop

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutSend = Listen.send port sock tickle
                              (threadWaitWrite $ fromIntegral fd) session
#else
    timeoutSend = Listen.send port tickle
                              (threadWaitWrite $ fromIntegral fd) session
#endif