module Simulation.Aivika.Distributed.Optimistic.Internal.TimeServer
       (TimeServerParams(..),
        TimeServerEnv(..),
        TimeServerStrategy(..),
        defaultTimeServerParams,
        defaultTimeServerEnv,
        timeServer,
        timeServerWithEnv,
        curryTimeServer) where
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe
import Data.IORef
import Data.Typeable
import Data.Binary
import Data.Time.Clock
import GHC.Generics
import Control.Monad
import Control.Monad.Trans
import Control.Exception
import qualified Control.Monad.Catch as C
import Control.Concurrent
import qualified Control.Distributed.Process as DP
import Simulation.Aivika.Distributed.Optimistic.Internal.Priority
import Simulation.Aivika.Distributed.Optimistic.Internal.Message
import Simulation.Aivika.Distributed.Optimistic.State
-- | Tunable parameters of the time server.
--
-- Every timeout, delay and interval below is measured in microseconds.
data TimeServerParams =
  TimeServerParams
  { tsLoggingPriority :: Priority
    -- ^ messages with a lower priority are not logged
  , tsName :: String
    -- ^ the name reported in the time server state
  , tsReceiveTimeout :: Int
    -- ^ how long a single receive of the main loop waits for a message
  , tsTimeSyncTimeout :: Int
    -- ^ after this period the logical processes are validated and a stalled
    -- global time computation is reset
  , tsTimeSyncDelay :: Int
    -- ^ the period between attempts to compute the global time
  , tsProcessMonitoringEnabled :: Bool
    -- ^ whether registered logical processes are monitored
  , tsProcessMonitoringDelay :: Int
    -- ^ the delay after reconnecting before the processes are monitored again
  , tsProcessReconnectingEnabled :: Bool
    -- ^ whether to reconnect to logical processes after a disconnect
  , tsProcessReconnectingDelay :: Int
    -- ^ the delay between a disconnect notification and the reconnection
  , tsSimulationMonitoringInterval :: Int
    -- ^ the period between requests for the time server state
  , tsSimulationMonitoringTimeout :: Int
    -- ^ how long the monitoring process waits for the requested state
  , tsStrategy :: TimeServerStrategy
    -- ^ what to do with logical processes that stay silent for too long
  } deriving (Eq, Ord, Show, Typeable, Generic)

instance Binary TimeServerParams
-- | The run-time environment of the time server.
--
-- It holds a function, so unlike 'TimeServerParams' it has no 'Binary'
-- instance.
data TimeServerEnv =
  TimeServerEnv
  { tsSimulationMonitoringAction :: Maybe (TimeServerState -> DP.Process ())
    -- ^ an optional action applied to every time server state received by
    -- the monitoring process
  }
-- | How the time server treats a logical process it has not heard from for
-- too long.
--
-- The timeouts are in microseconds. At each validation they are compared
-- with the age of the oldest timestamp among the registered processes.
data TimeServerStrategy
  = WaitIndefinitelyForLogicalProcess
    -- ^ never give up on a silent logical process
  | TerminateDueToLogicalProcessTimeout Int
    -- ^ terminate the time server once the timeout is exceeded
  | UnregisterLogicalProcessDueToTimeout Int
    -- ^ unregister the silent logical process once the timeout is exceeded
  deriving (Eq, Ord, Show, Typeable, Generic)

instance Binary TimeServerStrategy
-- | The mutable state of a running time server.
data TimeServer =
  TimeServer
  { tsParams :: TimeServerParams
    -- ^ the parameters the server was started with
  , tsInitQuorum :: Int
    -- ^ how many logical processes must register before the server starts
  , tsInInit :: IORef Bool
    -- ^ whether the server is still waiting for the quorum
  , tsTerminating :: IORef Bool
    -- ^ whether termination has been requested
  , tsTerminated :: IORef Bool
    -- ^ whether the main loop has exited
  , tsProcesses :: IORef (M.Map DP.ProcessId LogicalProcessInfo)
    -- ^ the registered logical processes
  , tsProcessesInFind :: IORef (S.Set DP.ProcessId)
    -- ^ the processes whose local time is still awaited
  , tsGlobalTime :: IORef (Maybe Double)
    -- ^ the last global virtual time provided to the processes
  , tsGlobalTimeTimestamp :: IORef (Maybe UTCTime)
    -- ^ when the global time was last provided or its computation reset
  , tsLogicalProcessValidationTimestamp :: IORef UTCTime
    -- ^ when the logical processes were last validated
  }
-- | What the time server knows about one registered logical process.
data LogicalProcessInfo =
  LogicalProcessInfo
  { lpId :: DP.ProcessId
    -- ^ the process identifier
  , lpLocalTime :: IORef (Maybe Double)
    -- ^ the last reported local time, or 'Nothing' before the first report
  , lpTimestamp :: IORef UTCTime
    -- ^ when the process was registered or last heard from
  , lpMonitorRef :: Maybe DP.MonitorRef
    -- ^ the monitor reference, set when process monitoring is enabled
  }
-- | The default time server parameters.
--
-- Process monitoring and reconnecting are disabled. A logical process that
-- stays silent for five minutes terminates the server.
defaultTimeServerParams :: TimeServerParams
defaultTimeServerParams =
  TimeServerParams
  { tsLoggingPriority              = WARNING
  , tsName                         = "Time Server"
  , tsReceiveTimeout               = 100000     -- 0.1 s
  , tsTimeSyncTimeout              = 60000000   -- 60 s
  , tsTimeSyncDelay                = 1000000    -- 1 s
  , tsProcessMonitoringEnabled     = False
  , tsProcessMonitoringDelay       = 3000000    -- 3 s
  , tsProcessReconnectingEnabled   = False
  , tsProcessReconnectingDelay     = 5000000    -- 5 s
  , tsSimulationMonitoringInterval = 30000000   -- 30 s
  , tsSimulationMonitoringTimeout  = 100000     -- 0.1 s
  , tsStrategy                     = TerminateDueToLogicalProcessTimeout 300000000  -- 300 s
  }
-- | The default environment, which has no simulation monitoring action.
defaultTimeServerEnv :: TimeServerEnv
defaultTimeServerEnv = TimeServerEnv Nothing
-- | Create the state of a fresh time server.
--
-- The server starts in the initial (waiting-for-quorum) phase. It has no
-- registered processes and no global time yet, and the validation
-- timestamp is set to the current time.
newTimeServer :: Int -> TimeServerParams -> IO TimeServer
newTimeServer n ps =
  TimeServer ps n
    <$> newIORef True                -- tsInInit
    <*> newIORef False               -- tsTerminating
    <*> newIORef False               -- tsTerminated
    <*> newIORef M.empty             -- tsProcesses
    <*> newIORef S.empty             -- tsProcessesInFind
    <*> newIORef Nothing             -- tsGlobalTime
    <*> newIORef Nothing             -- tsGlobalTimeTimestamp
    <*> (getCurrentTime >>= newIORef) -- tsLogicalProcessValidationTimestamp
-- | Handle a single protocol message received by the time server.
--
-- Most clauses first update the 'IORef' state in 'IO' and then return the
-- 'DP.Process' action to run afterwards, which 'join' executes.
processTimeServerMessage :: TimeServer -> TimeServerMessage -> DP.Process ()
processTimeServerMessage server (RegisterLogicalProcessMessage pid) =
  -- Register the process (warning if it is already known), optionally
  -- monitor it, acknowledge, and start the server once the quorum is met.
  join $ liftIO $
  do m <- readIORef (tsProcesses server)
     case M.lookup pid m of
       Just x ->
         return $
         logTimeServer server WARNING $
         "Time Server: already registered process identifier " ++ show pid
       Nothing  ->
         do t <- newIORef Nothing
            utc <- getCurrentTime >>= newIORef
            modifyIORef (tsProcesses server) $
              M.insert pid LogicalProcessInfo { lpId = pid, lpLocalTime = t, lpTimestamp = utc, lpMonitorRef = Nothing }
            return $
              do when (tsProcessMonitoringEnabled $ tsParams server) $
                   do logTimeServer server INFO $
                        "Time Server: monitoring the process by identifier " ++ show pid
                      r <- DP.monitor pid
                      liftIO $
                        modifyIORef (tsProcesses server) $
                        M.update (\x -> Just x { lpMonitorRef = Just r }) pid
                 serverId <- DP.getSelfPid
                 DP.send pid (RegisterLogicalProcessAcknowledgementMessage serverId)
                 tryStartTimeServer server
processTimeServerMessage server (UnregisterLogicalProcessMessage pid) =
  -- Forget the process and stop waiting for it, drop its monitor, and
  -- acknowledge. Then re-check whether the global time can be provided or
  -- the server can terminate.
  join $ liftIO $
  do m <- readIORef (tsProcesses server)
     case M.lookup pid m of
       Nothing ->
         return $
         logTimeServer server WARNING $
         "Time Server: unknown process identifier " ++ show pid
       Just x  ->
         do modifyIORef (tsProcesses server) $
              M.delete pid
            modifyIORef (tsProcessesInFind server) $
              S.delete pid
            return $
              do when (tsProcessMonitoringEnabled $ tsParams server) $
                   case lpMonitorRef x of
                     Nothing -> return ()
                     Just r  ->
                       do logTimeServer server INFO $
                            "Time Server: unmonitoring the process by identifier " ++ show pid
                          DP.unmonitor r
                 serverId <- DP.getSelfPid
                 DP.send pid (UnregisterLogicalProcessAcknowledgementMessage serverId)
                 tryProvideTimeServerGlobalTime server
                 tryTerminateTimeServer server
processTimeServerMessage server (TerminateTimeServerMessage pid) =
  -- Like unregistration, but it also switches the server into the
  -- terminating mode.
  join $ liftIO $
  do m <- readIORef (tsProcesses server)
     case M.lookup pid m of
       Nothing ->
         return $
         logTimeServer server WARNING $
         "Time Server: unknown process identifier " ++ show pid
       Just x  ->
         do modifyIORef (tsProcesses server) $
              M.delete pid
            modifyIORef (tsProcessesInFind server) $
              S.delete pid
            return $
              do when (tsProcessMonitoringEnabled $ tsParams server) $
                   case lpMonitorRef x of
                     Nothing -> return ()
                     Just r  ->
                       do logTimeServer server INFO $
                            "Time Server: unmonitoring the process by identifier " ++ show pid
                          DP.unmonitor r
                 serverId <- DP.getSelfPid
                 DP.send pid (TerminateTimeServerAcknowledgementMessage serverId)
                 startTerminatingTimeServer server
processTimeServerMessage server (RequestGlobalTimeMessage pid) =
  tryComputeTimeServerGlobalTime server
processTimeServerMessage server (LocalTimeMessage pid t') =
  -- Record the reported local time and stop waiting for this process. An
  -- unknown sender is registered first and the message is then re-handled.
  join $ liftIO $
  do m <- readIORef (tsProcesses server)
     case M.lookup pid m of
       Nothing ->
         return $
         do logTimeServer server WARNING $
              "Time Server: unknown process identifier " ++ show pid
            processTimeServerMessage server (RegisterLogicalProcessMessage pid)
            processTimeServerMessage server (LocalTimeMessage pid t')
       Just x  ->
         do utc <- getCurrentTime
            writeIORef (lpLocalTime x) (Just t')
            writeIORef (lpTimestamp x) utc
            modifyIORef (tsProcessesInFind server) $
              S.delete pid
            return $
              tryProvideTimeServerGlobalTime server
processTimeServerMessage server (ComputeLocalTimeAcknowledgementMessage pid) =
  -- Only refresh the timestamp of the process. An unknown sender is
  -- registered first and the message is then re-handled.
  join $ liftIO $
  do m <- readIORef (tsProcesses server)
     case M.lookup pid m of
       Nothing ->
         return $
         do logTimeServer server WARNING $
              "Time Server: unknown process identifier " ++ show pid
            processTimeServerMessage server (RegisterLogicalProcessMessage pid)
            processTimeServerMessage server (ComputeLocalTimeAcknowledgementMessage pid)
       Just x  ->
         do utc <- getCurrentTime
            writeIORef (lpTimestamp x) utc
            return $
              return ()
processTimeServerMessage server (ProvideTimeServerStateMessage pid) =
  -- Reply with a snapshot of the server state.
  do let ps   = tsParams server
         name = tsName ps
     serverId <- DP.getSelfPid
     t <- liftIO $ readIORef (tsGlobalTime server)
     m <- liftIO $ readIORef (tsProcesses server)
     let msg = TimeServerState { tsStateId = serverId,
                                 tsStateName = name,
                                 tsStateGlobalVirtualTime = t,
                                 tsStateLogicalProcesses = M.keys m }
     DP.send pid msg
processTimeServerMessage server (ReMonitorTimeServerMessage pids) =
  -- Monitor the reconnected processes again and restart the global time
  -- computation.
  do forM_ pids $ \pid ->
       do m <- liftIO $ readIORef (tsProcesses server)
          if M.member pid m
            then
              do logTimeServer server NOTICE $ "Time Server: re-monitoring " ++ show pid
                 r <- DP.monitor pid
                 -- Store the fresh reference. A later unregistration then
                 -- removes this monitor rather than the old one whose
                 -- notification triggered the reconnect.
                 liftIO $
                   modifyIORef (tsProcesses server) $
                   M.adjust (\x -> x { lpMonitorRef = Just r }) pid
                 logTimeServer server NOTICE $ "Time Server: started re-monitoring " ++ show pid
            else
              -- The process was unregistered during the reconnect delay.
              -- Monitoring it now would leave a monitor that nothing removes.
              logTimeServer server NOTICE $ "Time Server: skipped re-monitoring the unregistered process " ++ show pid
     resetComputingTimeServerGlobalTime server
-- | Compare two optional times with '>='. A missing time on either side
-- makes the result 'False'.
(.>=.) :: Maybe Double -> Maybe Double -> Bool
a .>=. b = fromMaybe False ((>=) <$> a <*> b)
-- | Compare two optional times with '>'. A missing time on either side
-- makes the result 'False'.
(.>.) :: Maybe Double -> Maybe Double -> Bool
a .>. b = fromMaybe False ((>) <$> a <*> b)
-- | Leave the initial phase once enough logical processes have registered,
-- and then attempt the first global time computation.
--
-- This does nothing if the server has already started or the quorum is not
-- yet reached.
tryStartTimeServer :: TimeServer -> DP.Process ()
tryStartTimeServer server =
  do ready <- liftIO $
       do waiting <- readIORef (tsInInit server)
          count   <- M.size <$> readIORef (tsProcesses server)
          let quorumReached = waiting && count >= tsInitQuorum server
          when quorumReached $
            writeIORef (tsInInit server) False
          return quorumReached
     when ready $
       do logTimeServer server INFO "Time Server: starting"
          tryComputeTimeServerGlobalTime server
  
-- | Start a new global time computation, but only after the server has
-- started and while no local time is still awaited from any process.
tryComputeTimeServerGlobalTime :: TimeServer -> DP.Process ()
tryComputeTimeServerGlobalTime server =
  do starting <- liftIO $ readIORef (tsInInit server)
     pending  <- liftIO $ readIORef (tsProcessesInFind server)
     unless (starting || not (S.null pending)) $
       computeTimeServerGlobalTime server
-- | Abandon the current global time computation. The server stops waiting
-- for any process and restarts the timer used to detect a stalled
-- computation.
resetComputingTimeServerGlobalTime :: TimeServer -> DP.Process ()
resetComputingTimeServerGlobalTime server =
  do logTimeServer server NOTICE "Time Server: reset computing the global time"
     now <- liftIO getCurrentTime
     liftIO $ writeIORef (tsProcessesInFind server) S.empty
     liftIO $ writeIORef (tsGlobalTimeTimestamp server) (Just now)
-- | Provide the global time once the server has started and no local time
-- is still awaited from any process.
tryProvideTimeServerGlobalTime :: TimeServer -> DP.Process ()
tryProvideTimeServerGlobalTime server =
  do (starting, pending) <- liftIO $
       (,) <$> readIORef (tsInInit server)
           <*> readIORef (tsProcessesInFind server)
     when (not starting && S.null pending) $
       provideTimeServerGlobalTime server
-- | Ask every registered logical process for its local time. Each process
-- is first marked as awaited in 'tsProcessesInFind'.
computeTimeServerGlobalTime :: TimeServer -> DP.Process ()
computeTimeServerGlobalTime server =
  do logTimeServer server DEBUG "Time Server: computing the global time..."
     pids <- liftIO $ M.keys <$> readIORef (tsProcesses server)
     liftIO $
       modifyIORef (tsProcessesInFind server) (S.union (S.fromList pids))
     forM_ pids $ \pid ->
       DP.send pid ComputeLocalTimeMessage
-- | Compute the global virtual time from the local times and broadcast it
-- to every registered process.
--
-- If some process has not reported its local time yet, nothing is sent. A
-- decrease relative to the previously provided value is logged.
provideTimeServerGlobalTime :: TimeServer -> DP.Process ()
provideTimeServerGlobalTime server =
  do gvt <- liftIO $ timeServerGlobalTime server
     logTimeServer server INFO $
       "Time Server: providing the global time = " ++ show gvt
     forM_ gvt $ \t ->
       do previous <- liftIO $ readIORef (tsGlobalTime server)
          when (previous .>. Just t) $
            logTimeServer server NOTICE "Time Server: the global time has decreased"
          now <- liftIO getCurrentTime
          liftIO $
            do writeIORef (tsGlobalTime server) (Just t)
               writeIORef (tsGlobalTimeTimestamp server) (Just now)
          pids <- liftIO $ M.keys <$> readIORef (tsProcesses server)
          forM_ pids $ \pid ->
            DP.send pid (GlobalTimeMessage t)
-- | The minimum of the local times of all registered processes.
--
-- Returns 'Nothing' when there are no processes or when any of them has
-- not reported a local time yet.
timeServerGlobalTime :: TimeServer -> IO (Maybe Double)
timeServerGlobalTime server =
  do infos <- M.elems <$> readIORef (tsProcesses server)
     times <- mapM (readIORef . lpLocalTime) infos
     return $
       case sequence times of
         Just (t : ts) -> Just (smallest t ts)
         _             -> Nothing
  where
    -- Each new time is passed as the first argument of 'min'.
    smallest acc [] = acc
    smallest acc (t : ts) =
      let acc' = min t acc
      in acc' `seq` smallest acc' ts
-- | The registered process with the oldest timestamp, i.e. the one heard
-- from least recently.
--
-- Ties go to the process that comes first in key order. Returns 'Nothing'
-- when no process is registered.
minTimestampLogicalProcess :: TimeServer -> IO (Maybe LogicalProcessInfo)
minTimestampLogicalProcess server =
  do infos <- M.elems <$> readIORef (tsProcesses server)
     stamped <- forM infos $ \info ->
       do t <- readIORef (lpTimestamp info)
          return (t, info)
     return $
       case stamped of
         []       -> Nothing
         (s : ss) -> Just (snd (earliest s ss))
  where
    earliest best [] = best
    earliest best@(t0, _) (cand@(t, _) : rest)
      | t0 <= t   = earliest best rest
      | otherwise = earliest cand rest
-- | Keep only the given identifiers that belong to registered logical
-- processes, preserving their order.
filterLogicalProcesses :: TimeServer -> [DP.ProcessId] -> IO [DP.ProcessId]
filterLogicalProcesses server pids =
  do registered <- readIORef (tsProcesses server)
     return [pid | pid <- pids, pid `M.member` registered]
-- | Switch the server into the terminating mode and terminate right away
-- if no logical process is left.
startTerminatingTimeServer :: TimeServer -> DP.Process ()
startTerminatingTimeServer server =
  logTimeServer server INFO "Time Server: start terminating..." >>
  liftIO (writeIORef (tsTerminating server) True) >>
  tryTerminateTimeServer server
-- | Terminate the server process if termination was requested and every
-- logical process has gone away.
tryTerminateTimeServer :: TimeServer -> DP.Process ()
tryTerminateTimeServer server =
  do (terminating, processes) <- liftIO $
       (,) <$> readIORef (tsTerminating server)
           <*> readIORef (tsProcesses server)
     when (terminating && M.null processes) $
       do logTimeServer server INFO "Time Server: terminate"
          DP.terminate
-- | Convert seconds to whole microseconds. The value is rounded to the
-- nearest integer first and then narrowed to 'Int'.
secondsToMicroseconds :: Double -> Int
secondsToMicroseconds =
  fromInteger . (round :: Double -> Integer) . (1000000 *)
-- | Everything the main loop of the time server can receive.
data InternalTimeServerMessage
  = InternalTimeServerMessage TimeServerMessage
    -- ^ a time server protocol message
  | InternalProcessMonitorNotification DP.ProcessMonitorNotification
    -- ^ a notification from a process monitor
  | InternalKeepAliveMessage KeepAliveMessage
    -- ^ a keep-alive message, which is only logged
-- | Log an exception that escaped a time server action and re-throw it.
-- Logging therefore does not change how the failure propagates.
--
-- The message carries the same "Time Server: " prefix as every other log
-- line of the server.
handleTimeServerException :: TimeServer -> SomeException -> DP.Process ()
handleTimeServerException server e =
  do logTimeServer server ERROR $ "Time Server: exception occurred: " ++ show e
     C.throwM e
-- | Run the time server with the given quorum and parameters in the
-- default environment, i.e. without simulation monitoring.
timeServer :: Int -> TimeServerParams -> DP.Process ()
timeServer quorum params =
  timeServerWithEnv quorum params defaultTimeServerEnv
-- | Run the time server with an explicit environment.
--
-- The first argument is the quorum: how many logical processes must
-- register before the server starts computing the global time.
--
-- Each iteration of the main loop does the following:
--
-- * waits up to 'tsReceiveTimeout' for a message and handles it;
-- * validates the logical processes once 'tsTimeSyncTimeout' has passed
--   since the last validation;
-- * resets a stalled global time computation once 'tsTimeSyncTimeout' has
--   passed since the global time was last provided or reset;
-- * attempts a new computation every 'tsTimeSyncDelay'.
--
-- When a simulation monitoring action is given, two helper processes are
-- spawned. One periodically asks the server for its state; the other
-- applies the action to each state it receives. Both check 'tsTerminated',
-- which is set when the main loop exits, and stop once they see it.
timeServerWithEnv :: Int -> TimeServerParams -> TimeServerEnv -> DP.Process ()
timeServerWithEnv n ps env =
  do server <- liftIO $ newTimeServer n ps
     logTimeServer server INFO "Time Server: starting..."
     -- utc0 is the moment of the last attempt to compute the global time.
     let loop utc0 =
           do let f1 :: TimeServerMessage -> DP.Process InternalTimeServerMessage
                  f1 x = return (InternalTimeServerMessage x)
                  f2 :: DP.ProcessMonitorNotification -> DP.Process InternalTimeServerMessage
                  f2 x = return (InternalProcessMonitorNotification x)
                  f3 :: KeepAliveMessage -> DP.Process InternalTimeServerMessage
                  f3 x = return (InternalKeepAliveMessage x)
              -- Nothing here means the receive timed out; the periodic
              -- checks below still run.
              a <- DP.receiveTimeout (tsReceiveTimeout ps) [DP.match f1, DP.match f2, DP.match f3]
              case a of
                Nothing -> return ()
                Just (InternalTimeServerMessage m) ->
                  do 
                     logTimeServer server DEBUG $
                       "Time Server: " ++ show m
                     
                     processTimeServerMessage server m
                Just (InternalProcessMonitorNotification m) ->
                  handleProcessMonitorNotification m server
                Just (InternalKeepAliveMessage m) ->
                  do 
                     logTimeServer server DEBUG $
                       "Time Server: " ++ show m
                     
                     return ()
              utc <- liftIO getCurrentTime
              validation <- liftIO $ readIORef (tsLogicalProcessValidationTimestamp server)
              timestamp <- liftIO $ readIORef (tsGlobalTimeTimestamp server)
              when (timeSyncTimeoutExceeded server validation utc) $
                validateLogicalProcesses server utc
              -- Nothing until the global time is first provided or reset,
              -- so no stall can be detected before that.
              case timestamp of
                Just x | timeSyncTimeoutExceeded server x utc ->
                  resetComputingTimeServerGlobalTime server
                _ -> return ()
              if timeSyncDelayExceeded server utc0 utc
                then do tryComputeTimeServerGlobalTime server
                        loop utc
                else loop utc0
         -- Mark the server as terminated however the main loop exits, so
         -- that the monitoring helpers stop.
         loop' utc0 =
           C.finally
           (loop utc0)
           (liftIO $
            atomicWriteIORef (tsTerminated server) True)
     case tsSimulationMonitoringAction env of
       Nothing  -> return ()
       Just act ->
         do serverId  <- DP.getSelfPid
            -- Receives the TimeServerState replies and applies the action.
            monitorId <-
              DP.spawnLocal $
              let loop =
                    do f <- liftIO $ readIORef (tsTerminated server)
                       unless f $
                         do x <- DP.expectTimeout (tsSimulationMonitoringTimeout ps)
                            case x of
                              Nothing -> return ()
                              Just st -> act st
                            loop
              in C.catch loop (handleTimeServerException server)
            -- Periodically asks the server to send its state to the
            -- process above.
            DP.spawnLocal $
              let loop =
                    do f <- liftIO $ readIORef (tsTerminated server)
                       unless f $
                         do liftIO $
                              threadDelay (tsSimulationMonitoringInterval ps)
                            DP.send serverId (ProvideTimeServerStateMessage monitorId)
                            loop
              in C.catch loop (handleTimeServerException server)
            return ()
     C.catch (liftIO getCurrentTime >>= loop') (handleTimeServerException server) 
-- | React to a process monitor notification.
--
-- The notification is always logged. When reconnecting is enabled and the
-- process was disconnected, the handler waits 'tsProcessReconnectingDelay'.
-- It then collects any other disconnect notifications that are already
-- queued and reconnects to those processes that are still registered.
-- Finally it spawns a helper that, after 'tsProcessMonitoringDelay', asks
-- the server to monitor them again.
handleProcessMonitorNotification :: DP.ProcessMonitorNotification -> TimeServer -> DP.Process ()
handleProcessMonitorNotification notification@(DP.ProcessMonitorNotification _ deadPid deathReason) server =
  do logNotification notification
     when (tsProcessReconnectingEnabled params && deathReason == DP.DiedDisconnect) $
       do liftIO $ threadDelay (tsProcessReconnectingDelay params)
          disconnected <- collectDisconnected [deadPid]
          pids <- liftIO $ filterLogicalProcesses server disconnected
          logTimeServer server NOTICE "Begin reconnecting..."
          forM_ pids reconnectTo
          serverId <- DP.getSelfPid
          _ <- DP.spawnLocal $
            C.catch (scheduleReMonitor serverId pids) (handleTimeServerException server)
          return ()
  where
    params = tsParams server

    -- Log a notification and hand it back, so it can serve as a match handler.
    logNotification n@(DP.ProcessMonitorNotification _ _ _) =
      do logTimeServer server WARNING $
           "Time Server: received a process monitor notification " ++ show n
         return n

    isDisconnect (DP.ProcessMonitorNotification _ _ r) = r == DP.DiedDisconnect

    -- Drain the disconnect notifications already in the mailbox without
    -- blocking, returning the identifiers in arrival order.
    collectDisconnected acc =
      do next <- DP.receiveTimeout 0 [DP.matchIf isDisconnect logNotification]
         case next of
           Nothing -> return (reverse acc)
           Just (DP.ProcessMonitorNotification _ p _) -> collectDisconnected (p : acc)

    reconnectTo pid =
      do logTimeServer server NOTICE $
           "Time Server: reconnecting to " ++ show pid
         DP.reconnect pid

    scheduleReMonitor serverId pids =
      do liftIO $ threadDelay (tsProcessMonitoringDelay params)
         logTimeServer server NOTICE $ "Time Server: proceed to the re-monitoring"
         DP.send serverId (ReMonitorTimeServerMessage pids)
-- | Whether more than 'tsTimeSyncDelay' microseconds separate the two times
-- (the second argument being the earlier one).
timeSyncDelayExceeded :: TimeServer -> UTCTime -> UTCTime -> Bool
timeSyncDelayExceeded server utc0 utc =
  elapsed > tsTimeSyncDelay (tsParams server)
  where
    elapsed = secondsToMicroseconds (fromRational (toRational (diffUTCTime utc utc0)))
-- | Whether more than 'tsTimeSyncTimeout' microseconds separate the two
-- times (the second argument being the earlier one).
timeSyncTimeoutExceeded :: TimeServer -> UTCTime -> UTCTime -> Bool
timeSyncTimeoutExceeded server utc0 utc =
  elapsed > tsTimeSyncTimeout (tsParams server)
  where
    elapsed = secondsToMicroseconds (fromRational (toRational (diffUTCTime utc utc0)))
-- | How many microseconds have passed between the process timestamp and
-- the given time.
diffLogicalProcessTimestamp :: UTCTime -> LogicalProcessInfo -> IO Int
diffLogicalProcessTimestamp utc lp =
  toMicroseconds <$> readIORef (lpTimestamp lp)
  where
    toMicroseconds utc0 =
      secondsToMicroseconds (fromRational (toRational (diffUTCTime utc utc0)))
-- | Apply the 'TimeServerStrategy' to the least recently heard-from
-- logical process and record the time of this validation.
--
-- Only the single oldest process is examined per validation.
validateLogicalProcesses :: TimeServer -> UTCTime -> DP.Process ()
validateLogicalProcesses server utc =
  do logTimeServer server NOTICE "Time Server: validating the logical processes"
     liftIO $ writeIORef (tsLogicalProcessValidationTimestamp server) utc
     case tsStrategy (tsParams server) of
       WaitIndefinitelyForLogicalProcess ->
         return ()
       TerminateDueToLogicalProcessTimeout timeout ->
         withStaleProcess timeout $ \_ ->
           do logTimeServer server WARNING
                "Time Server: terminating due to the exceeded logical process timeout"
              DP.terminate
       UnregisterLogicalProcessDueToTimeout timeout ->
         withStaleProcess timeout $ \lp ->
           do logTimeServer server WARNING
                "Time Server: unregistering the logical process due to the exceeded timeout"
              processTimeServerMessage server (UnregisterLogicalProcessMessage $ lpId lp)
  where
    -- Run the action on the oldest process if it has been silent for
    -- longer than the timeout (in microseconds).
    withStaleProcess timeout action =
      do oldest <- liftIO $ minTimestampLogicalProcess server
         forM_ oldest $ \lp ->
           do silence <- liftIO $ diffLogicalProcessTimestamp utc lp
              when (silence > timeout) $
                action lp
-- | 'timeServer' taking its quorum and parameters as a single pair.
curryTimeServer :: (Int, TimeServerParams) -> DP.Process ()
curryTimeServer = uncurry timeServer
-- | Log a message via 'DP.say', prefixed with its priority. Messages below
-- 'tsLoggingPriority' are dropped.
logTimeServer :: TimeServer -> Priority -> String -> DP.Process ()
logTimeServer server p message
  | tsLoggingPriority (tsParams server) <= p =
      DP.say (embracePriority p ++ " " ++ message)
  | otherwise =
      return ()