{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PackageImports #-}

-- | A simple socket backend: one thread accepts connections and queues them,
-- each dequeued connection is served on its own forked thread, and a single
-- timeout thread periodically kills connection threads that have not been
-- tickled recently.
module Snap.Internal.Http.Server.SimpleBackend
  ( Backend
  , BackendTerminatedException(..)
  , Connection
  , TimeoutException(..)
  , name
  , debug
  , bindIt
  , new
  , stop
  , withConnection
  , sendFile
  , tickleTimeout
  , getReadEnd
  , getWriteEnd
  , getRemoteAddr
  , getRemotePort
  , getLocalAddr
  , getLocalPort
  ) where

------------------------------------------------------------------------------
import "monads-fd" Control.Monad.Trans
import             Control.Concurrent
import             Control.Exception
import             Control.Monad
import             Data.ByteString (ByteString)
import             Data.ByteString.Internal (c2w, w2c)
import qualified   Data.ByteString as B
import             Data.DList (DList)
import qualified   Data.DList as D
import             Data.IORef
import             Data.Iteratee.WrappedByteString
import             Data.List (foldl')
import qualified   Data.PSQueue as PSQ
import             Data.PSQueue (PSQ)
import             Data.Typeable
import             Foreign hiding (new)
import             Foreign.C.Types (CTime)
import             GHC.Conc (labelThread, forkOnIO)
import             Network.Socket
import qualified   Network.Socket.ByteString as SB
import             Prelude hiding (catch)
------------------------------------------------------------------------------
import             Snap.Internal.Debug
import             Snap.Internal.Http.Server.Date
import             Snap.Iteratee hiding (foldl')

#if defined(HAS_SENDFILE)
import qualified   System.SendFile as SF
import             System.Posix.IO
import             System.Posix.Types (Fd(..))
#endif


------------------------------------------------------------------------------
-- | Thrown by @withConnection@ when the connection queue yields the
-- termination sentinel written by the accept thread during its cleanup.
data BackendTerminatedException = BackendTerminatedException
   deriving (Typeable)

instance Show BackendTerminatedException where
    show (BackendTerminatedException) = "Backend terminated"

instance Exception BackendTerminatedException


------------------------------------------------------------------------------
-- | Priority queue of connection threads keyed by thread id; the priority is
-- the time at which the thread was last tickled.
type TimeoutTable = PSQ ThreadId CTime

-- | Item carried on the connection queue: @Just@ an accepted connection, or
-- @Nothing@ as the "backend is shutting down" sentinel.
type QueueElem = Maybe (Socket,SockAddr)


------------------------------------------------------------------------------
-- | State shared by the accept thread, the timeout thread and the
-- connection handlers.
data Backend = Backend
    { _acceptSocket    :: !Socket
      -- ^ listening socket handed to @new@
    , _acceptThread    :: !ThreadId
      -- ^ thread running @acceptThread@
    , _timeoutEdits    :: !(IORef (DList (TimeoutTable -> TimeoutTable)))
      -- ^ pending timeout-table edits, drained by the timeout thread
    , _timeoutThread   :: !(MVar ThreadId)
      -- ^ filled by @new@ once the timeout thread has been forked
    , _connectionQueue :: !(Chan QueueElem)
      -- ^ accepted connections waiting for @withConnection@
    }


------------------------------------------------------------------------------
-- | A single accepted connection together with its endpoint information.
data Connection = Connection
    { _backend    :: Backend
    , _socket     :: Socket
    , _remoteAddr :: ByteString
      -- ^ remote host as rendered by @inet_ntoa@
    , _remotePort :: Int
    , _localAddr  :: ByteString
      -- ^ local host as rendered by @inet_ntoa@
    , _localPort  :: Int
    , _connTid    :: MVar ThreadId
      -- ^ filled by @withConnection@ with the handler thread id
    }


------------------------------------------------------------------------------
-- | Identifier of this backend.
{-# INLINE name #-}
name :: ByteString
name = "simple"


------------------------------------------------------------------------------
-- | Send part of a file over the connection.
--
-- The arguments are the connection, the file path, the starting offset and
-- the number of bytes. When built with @HAS_SENDFILE@ the file descriptor is
-- passed to @SF.sendFile@ and the connection timeout is tickled after every
-- partial send; otherwise the range @(start, start+sz)@ is enumerated into
-- the write end of the connection.
sendFile :: Connection -> FilePath -> Int64 -> Int64 -> IO ()
#if defined(HAS_SENDFILE)
sendFile c fp start sz = do
    bracket (openFd fp ReadOnly Nothing defaultFileFlags)
            (closeFd)
            (go start sz)
  where
    -- keep sending until no bytes remain; a short send resets the timeout
    -- before retrying from the new offset
    go off bytes fd
      | bytes == 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile sfd fd off bytes
            if sent < bytes
              then tickleTimeout c >> go (off+sent) (bytes-sent) fd
              else return ()

    sfd = Fd . fdSocket $ _socket c
#else
sendFile c fp start sz = do
    -- no need to count bytes
    enumFilePartial fp (start,start+sz) (getWriteEnd c) >>= run
    return ()
#endif


------------------------------------------------------------------------------
-- | Create a listening IPv4 stream socket bound to the given address and
-- port, with @ReuseAddr@ set and a listen backlog of 150.
--
-- If any step after creating the socket throws, the socket is closed before
-- the exception propagates, so a failed bind does not leak a descriptor.
bindIt :: ByteString         -- ^ bind address, or \"*\" for all
       -> Int                -- ^ port to bind to
       -> IO Socket
bindIt bindAddress bindPort =
    bracketOnError (socket AF_INET Stream 0) sClose $ \sock -> do
        addr <- getHostAddr bindPort bindAddress
        setSocketOption sock ReuseAddr 1
        bindSocket sock addr
        listen sock 150
        return sock


------------------------------------------------------------------------------
-- | Accept loop: every accepted connection is written to the queue as
-- @Just@. The loop only ends by exception; the cleanup then closes the
-- listening socket and writes ten @Nothing@ sentinels so that readers
-- blocked in @withConnection@ wake up and throw
-- @BackendTerminatedException@.
--
-- NOTE(review): only ten sentinels are written, so this presumably assumes
-- at most ten concurrent readers of the queue -- confirm against callers.
acceptThread :: Socket -> Chan QueueElem -> IO ()
acceptThread sock connq = loop `finally` cleanup
  where
    loop = do
        debug $ "acceptThread: calling accept()"
        s@(_,addr) <- accept sock
        debug $ "acceptThread: accepted connection from remote: " ++ show addr
        debug $ "acceptThread: queueing"
        writeChan connq $ Just s
        loop

    cleanup = block $ do
        debug $ "acceptThread: cleanup, closing socket and notifying "
                  ++ "chan listeners"
        sClose sock
        -- replicateM_ avoids building a list of unit results
        replicateM_ 10 $ writeChan connq Nothing


------------------------------------------------------------------------------
-- | Start a backend on an already-listening socket: forks the accept thread
-- on the given capability and a timeout thread, and returns the shared
-- state.
new :: Socket   -- ^ value you got from bindIt
    -> Int      -- ^ capability passed to @forkOnIO@ for the accept thread
    -> IO Backend
new sock cpu = do
    debug $ "Backend.new: listening"

    ed        <- newIORef D.empty
    t         <- newEmptyMVar
    connq     <- newChan
    accThread <- forkOnIO cpu $ acceptThread sock connq

    let b = Backend sock accThread ed t connq

    -- the timeout thread needs the Backend, so its id is published through
    -- the MVar after the record has been built
    tid <- forkIO $ timeoutThread b
    putMVar t tid

    return b


------------------------------------------------------------------------------
-- | Timeout loop: every 5,000,000 microseconds the queued edits are applied
-- to the table and every thread whose last tickle is at least @tIMEOUT@ old
-- is killed. Any exception reaching the loop (including the @killThread@
-- issued by @stop@) triggers @killAll@, which kills every thread still in
-- the table.
timeoutThread :: Backend -> IO ()
timeoutThread backend = do
    tref <- newIORef $ PSQ.empty
    let loop = do
            killTooOld tref
            threadDelay (5000000)
            loop

    loop `catch` (\(_::SomeException) -> killAll tref)

  where
    -- atomically swap out the pending edit list and fold it over the table
    applyEdits table = do
        edits <- atomicModifyIORef tedits $ \t -> (D.empty, D.toList t)
        return $ foldl' (flip ($)) table edits

    killTooOld tref = do
        !table <- readIORef tref
        now    <- getCurrentDateTime
        table' <- applyEdits table
        !t'    <- killOlderThan now table'
        writeIORef tref t'

    -- timeout = 30 seconds
    tIMEOUT = 30

    killAll !tref = do
        debug "Backend.timeoutThread: shutdown, killing all connections"
        !table  <- readIORef tref
        !table' <- applyEdits table
        go table'
      where
        go !t = maybe (return ())
                      (\m -> (killThread $ PSQ.key m)
                               >> (go $ PSQ.deleteMin t))
                      (PSQ.findMin t)

    -- pop and kill minimum-priority entries until the minimum is younger
    -- than the cutoff; returns the pruned table
    killOlderThan now !table = do
        debug "Backend.timeoutThread: killing old connections"
        let mmin = PSQ.findMin table
        maybe (return table)
              (\m -> do
                   debug $ "Backend.timeoutThread: minimum value "
                             ++ show (PSQ.prio m) ++ ", cutoff="
                             ++ show (now - tIMEOUT)
                   if now - PSQ.prio m >= tIMEOUT
                     then do
                         killThread $ PSQ.key m
                         killOlderThan now $ PSQ.deleteMin table
                     else return table)
              mmin

    tedits = _timeoutEdits backend


------------------------------------------------------------------------------
-- | Shut the backend down by killing the accept thread and then the timeout
-- thread.
stop :: Backend -> IO ()
stop backend = do
    debug $ "Backend.stop: killing accept thread"
    killThread acthr

    debug $ "Backend.stop: killing timeout thread"

    -- kill timeout thread; timeout thread handler will stop all of the running
    -- connection threads
    readMVar tthr >>= killThread

    debug $ "Backend.stop: exiting.."

  where
    acthr = _acceptThread backend
    tthr  = _timeoutThread backend


------------------------------------------------------------------------------
-- | Thrown by @withConnection@ when a socket address is not
-- @SockAddrInet@; carries the shown address.
data AddressNotSupportedException = AddressNotSupportedException String
   deriving (Typeable)

instance Show AddressNotSupportedException where
    show (AddressNotSupportedException x) = "Address not supported: " ++ x

instance Exception AddressNotSupportedException


------------------------------------------------------------------------------
-- | Take the next connection off the queue and run the handler on it in a
-- new thread forked on the given capability.
--
-- Throws @BackendTerminatedException@ when the queue yields the termination
-- sentinel, and @AddressNotSupportedException@ when either endpoint is not
-- an IPv4 address. If setting up the connection fails, the accepted socket
-- is closed before the exception propagates; once the handler thread exists
-- it owns the socket and closes it when the handler finishes or is killed.
withConnection :: Backend -> Int -> (Connection -> IO ()) -> IO ()
withConnection backend cpu proc = do
    debug $ "Backend.withConnection: reading from chan"

    qelem       <- readChan $ _connectionQueue backend
    (sock,addr) <- maybe terminated return qelem

    let fd = fdSocket sock

    debug $ "Backend.withConnection: dequeued connection from remote: "
              ++ show addr

    -- no handler thread exists yet, so nobody else would close the socket
    -- if address resolution throws
    c <- mkConnection sock addr `onException` eatException (sClose sock)

    let tmvar = _connTid c

    tid <- forkOnIO cpu $ do
        labelMe $ "connHndl " ++ show fd
        bracket (return c)
                (\_ -> block $ do
                     debug "thread killed, closing socket"
                     thr <- readMVar tmvar

                     -- remove thread from timeout table
                     atomicModifyIORef (_timeoutEdits backend) $
                         \es -> (D.snoc es (PSQ.delete thr), ())

                     eatException $ shutdown sock ShutdownBoth
                     eatException $ sClose sock
                )
                proc

    putMVar tmvar tid
    tickleTimeout c
    return ()

  where
    terminated = do
        debug $ "Backend.withConnection: channel terminated, throwing "
                  ++ "BackendTerminatedException"
        throwIO BackendTerminatedException

    -- resolve both endpoints and assemble the Connection record
    mkConnection sock addr = do
        (port,host)   <- addrParts addr
        laddr         <- getSocketName sock
        (lport,lhost) <- addrParts laddr
        tmvar         <- newEmptyMVar
        return $ Connection backend sock host port lhost lport tmvar

    -- split an IPv4 address into (port, host); reject anything else
    addrParts :: SockAddr -> IO (Int, ByteString)
    addrParts (SockAddrInet p h) = do
        h' <- inet_ntoa h
        return (fromIntegral p, B.pack $ map c2w h')
    addrParts x = throwIO $ AddressNotSupportedException $ show x


------------------------------------------------------------------------------
-- | Attach a label to the calling thread.
labelMe :: String -> IO ()
labelMe s = do
    tid <- myThreadId
    labelThread tid s


------------------------------------------------------------------------------
-- | Run an action, discarding both its result and any exception it throws.
-- Used for best-effort socket teardown.
eatException :: IO a -> IO ()
eatException act = (act >> return ()) `catch` \(_::SomeException) -> return ()


------------------------------------------------------------------------------
-- | Enumerator that feeds data received on the connection to an iteratee.
getReadEnd :: Connection -> Enumerator IO a
getReadEnd = enumerate


------------------------------------------------------------------------------
-- | Iteratee that sends every chunk it is given over the connection.
getWriteEnd :: Connection -> Iteratee IO ()
getWriteEnd = writeOut


------------------------------------------------------------------------------
-- | Remote host of the connection.
getRemoteAddr :: Connection -> ByteString
getRemoteAddr = _remoteAddr

-- | Remote port of the connection.
getRemotePort :: Connection -> Int
getRemotePort = _remotePort

-- | Local host of the connection.
getLocalAddr :: Connection -> ByteString
getLocalAddr = _localAddr

-- | Local port of the connection.
getLocalPort :: Connection -> Int
getLocalPort = _localPort


------------------------------------------------------------------------------
-- | Build an IPv4 socket address from a port and a host string; the host
-- \"*\" maps to @iNADDR_ANY@, anything else is parsed with @inet_addr@.
getHostAddr :: Int
            -> ByteString
            -> IO SockAddr
getHostAddr p s = do
    h <- if s == "*"
           then return iNADDR_ANY
           else inet_addr (map w2c . B.unpack $ s)

    return $ SockAddrInet (fromIntegral p) h


------------------------------------------------------------------------------
-- | Exported timeout marker exception; it is not thrown anywhere in this
-- module.
data TimeoutException = TimeoutException
   deriving (Typeable)

instance Show TimeoutException where
    show TimeoutException = "timeout"

instance Exception TimeoutException


------------------------------------------------------------------------------
-- | Reset the timeout of a connection: queues an edit that inserts the
-- handler thread into the timeout table with the current time as its
-- priority. The edit takes effect the next time the timeout thread drains
-- the edit list.
tickleTimeout :: Connection -> IO ()
tickleTimeout conn = do
    debug "Backend.tickleTimeout"
    now <- getCurrentDateTime
    tid <- readMVar $ _connTid conn

    atomicModifyIORef tedits $ \es -> (D.snoc es (PSQ.insert tid now), ())

  where
    tedits = _timeoutEdits $ _backend conn


------------------------------------------------------------------------------
-- | Queue an edit that removes the handler thread of a connection from the
-- timeout table.
_cancelTimeout :: Connection -> IO ()
_cancelTimeout conn = do
    debug "Backend.cancelTimeout"
    tid <- readMVar $ _connTid conn

    atomicModifyIORef tedits $ \es -> (D.snoc es (PSQ.delete tid), ())

  where
    tedits = _timeoutEdits $ _backend conn


------------------------------------------------------------------------------
-- | Receive up to @n@ bytes from the connection socket. The timeout is not
-- tickled here.
timeoutRecv :: Connection -> Int -> IO ByteString
timeoutRecv conn n = do
    let sock = _socket conn
    SB.recv sock n


------------------------------------------------------------------------------
-- | Send the whole string over the connection socket, then tickle the
-- connection timeout.
timeoutSend :: Connection -> ByteString -> IO ()
timeoutSend conn s = do
    let len = B.length s
    debug $ "Backend.timeoutSend: entered w/ " ++ show len ++ " bytes"
    let sock = _socket conn
    SB.sendAll sock s
    debug $ "Backend.timeoutSend: sent all"
    tickleTimeout conn


------------------------------------------------------------------------------
-- | Maximum number of bytes requested per read in @enumerate@.
bLOCKSIZE :: Int
bLOCKSIZE = 8192


------------------------------------------------------------------------------
-- | Read blocks of at most @bLOCKSIZE@ bytes from the connection and feed
-- them to the iteratee until it is done or signals an error. An empty read
-- is delivered to the iteratee as @EOF@.
enumerate :: (MonadIO m) => Connection -> Enumerator m a
enumerate = loop
  where
    loop conn f = do
        debug $ "Backend.enumerate: reading from socket"
        s <- liftIO $ timeoutRecv conn bLOCKSIZE
        debug $ "Backend.enumerate: got " ++ Prelude.show (B.length s)
                  ++ " bytes from read end"
        sendOne conn f s

    sendOne conn f s = do
        v <- runIter f (if B.null s
                          then EOF Nothing
                          else Chunk $ WrapBS s)
        case v of
          r@(Done _ _)      -> return $ liftI r
          (Cont k Nothing)  -> loop conn k
          (Cont _ (Just e)) -> return $ throwErr e


------------------------------------------------------------------------------
-- | Iteratee that sends each chunk with @timeoutSend@. It finishes on
-- @EOF@; if a send throws, it finishes with an @EOF@ carrying the shown
-- exception as an error.
writeOut :: (MonadIO m) => Connection -> Iteratee m ()
writeOut conn = IterateeG out
  where
    out c@(EOF _) = return $ Done () c

    out (Chunk s) = do
        let x = unWrap s

        ee <- liftIO $ ((try $ timeoutSend conn x)
                          :: IO (Either SomeException ()))

        case ee of
          (Left e)  -> return $ Done () (EOF $ Just $ Err $ show e)
          (Right _) -> return $ Cont (writeOut conn) Nothing