module Snap.Internal.Http.Server.SimpleBackend
( simpleEventLoop
) where
import Control.Monad.Trans
import Control.Concurrent hiding (yield)
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.ByteString.Internal (c2w)
import Data.Maybe
import Data.Typeable
import Data.Word
import Foreign hiding (new)
import Foreign.C
import GHC.Conc (labelThread, forkOnIO)
import Network.Socket
import Prelude hiding (catch)
import Data.Concurrent.HashMap (hashString)
import Snap.Internal.Debug
import Snap.Internal.Http.Server.Date
import qualified Snap.Internal.Http.Server.TimeoutTable as TT
import Snap.Internal.Http.Server.TimeoutTable (TimeoutTable)
import Snap.Internal.Http.Server.Backend
import qualified Snap.Internal.Http.Server.ListenHelpers as Listen
import Snap.Iteratee hiding (map)
#if defined(HAS_SENDFILE)
import qualified System.SendFile as SF
import System.Posix.IO
import System.Posix.Types (Fd(..))
#endif
-- | Handles for one per-capability event loop, as built by 'newLoop'.
data EventLoopCpu = EventLoopCpu
    { _boundCpu      :: Int
      -- ^ capability index handed to 'forkOnIO' for every thread of this loop
    , _acceptThreads :: [ThreadId]
      -- ^ one accept thread per listening socket
    , _timeoutTable  :: TimeoutTable
      -- ^ table of connection deadlines shared by this loop's connections
    , _timeoutThread :: ThreadId
      -- ^ thread running 'timeoutThread' over '_timeoutTable'
    , _exitMVar      :: !(MVar ())
      -- ^ filled by 'timeoutThread' once it stops running
    }
-- | Run the simple event loop.
--
-- One 'EventLoopCpu' is started for each capability index @0 .. cap - 1@.
-- Each has its own accept threads (one per listening socket) and its own
-- timeout thread. The calling thread then blocks until every loop's exit
-- 'MVar' has been filled. However it leaves that wait (normally or through an
-- exception), all loop threads are killed and the listening sockets are
-- closed.
simpleEventLoop :: EventLoop
simpleEventLoop defaultTimeout sockets cap elog handler = do
    loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
                          [0 .. cap - 1]
    debug "simpleEventLoop: waiting for mvars"
    Prelude.mapM_ (takeMVar . _exitMVar) loops `finally` do
        debug "simpleEventLoop: killing all threads"
        mapM_ stopLoop loops
        mapM_ Listen.closeSocket sockets
-- | Start one event loop pinned to capability @cpu@.
--
-- Creates a fresh timeout table and exit 'MVar'. It then forks, all on
-- @cpu@, one 'acceptThread' per listening socket and a single
-- 'timeoutThread', and returns the handles needed to stop them later.
newLoop :: Int
        -> [ListenSocket]
        -> SessionHandler
        -> (S.ByteString -> IO ())
        -> Int
        -> IO EventLoopCpu
newLoop defaultTimeout sockets handler elog cpu = do
    table   <- TT.new
    exitVar <- newEmptyMVar
    let spawnAcceptor lsock =
            forkOnIO cpu $
                acceptThread defaultTimeout handler table elog cpu lsock
    acceptors <- Prelude.mapM spawnAcceptor sockets
    reaper    <- forkOnIO cpu (timeoutThread table exitVar)
    return EventLoopCpu { _boundCpu      = cpu
                        , _acceptThreads = acceptors
                        , _timeoutTable  = table
                        , _timeoutThread = reaper
                        , _exitMVar      = exitVar
                        }
-- | Kill every thread of an event loop with asynchronous exceptions blocked.
-- The accept threads are killed first, in order, and the timeout thread last.
stopLoop :: EventLoopCpu -> IO ()
stopLoop evLoop =
    block $ mapM_ killThread (_acceptThreads evLoop ++ [_timeoutThread evLoop])
-- | Accept connections on one listening socket forever.
--
-- Each accepted connection is handled by 'runSession' in a new thread forked
-- onto the same capability @cpu@. An 'AsyncException' that ends a session is
-- ignored. Any other exception is reported through @elog@.
acceptThread :: Int
             -> SessionHandler
             -> TimeoutTable
             -> (S.ByteString -> IO ())
             -> Int
             -> ListenSocket
             -> IO ()
acceptThread defaultTimeout handler tt elog cpu sock = acceptLoop
  where
    acceptLoop = do
        debug $ "acceptThread: calling accept() on socket " ++ show sock
        (conn, peer) <- accept (Listen.listenSocket sock)
        debug $ "acceptThread: accepted connection from remote: " ++ show peer
        _ <- forkOnIO cpu $
                 runSession defaultTimeout handler tt sock conn peer
                     `catches` handlers
        acceptLoop

    handlers = [ Handler ignoreAsync
               , Handler logFailure
               ]

    ignoreAsync :: AsyncException -> IO ()
    ignoreAsync _ = return ()

    logFailure :: SomeException -> IO ()
    logFailure e = elog $ S.concat [ "SimpleBackend.acceptThread: "
                                   , S.pack (map c2w (show e))
                                   ]
-- | Reap expired connections from a timeout table.
--
-- Waits for activity on the table, then kills every entry older than the
-- current time, and repeats. When any exception interrupts this loop, all
-- remaining connections in the table are killed and @exitMVar@ is filled.
timeoutThread :: TimeoutTable -> MVar () -> IO ()
timeoutThread table exitMVar = do
    reap `catch` onShutdown
    putMVar exitMVar ()
  where
    reap = do
        debug "timeoutThread: waiting for activity on thread table"
        TT.waitForActivity table
        debug "timeoutThread: woke up, killing old connections"
        getCurrentDateTime >>= \now -> TT.killOlderThan now table
        reap

    onShutdown :: SomeException -> IO ()
    onShutdown _ = do
        debug "Backend.timeoutThread: shutdown, killing all connections"
        TT.killAll table
-- | Thrown by 'runSession' when a connection's local or remote address is
-- not a 'SockAddrInet'. The payload is the 'show'n address.
data AddressNotSupportedException = AddressNotSupportedException String
  deriving (Typeable)

instance Show AddressNotSupportedException where
    showsPrec _ (AddressNotSupportedException addrText) =
        showString "Address not supported: " . showString addrText

instance Exception AddressNotSupportedException
-- | Serve one accepted connection on the calling thread.
--
-- The session proceeds in four steps:
--
-- * Resolve the remote and local IPv4 endpoints into a 'SessionInfo'.
-- * Register this thread in the timeout table.
-- * Create a network session with 'Listen.createSession'.
-- * Run the 'SessionHandler' with socket-backed read and write ends.
--
-- The accepted socket is shut down and closed on every exit path. This
-- includes failures before a session exists: a non-'SockAddrInet' address
-- (which raises 'AddressNotSupportedException') or a failing
-- 'Listen.createSession'.
runSession :: Int
           -> SessionHandler
           -> TimeoutTable
           -> ListenSocket
           -> Socket
           -> SockAddr -> IO ()
runSession defaultTimeout handler tt lsock sock addr = do
    let fd = fdSocket sock
    curId <- myThreadId

    debug $ "Backend.withConnection: running session: " ++ show addr
    labelThread curId $ "connHndl " ++ show fd

    -- No bracket owns the socket yet, so close it here if the endpoints
    -- cannot be resolved (e.g. a non-IPv4 peer); otherwise it would leak.
    sinfo <- resolveInfo `onException` closeRaw

    let curHash = hashString $ show curId
        timeout = tickleTimeout tt curId curHash

        forget  = TT.delete curHash curId tt

        -- Register the timeout inside bracket's (masked) acquire step. If
        -- session creation fails, release never runs, so undo the
        -- registration and close the socket ourselves.
        acquire = do
            timeout defaultTimeout
            Listen.createSession lsock 8192 fd
                                 (threadWaitRead $ fromIntegral fd)
                `onException` (forget >> closeRaw)

        release session = block $ do
            debug "thread killed, closing socket"
            forget
            eatException $ Listen.endSession lsock session
            closeRaw

    bracket acquire release $ \s ->
        let writeEnd = writeOut lsock s sock (timeout defaultTimeout)
        in handler sinfo
                   (enumerate lsock s sock)
                   writeEnd
                   (sendFile lsock (timeout defaultTimeout) fd writeEnd)
                   timeout
  where
    -- Remote endpoint first, then local, as (port, host) pairs.
    resolveInfo = do
        (rport, rhost) <- inetInfo addr
        (lport, lhost) <- getSocketName sock >>= inetInfo
        return $ SessionInfo lhost lport rhost rport $ Listen.isSecure lsock

    -- Only IPv4 addresses are supported.
    inetInfo (SockAddrInet p h) = do
        h' <- inet_ntoa h
        return (fromIntegral p, S.pack $ map c2w h')
    inetInfo x = throwIO $ AddressNotSupportedException $ show x

    -- Best-effort teardown of the raw socket; errors are ignored.
    closeRaw = do
        eatException $ shutdown sock ShutdownBoth
        eatException $ sClose sock
-- | Run an action for its effects only. Its result is discarded, and any
-- exception it throws (of any type) is swallowed.
eatException :: IO a -> IO ()
eatException act = handle swallow $ act >> return ()
  where
    swallow :: SomeException -> IO ()
    swallow _ = return ()
-- | Send @sz@ bytes of file @fp@, starting at byte offset @start@.
--
-- When built with @HAS_SENDFILE@, plain-HTTP listeners write the range
-- straight to the socket descriptor with @sendfile@. The tickle action runs
-- after every partial write so the connection's timeout is pushed back while
-- progress is made. In every other case the file range is enumerated into
-- the supplied write iteratee.
sendFile :: ListenSocket
         -> IO ()                      -- ^ timeout tickle action
         -> CInt                       -- ^ raw socket descriptor
         -> Iteratee ByteString IO ()  -- ^ write end used when not using sendfile
         -> FilePath                   -- ^ file to send
         -> Int64                      -- ^ starting offset
         -> Int64                      -- ^ number of bytes to send
         -> IO ()
#if defined(HAS_SENDFILE)
sendFile lsock tickle sock writeEnd fp start sz =
    case lsock of
        ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
                                closeFd
                                (go start sz)
        _            -> do
            step <- runIteratee writeEnd
            run_ $ enumFilePartial fp (start,start+sz) step
  where
    -- sendfile may stop after a partial write, so loop on the remainder.
    -- A non-positive byte count sends nothing.
    go off bytes fd
      | bytes <= 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile sfd fd off bytes
            if sent < bytes
              then tickle >> go (off + sent) (bytes - sent) fd
              else return ()

    sfd = Fd sock
#else
sendFile _ _ _ writeEnd fp start sz = do
    step <- runIteratee writeEnd
    run_ $ enumFilePartial fp (start,start+sz) step
#endif
-- | (Re)register thread @tid@ (with hash @thash@) in the timeout table. Its
-- deadline is the current date-time plus @tm@ (presumably seconds, given
-- 'getCurrentDateTime' — NOTE(review): confirm the unit).
tickleTimeout :: TimeoutTable -> ThreadId -> Word -> Int -> IO ()
tickleTimeout table tid thash tm = do
    debug "Backend.tickleTimeout"
    deadline <- fmap (+ toEnum tm) getCurrentDateTime
    TT.insert thash tid deadline table
-- | Enumerator that feeds data read from a connection into an iteratee.
--
-- Each time the iteratee asks for more input, one chunk is read with
-- 'Listen.recv'. A missing chunk (EOF) or an empty chunk is delivered as EOF.
-- Otherwise the chunk is passed on, and reading continues for as long as the
-- iteratee keeps answering 'Continue'.
enumerate :: (MonadIO m)
          => ListenSocket
          -> NetworkSession
          -> Socket
          -> Enumerator ByteString m a
enumerate port session sock = pump
  where
    dbg msg = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
                      ++ "): " ++ msg

    -- Only a 'Continue' step wants input; any other step is handed back.
    pump (Continue k) = do
        dbg "reading from socket"
        received <- liftIO timeoutRecv
        maybe (dbg "got EOF from socket" >> deliver k "")
              (\chunk -> do
                   dbg $ "got " ++ Prelude.show (S.length chunk)
                         ++ " bytes from read end"
                   deliver k chunk)
              received
    pump other = returnI other

    deliver k chunk
      | S.null chunk = do
            dbg "sending EOF to continuation"
            enumEOF $ Continue k
      | otherwise = do
            dbg $ "sending " ++ show chunk ++ " to continuation"
            next <- lift $ runIteratee $ k $ Chunks [chunk]
            case next of
                Continue _ -> do
                    dbg $ "got continue"
                    pump next
                Yield x rest -> do
                    dbg $ "got yield, remainder is " ++ show rest
                    yield x rest
                Error err -> throwError err

    rawFd = fromIntegral (fdSocket sock)
#ifdef PORTABLE
    timeoutRecv = Listen.recv port sock (threadWaitRead rawFd) session
#else
    timeoutRecv = Listen.recv port (threadWaitRead rawFd) session
#endif
-- | Iteratee that writes every chunk it receives to the connection.
--
-- Each 'Chunks' input is concatenated and sent with 'Listen.send'. The
-- @tickle@ action is also passed to 'Listen.send', presumably to extend the
-- connection's timeout during writes (NOTE(review): confirm). A send failure
-- is logged and re-raised as an iteratee error. 'EOF' finishes with @()@.
writeOut :: (MonadIO m)
         => ListenSocket
         -> NetworkSession
         -> Socket
         -> (IO ())
         -> Iteratee ByteString m ()
writeOut port session sock tickle = loop
  where
    dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
                    ++ "): " ++ s

    loop = continue k

    k EOF = yield () EOF
    k (Chunks xs) = do
        let s = S.concat xs
        let n = S.length s
        dbg $ "got chunk with " ++ show n ++ " bytes"
        ee <- liftIO $ try $ timeoutSend s
        case ee of
          (Left (e::SomeException)) -> do
              dbg $ "timeoutSend got error " ++ show e
              throwError e
          (Right _) -> do
              -- Only the final (up to) 10 bytes are shown in the debug log.
              let last10 = S.drop (n - 10) s
              dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
              loop

    fd = fdSocket sock
#ifdef PORTABLE
    timeoutSend = Listen.send port sock tickle
                              (threadWaitWrite $ fromIntegral fd) session
#else
    timeoutSend = Listen.send port tickle
                              (threadWaitWrite $ fromIntegral fd) session
#endif