-- | A multi-connection WebSocket server wrapper.
--
-- Each accepted connection runs a caller-supplied 'ServerProgram' on its own
-- thread. Console output and per-connection sends are serialised through
-- 'MVar'-based locks ('StdOutMutex' and 'ConnectionSendMutex').
module Network.WebSockets.ExtendedServer
  ( ServerProgram
  , server
  , safeSendText
  , safeSend
  , StdOutMutex
  , ConnectionSendMutex
  ) where

import Network.WebSockets.Connection
import Network.WebSockets.Types

import qualified Data.Text as T (Text, pack)
import qualified Network.Socket as S (Socket(..), withSocketsDo)

import Control.Monad (forever, when)
import Control.Exception (handle, SomeException(..))
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, newMVar, takeMVar, putMVar, withMVar, modifyMVar_)
import Control.Concurrent.Event (Event, wait, set, new, isSet)

import Data.List ((\\))
import Data.Maybe (isNothing, fromJust)

import GHC.IO.Exception (IOException(..))

-- | Id allocator state: the next never-issued id, paired with the ids that
-- have been handed back and can be reused.
type ConnectionIdList = (ConnectionId, [ConnectionId])

-- | Identifier assigned to each accepted connection.
type ConnectionId = Int

-- | Lock guarding console output. Its payload is the queue of lines that are
-- waiting to be printed.
type StdOutMutex = Mutex [[Char]]

-- | Lock serialising sends on a single connection.
type ConnectionSendMutex = Mutex ()

-- | An 'MVar' used as a lock: full means free, empty means held.
type Mutex a = MVar a

-- | Event a worker thread sets once it has finished.
type ClosedThread = Event

-- | Event that is set to ask a worker thread to finish.
type CloseThread = Event

-- | Every live connection together with its id and its send lock.
type ConnectionList = [(ConnectionId, Connection, ConnectionSendMutex)]

-- | The per-connection handler supplied by the user of 'server'.
type ServerProgram = Connection -> StdOutMutex -> ConnectionSendMutex -> IO ()

-- Utility functions --

-- | Create a fresh, unset \"please stop\" event.
safeStartCloseThread :: IO Event
safeStartCloseThread = new

-- | Create a fresh, unset \"has stopped\" event.
safeStartClosedThread :: IO Event
safeStartClosedThread = new

-- Safe function --

-- | Run an action while holding the given lock.
--
-- The lock is released again even when the action throws; otherwise a single
-- failing action would leave the 'MVar' empty and block every later user of
-- the same lock forever.
safe :: Mutex a -> IO b -> IO b
safe mu f = withMVar mu (const f)

-- Print to stdOut functions --

-- | Append a line to the pending-output queue held in the 'StdOutMutex'.
-- Nothing is printed here; the queue is only extended.
safePutStr :: StdOutMutex -> String -> IO ()
safePutStr mu str = modifyMVar_ mu (return . (++ [str]))

-- Doesn't work correctly on Windows?
-- | Read one line from stdin.
--
-- The first character is awaited /without/ holding the output lock, so queued
-- output keeps flowing while the user is idle. Once a character arrives the
-- lock is taken for the rest of the line, which stalls 'safePutStr' and
-- 'printStep' until the line is complete.
safeGetLine :: StdOutMutex -> IO String
safeGetLine mu = do
  c <- getChar
  rest <- safe mu getLine
  return (c : rest)

-- | Create the output lock with an empty queue of pending lines.
safeStartStdOutMutex :: IO StdOutMutex
safeStartStdOutMutex = newMVar []

-- Send data over the socket --

-- | Send a text message on the connection while holding its send lock.
safeSendText :: ConnectionSendMutex -> Connection -> T.Text -> IO ()
safeSendText mu conn = safeSend mu conn . Text . toLazyByteString

-- | Send an arbitrary data message on the connection while holding its send
-- lock.
safeSend :: ConnectionSendMutex -> Connection -> DataMessage -> IO ()
safeSend mu conn msg = safe mu $ sendDataMessage conn msg

-- | Create a fresh, free lock.
safeStartMutex :: IO (Mutex ())
safeStartMutex = newMVar ()

-- Connection list functions --

-- | Register a connection (with its id and send lock) at the head of the
-- shared list.
safeAddConnection :: MVar ConnectionList -> ConnectionId -> Connection -> ConnectionSendMutex -> IO ()
safeAddConnection connListM connId conn mu = do
  connList <- takeMVar connListM
  putMVar connListM ((connId, conn, mu) : connList)

-- | Drop every entry with the given id from the shared list.
safeRemoveConnection :: MVar ConnectionList -> ConnectionId -> IO ()
safeRemoveConnection connListM connId = do
  connList <- takeMVar connListM
  putMVar connListM (filter (\(connId', _, _) -> connId' /= connId) connList)

-- | Look up the connection registered under the given id, if any.
safeGetConnection :: MVar ConnectionList -> ConnectionId -> IO (Maybe Connection)
safeGetConnection connListM cid = do
  connList <- takeMVar connListM
  putMVar connListM connList
  return $ getConnection connList cid

-- | Pure lookup of a connection by id; the first matching entry wins.
getConnection :: ConnectionList -> ConnectionId -> Maybe Connection
getConnection connList cid =
  lookup cid [(cid', conn) | (cid', conn, _) <- connList]

-- | Look up the send lock registered under the given id, if any.
safeGetConnSendMutex :: MVar ConnectionList -> ConnectionId -> IO (Maybe ConnectionSendMutex)
safeGetConnSendMutex connListM connId = do
  connList <- takeMVar connListM
  putMVar connListM connList
  return $ getConnSendMutex connList connId

-- | Pure lookup of a send lock by id; the first matching entry wins.
getConnSendMutex :: ConnectionList -> ConnectionId -> Maybe ConnectionSendMutex
getConnSendMutex connList connId =
  lookup connId [(connId', connSendM) | (connId', _, connSendM) <- connList]

-- | Create the shared connection list, initially empty.
safeStartConnectionList :: IO (MVar ConnectionList)
safeStartConnectionList = newMVar []

-- Connection Id List functions --

-- | Allocate an id: reuse a returned one when available, otherwise issue the
-- next fresh id and advance the counter.
takeConnId :: ConnectionIdList -> (ConnectionId, ConnectionIdList)
takeConnId (next, []) = (next, (next + 1, []))
takeConnId (next, i : is) = (i, (next, is))

-- | Hand an id back so that it can be reused.
returnConnId :: ConnectionId -> ConnectionIdList -> ConnectionIdList
returnConnId i (next, is) = (next, i : is)

-- | Allocate an id from the shared allocator.
safeTakeConnId :: MVar ConnectionIdList -> IO ConnectionId
safeTakeConnId connIdListM = do
  connIdList <- takeMVar connIdListM
  let (i, connIdList') = takeConnId connIdList
  putMVar connIdListM connIdList'
  return i

-- | Hand an id back to the shared allocator.
safeReturnConnId :: ConnectionId -> MVar ConnectionIdList -> IO ()
safeReturnConnId i connIdListM = do
  connIdList <- takeMVar connIdListM
  putMVar connIdListM (returnConnId i connIdList)

-- | Create the shared id allocator: next id 0, nothing returned yet.
safeStartConnId :: IO (MVar ConnectionIdList)
safeStartConnId = newMVar (0, [])

-- Server function --

-- | Run the server on the given host and port.
--
-- Starts the printer thread and the console command thread, then accepts
-- connections in a loop, running @prog@ for each one on its own thread. The
-- accept loop ends when an 'IOException' is raised (the console @quit@
-- command closes the listening socket). After that the function waits until
-- every issued connection id has been returned and both helper threads have
-- signalled that they are done.
server :: String -> Int -> ServerProgram -> IO ()
server host port prog = S.withSocketsDo $ do
  sock <- makeSocket host port
  connListM <- safeStartConnectionList
  stdoutM <- safeStartStdOutMutex
  connIdListM <- safeStartConnId
  closePrinter <- safeStartCloseThread
  closedPrinter <- safeStartClosedThread
  closedTextInt <- safeStartClosedThread
  safePutStr stdoutM ""
  _ <- forkIO (printInterface stdoutM closePrinter closedPrinter)
  _ <- forkIO (textInterface connListM stdoutM closePrinter closedTextInt sock)
  handleMainThreadExceptions $ forever $ do
    penCon <- makePendingConnection sock
    conn <- acceptRequest penCon
    connId <- safeTakeConnId connIdListM
    connSendM <- safeStartMutex
    safeAddConnection connListM connId conn connSendM
    safePutStr stdoutM ("Connection established: " ++ show connId)
    _ <- forkIO $
      handleConnectionThreadExceptions connListM connIdListM stdoutM connId $
        prog conn stdoutM connSendM
    return ()
  waitForConnectionsCIL connIdListM
  wait closedTextInt
  wait closedPrinter

-- | Poll (every 100 ms) until every id that was ever issued is back in the
-- allocator's returned list, i.e. no connection still owns an id.
waitForConnectionsCIL :: MVar ConnectionIdList -> IO ()
waitForConnectionsCIL connIdListM = do
  (maxId, returned) <- takeMVar connIdListM
  let outstanding = buildList (maxId - 1) \\ returned
  putMVar connIdListM (maxId, returned)
  threadDelay 100000
  when (not (null outstanding)) (waitForConnectionsCIL connIdListM)

-- | Poll (every 100 ms) until the shared connection list is empty.
waitForConnectionsCL :: MVar ConnectionList -> IO ()
waitForConnectionsCL connListM = do
  connList <- takeMVar connListM
  putMVar connListM connList
  threadDelay 100000
  when (not $ null connList) (waitForConnectionsCL connListM)

-- | The list @[i, i-1 .. 0]@; empty for a negative argument.
buildList :: Int -> [Int]
buildList i
  | i < 0 = []
  | otherwise = [i, i - 1 .. 0]

-- Printer Interface --

-- | Printer thread body: flush the pending-output queue every 10 ms until the
-- close event is set, then flush one final time and set the closed event.
--
-- The loop recurses in tail position. Recursing from the middle of the block
-- would leave one pending continuation per poll and grow the stack for as
-- long as the server runs.
printInterface :: StdOutMutex -> CloseThread -> ClosedThread -> IO ()
printInterface stdoutM closeThread closedThread = do
  printStep stdoutM
  threadDelay 10000
  close <- isSet closeThread
  if close
    then do
      closed <- isSet closedThread
      when (not closed) $ do
        printStep stdoutM
        set closedThread
    else printInterface stdoutM closeThread closedThread

-- | Print every queued line and leave the queue empty. The output lock is
-- held for the duration of the printing.
printStep :: StdOutMutex -> IO ()
printStep stdoutM = do
  strList <- takeMVar stdoutM
  strList' <- flushPrintStr strList
  putMVar stdoutM strList'

-- | Print each line in order and return the (empty) remaining queue.
flushPrintStr :: [[Char]] -> IO [[Char]]
flushPrintStr strList = mapM_ putStrLn strList >> return []

-- Textual User Interface --

-- | Console thread body: read commands line by line until one of them asks
-- the server to quit.
textInterface :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
textInterface connListM stdoutM closePrinter closedThread sock = do
  line <- safeGetLine stdoutM
  doQuit <- doCommand connListM stdoutM closePrinter closedThread sock (words line)
  when (not doQuit) (textInterface connListM stdoutM closePrinter closedThread sock)

-- | Execute one console command given as a list of words. Returns 'True' when
-- the console loop should stop. Only @quit@ is recognised; everything else,
-- including an empty line, is ignored.
doCommand :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> [String] -> IO Bool
doCommand _ _ _ _ _ [] = return False
doCommand connListM stdoutM closePrinter closedThread sock (c : _)
  | c == "quit" = do
      quit connListM stdoutM closePrinter closedThread sock
      return True
  | otherwise = return False

-- | Shut the server down: close the listening socket, send a close message to
-- every registered connection, wait until the connection list is empty, then
-- ask the printer thread to stop and mark the console thread as finished.
--
-- The connection list lock is held while the socket is closed and the close
-- messages are sent, so no connection can be added or removed during that
-- window.
quit :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
quit connListM stdoutM closePrinter closedThread sock = do
  safePutStr stdoutM "Closing..."
  connList <- takeMVar connListM
  safePutStr stdoutM "Waiting for Server to close..."
  closeSocket sock
  safePutStr stdoutM "Server is Closed!"
  closeAllConnections connList "Server is closing"
  putMVar connListM connList
  safePutStr stdoutM "Waiting for connections to close..."
  waitForConnectionsCL connListM
  safePutStr stdoutM "All Connections are closed!"
  safePutStr stdoutM "All done!"
  set closePrinter
  set closedThread

-- Close connection(s) --

-- | Send a close message with the given reason to every connection in the
-- list.
--
-- Closing is best-effort per connection: a 'ConnectionException' or
-- 'IOException' raised for one connection (for example because the peer has
-- already gone away) is dropped so that the remaining connections are still
-- closed. Without this, the first failure would abort 'quit' part-way through
-- and leave the connection list lock taken and the shutdown events unset.
closeAllConnections :: ConnectionList -> String -> IO ()
closeAllConnections conns str = mapM_ closeOne conns
  where
    closeOne (_, conn, mu) =
      handle ignoreIO $ handle ignoreConn $ closeConnection mu conn str

    ignoreConn :: ConnectionException -> IO ()
    ignoreConn _ = return ()

    ignoreIO :: IOException -> IO ()
    ignoreIO _ = return ()

-- | Send a close message with the given reason while holding the
-- connection's send lock.
closeConnection :: ConnectionSendMutex -> Connection -> String -> IO ()
closeConnection connMu conn str = safe connMu $ sendClose conn (T.pack str)

-- Main Thread Exception handlers --

-- | Exception policy of the accept loop: an 'IOException' ends the loop
-- quietly.
handleMainThreadExceptions :: IO () -> IO ()
handleMainThreadExceptions = handleIOError

-- | Run an action, discarding any 'IOException' it raises.
handleIOError :: IO () -> IO ()
handleIOError = handle ioErrorHandler

-- | Handler that ignores the 'IOException' it is given.
ioErrorHandler :: IOException -> IO ()
ioErrorHandler (IOError {}) = return ()

-- Connection Thread Exception handlers --

-- | Wrap a connection's program in three handler layers, innermost first:
--
-- 1. a 'ConnectionException' raised by the program triggers cleanup;
-- 2. any other exception is logged, a close message is sent, and the
--    connection is read until it raises a 'ConnectionException';
-- 3. that 'ConnectionException' is caught by the outer layer, which performs
--    the cleanup.
--
-- NOTE(review): cleanup (list removal, id return) only happens inside the
-- 'ConnectionException' handler; a program that returns normally appears to
-- leave its entry registered — confirm this is intended.
handleConnectionThreadExceptions :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionThreadExceptions connListM connIdListM stdoutM connId io =
  handleConnExcept $
    handleOtherExceptions stdoutM connListM connId $
      handleConnExcept io
  where
    handleConnExcept = handleConnectionException connListM connIdListM stdoutM connId

-- | Install 'connectionOtherExceptionHandler' around an action.
handleOtherExceptions :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> IO () -> IO ()
handleOtherExceptions stdoutM connListM connId =
  handle (connectionOtherExceptionHandler stdoutM connListM connId)

-- | Handle an arbitrary exception from a connection's program: log it and,
-- if the connection is still registered, send a close message and keep
-- reading (and logging) incoming messages. The read loop only ends when
-- 'receiveData' throws, which is what hands control to the enclosing
-- 'ConnectionException' handler.
connectionOtherExceptionHandler :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> SomeException -> IO ()
connectionOtherExceptionHandler stdoutM connListM cid e = do
  connM <- safeGetConnection connListM cid
  connSendMM <- safeGetConnSendMutex connListM cid
  safePutStr stdoutM ("Connection [" ++ show cid ++ "] has had an exception: " ++ show e)
  case (connM, connSendMM) of
    (Just conn, Just connSendM) -> do
      closeConnection connSendM conn "Connection has caused a server-side exception"
      safePutStr stdoutM "Send close message..."
      safePutStr stdoutM "Reading from the connection for close event"
      forever $ do
        msg <- receiveData conn :: IO T.Text
        safePutStr stdoutM (show msg)
    -- The connection is no longer registered: nothing left to close.
    _ -> return ()

-- | Install 'connectionExceptionHandler' around an action.
handleConnectionException :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionException connListM connIdListM stdoutM connId =
  handle (connectionExceptionHandler connListM connIdListM stdoutM connId)

-- | Clean up after a connection has ended: remove it from the connection
-- list, return its id to the allocator, and log how it ended.
--
-- NOTE(review): only 'CloseRequest' and 'ConnectionClosed' are matched. If
-- the 'ConnectionException' type in use has further constructors, those
-- would not be handled here — confirm against the library version.
connectionExceptionHandler :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> ConnectionException -> IO ()
connectionExceptionHandler connListM connIdListM stdoutM connId (CloseRequest code msg) = do
  safeRemoveConnection connListM connId
  safeReturnConnId connId connIdListM
  safePutStr stdoutM
    ( "Connection[" ++ show connId ++ "] closed: "
        ++ show code ++ " "
        ++ show (fromLazyByteString msg :: T.Text)
    )
connectionExceptionHandler connListM connIdListM stdoutM connId ConnectionClosed = do
  safeRemoveConnection connListM connId
  safeReturnConnId connId connIdListM
  safePutStr stdoutM
    ("Connection[" ++ show connId ++ "] closed unexpectedly: " ++ show connId)