{-# LANGUAGE ConstraintKinds #-} 
module Network.SimpleServer(CmdHandler,
                            ConnectionHandler,
                            DisconnectHandler,
                            ClientConn(cid, lookup, modify),
                            Server(),
                            new, 
                            addCommand,
                            start, 
                            stop, 
                            respond, 
                            broadcast, 
                            disconnectClient,
                            clientList) where

import Control.Concurrent hiding(modifyMVar)
import qualified Control.Concurrent.Lock as Lock
import Control.Concurrent.MVar hiding(modifyMVar)
import Control.Concurrent.Thread.Delay
import Control.Exception
import Control.Monad
import qualified Data.ByteString.Char8 as ByteS
import Data.Either
import Data.Foldable(toList)
import qualified Data.HashTable.IO as HT
import Data.IORef
import Data.Maybe
import Data.Time.Clock
import qualified Data.Sequence as Seq
import qualified Network as Net
import qualified Network.Socket as Net(close)
import System.IO(Handle, hClose, hSetBuffering, BufferMode(NoBuffering))

-- |A CmdHandler is used to handle a command in the form of a list of strings.
-- The list is the 'words' of the received line, so its first element is the
-- command name itself; the handler also receives the server and the client
-- that sent the line.
type CmdHandler = [String] -> Server -> ClientConn -> IO ()

-- |A ConnectionHandler is called each time a client connects to the server,
-- after the client has been added to the client list and its reader and
-- writer threads have been started.
type ConnectionHandler = Server -> ClientConn -> IO ()

-- |A DisconnectHandler is called each time a client is disconnected from the
-- server, before any remaining queued messages are sent to it.
type DisconnectHandler = Server -> ClientConn -> IO ()

{-|
Describes a client's connection and provides an interface for
storing data associated with the client. Each client is given
a unique cid, and two clients are equal ('Eq') exactly when their
cids are equal.

A ClientConn comes packaged with two functions for storing additional
information as Strings: `lookup` and `modify`. The lookup function
takes a key and returns the current value of the key, or the empty
string if it has never been set. The modify function
takes a key and a value and updates the store so that the next call to
lookup with that key returns the value provided. Both functions take a
per-client lock around the underlying hash table.
-}
data ClientConn = ClientConn { -- | The Unique ID for this client
                               cid       :: Integer,
                               -- | A lookup function for this client
                               lookup    :: (String -> IO String),
                               -- | A modify function for this client
                               modify    :: (String -> String -> IO ()),
                               -- | Handle returned by 'Net.accept' for this client
                               chandle   :: Handle,
                               -- | Remote host name reported by 'Net.accept'
                               host      :: Net.HostName,
                               -- | Remote port reported by 'Net.accept'
                               pid       :: Net.PortNumber,
                               -- | Outgoing lines, drained by the writer thread
                               msgList  :: List String,
                               -- | Set to True when the client should be dropped
                               dead      :: MVar Bool,
                               -- | Time the last line was received (set on accept)
                               timestamp :: TimeStamp,
                               -- | (writer thread, reader thread)
                               tid       :: MVar (ThreadId, ThreadId),
                               -- | The server's lock, used to serialize debug output
                               lock      :: MVar Lock.Lock}

-- | Two connections are the same client iff they carry the same cid.
instance Eq ClientConn where
  a == b = cid a == cid b

{-|
A Generic Server.

Internal fields (only the type is exported):

  * @port@ – port passed to 'Net.listenOn' by 'start'.
  * @socket@ – the listening socket; 'Nothing' while the server is stopped.
  * @clients@ – currently connected clients.
  * @cmdList@ – received lines waiting to be processed by the main loop.
  * @lastclean@ – time of the last periodic clean.
  * @timeout@ – both the client inactivity limit and the clean interval.
  * @serverLock@ – lock used to serialize debug output.
  * @cmdTable@ – command name to handler mapping.
  * @nextID@ – cid to hand out to the next accepted client.
  * @cHandler@ / @dHandler@ – connect / disconnect callbacks.
  * @threads@ – (main loop thread, accept thread), filled by 'start'.
-}
data Server = Server { port       :: Net.PortID, 
                       socket     :: IORef (Maybe Net.Socket), 
                       clients    :: List ClientConn,
                       cmdList   :: List Message,
                       lastclean  :: TimeStamp,
                       timeout    :: NominalDiffTime,
                       serverLock :: MVar Lock.Lock,
                       cmdTable   :: CmdTable,
                       nextID     :: MVar Integer,
                       cHandler   :: ConnectionHandler,
                       dHandler   :: DisconnectHandler,
                       threads    :: MVar (ThreadId, ThreadId)}





{-|
Creates a new server that is not yet listening; call 'start' to begin
accepting connections on the given port number.

The server uses a fixed 60-second timeout. A client that sends nothing for
longer than that is disconnected at the next clean, and cleans themselves
run at most once per timeout interval, so an idle client may linger for up
to roughly twice the timeout.
-}
new :: ConnectionHandler -> DisconnectHandler -> Int -> IO Server
new onConnect onDisconnect portNum = do
  sockRef    <- newIORef Nothing
  clientQ    <- emptyList
  commandQ   <- emptyList
  cleanStamp <- getCurrentTime >>= newMVar
  outputLock <- Lock.new >>= newMVar
  table      <- HT.new
  idCounter  <- newMVar 0
  threadIds  <- newEmptyMVar
  return Server { port       = Net.PortNumber (fromIntegral portNum)
                , socket     = sockRef
                , clients    = clientQ
                , cmdList    = commandQ
                , lastclean  = cleanStamp
                , timeout    = 60
                , serverLock = outputLock
                , cmdTable   = table
                , nextID     = idCounter
                , cHandler   = onConnect
                , dHandler   = onDisconnect
                , threads    = threadIds
                }

{-|
Registers a handler for a command name on the server. Registering a name
that already has a handler replaces the old one.
-}
addCommand :: Server -> String -> CmdHandler -> IO ()
addCommand server = HT.insert (cmdTable server)

{-|
Starts a server if it is currently not started; otherwise does nothing.
On success, the listening socket is stored and the main loop and accept
loop are forked. If listening fails, the error is only reported through
the debug output.
-}
start :: Server -> IO ()
start server = Net.withSocketsDo $ do
  current <- readIORef (socket server)
  when (isNothing current) $ do
    attempt <- try (Net.listenOn (port server)) :: IO (Either IOException Net.Socket)
    either reportFailure launch attempt
  where
    reportFailure e =
      debugLn' (serverLock server) $ "The server could not be started: " ++ (show e)
    launch listener = do
      writeIORef (socket server) (Just listener)
      runner   <- forkIO $ runServer server
      acceptor <- forkIO $ acceptCon server listener
      putMVar (threads server) (runner, acceptor)

{-|
Stops a server if it is running; otherwise does nothing.
Every connected client is disconnected (running the disconnect handler),
then the main loop and accept threads are killed and the listening socket
is closed. Any shutdown operations should be run before this is called.
-}
stop :: Server -> IO ()
stop server = Net.withSocketsDo $ readIORef (socket server) >>= maybe (return ()) shutDown
  where
    shutDown listener = do
      connected <- takeAll (clients server)
      mapM_ (disconnect server) connected
      (runner, acceptor) <- takeMVar (threads server)
      mapM_ killThread [runner, acceptor]
      Net.close listener
      writeIORef (socket server) Nothing

{-|
Queues a line to be sent to the given client by its writer thread.
-}
respond :: ClientConn -> String -> IO ()
respond client = put (msgList client)

{-|
Queues a line for every client currently in the server's client list.
-}
broadcast :: Server -> String -> IO ()
broadcast server string = do
  say "Reading client list"
  everyone <- readAll (clients server)
  say "Processing client list"
  forM_ everyone $ \c -> put (msgList c) string
  say "Message queued."
  where
    say = debugLn' (serverLock server)

{-|
Marks the client dead and immediately runs a clean, which removes and
disconnects it if it is on this server. If the client is not on this
server, the results are unspecified.
-}
disconnectClient :: Server -> ClientConn -> IO ()
disconnectClient server client =
  void (swapMVar (dead client) True) >> clean server

{-|
Returns a snapshot of the clients currently in the server's client list.
-}
clientList :: Server -> IO [ClientConn]
clientList server = readAll (clients server)


--------------------------------------
-- Helper Functions and Types Begin --
--------------------------------------

-- | A queue guarded by an 'MVar'; manipulated with 'put', 'takeAll' and 'readAll'.
type List a = MVar (Seq.Seq a)
-- | A mutable timestamp shared between threads.
type TimeStamp = MVar UTCTime
-- | Maps a command name (the first word of a received line) to its handler.
type CmdTable = HT.BasicHashTable String CmdHandler
-- | Per-client key/value store behind the 'lookup' and 'modify' fields.
type UserTable = HT.BasicHashTable String String


-- | A line received from a client (stored whole in 'cmd') together with the
-- client that sent it; queued by 'readClient' for 'processCommand'.
data Message = Message { cmd    :: String,
                         client :: ClientConn } deriving Eq

{-|
Creates a new client connection record with an empty outgoing queue, a
fresh key/value store and the dead flag cleared. The timestamp, thread-id
and lock MVars start empty; 'acceptCon' fills them in afterwards.
-}
newConn :: Integer -> Handle -> Net.HostName -> Net.PortNumber -> IO ClientConn
newConn ident h remoteHost remotePort = do
  outbox    <- emptyList
  deadFlag  <- newMVar False
  threadIds <- newEmptyMVar
  lastSeen  <- newEmptyMVar
  debugLock <- newEmptyMVar
  store     <- HT.new
  storeLock <- Lock.new
  return $ ClientConn ident
                      (safeLookup storeLock store)
                      (safeModify storeLock store)
                      h remoteHost remotePort
                      outbox deadFlag lastSeen threadIds debugLock

-- | Builds a lookup function over a user table guarded by a lock. The
-- returned function yields the stored value for a key, or "" when the key
-- has never been set.
--
-- The lock is held via 'bracket_' so it is released even if the lookup is
-- interrupted (e.g. by 'killThread'); otherwise every later lookup or modify
-- for this client would block forever.
safeLookup :: Lock.Lock -> UserTable -> (String -> IO String)
safeLookup lock usertable = \key -> do
  val <- bracket_ (Lock.acquire lock) (Lock.release lock) (HT.lookup usertable key)
  return (fromMaybe "" val)

-- | Builds a modify function over a user table guarded by a lock. The
-- returned function sets a key to a value, overwriting any previous value.
--
-- The lock is held via 'bracket_' so it is released even if the insert is
-- interrupted (e.g. by 'killThread'); otherwise every later lookup or modify
-- for this client would block forever.
safeModify :: Lock.Lock -> UserTable -> (String -> String -> IO ())
safeModify lock usertable = \key val ->
  bracket_ (Lock.acquire lock) (Lock.release lock) (HT.insert usertable key val)



{-|
The main loop for the server. While the server has a listening socket,
each iteration:

  1. runs 'checkClean', which removes dead or timed-out clients once per
     timeout interval (the server's @timeout@ field, 60 seconds by default);
  2. processes every queued command in order.

When no commands were queued, the loop waits approx. 1/10th of a second
before the next iteration. The loop ends once the socket is cleared by 'stop'.
-}
runServer :: Server -> IO ()
runServer server = Net.withSocketsDo $ do
  running <- fmap isJust (readIORef (socket server))
  when running $ do
    checkClean server
    pending <- takeAll (cmdList server)
    case pending of
      [] -> delay (1000*100)
      _  -> do
        debugLn' (serverLock server) "Processing Commands..."
        mapM_ (processCommand server) pending
        debugLn' (serverLock server) "Done."
    runServer server

{-|
Dispatches one message through the server's command table. The message is
split into words, and the first word selects the handler, which receives
the full word list. Blank messages are ignored. An unknown command queues
an "Invalid command: ..." reply to the sending client.
-}
processCommand :: Server -> Message -> IO ()
processCommand server msg =
  case words (cmd msg) of
    [] -> return ()
    commands@(name:_) -> do
      handler <- HT.lookup (cmdTable server) name
      case handler of
        Just f  -> f commands server (client msg)
        Nothing -> do
          debugLn' (serverLock server) $ "Could not process command: " ++ (cmd msg)
          put (msgList (client msg)) ("Invalid command: " ++ (cmd msg))
{-|
Runs 'clean' if more than the server's timeout has elapsed since the last
clean, and records the current time as the new last-clean time.
-}
checkClean :: Server -> IO ()
checkClean server = do
  now <- getCurrentTime
  previous <- readMVar (lastclean server)
  when (diffUTCTime now previous > timeout server) $ do
    _ <- swapMVar (lastclean server) now
    clean server

-- | Removes every client that is marked dead or has been silent for longer
-- than the server's timeout from the client list, then disconnects the
-- removed clients.
--
-- The take/put of the client list runs under 'mask' with an 'onException'
-- restore. An exception, including the 'killThread' that 'stop' sends to the
-- runServer thread, therefore cannot leave the clients MVar empty. An empty
-- MVar would block every later 'broadcast', 'clientList' and new connection.
clean :: Server -> IO ()
clean server = do
  let allowed = timeout server
  removed <- mask $ \restore -> do
    clist <- takeMVar (clients server)
    (newCList, dropped) <- restore (filterM' (timedout server allowed) clist)
                             `onException` putMVar (clients server) clist
    putMVar (clients server) (Seq.fromList newCList)
    return dropped
  mapM_ (disconnect server) removed


-- | Partitions a 'Seq.Seq' with a monadic predicate, returning
-- @(elements where the predicate held, elements where it did not)@, each in
-- the original order.
--
-- The predicate runs exactly once per element, left to right. It may be
-- effectful ('timedout' reads the clock and client state). Evaluating it
-- twice could classify an element differently each time and put it in both
-- lists or in neither.
filterM' :: Monad m => (a -> m Bool) -> Seq.Seq a -> m ([a],[a])
filterM' p xs = go (toList xs)
  where
    go [] = return ([], [])
    go (x:rest) = do
      keep <- p x
      (kept, dropped) <- go rest
      return $ if keep then (x:kept, dropped) else (kept, x:dropped)

-- | Keep-predicate used by 'clean'. Despite its name, it returns True when
-- the client should STAY connected: it is not marked dead, and no more than
-- @allowed@ time has passed since its timestamp was last updated. The server
-- argument is unused.
timedout :: Server -> NominalDiffTime -> ClientConn -> IO Bool
timedout _ allowed client = do
  now      <- getCurrentTime
  lastSeen <- readMVar (timestamp client)
  isDead   <- readMVar (dead client)
  let expired = diffUTCTime now lastSeen > allowed
  return (not expired && not isDead)

-- | Disconnects a client. It runs the server's disconnect handler, stops the
-- client's writer and reader threads, sends any still-queued messages, closes
-- the client's handle and marks the client dead.
--
-- The worker threads are killed before flushing so 'writeClient' cannot
-- write to the handle at the same time as 'flush'. Flushing and closing are
-- guarded against 'IOException' because the peer may already be gone. Such a
-- failure is only reported via debug output instead of escaping into the
-- caller (e.g. the runServer loop via 'clean'). Closing the handle releases
-- the client's socket.
disconnect :: Server -> ClientConn -> IO ()
disconnect server client = do
  (dHandler server) server client
  (wio, rio) <- readMVar $ tid client
  killThread wio
  killThread rio
  flushed <- try (flush server client) :: IO (Either IOException ())
  either (logFailure "Could not flush client: ") return flushed
  closed <- try (hClose (chandle client)) :: IO (Either IOException ())
  either (logFailure "Could not close client handle: ") return closed
  _ <- swapMVar (dead client) True
  return ()
  where
    logFailure prefix e = debugLn' (serverLock server) $ prefix ++ show e

-- | Drains the client's outgoing queue and writes each line to its handle,
-- in queue order. The server argument is unused.
flush :: Server -> ClientConn -> IO ()
flush _ client = takeAll (msgList client) >>= mapM_ (hPutStrLn (chandle client))

{-|
Accept loop. For each incoming connection it assigns the next cid, builds
the client record, fills in its timestamp and debug lock, and adds it to the
client list. It then starts the client's writer and reader threads and
finally runs the connection handler. The loop runs until the thread is
killed or 'Net.accept' throws.
-}
acceptCon :: Server -> Net.Socket -> IO ()
acceptCon server sock = forever $ do
  (h, remoteHost, remotePort) <- Net.accept sock
  hSetBuffering h NoBuffering
  ident <- takeMVar (nextID server)
  putMVar (nextID server) (ident + 1)
  conn <- newConn ident h remoteHost remotePort
  getCurrentTime >>= putMVar (timestamp conn)
  readMVar (serverLock server) >>= putMVar (lock conn)
  put (clients server) conn
  writer <- forkIO (writeClient conn)
  reader <- forkIO (readClient conn (cmdList server))
  putMVar (tid conn) (writer, reader)
  cHandler server server conn
  
{-
Reader loop for one client. It reads a line, records the time it arrived,
queues it as a Message for the main loop, and repeats. When reading fails
(e.g. the peer closed the connection), the client is marked dead and the
loop stops.
-}
readClient :: ClientConn -> List Message -> IO ()
readClient client queue = do
  received <- try (hGetLine (chandle client)) :: IO (Either IOException String)
  case received of
    Left _ -> void $ swapMVar (dead client) True
    Right line -> do
      getCurrentTime >>= void . swapMVar (timestamp client)
      put queue (Message line client)
      readClient client queue
{-
Writer loop for one client. Each iteration sends every queued message. If
the queue was empty, the thread sleeps for approx. 1/10th of a second
before checking again. If writing to the handle fails, the client is marked
dead and the loop stops, so that a later clean removes the client.
-}
writeClient :: ClientConn -> IO ()
writeClient client = do
  queue <- takeAll (msgList client)
  if null queue
    then do
      delay (1000*100)
      writeClient client
    else do
      debugLn' (lock client) "Client List non-empty. Writing to client."
      result <- try $ mapM_ (hPutStrLn (chandle client)) queue :: IO (Either IOException ())
      case result of
        Left e -> do
          -- This is a write failure; the message says so.
          debugLn' (lock client) $ "Could not write to handle: " ++ (show e)
          _ <- swapMVar (dead client) True
          return ()
        Right _ -> writeClient client


-- | Debugging putStrLn that holds the given lock while printing, so lines
-- from different threads do not interleave.
--
-- 'bracket_' guarantees the lock is released even if printing throws or the
-- thread is killed mid-print. Otherwise all later debug output, and any
-- thread waiting to produce it, would block forever.
putStrLn' :: MVar Lock.Lock -> String -> IO ()
putStrLn' mvar string = do
  outputLock <- readMVar mvar
  bracket_ (Lock.acquire outputLock) (Lock.release outputLock) (putStrLn string)
  
-- | Writes a line to a handle using strict ByteString I/O.
--
-- NOTE(review): 'ByteS.pack' keeps only the low 8 bits of each 'Char', so
-- characters outside Latin-1 are not sent faithfully. Confirm whether
-- clients need UTF-8.
hPutStrLn :: Handle -> String -> IO ()
hPutStrLn h = ByteS.hPutStrLn h . ByteS.pack

-- | Reads one line from a handle using strict ByteString I/O and unpacks it
-- one byte per 'Char'. The '\n' terminator is not included; a preceding
-- '\r' from CRLF clients is kept.
hGetLine :: Handle -> IO String
hGetLine = fmap ByteS.unpack . ByteS.hGetLine

-- | Compile-time switch for 'debugLn'' output.
debug :: Bool
debug = False

-- | Prints a debug line under the given lock, but only when 'debug' is on.
debugLn' :: MVar Lock.Lock -> String -> IO ()
debugLn' lock str = when debug $ putStrLn' lock str

-- | Creates a new, empty MVar-guarded queue.
emptyList :: IO (List a)
emptyList = newMVar (Seq.fromList [])

-- | Atomically empties the queue and returns its former contents in order.
takeAll :: List a -> IO [a]
takeAll queue = fmap toList (swapMVar queue Seq.empty)

-- | Returns the queue's current contents in order without removing them.
readAll :: List a -> IO [a]
readAll = fmap toList . readMVar

-- | Applies a pure function to the contents of an 'MVar'.
--
-- Delegates to 'modifyMVar_' from "Control.Concurrent.MVar", which restores
-- the original contents if an exception (e.g. 'killThread') arrives between
-- the take and the put. A bare takeMVar/putMVar pair could leave the MVar
-- empty and block every later user of it.
modifyMVar :: MVar a -> (a -> a) -> IO ()
modifyMVar mvar f = modifyMVar_ mvar (return . f)

-- | Appends an element to the back of the queue.
put :: List a -> a -> IO ()
put queue el = modifyMVar queue (\q -> q Seq.|> el)