module Network.WebSockets.ExtendedServer 
            ) where
import Network.WebSockets.Connection
import Network.WebSockets.Types

import qualified Data.Text          as T  (Text, pack)
import qualified Network.Socket     as S  (Socket(..), withSocketsDo)
import Control.Monad                      (forever, when)
import Control.Exception                  (handle, SomeException(..))
import Control.Concurrent                 (forkIO, threadDelay)
import Control.Concurrent.MVar            (MVar, newMVar, takeMVar, putMVar, withMVar)
import Control.Concurrent.Event           (Event, wait, set, new, isSet)
import Data.List                          ((\\))
import Data.Maybe                         (isNothing, fromJust)
import GHC.IO.Exception                   (IOException(..))

-- | Numeric identifier handed out to each accepted connection.
type ConnectionId = Int

-- | Id allocator state: (next never-issued id, ids returned for reuse).
type ConnectionIdList = (ConnectionId, [ConnectionId])

-- | A lock implemented as an 'MVar' that carries some payload.
type Mutex a = MVar a

-- | Queue of lines waiting to be written to stdout by the printer thread.
type StdOutMutex = Mutex [String]

-- | Per-connection lock that serialises sends on one 'Connection'.
type ConnectionSendMutex = Mutex ()

-- | Event that is set to ask a worker thread to stop.
type CloseThread = Event

-- | Event that a worker thread sets once it has stopped.
type ClosedThread = Event

-- | Registry of live connections together with their send locks.
type ConnectionList = [(ConnectionId, Connection, ConnectionSendMutex)]

-- | Per-connection handler run in its own thread by 'server'.
type ServerProgram = Connection -> StdOutMutex -> ConnectionSendMutex -> IO ()

-- Utility functions
-- | Create a fresh event used to ask a thread to stop.
safeStartCloseThread :: IO Event
safeStartCloseThread = new

-- | Create a fresh event a thread sets once it has stopped. It is the same
-- kind of object as the close event, so it reuses that constructor.
safeStartClosedThread :: IO Event
safeStartClosedThread = safeStartCloseThread
-- Safe function

-- | Run an action while holding the given mutex, and return its result.
--
-- The mutex contents are taken for the duration of the action and put back
-- afterwards. 'withMVar' also puts them back if the action throws. Without
-- that, an exception (e.g. a failed send) would leave the mutex empty and
-- deadlock every later caller, including the close path in the connection
-- exception handler.
safe :: Mutex a -> IO b -> IO b
safe mu f = withMVar mu (const f)

-- Print to stdOut functions
-- | Queue a line for the printer thread, after any lines already waiting.
-- Nothing is printed here; 'printStep' does the actual output.
safePutStr :: StdOutMutex -> String -> IO ()
safePutStr mu str = takeMVar mu >>= putMVar mu . (++ [str])

-- | Read one line from stdin, pausing queued output while the user types.
--
-- The first character is read without holding the stdout mutex, so the
-- printer thread keeps flushing while the console is idle. Once typing starts,
-- the mutex is held until the rest of the line is read. This blocks
-- 'safePutStr' and 'printStep', so output does not interleave with the input.
--
-- A lone newline (empty line) returns "" straight away. Without this check,
-- the newline was prepended to the *next* line and the empty line was merged
-- into it.
--
-- NOTE(review): reported not to work correctly on Windows — unverified.
safeGetLine :: StdOutMutex -> IO String
safeGetLine mu = do
    c <- getChar
    if c == '\n'
        then return ""
        else (c :) <$> safe mu getLine

-- | Create the stdout queue with no pending lines.
safeStartStdOutMutex :: IO StdOutMutex
safeStartStdOutMutex = newMVar mempty

-- Send data over the socket
-- | Send text as a 'Text' data message while holding the connection's
-- send lock.
safeSendText :: ConnectionSendMutex -> Connection -> T.Text -> IO ()
safeSendText mu conn txt = safeSend mu conn (Text (toLazyByteString txt))

-- | Send a data message, serialised against other senders by the send lock.
safeSend :: ConnectionSendMutex -> Connection -> DataMessage -> IO ()
safeSend mu conn = safe mu . sendDataMessage conn
-- | Create a new mutex that starts full, i.e. unlocked.
safeStartMutex :: IO (Mutex ())
safeStartMutex = newMVar mempty

-- Connection list functions
-- | Register a connection and its send lock at the front of the registry.
safeAddConnection :: MVar ConnectionList -> ConnectionId -> Connection -> ConnectionSendMutex -> IO ()
safeAddConnection connListM connId conn mu =
    takeMVar connListM >>= putMVar connListM . ((connId, conn, mu) :)
-- | Remove every registry entry with the given id.
-- (The original omitted the @let@ on the pure binding and did not parse.)
safeRemoveConnection :: MVar ConnectionList -> ConnectionId -> IO ()
safeRemoveConnection connListM connId = do
    connList <- takeMVar connListM
    let connList' = filter (\(connId', _, _) -> connId' /= connId) connList
    putMVar connListM connList'

-- | Look up a connection by id. The registry is read and put back unchanged.
safeGetConnection :: MVar ConnectionList -> ConnectionId -> IO (Maybe Connection)
safeGetConnection connListM cid =
    takeMVar connListM >>= \snapshot ->
        getConnection snapshot cid <$ putMVar connListM snapshot

-- | The first connection registered under the given id, if there is one.
getConnection :: ConnectionList -> ConnectionId -> Maybe Connection
getConnection entries wanted =
    lookup wanted [ (cid, conn) | (cid, conn, _) <- entries ]

-- | Look up a connection's send lock by id. The registry is read and put
-- back unchanged.
safeGetConnSendMutex :: MVar ConnectionList -> ConnectionId -> IO (Maybe ConnectionSendMutex)
safeGetConnSendMutex connListM connId =
    takeMVar connListM >>= \snapshot ->
        getConnSendMutex snapshot connId <$ putMVar connListM snapshot
-- | The send lock of the first connection registered under the given id.
getConnSendMutex :: ConnectionList -> ConnectionId -> Maybe ConnectionSendMutex
getConnSendMutex entries wanted =
    lookup wanted [ (cid, sendLock) | (cid, _, sendLock) <- entries ]
-- | Create an empty connection registry.
safeStartConnectionList :: IO (MVar ConnectionList)
safeStartConnectionList = newMVar mempty

-- Connection Id List functions

-- | Allocate a connection id. A previously returned id is reused if one is
-- available; otherwise the next fresh id is issued and the counter advances.
takeConnId :: ConnectionIdList -> (ConnectionId, ConnectionIdList)
takeConnId (next, reusable) = case reusable of
    []           -> (next, (next + 1, []))
    (old : rest) -> (old, (next, rest))

-- | Put an id back on the reuse list; the fresh-id counter is untouched.
-- ('fmap' on a pair maps its second component.)
returnConnId :: ConnectionId -> ConnectionIdList -> ConnectionIdList
returnConnId i = fmap (i :)

-- | Atomically allocate a connection id from the shared allocator.
-- (The original omitted the @let@ on the pure binding and did not parse.)
safeTakeConnId :: MVar ConnectionIdList -> IO ConnectionId
safeTakeConnId connIdListM = do
    connIdList <- takeMVar connIdListM
    let (i, connIdList') = takeConnId connIdList
    putMVar connIdListM connIdList'
    return i
-- | Atomically return a connection id to the shared allocator for reuse.
-- (The original omitted the @let@ on the pure binding and did not parse.)
safeReturnConnId :: ConnectionId -> MVar ConnectionIdList -> IO ()
safeReturnConnId i connIdListM = do
    connIdList <- takeMVar connIdListM
    putMVar connIdListM (returnConnId i connIdList)
-- | Create the id allocator: next fresh id is 0, nothing to reuse yet.
safeStartConnId :: IO (MVar ConnectionIdList)
safeStartConnId = newMVar initial
  where
    initial = (0, [])
-- Server function
-- | Run the WebSocket server on the given host and port.
--
-- Startup: creates the listening socket, the connection registry, the stdout
-- queue and the id allocator. Forks a printer thread ('printInterface') and a
-- console thread ('textInterface'). Then accepts connections in a loop. Each
-- accepted connection gets an id, its own send mutex and a registry entry,
-- and runs @prog@ in a new thread under 'handleConnectionThreadExceptions'.
--
-- The accept loop is 'forever', so it only ends when an exception escapes it.
-- 'IOException's are swallowed by 'handleMainThreadExceptions'. Presumably
-- this is how shutdown works: @quit@ closes the listening socket, which makes
-- the next accept fail.
--
-- Shutdown then waits until every issued connection id has been returned,
-- then waits for the console thread and the printer thread to signal that
-- they have finished.
--
-- NOTE(review): in this file ids are returned only by
-- 'connectionExceptionHandler'. If @prog@ returns normally, its id is never
-- returned and 'waitForConnectionsCIL' never finishes — confirm intended.
-- NOTE(review): exceptions other than 'IOException' raised in the accept
-- loop (e.g. from the handshake) are not caught here and would end the
-- server — verify against the websockets version in use.
server :: String -> Int -> ServerProgram -> IO ()
server host port prog = S.withSocketsDo $ do
                            sock          <- makeSocket host port
                            connListM     <- safeStartConnectionList
                            stdoutM       <- safeStartStdOutMutex
                            connIdListM   <- safeStartConnId
                            -- closePrinter is set by quit to stop the printer;
                            -- closedPrinter is set by the printer when done.
                            closePrinter  <- safeStartCloseThread
                            closedPrinter <- safeStartClosedThread
                            -- Set by quit (via the console thread) when done.
                            closedTextInt <- safeStartClosedThread
                            safePutStr stdoutM ""
                            _ <- forkIO (printInterface stdoutM closePrinter closedPrinter)
                            _ <- forkIO (textInterface connListM stdoutM closePrinter closedTextInt sock)
                            handleMainThreadExceptions $ forever $ do
                                -- Accept and complete the handshake, then register.
                                penCon    <- makePendingConnection sock
                                conn      <- acceptRequest penCon
                                connId    <- safeTakeConnId connIdListM
                                connSendM <- safeStartMutex
                                safeAddConnection connListM connId conn connSendM
                                safePutStr stdoutM ("Connection established: " ++ (show connId))
                                _ <- forkIO (handleConnectionThreadExceptions connListM connIdListM stdoutM connId $ prog conn stdoutM connSendM)
                                return ()
                            -- Reached only after the accept loop has been broken.
                            waitForConnectionsCIL connIdListM
                            wait closedTextInt
                            wait closedPrinter

-- | Block until every connection id ever issued has been returned.
--
-- Ids @0 .. next-1@ have been handed out. All connections are finished once
-- each of them is back on the reuse list. The allocator is sampled (and put
-- back unchanged) every 100 ms.
--
-- The original omitted the @let@ on its pure bindings and did not parse. It
-- also slept once more even after everything was returned; now it returns
-- as soon as nothing is outstanding.
waitForConnectionsCIL :: MVar ConnectionIdList -> IO ()
waitForConnectionsCIL connIdListM = do
    (maxId, list) <- takeMVar connIdListM
    putMVar connIdListM (maxId, list)
    let outstanding = buildList (maxId - 1) \\ list
    when (not (null outstanding)) $ do
        threadDelay 100000
        waitForConnectionsCIL connIdListM

-- | Poll the connection registry every 100 ms until it is empty.
-- Always sleeps at least once before returning.
waitForConnectionsCL :: MVar ConnectionList -> IO ()
waitForConnectionsCL connListM = poll
  where
    poll = do
        snapshot <- takeMVar connListM
        putMVar connListM snapshot
        threadDelay 100000
        when (not (null snapshot)) poll
-- | Descending list @[i, i-1 .. 0]@; empty when @i@ is negative.
--
-- Uses an enumeration instead of hand-written recursion. This also removes
-- the old trailing catch-all equation, which the two exhaustive guards made
-- unreachable. The negative case is handled first so @i - 1@ cannot underflow.
buildList :: Int -> [Int]
buildList i
    | i < 0     = []
    | otherwise = [i, i - 1 .. 0]
-- Printer Interface
-- | Printer thread body.
--
-- Flushes queued stdout lines about every 10 ms until the close event is
-- set. It then does one final flush and sets the closed event.
--
-- Written as a tail-recursive loop. The previous version recursed inside
-- @when@ and ran more actions after the recursive call returned. That left a
-- pending continuation behind on every 10 ms tick, so memory grew for as
-- long as the server ran.
printInterface :: StdOutMutex -> CloseThread -> ClosedThread -> IO ()
printInterface stdoutM closeThread closedThread = loop
  where
    loop = do
        printStep stdoutM
        threadDelay 10000
        close <- isSet closeThread
        if close then finish else loop

    -- The final flush picks up lines queued after the last tick. It is
    -- skipped if the closed event is already set.
    finish = do
        closed <- isSet closedThread
        when (not closed) $ do
            printStep stdoutM
            set closedThread

-- | Print every queued line and leave the queue empty. The queue's mutex is
-- held while printing, so writers wait until the flush is done.
printStep :: StdOutMutex -> IO ()
printStep stdoutM = takeMVar stdoutM >>= flushPrintStr >>= putMVar stdoutM
-- | Print each line in order with 'putStrLn' and return the emptied queue.
flushPrintStr :: [[Char]] -> IO [[Char]]
flushPrintStr pending = mapM_ putStrLn pending >> return []
-- Textual User Interface
-- | Console loop: read a line, split it into words and run it as a command.
-- Stops once a command reports that the server is quitting.
textInterface :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
textInterface connListM stdoutM closePrinter closedThread sock = repl
  where
    repl = do
        input <- safeGetLine stdoutM
        finished <- doCommand connListM stdoutM closePrinter closedThread sock (words input)
        when (not finished) repl
-- | Execute one tokenised console command.
-- Only @quit@ (the first word) is recognised: it shuts the server down and
-- returns 'True'. Anything else, including a blank line, is ignored and
-- returns 'False'.
doCommand :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> [String] -> IO Bool
doCommand connListM stdoutM closePrinter closedThread sock tokens = case tokens of
    (cmd : _) | cmd == "quit" -> do
        quit connListM stdoutM closePrinter closedThread sock
        return True
    _ -> return False
-- | Shut the server down (run by the @quit@ console command).
--
-- Steps:
--
-- 1. Close the listening socket.
-- 2. Send a close message to every registered connection.
-- 3. Poll until the connection registry is empty.
-- 4. Set the printer's close event and this thread's closed event.
--
-- The registry mutex is held while the socket is closed and the close
-- messages are sent. 'safeAddConnection' and 'safeRemoveConnection' block
-- until it is put back.
--
-- NOTE(review): if 'closeSocket' or any 'sendClose' throws, the registry is
-- never put back and @closedThread@ is never set. Other threads touching the
-- registry, and 'server''s final wait, would then block. Consider guarding
-- with 'withMVar' / 'finally'.
quit :: MVar ConnectionList -> StdOutMutex -> CloseThread -> ClosedThread -> S.Socket -> IO ()
quit connListM stdoutM closePrinter closedThread sock = do
                                                            safePutStr stdoutM "Closing..."
                                                            connList <- takeMVar connListM
                                                            safePutStr stdoutM "Waiting for Server to close..."
                                                            -- Presumably makes the accept loop in server fail and exit.
                                                            closeSocket sock
                                                            safePutStr stdoutM "Server is Closed!"
                                                            closeAllConnections connList "Server is closing"
                                                            putMVar connListM connList
                                                            safePutStr stdoutM "Waiting for connections to close..."
                                                            -- Entries are removed by each connection's exception handler.
                                                            waitForConnectionsCL connListM
                                                            safePutStr stdoutM "All Connections are closed!"
                                                            safePutStr stdoutM "All done!"
                                                            set closePrinter
                                                            set closedThread
-- Close connection(s)                                                               
-- | Send a close message with the given reason to every connection, in
-- registry order. Stops at the first exception.
closeAllConnections :: ConnectionList -> String -> IO ()
closeAllConnections entries reason =
    mapM_ (\(_, conn, sendLock) -> closeConnection sendLock conn reason) entries

-- | Send a close message carrying the given reason, holding the
-- connection's send lock.
closeConnection :: ConnectionSendMutex -> Connection -> String -> IO ()
closeConnection connMu conn = safe connMu . sendClose conn . T.pack
-- Main Thread Exception handlers
-- | Wrapper used around the accept loop in 'server'. Any 'IOException' the
-- action raises ends it quietly (nothing is printed).
handleMainThreadExceptions :: IO () -> IO ()
handleMainThreadExceptions action = handleIOError action

-- | Run an action, discarding any 'IOException' it raises.
handleIOError :: IO () -> IO ()
handleIOError = handle ioErrorHandler

-- | Handler that ignores every 'IOException'.
ioErrorHandler :: IOException -> IO ()
ioErrorHandler IOError{} = pure ()

-- Connection Thread Exception handlers
-- | Run a connection's program with layered exception handling.
--
-- Layers, innermost first:
--
-- 1. A 'ConnectionException' from the program goes to
--    'handleConnectionException' (registry/id cleanup plus a log line).
-- 2. Any other exception goes to 'handleOtherExceptions'. That handler sends
--    a close message and then reads from the connection in a 'forever' loop,
--    which can only end by throwing.
-- 3. If that loop ends with a 'ConnectionException' (presumably the peer's
--    close — verify with the websockets docs), the outer
--    'handleConnectionException' catches it and performs the cleanup.
--
-- The original omitted the @where@ before the local helper and did not parse.
handleConnectionThreadExceptions :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionThreadExceptions connListM connIdListM stdoutM connId io =
    handleConnExcept $ handleOtherExceptions stdoutM connListM connId $ handleConnExcept io
  where
    handleConnExcept = handleConnectionException connListM connIdListM stdoutM connId

-- | Run an action, routing any exception it raises ('SomeException') to
-- 'connectionOtherExceptionHandler' for this connection.
handleOtherExceptions :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> IO () -> IO ()
handleOtherExceptions stdoutM connListM connId action = handle onFailure action
  where
    onFailure = connectionOtherExceptionHandler stdoutM connListM connId

-- | Handle an unexpected exception from a connection's program.
--
-- Logs the exception. If the connection and its send lock are still in the
-- registry, it sends a close message and then reads (and logs) messages in a
-- 'forever' loop. That loop only ends by throwing. Under
-- 'handleConnectionThreadExceptions', a 'ConnectionException' thrown from it
-- is caught by the outer layer, which performs the registry/id cleanup.
--
-- The original omitted the @let@ on its pure bindings and did not parse.
-- The @isNothing@/@fromJust@ flags are replaced by a single pattern match.
--
-- NOTE(review): if the connection is no longer registered, this only logs.
-- The thread then ends without returning the connection id.
connectionOtherExceptionHandler :: StdOutMutex -> MVar ConnectionList -> ConnectionId -> SomeException -> IO ()
connectionOtherExceptionHandler stdoutM connListM cid e = do
    connM      <- safeGetConnection connListM cid
    connSendMM <- safeGetConnSendMutex connListM cid
    safePutStr stdoutM ("Connection [" ++ show cid ++ "] has had an exception: " ++ show e)
    case (connM, connSendMM) of
        (Just conn, Just connSendM) -> do
            closeConnection connSendM conn "Connection has caused a server-side exception"
            safePutStr stdoutM "Send close message..."
            safePutStr stdoutM "Reading from the connection for close event"
            forever $ do
                msg <- receiveData conn :: IO T.Text
                safePutStr stdoutM (show msg)
        _ -> return ()

-- | Run an action, routing any 'ConnectionException' it raises to
-- 'connectionExceptionHandler' for this connection.
handleConnectionException :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> IO () -> IO ()
handleConnectionException connListM connIdListM stdoutM connId action = handle onConnException action
  where
    onConnException = connectionExceptionHandler connListM connIdListM stdoutM connId

-- | Clean up after a connection ends with a 'ConnectionException'.
--
-- Removes the connection from the registry, returns its id for reuse, and
-- logs how it ended.
--
-- The original matched only 'CloseRequest' and 'ConnectionClosed'. Any other
-- constructor made the handler raise 'PatternMatchFail', skipping the cleanup
-- and leaving the id outstanding, which blocks shutdown. Such constructors
-- now get the same cleanup and are logged with their 'show' output.
connectionExceptionHandler :: MVar ConnectionList -> MVar ConnectionIdList -> StdOutMutex -> ConnectionId -> ConnectionException -> IO ()
connectionExceptionHandler connListM connIdListM stdoutM connId e = do
    safeRemoveConnection connListM connId
    safeReturnConnId connId connIdListM
    safePutStr stdoutM (describe e)
  where
    prefix = "Connection[" ++ show connId ++ "] "
    describe (CloseRequest code msg) =
        prefix ++ "closed: " ++ show code ++ " " ++ show (fromLazyByteString msg :: T.Text)
    describe ConnectionClosed =
        prefix ++ "closed unexpectedly: " ++ show connId
    describe other =
        prefix ++ "failed: " ++ show other