module Network.WebSockets.ExtendedServer
(
ServerProgram,
server,
safeSendText,
safeSend,
StdOutMutex,
ConnectionSendMutex
) where
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, newMVar, takeMVar, putMVar, withMVar)
import Control.Exception (handle, SomeException(..))
import Control.Monad (forever, when)
import Data.List ((\\))
import Data.Maybe (isNothing, fromJust)
import GHC.IO.Exception (IOException(..))

import Control.Concurrent.Event (Event, wait, set, new, isSet)
import qualified Data.Text as T (Text, pack)
import qualified Network.Socket as S (Socket(..), withSocketsDo)

import Network.WebSockets.Connection
import Network.WebSockets.Types
-- | Connection-id pool: the next never-used id, paired with the ids that
-- have been handed back for reuse (see 'takeConnId' / 'returnConnId').
type ConnectionIdList = (ConnectionId, [ConnectionId])
-- | Identifier assigned to each accepted connection.
type ConnectionId = Int
-- | Lock around stdout; its payload is the queue of lines waiting to be
-- written by the printer thread ('printStep').
type StdOutMutex = Mutex [[Char]]
-- | Per-connection lock held while sending on that connection.
type ConnectionSendMutex = Mutex ()
-- | A lock built from an 'MVar': taking the value acquires the lock,
-- putting it back releases it.
type Mutex a = MVar a
-- | Event a worker thread sets once it has finished shutting down.
type ClosedThread = Event
-- | Event that is set to ask a worker thread to stop.
type CloseThread = Event
-- | Registry of live connections: id, connection, and its send lock.
type ConnectionList = [(ConnectionId, Connection, ConnectionSendMutex)]
-- | Per-connection user program. 'server' runs it in its own thread with the
-- connection, the shared stdout lock and that connection's send lock.
type ServerProgram = Connection -> StdOutMutex -> ConnectionSendMutex -> IO ()
-- | Allocate a fresh, unset event used to ask a worker thread to stop.
safeStartCloseThread :: IO Event
safeStartCloseThread = Control.Concurrent.Event.new
-- | Allocate a fresh, unset event that a worker thread sets once it is done.
safeStartClosedThread :: IO Event
safeStartClosedThread = Control.Concurrent.Event.new
-- | Run an action while holding the given mutex and return its result.
--
-- The mutex is released even if the action throws. Otherwise a failing
-- send (e.g. on a connection the peer already dropped) or a failing
-- 'getLine' would leave the 'MVar' empty forever and deadlock every other
-- thread that uses the same lock. 'withMVar' restores the original value on
-- both normal and exceptional exit.
safe :: Mutex a -> IO b -> IO b
safe mu f = withMVar mu (const f)
-- | Append a line to the pending-output queue guarded by the stdout mutex.
-- Nothing is written here; the printer thread flushes the queue.
safePutStr :: StdOutMutex -> String -> IO ()
safePutStr mu str = takeMVar mu >>= putMVar mu . (++ [str])
-- | Read one line from stdin. The first character is read without the lock.
-- The rest of the line is read while holding the stdout mutex, so the
-- printer thread cannot flush output mid-input.
safeGetLine :: StdOutMutex -> IO String
safeGetLine mu = (:) <$> getChar <*> safe mu getLine
-- | Create the stdout mutex with an empty pending-output queue.
safeStartStdOutMutex :: IO StdOutMutex
safeStartStdOutMutex = newMVar mempty
-- | Send a text frame on the connection while holding its send mutex.
safeSendText :: ConnectionSendMutex -> Connection -> T.Text -> IO ()
safeSendText mu conn txt = safeSend mu conn (Text (toLazyByteString txt))
-- | Send a data message on the connection while holding its send mutex, so
-- concurrent senders on the same connection do not interleave.
safeSend :: ConnectionSendMutex -> Connection -> DataMessage -> IO ()
safeSend mu conn = safe mu . sendDataMessage conn
-- | Create an unlocked mutex; the payload is just @()@ ('mempty').
safeStartMutex :: IO (Mutex ())
safeStartMutex = newMVar mempty
-- | Register a connection (with its id and send mutex) at the front of the
-- shared connection list.
safeAddConnection :: MVar ConnectionList -> ConnectionId -> Connection -> ConnectionSendMutex -> IO ()
safeAddConnection connListM connId conn mu =
  takeMVar connListM >>= putMVar connListM . (entry :)
  where
    entry = (connId, conn, mu)
-- | Drop every entry with the given id from the shared connection list.
safeRemoveConnection :: MVar ConnectionList -> ConnectionId -> IO ()
safeRemoveConnection connListM connId = do
  current <- takeMVar connListM
  putMVar connListM [entry | entry@(cid, _, _) <- current, cid /= connId]
-- | Look up a connection by id in a snapshot of the shared list. The list is
-- put back unchanged before the lookup.
safeGetConnection :: MVar ConnectionList -> ConnectionId -> IO (Maybe Connection)
safeGetConnection connListM cid = do
  snapshot <- takeMVar connListM
  putMVar connListM snapshot
  pure (getConnection snapshot cid)
-- | First connection whose id matches, if any.
getConnection :: ConnectionList -> ConnectionId -> Maybe Connection
getConnection connList wanted =
  lookup wanted [(cid, conn) | (cid, conn, _) <- connList]
-- | Look up a connection's send mutex by id in a snapshot of the shared
-- list. The list is put back unchanged before the lookup.
safeGetConnSendMutex :: MVar ConnectionList -> ConnectionId -> IO (Maybe ConnectionSendMutex)
safeGetConnSendMutex connListM connId =
  takeMVar connListM >>= \current ->
    putMVar connListM current >> return (getConnSendMutex current connId)
-- | Send mutex of the first entry whose id matches, if any.
getConnSendMutex :: ConnectionList -> ConnectionId -> Maybe ConnectionSendMutex
getConnSendMutex connList connId =
  lookup connId [(cid, sendMutex) | (cid, _, sendMutex) <- connList]
-- | Create the shared connection registry, initially empty.
safeStartConnectionList :: IO (MVar ConnectionList)
safeStartConnectionList = newMVar mempty
-- | Hand out an id: reuse a returned one if available, otherwise mint the
-- next fresh id and advance the counter.
takeConnId :: ConnectionIdList -> (ConnectionId, ConnectionIdList)
takeConnId (next, free) = case free of
  [] -> (next, (next + 1, []))
  (recycled : rest) -> (recycled, (next, rest))
-- | Put an id back on the free list; the fresh-id counter is untouched.
-- ('fmap' on a pair maps only its second component.)
returnConnId :: ConnectionId -> ConnectionIdList -> ConnectionIdList
returnConnId i = fmap (i :)
-- | Atomically take an id from the shared pool (see 'takeConnId').
safeTakeConnId :: MVar ConnectionIdList -> IO ConnectionId
safeTakeConnId connIdListM = do
  pool <- takeMVar connIdListM
  let handedOut = takeConnId pool
  putMVar connIdListM (snd handedOut)
  return (fst handedOut)
-- | Atomically give an id back to the shared pool (see 'returnConnId').
safeReturnConnId :: ConnectionId -> MVar ConnectionIdList -> IO ()
safeReturnConnId i connIdListM =
  takeMVar connIdListM >>= putMVar connIdListM . returnConnId i
-- | Create the id pool: fresh ids start at 0 and nothing has been returned.
safeStartConnId :: IO (MVar ConnectionIdList)
safeStartConnId = newMVar (0, mempty)
-- | Run a websocket server on @host:port@, running @prog@ in its own thread
-- for every accepted connection.
--
-- Startup also forks a printer thread (flushes queued output) and a
-- text-interface thread (reads commands from stdin). The accept loop runs
-- until an 'IOException' ends it (handled by 'handleMainThreadExceptions').
-- Afterwards this waits until every handed-out connection id has been
-- returned, then for the text interface and the printer to signal that they
-- have finished.
server :: String -> Int -> ServerProgram -> IO ()
server host port prog = S.withSocketsDo $ do
  sock <- makeSocket host port
  connListM <- safeStartConnectionList
  stdoutM <- safeStartStdOutMutex
  connIdListM <- safeStartConnId
  closePrinter <- safeStartCloseThread
  closedPrinter <- safeStartClosedThread
  closedTextInt <- safeStartClosedThread
  safePutStr stdoutM ""
  _ <- forkIO (printInterface stdoutM closePrinter closedPrinter)
  _ <- forkIO (textInterface connListM stdoutM closePrinter closedTextInt sock)
  -- One iteration of the accept loop: handshake, register, spawn a worker.
  let acceptOne = do
        pending <- makePendingConnection sock
        conn <- acceptRequest pending
        connId <- safeTakeConnId connIdListM
        sendMutex <- safeStartMutex
        safeAddConnection connListM connId conn sendMutex
        safePutStr stdoutM ("Connection established: " ++ show connId)
        let worker = prog conn stdoutM sendMutex
        () <$ forkIO (handleConnectionThreadExceptions connListM connIdListM stdoutM connId worker)
  handleMainThreadExceptions (forever acceptOne)
  waitForConnectionsCIL connIdListM
  wait closedTextInt
  wait closedPrinter
-- | Poll (every 100 ms) until every connection id ever handed out has been
-- returned to the pool.
--
-- The pool's counter @maxId@ is the next fresh id, so ids @0 .. maxId - 1@
-- have been issued. Everything is done once each of them is on the free
-- list. (The original wrote @maxId 1@, i.e. applied an 'Int' as a function;
-- the intended expression is @maxId - 1@.)
waitForConnectionsCIL :: MVar ConnectionIdList -> IO ()
waitForConnectionsCIL connIdListM = do
  (maxId, list) <- takeMVar connIdListM
  let issued = buildList (maxId - 1)
      stillOut = issued \\ list
      allDone = null stillOut
  putMVar connIdListM (maxId, list)
  threadDelay 100000
  when (not allDone) (waitForConnectionsCIL connIdListM)
-- | Poll (every 100 ms) until the shared connection list is empty. The list
-- is read and put back unchanged, then the thread sleeps before re-checking.
waitForConnectionsCL :: MVar ConnectionList -> IO ()
waitForConnectionsCL connListM = do
  active <- takeMVar connListM
  putMVar connListM active
  threadDelay 100000
  if null active then return () else waitForConnectionsCL connListM
-- | The integers from @i@ down to 0 in descending order, or @[]@ when @i@ is
-- negative.
--
-- The original recursed on @(i 1)@, applying an 'Int' as a function (a lost
-- minus sign). Its trailing @buildList _ = []@ clause was unreachable because
-- the guards are exhaustive.
buildList :: Int -> [Int]
buildList i
  | i < 0 = []
  | otherwise = i : buildList (i - 1)
-- | Printer thread: flush queued output every 10 ms until @closeThread@ is
-- set. Then, unless @closedThread@ is already set, do one final flush and
-- set @closedThread@.
--
-- In the original, the recursive call was followed by more statements, so it
-- was not a tail call. Each 10 ms tick left a stack frame behind for the whole
-- lifetime of the server. The polling loop is now tail-recursive. The final
-- flush still happens once, exactly as before (in the original, the outer
-- frames saw @closedThread@ already set and skipped it).
printInterface :: StdOutMutex -> CloseThread -> ClosedThread -> IO ()
printInterface stdoutM closeThread closedThread = do
  pollUntilClosed
  closed <- isSet closedThread
  when (not closed) $ do
    printStep stdoutM
    set closedThread
  where
    -- Flush, sleep, and repeat until asked to stop (tail call: constant stack).
    pollUntilClosed = do
      printStep stdoutM
      threadDelay 10000
      close <- isSet closeThread
      when (not close) pollUntilClosed
-- | Under the stdout mutex, print every queued line and store back what
-- 'flushPrintStr' leaves (an empty queue).
printStep :: StdOutMutex -> IO ()
printStep stdoutM =
  takeMVar stdoutM >>= flushPrintStr >>= putMVar stdoutM
-- | Print each line in order with 'putStrLn' and return the empty queue.
flushPrintStr :: [[Char]] -> IO [[Char]]
flushPrintStr pending = mapM_ putStrLn pending >> return []
-- | Command-line thread: read a line, split it into words and dispatch it
-- via 'doCommand'. Repeat until a command reports that the server quit.
textInterface :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
textInterface connListM stdoutM closePrinter closedThread sock = loop
  where
    loop = do
      line <- safeGetLine stdoutM
      quitting <- doCommand connListM stdoutM closePrinter closedThread sock (words line)
      if quitting then return () else loop
-- | Run the command named by the first word. Only @quit@ is recognised: it
-- shuts the server down and yields 'True'. Anything else (including an empty
-- line) is ignored and yields 'False'.
doCommand :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> [String] -> IO Bool
doCommand connListM stdoutM closePrinter closedThread sock ws =
  case ws of
    ("quit" : _) -> do
      quit connListM stdoutM closePrinter closedThread sock
      return True
    _ -> return False
-- | Shut the server down:
--
-- 1. While holding the connection list, close the listening socket and send
--    a close frame to every connection.
-- 2. Release the list and wait for all connections to deregister.
-- 3. Signal the printer to stop and mark this thread as finished.
quit :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
quit connListM stdoutM closePrinter closedThread sock = do
  say "Closing..."
  snapshot <- takeMVar connListM
  say "Waiting for Server to close..."
  closeSocket sock
  say "Server is Closed!"
  closeAllConnections snapshot "Server is closing"
  putMVar connListM snapshot
  say "Waiting for connections to close..."
  waitForConnectionsCL connListM
  mapM_ say ["All Connections are closed!", "All done!"]
  set closePrinter
  set closedThread
  where
    say = safePutStr stdoutM
-- | Send a close frame with the given reason to each listed connection, in
-- list order, each under its own send mutex.
closeAllConnections :: ConnectionList -> String -> IO ()
closeAllConnections conns str =
  mapM_ (\(_, conn, mu) -> closeConnection mu conn str) conns
-- | Send a close frame carrying @str@ as its reason, holding the
-- connection's send mutex.
closeConnection :: ConnectionSendMutex -> Connection -> String -> IO ()
closeConnection connMu conn str =
  let reason = T.pack str
  in safe connMu (sendClose conn reason)
-- | Exception policy for the accept loop in 'server': delegate to
-- 'handleIOError'.
handleMainThreadExceptions :: IO () -> IO ()
handleMainThreadExceptions action = handleIOError action
-- | Run an action, discarding any 'IOException' it raises via 'ioErrorHandler'.
handleIOError :: IO () -> IO ()
handleIOError action = handle ioErrorHandler action
-- | Ignore an 'IOException' (after forcing it to its constructor). The accept
-- loop uses this to end quietly. NOTE(review): presumably the exception comes
-- from 'quit' closing the listening socket; confirm.
ioErrorHandler :: IOException -> IO ()
ioErrorHandler err = case err of
  IOError {} -> return ()
-- | Wrap a per-connection program in three exception layers:
--
-- * inner: 'ConnectionException's from the program deregister the connection;
-- * middle: any other exception goes to 'connectionOtherExceptionHandler',
--   which sends a close frame and reads until the peer's close arrives;
-- * outer: the 'ConnectionException' that reading loop ends with is handled
--   like the inner case.
handleConnectionThreadExceptions :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionThreadExceptions connListM connIdListM stdoutM connId io =
  onConnException (onOtherException (onConnException io))
  where
    onConnException = handleConnectionException connListM connIdListM stdoutM connId
    onOtherException = handleOtherExceptions stdoutM connListM connId
-- | Run an action, routing any exception to 'connectionOtherExceptionHandler'.
handleOtherExceptions :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> IO () -> IO ()
handleOtherExceptions stdoutM connListM connId action =
  handle (connectionOtherExceptionHandler stdoutM connListM connId) action
-- | Handle a non-websocket exception raised by a connection's program.
--
-- Logs the exception. If the connection and its send mutex are still
-- registered, it sends a close frame and then keeps reading (logging each
-- message) until reading throws. The caller wraps this handler so that the
-- resulting 'ConnectionException' is handled there. If the connection is no
-- longer registered, it just returns.
--
-- Both lookups are matched together with a @case@ instead of 'fromJust'
-- plus 'isNothing' flags, so the code is total by construction.
connectionOtherExceptionHandler :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> SomeException -> IO ()
connectionOtherExceptionHandler stdoutM connListM cid e = do
  connM <- safeGetConnection connListM cid
  connSendMM <- safeGetConnSendMutex connListM cid
  safePutStr stdoutM ("Connection [" ++ show cid ++ "] has had an exception: " ++ show e)
  case (connM, connSendMM) of
    (Just conn, Just connSendM) -> do
      closeConnection connSendM conn "Connection has caused a server-side exception"
      safePutStr stdoutM "Send close message..."
      safePutStr stdoutM "Reading from the connection for close event"
      -- Drain until the peer's close (or any error) throws out of the loop.
      forever $ do
        msg <- receiveData conn :: IO T.Text
        safePutStr stdoutM (show msg)
    _ -> return ()
-- | Run an action, routing any 'ConnectionException' to
-- 'connectionExceptionHandler'.
handleConnectionException :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionException connListM connIdListM stdoutM connId action =
  handle (connectionExceptionHandler connListM connIdListM stdoutM connId) action
-- | Deregister a connection that ended with a 'ConnectionException'. It
-- removes the connection from the registry, returns its id to the pool and
-- logs how it ended.
--
-- The original matched only 'CloseRequest' and 'ConnectionClosed'. Newer
-- websockets versions have more constructors (e.g. @ParseException@). Any of
-- those made this handler itself fail, leaving the connection registered and
-- its id unreturned, so shutdown ('waitForConnectionsCL' /
-- 'waitForConnectionsCIL') would wait forever. A catch-all clause now covers
-- them. On versions with only the two constructors, GHC may warn that it is
-- redundant.
connectionExceptionHandler :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> ConnectionException -> IO ()
connectionExceptionHandler connListM connIdListM stdoutM connId ex = do
  safeRemoveConnection connListM connId
  safeReturnConnId connId connIdListM
  safePutStr stdoutM (describe ex)
  where
    prefix = "Connection[" ++ show connId ++ "] "
    describe (CloseRequest code msg) =
      prefix ++ "closed: " ++ show code ++ " " ++ show (fromLazyByteString msg :: T.Text)
    describe ConnectionClosed =
      prefix ++ "closed unexpectedly: " ++ show connId
    describe other =
      prefix ++ "closed after a connection error: " ++ show other