module Control.Distributed.Process.Management.Internal.Trace.Tracer
  ( 
    traceController
    
  , defaultTracer
  , systemLoggerTracer
  , logfileTracer
  , eventLogTracer
  ) where
import Control.Applicative
import Control.Concurrent.Chan (writeChan)
import Control.Concurrent.MVar
  ( MVar
  , putMVar
  )
import Control.Distributed.Process.Internal.CQueue
  ( CQueue
  )
import Control.Distributed.Process.Internal.Primitives
  ( die
  , receiveWait
  , forward
  , sendChan
  , match
  , matchAny
  , matchIf
  , handleMessage
  , matchUnknown
  )
import Control.Distributed.Process.Management.Internal.Types
  ( MxEvent(..)
  , Addressable(..)
  )
import Control.Distributed.Process.Management.Internal.Trace.Types
  ( SetTrace(..)
  , TraceSubject(..)
  , TraceFlags(..)
  , TraceOk(..)
  , defaultTraceFlags
  )
import Control.Distributed.Process.Management.Internal.Trace.Primitives
  ( traceOn )
import Control.Distributed.Process.Internal.Types
  ( LocalNode(..)
  , NCMsg(..)
  , ProcessId
  , Process
  , LocalProcess(..)
  , Identifier(..)
  , ProcessSignal(NamedSend)
  , Message
  , SendPort
  , forever'
  , nullProcessId
  , createUnencodedMessage
  )
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Reader (ask)
import Control.Monad.Catch
  ( catch
  , finally
  )
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Format (formatTime)
import Debug.Trace (traceEventIO)
import Prelude
#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import System.Environment (getEnv)
import System.IO
  ( Handle
  , IOMode(AppendMode)
  , BufferMode(..)
  , openFile
  , hClose
  , hPutStrLn
  , hSetBuffering
  )
#if MIN_VERSION_time(1,5,0)
import Data.Time.Format (defaultTimeLocale)
#else
import System.Locale (defaultTimeLocale)
#endif
import System.Mem.Weak
  ( Weak
  )
data TracerState =
  TracerST
  {
    client   :: !(Maybe ProcessId)
  , flags    :: !TraceFlags
  , regNames :: !(Map ProcessId (Set String))
  }
defaultTracer :: Process ()
defaultTracer =
  catch (checkEnv "DISTRIBUTED_PROCESS_TRACE_FILE" >>= logfileTracer)
        (\(_ :: IOError) -> defaultTracerAux)
defaultTracerAux :: Process ()
defaultTracerAux =
  catch (checkEnv "DISTRIBUTED_PROCESS_TRACE_CONSOLE" >> systemLoggerTracer)
        (\(_ :: IOError) -> defaultEventLogTracer)
defaultEventLogTracer :: Process ()
defaultEventLogTracer =
  catch (checkEnv "DISTRIBUTED_PROCESS_TRACE_EVENTLOG" >> eventLogTracer)
        (\(_ :: IOError) -> nullTracer)
checkEnv :: String -> Process String
checkEnv s = liftIO $ getEnv s
nullTracer :: Process ()
nullTracer =
  forever' $ receiveWait [ matchUnknown (return ()) ]
systemLoggerTracer :: Process ()
systemLoggerTracer = do
  node <- processNode <$> ask
  let tr = sendTraceLog node
  forever' $ receiveWait [ matchAny (\m -> handleMessage m tr) ]
  where
    sendTraceLog :: LocalNode -> MxEvent -> Process ()
    sendTraceLog node ev = do
      now <- liftIO $ getCurrentTime
      msg <- return $ (formatTime defaultTimeLocale "%c" now, buildTxt ev)
      emptyPid <- return $ (nullProcessId (localNodeId node))
      traceMsg <- return $ NCMsg {
                             ctrlMsgSender = ProcessIdentifier (emptyPid)
                           , ctrlMsgSignal = (NamedSend "trace.logger"
                                                 (createUnencodedMessage msg))
                           }
      liftIO $ writeChan (localCtrlChan node) traceMsg
    buildTxt :: MxEvent -> String
    buildTxt (MxLog msg) = msg
    buildTxt ev          = show ev
eventLogTracer :: Process ()
eventLogTracer =
  
  
  
  forever' $ receiveWait [ matchAny (\m -> handleMessage m writeTrace) ]
  where
    writeTrace :: MxEvent -> Process ()
    writeTrace ev = liftIO $ traceEventIO (show ev)
logfileTracer :: FilePath -> Process ()
logfileTracer p = do
  
  h <- liftIO $ openFile p AppendMode
  liftIO $ hSetBuffering h LineBuffering
  logger h `finally` (liftIO $ hClose h)
  where
    logger :: Handle -> Process ()
    logger h' = forever' $ do
      receiveWait [
          matchIf (\ev -> case ev of
                            MxTraceDisable      -> True
                            (MxTraceTakeover _) -> True
                            _                   -> False)
                  (\_ -> (liftIO $ hClose h') >> die "trace stopped")
        , matchAny (\ev -> handleMessage ev (writeTrace h'))
        ]
    writeTrace :: Handle -> MxEvent -> Process ()
    writeTrace h ev = do
      liftIO $ do
        now <- getCurrentTime
        hPutStrLn h $ (formatTime defaultTimeLocale "%c - " now) ++ (show ev)
traceController :: MVar ((Weak (CQueue Message))) -> Process ()
traceController mv = do
    
    
    weakQueue <- processWeakQ <$> ask
    liftIO $ putMVar mv weakQueue
    initState <- initialState
    traceLoop initState { client = Nothing }
  where
    traceLoop :: TracerState -> Process ()
    traceLoop st = do
      let client' = client st
      
      
      
      st' <- receiveWait [
          match (\(setResp, set :: SetTrace) -> do
                  
                  
                  
                  
                  case set of
                    (TraceEnable pid) -> do
                      
                      sendTraceMsg client' (createUnencodedMessage (MxTraceTakeover pid))
                      sendOk setResp
                      return st { client = (Just pid) }
                    TraceDisable -> do
                      sendTraceMsg client' (createUnencodedMessage MxTraceDisable)
                      sendOk setResp
                      return st { client = Nothing })
        , match (\(confResp, flags') ->
                  sendOk confResp >> applyTraceFlags flags' st)
        , match (\chGetFlags -> sendChan chGetFlags (flags st) >> return st)
        , match (\chGetCurrent -> sendChan chGetCurrent (client st) >> return st)
          
        , matchAny (\ev ->
            handleMessage ev (handleTrace st ev) >>= return . fromMaybe st)
        ]
      traceLoop st'
    sendOk :: Maybe (SendPort TraceOk) -> Process ()
    sendOk Nothing   = return ()
    sendOk (Just sp) = sendChan sp TraceOk
    initialState :: Process TracerState
    initialState = do
      flags' <- checkEnvFlags
      return $ TracerST { client   = Nothing
                        , flags    = flags'
                        , regNames = Map.empty
                        }
    checkEnvFlags :: Process TraceFlags
    checkEnvFlags =
      catch (checkEnv "DISTRIBUTED_PROCESS_TRACE_FLAGS" >>= return . parseFlags)
            (\(_ :: IOError) -> return defaultTraceFlags)
    parseFlags :: String -> TraceFlags
    parseFlags s = parseFlags' s defaultTraceFlags
      where parseFlags' :: String -> TraceFlags -> TraceFlags
            parseFlags' [] parsedFlags = parsedFlags
            parseFlags' (x:xs) parsedFlags
              | x == 'p'  = parseFlags' xs parsedFlags { traceSpawned = traceOn }
              | x == 'n'  = parseFlags' xs parsedFlags { traceRegistered = traceOn }
              | x == 'u'  = parseFlags' xs parsedFlags { traceUnregistered = traceOn }
              | x == 'd'  = parseFlags' xs parsedFlags { traceDied = traceOn }
              | x == 's'  = parseFlags' xs parsedFlags { traceSend = traceOn }
              | x == 'r'  = parseFlags' xs parsedFlags { traceRecv = traceOn }
              | x == 'l'  = parseFlags' xs parsedFlags { traceNodes = True }
              | otherwise = parseFlags' xs parsedFlags
applyTraceFlags :: TraceFlags -> TracerState -> Process TracerState
applyTraceFlags flags' state = return state { flags = flags' }
handleTrace :: TracerState -> Message -> MxEvent -> Process TracerState
handleTrace st msg ev@(MxRegistered p n) =
  let regNames' =
        Map.insertWith (\_ ns -> Set.insert n ns) p
                       (Set.singleton n)
                       (regNames st)
  in do
    traceEv ev msg (traceRegistered (flags st)) st
    return st { regNames = regNames' }
handleTrace st msg ev@(MxUnRegistered p n) =
  let f ns = case ns of
               Nothing  -> Nothing
               Just ns' -> Just (Set.delete n ns')
      regNames' = Map.alter f p (regNames st)
  in do
    traceEv ev msg (traceUnregistered (flags st)) st
    return st { regNames = regNames' }
handleTrace st msg ev@(MxSpawned  _)   = do
  traceEv ev msg (traceSpawned (flags st)) st >> return st
handleTrace st msg ev@(MxProcessDied _ _)     = do
  traceEv ev msg (traceDied (flags st)) st >> return st
handleTrace st msg ev@(MxSent _ _ _)   =
  traceEv ev msg (traceSend (flags st)) st >> return st
handleTrace st msg ev@(MxReceived _ _) =
  traceEv ev msg (traceRecv (flags st)) st >> return st
handleTrace st msg ev = do
  case ev of
    (MxNodeDied _ _) ->
      case (traceNodes (flags st)) of
        True  -> sendTrace st ev msg
        False -> return ()
    (MxUser _) -> sendTrace st ev msg
    (MxLog _)  -> sendTrace st ev msg
    _ ->
      case (traceConnections (flags st)) of
        True  -> sendTrace st ev msg
        False -> return ()
  return st
traceEv :: MxEvent
        -> Message
        -> Maybe TraceSubject
        -> TracerState
        -> Process ()
traceEv _  _   Nothing                  _  = return ()
traceEv ev msg (Just TraceAll)          st = sendTrace st ev msg
traceEv ev msg (Just (TraceProcs pids)) st = do
  node <- processNode <$> ask
  let p = case resolveToPid ev of
            Nothing  -> (nullProcessId (localNodeId node))
            Just pid -> pid
  case (Set.member p pids) of
    True  -> sendTrace st ev msg
    False -> return ()
traceEv ev msg (Just (TraceNames names)) st = do
  
  
  node <- processNode <$> ask
  let p = case resolveToPid ev of
            Nothing  -> (nullProcessId (localNodeId node))
            Just pid -> pid
  case (Map.lookup p (regNames st)) of
    Nothing -> return ()
    Just ns -> if (Set.null (Set.intersection ns names))
                 then return ()
                 else sendTrace st ev msg
sendTrace :: TracerState -> MxEvent -> Message -> Process ()
sendTrace st ev msg = do
  let c = client st
  if c == (resolveToPid ev)  
     then return ()
     else sendTraceMsg c msg
sendTraceMsg :: Maybe ProcessId -> Message -> Process ()
sendTraceMsg Nothing  _   = return ()
sendTraceMsg (Just p) msg = (flip forward) p msg