{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PackageImports #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} module Snap.Internal.Http.Server.SimpleBackend ( simpleEventLoop ) where ------------------------------------------------------------------------------ import Control.Monad.Trans import Control.Concurrent hiding (yield) import Control.Exception import Control.Monad import Data.ByteString (ByteString) import qualified Data.ByteString as S import Data.ByteString.Internal (c2w) import Data.Maybe import Foreign hiding (new) import Foreign.C import GHC.Conc (labelThread, forkOnIO) import Network.Socket import Prelude hiding (catch) ------------------------------------------------------------------------------ import Snap.Internal.Debug import Snap.Internal.Http.Server.Date import qualified Snap.Internal.Http.Server.TimeoutManager as TM import Snap.Internal.Http.Server.TimeoutManager (TimeoutManager) import Snap.Internal.Http.Server.Backend import Snap.Internal.Http.Server.Address import qualified Snap.Internal.Http.Server.ListenHelpers as Listen import Snap.Iteratee hiding (map) #if defined(HAS_SENDFILE) import qualified System.SendFile as SF import System.Posix.IO import System.Posix.Types (Fd(..)) #endif ------------------------------------------------------------------------------ -- | For each cpu, we store: -- * A list of accept threads, one per port. 
-- * A TimeoutManager
-- * An MVar that is filled (via 'tryPutMVar') as soon as any of this cpu's
--   accept threads exits; 'simpleEventLoop' waits on it.
data EventLoopCpu = EventLoopCpu
    { _boundCpu        :: Int
    , _acceptThreads   :: [ThreadId]
    , _timeoutManager  :: TimeoutManager
    , _exitMVar        :: !(MVar ())
    }


------------------------------------------------------------------------------
-- | Runs one accept loop per capability (@[0 .. cap-1]@), each with its own
-- timeout manager and one accept thread per listening socket.
--
-- Blocks until every loop's exit MVar has been filled. On the way out,
-- whether normally or through an exception, it stops every loop and closes
-- all listening sockets.
simpleEventLoop :: EventLoop
simpleEventLoop defaultTimeout sockets cap elog handler = do
    loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
                          [0..(cap-1)]

    debug "simpleEventLoop: waiting for mvars"

    -- wait until every cpu reports that one of its accept threads exited
    Prelude.mapM_ (takeMVar . _exitMVar) loops `finally` do
        debug "simpleEventLoop: killing all threads"
        mapM_ stopLoop loops
        mapM_ Listen.closeSocket sockets


------------------------------------------------------------------------------
-- | Builds the per-cpu state.
--
-- Initializes a 'TimeoutManager' and forks one 'acceptThread' per listening
-- socket, pinned to @cpu@ with 'forkOnIO'. All accept threads of this cpu
-- share a single exit MVar.
newLoop :: Int
        -> [ListenSocket]
        -> SessionHandler
        -> (S.ByteString -> IO ())
        -> Int
        -> IO EventLoopCpu
newLoop defaultTimeout sockets handler elog cpu = do
    tmgr       <- TM.initialize defaultTimeout getCurrentDateTime
    exit       <- newEmptyMVar
    accThreads <- forM sockets $ \p -> forkOnIO cpu $
                  acceptThread defaultTimeout handler tmgr elog cpu p exit

    return $ EventLoopCpu cpu accThreads tmgr exit


------------------------------------------------------------------------------
-- | Stops a loop's timeout manager and kills its accept threads, with
-- asynchronous exceptions blocked.
stopLoop :: EventLoopCpu -> IO ()
stopLoop loop = block $ do
    TM.stop $ _timeoutManager loop
    Prelude.mapM_ killThread $ _acceptThreads loop


------------------------------------------------------------------------------
-- | Accepts connections on one listening socket forever.
--
-- Each accepted connection is served by 'runSession' in a new thread pinned
-- to the same @cpu@. In those session threads, 'AsyncException's (e.g. the
-- kill issued by the timeout manager) are ignored, and any other exception
-- is reported through @elog@.
--
-- When the accept loop terminates for any reason, @exitMVar@ is filled.
-- 'tryPutMVar' is used because several accept threads share the MVar.
acceptThread :: Int
             -> SessionHandler
             -> TimeoutManager
             -> (S.ByteString -> IO ())
             -> Int
             -> ListenSocket
             -> MVar ()
             -> IO ()
acceptThread defaultTimeout handler tmgr elog cpu sock exitMVar =
    loop `finally` (tryPutMVar exitMVar () >> return ())
  where
    loop = do
        debug $ "acceptThread: calling accept() on socket " ++ show sock
        (s,addr) <- accept $ Listen.listenSocket sock
        debug $ "acceptThread: accepted connection from remote: " ++ show addr
        _ <- forkOnIO cpu (go s addr `catches` cleanup)
        loop

    go = runSession defaultTimeout handler tmgr sock

    cleanup = [
                Handler $ \(_ :: AsyncException) -> return ()
              , Handler $ \(e :: SomeException) -> elog
                        $ S.concat [ "SimpleBackend.acceptThread: "
                                   , S.pack . map c2w $ show e]
              ]


------------------------------------------------------------------------------
-- | Serves one accepted connection.
--
-- The calling thread is registered with the timeout manager; when the
-- timeout fires, this thread is killed. A network session is created on the
-- socket, and @handler@ is run with the request enumerator, response
-- writer, sendfile action and timeout setter.
--
-- Cleanup order: cancel the timeout, end the network session, then shut
-- down and close the socket. The socket shutdown/close is an outer
-- 'finally'. That way the client socket is also released when anything
-- before the session exists throws (address lookup, 'getSocketName',
-- 'TM.register', 'Listen.createSession'); the socket is never closed twice.
runSession :: Int
           -> SessionHandler
           -> TimeoutManager
           -> ListenSocket
           -> Socket
           -> SockAddr
           -> IO ()
runSession defaultTimeout handler tmgr lsock sock addr =
    serve `finally` closeSocket
  where
    fd = fdSocket sock

    -- best-effort: errors while tearing the socket down are ignored
    closeSocket = block $ do
        eatException $ shutdown sock ShutdownBoth
        eatException $ sClose sock

    serve = do
        curId <- myThreadId
        debug $ "Backend.withConnection: running session: " ++ show addr
        labelThread curId $ "connHndl " ++ show fd

        (rport,rhost) <- getAddress addr
        (lport,lhost) <- getSocketName sock >>= getAddress

        let sinfo = SessionInfo lhost lport rhost rport $ Listen.isSecure lsock

        timeoutHandle <- TM.register (killThread curId) tmgr
        let setTimeout    = TM.set timeoutHandle
        let tickleTimeout = TM.tickle timeoutHandle

        bracket (Listen.createSession lsock 8192 fd
                     (threadWaitRead $ fromIntegral fd)
                   -- if no session is created the release below never runs,
                   -- so cancel the timeout here instead
                   `onException` TM.cancel timeoutHandle)
                (\session -> block $ do
                     debug "thread killed, closing socket"
                     -- cancel the timeout first so it cannot interrupt
                     -- endSession
                     TM.cancel timeoutHandle
                     eatException $ Listen.endSession lsock session)
                (\s -> let writeEnd = writeOut lsock s sock
                                               (tickleTimeout defaultTimeout)
                       in handler sinfo
                                  (enumerate lsock s sock)
                                  writeEnd
                                  (sendFile lsock (tickleTimeout defaultTimeout)
                                            fd writeEnd)
                                  setTimeout)


------------------------------------------------------------------------------
-- | Runs an action for its effects, discarding its result and any exception.
--
-- Because it catches 'SomeException', asynchronous exceptions are swallowed
-- too; it is meant for best-effort cleanup steps only.
eatException :: IO a -> IO ()
eatException act = (act >> return ()) `catch` \(_::SomeException) -> return ()


------------------------------------------------------------------------------
-- | Sends @sz@ bytes of file @fp@, starting at byte offset @start@, to the
-- client.
--
-- When built with @HAS_SENDFILE@ and the listener is plain HTTP, the bytes
-- go straight from the file descriptor to the socket via sendfile. @tickle@
-- is called after every partial send. In all other cases the file range is
-- streamed through @writeEnd@ with 'enumFilePartial'.
sendFile :: ListenSocket
         -> IO ()
         -> CInt
         -> Iteratee ByteString IO ()
         -> FilePath
         -> Int64
         -> Int64
         -> IO ()
#if defined(HAS_SENDFILE)
sendFile lsock tickle sock writeEnd fp start sz =
    case lsock of
        ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
                                (closeFd)
                                (go start sz)
        _            -> do
                   step <- runIteratee writeEnd
                   run_ $ enumFilePartial fp (start,start+sz) step
  where
    go off bytes fd
      | bytes == 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile (threadWaitWrite $ fromIntegral sock)
                                sfd fd off bytes
            -- sendfile(2) reports 0 bytes once the input file is exhausted
            -- (e.g. the file shrank after its size was taken). Retrying
            -- would spin forever while tickling the timeout, so fail.
            if sent <= 0
              then ioError $ userError $
                     "SimpleBackend.sendFile: " ++ fp
                     ++ " ended before all bytes were sent"
              else if sent < bytes
                then tickle >> go (off+sent) (bytes-sent) fd
                else return ()

    sfd = Fd sock
#else
sendFile _ _ _ writeEnd fp start sz = do
    -- no need to count bytes
    step <- runIteratee writeEnd
    run_ $ enumFilePartial fp (start,start+sz) step
    return ()
#endif


------------------------------------------------------------------------------
-- | Enumerator that feeds socket data to an iteratee.
--
-- Data is read with 'Listen.recv'. A 'Nothing' result, or an empty chunk, is
-- delivered to the iteratee as EOF. Iteratee errors are rethrown with
-- 'throwError'.
enumerate :: (MonadIO m)
          => ListenSocket
          -> NetworkSession
          -> Socket
          -> Enumerator ByteString m a
enumerate port session sock = loop
  where
    dbg s = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
                    ++ "): " ++ s

    loop (Continue k) = do
        dbg "reading from socket"
        s <- liftIO $ timeoutRecv
        case s of
          Nothing -> do
                 dbg "got EOF from socket"
                 sendOne k ""
          Just s' -> do
                 dbg $ "got " ++ Prelude.show (S.length s')
                         ++ " bytes from read end"
                 sendOne k s'

    loop x = returnI x

    sendOne k s
      | S.null s  = do
          dbg "sending EOF to continuation"
          enumEOF $ Continue k
      | otherwise = do
          dbg $ "sending " ++ show s ++ " to continuation"
          step <- lift $ runIteratee $ k $ Chunks [s]
          case step of
            (Yield x st)   -> do
                        dbg $ "got yield, remainder is " ++ show st
                        yield x st
            r@(Continue _) -> do
                        dbg $ "got continue"
                        loop r
            (Error e)      -> throwError e

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutRecv = Listen.recv port sock (threadWaitRead $
                  fromIntegral fd) session
#else
    timeoutRecv = Listen.recv port (threadWaitRead $
                  fromIntegral fd) session
#endif


------------------------------------------------------------------------------
-- | Iteratee that writes every chunk it receives to the socket.
--
-- Each chunk list is concatenated and sent with 'Listen.send', with @tickle@
-- passed through to it. Any exception from the send is rethrown inside the
-- iteratee with 'throwError'. On EOF the iteratee yields @()@.
writeOut :: (MonadIO m)
         => ListenSocket
         -> NetworkSession
         -> Socket
         -> (IO ())
         -> Iteratee ByteString m ()
writeOut port session sock tickle = loop
  where
    dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
                    ++ "): " ++ s

    loop = continue k

    k EOF = yield () EOF
    k (Chunks xs) = do
        let s = S.concat xs
        let n = S.length s
        dbg $ "got chunk with " ++ show n ++ " bytes"
        ee <- liftIO $ try $ timeoutSend s
        case ee of
          (Left (e::SomeException)) -> do
              dbg $ "timeoutSend got error " ++ show e
              throwError e
          (Right _) -> do
              let last10 = S.drop (n-10) s
              dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
              loop

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutSend = Listen.send port sock tickle
                  (threadWaitWrite $ fromIntegral fd) session
#else
    timeoutSend = Listen.send port tickle
                  (threadWaitWrite $ fromIntegral fd) session
#endif