module Network.C10kServer (
C10kConfig(..)
, C10kServer, runC10kServer
, C10kServerH, runC10kServerH
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.IORef
import Network.BSD
import Network.Socket
import qualified Network.TCPInfo as T (accept)
import Network.TCPInfo hiding (accept)
import Prelude hiding (catch)
import System.Exit
import System.IO
import System.Posix.Process
import System.Posix.Signals
import System.Posix.Types
import System.Posix.User
type C10kServer = Socket -> IO ()
type C10kServerH = Handle -> TCPInfo -> IO ()
data C10kConfig = C10kConfig {
initHook :: IO ()
, exitHook :: String -> IO ()
, parentStartedHook :: IO ()
, startedHook :: IO ()
, sleepTimer :: Int
, preforkProcessNumber :: Int
, threadNumberPerProcess :: Int
, portName :: ServiceName
, ipAddr :: Maybe HostName
, pidFile :: FilePath
, user :: String
, group :: String
}
runC10kServer :: C10kServer -> C10kConfig -> IO ()
runC10kServer srv cnf = runC10kServer' (dispatchS srv) cnf `catch` errorHandle cnf
runC10kServerH :: C10kServerH -> C10kConfig -> IO ()
runC10kServerH srv cnf = runC10kServer' (dispatchH srv) cnf `catch` errorHandle cnf
errorHandle :: C10kConfig -> SomeException -> IO ()
errorHandle cnf e = do
exitHook cnf (show e)
exitFailure
runC10kServer' :: (Socket -> Dispatch) -> C10kConfig -> IO ()
runC10kServer' sDispatch cnf = do
initHook cnf `catch` ignore
if onlyOne
then stay sDispatch cnf
else prefork sDispatch cnf
where
onlyOne = preforkProcessNumber cnf == 1
stay :: (Socket -> Dispatch) -> C10kConfig -> IO ()
stay sDispatch cnf = do
parentStartedHook cnf `catch` ignore
s <- listenTo addr port
writePidFile cnf
setGroupUser cnf
runServer (sDispatch s) cnf
sClose s
where
port = portName cnf
addr = ipAddr cnf
prefork :: (Socket -> Dispatch) -> C10kConfig -> IO ()
prefork sDispatch cnf = do
cids <- doPrefork sDispatch cnf
parentStartedHook cnf `catch` ignore
handleSignal cids
pause
terminateChildren cids
where
handleSignal cids = do
ignoreSigChild
mapM_ (terminator cids) [sigTERM,sigINT]
pause = blockSignals reservedSignals >> awaitSignal Nothing >> yield
initHandler func sig = installHandler sig func Nothing
ignoreSigChild = initHandler Ignore sigCHLD
terminator cids = initHandler (Catch (terminateChildren cids))
terminateChildren cids = do
ignoreSigChild
mapM_ terminateChild cids
terminateChild cid = signalProcess sigTERM cid `catch` ignore
doPrefork :: (Socket -> Dispatch) -> C10kConfig -> IO [ProcessID]
doPrefork sDispatch cnf = do
s <- listenTo addr port
writePidFile cnf
setGroupUser cnf
cids <- fork (sDispatch s)
sClose s
return cids
where
port = portName cnf
addr = ipAddr cnf
n = preforkProcessNumber cnf
fork dispatch = replicateM n . forkProcess $ runServer dispatch cnf
setGroupUser :: C10kConfig -> IO ()
setGroupUser cnf = do
uid <- getRealUserID
when (uid == 0) $ do
getGroupEntryForName (group cnf) >>= setGroupID . groupID
getUserEntryForName (user cnf) >>= setUserID . userID
writePidFile :: C10kConfig -> IO ()
writePidFile cnf = do
pid <- getProcessID
writeFile (pidFile cnf) $ show pid ++ "\n"
runServer :: Dispatch -> C10kConfig -> IO ()
runServer dispatch cnf = do
startedHook cnf
ref <- newIORef 0 :: IO (IORef Int)
dispatchOrSleep ref dispatch cnf
dispatchOrSleep :: IORef Int -> Dispatch -> C10kConfig -> IO ()
dispatchOrSleep ref dispatch cnf = do
n <- howMany
if n > threadNumberPerProcess cnf
then sleep (sleepTimer cnf * microseconds)
else dispatch increase decrease
dispatchOrSleep ref dispatch cnf
where
howMany = readIORef ref
increase = atomicModifyIORef ref (\n -> (n+1,()))
decrease = atomicModifyIORef ref (\n -> (n1,()))
sleep = threadDelay
type Dispatch = IO () -> IO () -> IO ()
dispatchS :: C10kServer -> Socket -> Dispatch
dispatchS srv sock inc dec = do
(connsock,_) <- accept sock
inc
forkIO $ srv connsock `finally` (dec >> sClose connsock)
return ()
dispatchH :: C10kServerH -> Socket -> Dispatch
dispatchH srv sock inc dec = do
(hdl,tcpi) <- T.accept sock
inc
forkIO $ srv hdl tcpi `finally` (dec >> hClose hdl)
return ()
ignore :: SomeException -> IO ()
ignore _ = return ()
microseconds :: Int
microseconds = 1000000
listenTo :: Maybe HostName -> ServiceName -> IO Socket
listenTo maddr serv = do
proto <- getProtocolNumber "tcp"
let hints = defaultHints {
addrFlags = [ AI_ADDRCONFIG, AI_NUMERICHOST
, AI_NUMERICSERV, AI_PASSIVE
]
, addrSocketType = Stream
, addrProtocol = proto
}
addrs <- getAddrInfo (Just hints) maddr (Just serv)
let addrs' = filter (\x -> addrFamily x == AF_INET6) addrs
addr = if null addrs' then head addrs else head addrs'
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
setSocketOption sock ReuseAddr 1
bindSocket sock (addrAddress addr)
listen sock maxListenQueue
return sock