-- | Implement Erlang style message passing concurrency.
--
-- This handles the 'MessagePassing' and 'Process' effects, using
-- 'STM.TQueue's and 'forkIO'.
--
-- This aims to be a pragmatic implementation, so even logging is
-- supported.
--
-- At the core is a /main process/ that enters 'runMainProcess'
-- and creates all of the internal state stored in 'STM.TVar's
-- to manage processes with message queues.
--
-- The 'Eff' handlers for 'Process' and 'MessagePassing' are implemented
-- here and made available through 'spawn'.
--
-- 'spawn' uses 'forkFinally' and 'STM.TQueue's and tries to catch
-- most exceptions.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE GADTs #-}
module Control.Eff.Concurrent.Dispatcher
  ( runMainProcess
  , defaultMain
  , spawn
  , DispatcherError(..)
  , DispatcherIO
  , HasDispatcherIO
  , ProcIO
  )
where

import           GHC.Stack
import           Data.Maybe
import           Data.Kind
import qualified Control.Exception             as Exc
import           Control.Concurrent            as Concurrent
import           Control.Concurrent.STM        as STM
import           Control.Eff
import           Control.Eff.Concurrent.MessagePassing
import           Control.Eff.ExceptionExtra
import           Control.Eff.Lift
import           Control.Eff.Log
import           Control.Eff.Reader.Strict     as Reader
import           Control.Lens
import           Control.Monad                  ( when
                                                , void
                                                )
import qualified Control.Monad.State           as Mtl
import           Data.Dynamic
import           Data.Typeable                  ( typeRep )
import           Data.Map                       ( Map )
import qualified Data.Map                      as Map

-- | Per-process bookkeeping used by the 'MessagePassing' and 'Process'
-- handlers. Each process owns an 'STM.TQueue' holding its incoming
-- messages as 'Dynamic' values.
data ProcessInfo = ProcessInfo
  { _processId      :: ProcessId
    -- ^ The identifier of the process.
  , _messageQ       :: STM.TQueue Dynamic
    -- ^ Incoming messages; they are converted back with 'fromDynamic' when
    -- received.
  , _exitOnShutdown :: Bool
    -- ^ 'True' unless the process traps exits (set via 'TrapExit').
  }

makeLenses ''ProcessInfo

-- | Renders the process id and whether the process traps exits.
instance Show ProcessInfo where
  show pinfo = concat
    [ "ProcessInfo: "
    , show (view processId pinfo)
    , " trapExit: "
    , show (not (view exitOnShutdown pinfo))
    ]

-- | The shared dispatcher state. It holds:
--
-- * the 'ProcessInfo' of every registered process;
-- * the thread running each process;
-- * the next free 'ProcessId';
-- * a shutdown flag;
-- * the 'LogChannel' to which the logs of all processes are forwarded.
data Dispatcher = Dispatcher
  { _nextPid               :: ProcessId
    -- ^ The 'ProcessId' handed out to the next registering process.
  , _processTable          :: Map ProcessId ProcessInfo
    -- ^ All registered processes.
  , _threadIdTable         :: Map ProcessId ThreadId
    -- ^ The thread each registered process runs in.
  , _schedulerShuttingDown :: Bool
    -- ^ Set during tear down; no new processes are registered afterwards.
  , _logChannel            :: LogChannel String
    -- ^ Destination of the logs of all processes.
  }

makeLenses ''Dispatcher

-- | The 'STM.TVar' holding the 'Dispatcher' state, wrapped so that it can be
-- provided through the 'Reader' effect. It is created by 'runMainProcess'
-- and read, for example, by 'spawn'.
newtype DispatcherVar = DispatcherVar
  { fromDispatcherVar :: STM.TVar Dispatcher
  }
  deriving Typeable

-- | The errors that can occur while dispatching messages.
data DispatcherError
  = UnhandledMessageReceived Dynamic ProcessId
    -- ^ A message taken from the queue of the given process could not be
    -- converted to the expected type with 'fromDynamic', and the process
    -- does not trap exits.
  | ProcessNotFound ProcessId
    -- ^ No 'ProcessInfo' exists for the 'ProcessId' during internal
    -- processing. NOTE: This is **ONLY** caused by internal errors, probably
    -- by an incorrect 'MessagePassing' handler in this module. Sending a
    -- message to a process that does not exist does **not** raise this
    -- error; the send merely yields 'False'.
  | ProcessException String ProcessId
    -- ^ A process raised an error ('RaiseError'); carries the message.
  | DispatcherShuttingDown
    -- ^ An action could not be performed because the dispatcher is shutting
    -- down, e.g. registering a new process.
  | LowLevelIOException Exc.SomeException
    -- ^ A 'Control.Exception.SomeException' was caught while accessing the
    -- dispatcher state or a message queue.
  deriving (Typeable, Show)

instance Exc.Exception DispatcherError

-- | The constraints an effect list @r@ must satisfy to use this dispatcher,
-- e.g. to 'spawn' new 'Process'es. 'DispatcherIO' is a concrete effect list
-- that satisfies them.
type HasDispatcherIO r =
  ( HasCallStack
  , SetMember Lift (Lift IO) r
  , Member (Exc DispatcherError) r
  , Member (Logs String) r
  , Member (Reader DispatcherVar) r
  )

-- | The concrete effect list of this dispatcher implementation; it
-- satisfies 'HasDispatcherIO'.
type DispatcherIO =
  '[ Exc DispatcherError, Reader DispatcherVar, Logs String, Lift IO ]

-- | The concrete effect list of a process: 'MessagePassing' and 'Process'
-- stacked on top of 'DispatcherIO'.
type ProcIO = ConsProcIO DispatcherIO

-- | Put 'MessagePassing' and then 'Process' in front of the effect list @r@.
type ConsProcIO r = MessagePassing ': (Process ': r)

-- | Lets 'ProcIO' processes log 'String' messages; the method simply
-- delegates to 'logMessageFreeEff'.
instance MonadLog String (Eff ProcIO) where
  logMessageFree = logMessageFreeEff

-- | This is the main entry point to running a message passing concurrency
-- application. This function takes a 'ProcIO' effect and a 'LogChannel' for
-- concurrent logging.
--
-- The effect runs as the /main process/ (via 'dispatchMessages') in the
-- calling thread. Its clean up action is run when it returns, and also when
-- it throws a 'DispatcherError'. On the error path the error is logged,
-- rethrown, and finally handed to 'runErrorRethrowIO'.
--
-- Afterwards the dispatcher is torn down as the release action of
-- 'Exc.bracket':
--
-- * the shutdown flag is set;
-- * the threads of all processes still in the thread table are killed,
--   except the calling thread;
-- * this function blocks until the process table and the thread table are
--   both empty.
runMainProcess :: Eff ProcIO a -> LogChannel String -> IO a
runMainProcess e logC = withDispatcher
  (dispatchMessages
    (\cleanup -> do
      mt <- lift myThreadId
      mp <- self
      logMessage (show mp ++ " main process started in thread " ++ show mt)
      -- Catch a 'DispatcherError' of the main process, so that the clean up
      -- action runs on the error path as well as on the success path.
      res <- try e
      case res of
        -- NOTE: after the rethrow below, 'dispatchMessages' catches the
        -- error again and runs the clean up action a second time. This
        -- relies on the clean up action tolerating a repeated run.
        Left ex ->
          do
              logMessage
                (  show mp
                ++ " main process exception: "
                ++ ((show :: DispatcherError -> String) ex)
                )
              lift (runCleanUpAction cleanup)
            >> throwError ex
        Right rres -> do
          logMessage (show mp ++ " main process exited")
          lift (runCleanUpAction cleanup)
          return rres
    )
  )
  where
    -- Create the shared 'Dispatcher' state and run the main process action
    -- against it. The dispatcher is torn down afterwards, even if the
    -- action throws.
    withDispatcher :: Eff DispatcherIO a -> IO a
    withDispatcher mainProcessAction = do
      myTId <- myThreadId
      Exc.bracket
        (newTVarIO (Dispatcher myPid Map.empty Map.empty False logC))
        (tearDownDispatcher myTId)
        (\sch -> runLift
          (forwardLogsToChannel
            logC
            (runReader (runErrorRethrowIO mainProcessAction) (DispatcherVar sch))
          )
        )
     where
      -- The first 'ProcessId' handed out. The main process registers first,
      -- so it gets this one.
      myPid = 1
      -- Release action of the bracket. @myTId@ is the thread that runs
      -- 'runMainProcess'.
      tearDownDispatcher myTId v = do
        logChannelPutIO logC (show myTId ++" begin dispatcher tear down")
        -- Set the shutdown flag and snapshot the state in one transaction;
        -- from here on no new process can register.
        sch <-
          (atomically
            (do
              sch <- readTVar v
              let sch' = sch & schedulerShuttingDown .~ True
              writeTVar v sch'
              return sch
            )
          )
        logChannelPutIO logC (show myTId ++ " killing threads: " ++
                                   show (sch ^.. threadIdTable. traversed))
        -- Kill every registered thread, except the calling thread.
        imapM_ (killProcThread myTId) (sch ^. threadIdTable)
        -- Block until both tables are empty. NOTE(review): this relies on
        -- each killed process removing its own entries, via its clean up
        -- action. Entries that are never removed keep this blocked.
        atomically
          (do
              dispatcher <- readTVar v
              let allThreadsDead = dispatcher^.threadIdTable.to Map.null
                                   && dispatcher^.processTable.to Map.null
              STM.check allThreadsDead)
        logChannelPutIO logC "all threads dead"


      -- Kill the thread @tid@ of process @pid@, unless it is the calling
      -- thread @myTId@.
      killProcThread myTId pid tid = when
        (myTId /= tid)
        (  logChannelPutIO logC ("killing thread " ++ show pid)
        >> killThread tid
        )


-- | Start the message passing concurrency system, then run the given
-- 'ProcIO' effect as the main process (see 'runMainProcess').
--
-- Every log line is written to standard output with 'print', i.e. in its
-- 'show'n (quoted) form. The two banner strings are passed to
-- 'logChannelBracket'; presumably they are logged when the main process
-- starts and exits.
defaultMain :: Eff ProcIO a -> IO a
defaultMain mainAction =
  runLoggingT
    (logChannelBracket startBanner exitBanner (runMainProcess mainAction))
    printLogLine
 where
  startBanner = Just "~~~~~~ main process started"
  exitBanner  = Just "====== main process exited"
  printLogLine :: String -> IO ()
  printLogLine = print


-- | Run the body of a child process in the calling thread. Its effects are
-- interpreted with 'dispatchMessages', and its logs are forwarded to the
-- dispatcher's 'LogChannel'. A 'DispatcherError' that escapes is returned
-- as 'Left'.
runChildProcess
  :: DispatcherVar
  -> (CleanUpAction -> Eff ProcIO a)
  -> IO (Either DispatcherError a)
runChildProcess dispatcherVar childAction = do
  dispatcher <- readTVarIO (fromDispatcherVar dispatcherVar)
  runLift
    (forwardLogsToChannel
      (dispatcher ^. logChannel)
      (runReader (runError (dispatchMessages childAction)) dispatcherVar))

-- | Read the 'LogChannel' stored in the current 'Dispatcher' state.
getLogChannel :: HasDispatcherIO r => Eff r (LogChannel String)
getLogChannel = view logChannel <$> getDispatcher


-- | Run a state action on the 'ProcessInfo' of process @pid@ and write the
-- updated info back. Both steps happen in one 'Dispatcher' transaction.
-- If @pid@ is not in the process table, 'ProcessNotFound' is passed to
-- 'liftEither' instead.
overProcessInfo
  :: HasDispatcherIO r
  => ProcessId
  -> Mtl.StateT ProcessInfo STM.STM a
  -> Eff r a
overProcessInfo pid stAction = overDispatcher updateEntry >>= liftEither
 where
  updateEntry = do
    mPinfo <- use (processTable . at pid)
    maybe (return (Left (ProcessNotFound pid))) runOnEntry mPinfo
  runOnEntry pinfo = do
    (x, pinfo') <- Mtl.lift (Mtl.runStateT stAction pinfo)
    processTable . at pid . _Just .= pinfo'
    return (Right x)

-- ** MessagePassing execution

-- | Start a new process that runs the given 'ProcIO' action in a new thread
-- (via 'Concurrent.forkFinally'), and return its 'ProcessId'.
--
-- Blocks until the child has published its 'ProcessId'. If the child ends
-- before doing so (e.g. because the dispatcher is shutting down and the
-- child could not register), 'DispatcherShuttingDown' is thrown. A
-- 'DispatcherError' thrown by the child's action is logged in the child
-- and is not propagated to the caller.
spawn :: HasDispatcherIO r => Eff ProcIO () -> Eff r ProcessId
spawn mfa = do
  schedulerVar <- ask
  -- Filled by the child with @Just pid@, or with 'Nothing' by the finally
  -- handler if the child ended before publishing its pid.
  pidVar       <- lift newEmptyTMVarIO
  -- Filled by the child with its clean up action; taken by the finally
  -- handler.
  cleanupVar   <- lift newEmptyTMVarIO
  lc           <- getLogChannel
  void
    (lift
      (Concurrent.forkFinally
        (runChildProcess
          schedulerVar
          (\cleanUpAction -> do
            -- Publish the clean up action first, so the finally handler can
            -- unregister the process however the thread ends.
            lift (atomically (STM.putTMVar cleanupVar cleanUpAction))
            pid <- self
            lift (atomically (STM.putTMVar pidVar (Just pid)))
            catchError
              mfa
              ( logMessage
              . ("process exception: " ++)
              . (show :: DispatcherError -> String)
              )
          )
        )
        (\eres -> do
          mcleanup <- atomically (STM.tryTakeTMVar cleanupVar)
          -- Make sure 'spawn' does not wait forever if the child never
          -- published its pid; if it did, this put has no effect on spawn.
          void (atomically (tryPutTMVar pidVar Nothing))
          -- Only a child that got far enough to publish its clean up action
          -- is unregistered and logged here.
          case mcleanup of
            Nothing -> return ()
            Just ca -> do
              runCleanUpAction ca
              mt <- myThreadId
              case eres of
                Left se -> logChannelPutIO
                  lc
                  ("thread " ++ show mt ++ " killed by exception: " ++ show se)
                Right _ ->
                  logChannelPutIO lc ("thread " ++ show mt ++ " exited")
        )
      )
    )
  mPid <- lift (atomically (STM.takeTMVar pidVar))
  maybe (throwError DispatcherShuttingDown) return mPid

-- | An 'IO' action that unregisters a process from the dispatcher. Execute
-- it with 'runCleanUpAction'.
newtype CleanUpAction = CleanUpAction
  { runCleanUpAction :: IO ()
  }

-- | Run the action of one process in the calling thread.
--
-- The process is first registered with 'withMessageQueue'. Then the
-- 'MessagePassing' and 'Process' effects of the action are interpreted
-- against the shared 'Dispatcher' state.
--
-- The action receives the process's 'CleanUpAction'; this function does not
-- run it when the action returns normally. If a 'DispatcherError' escapes
-- the action, three things happen in order:
--
-- 1. the error is logged;
-- 2. the clean up action is run;
-- 3. the error is rethrown.
dispatchMessages
  :: forall r a
   . (HasDispatcherIO r, HasCallStack)
  => (CleanUpAction -> Eff (ConsProcIO r) a)
  -> Eff r a
dispatchMessages processAction = withMessageQueue
  (\cleanUpAction pinfo ->
     try
     (handle_relay
      return
      (goProc (pinfo ^. processId))
       (handle_relay return (go (pinfo ^. processId))
        (processAction cleanUpAction)
       ))
     >>=
     either
     (\(e :: DispatcherError) ->
        do logMsg (show (pinfo^.processId)
                    ++ " cleanup on exception: "
                    ++ show e)
           lift (runCleanUpAction cleanUpAction)
           throwError e)
     return
  )
 where
  -- Handler for the 'MessagePassing' requests of process @pid@.
  go
    :: forall v
     . HasCallStack
    => ProcessId
    -> MessagePassing v
    -> (v -> Eff (Process ': r) a)
    -> Eff (Process ': r) a
  -- Put the message, wrapped as 'Dynamic', into the receiver's queue.
  -- Yields 'True' if the receiver exists and 'False' otherwise; an unknown
  -- receiver is not an error.
  go _pid (SendMessage toPid reqIn) k = do
    psVar <- getDispatcherTVar
    liftRethrow
        LowLevelIOException
        (atomically
          (do
            p <- readTVar psVar
            let mto = p ^. processTable . at toPid
            case mto of
              Just toProc ->
                let dReq = toDyn reqIn
                in  do
                      writeTQueue (toProc ^. messageQ) dReq
                      return True
              Nothing -> return False
          )
        )
      >>= k
  -- Block until the next message arrives and convert it to the expected
  -- type. If the process traps exits, a bad message or a 'DispatcherError'
  -- during the receive becomes a 'ProcessControlMessage'; otherwise it is
  -- thrown.
  --
  -- Only the receive itself is guarded by 'catchError'. The continuation
  -- @k@ runs outside the handler, for two reasons:
  --
  -- * errors raised later in the process are not mistaken for receive
  --   errors, and do not resume @k@ a second time;
  -- * a process that receives in a loop does not accumulate nested
  --   handlers.
  go pid (ReceiveMessage onMsg) k =
    catchError
        (do
          dynMsg <- overProcessInfo pid
            (do
              mq <- use messageQ
              Mtl.lift (readTQueue mq))
          case fromDynamic dynMsg of
            Just req -> return (Message (onMsg req))
            nix@Nothing -> do
              isExitOnShutdown <- overProcessInfo pid (use exitOnShutdown)
              if isExitOnShutdown
                then throwError (UnhandledMessageReceived dynMsg pid)
                else return
                  (ProcessControlMessage
                    (  "unexpected message: "
                    ++ show dynMsg
                    ++ " expected: "
                    ++ show (typeRep nix)
                    )
                  )
        )
        (\(se :: DispatcherError) -> do
          isExitOnShutdown <- overProcessInfo pid (use exitOnShutdown)
          if isExitOnShutdown
            then throwError se
            else return (ProcessControlMessage (show se))
        )
      >>= k

  -- Handler for the 'Process' requests of process @pid@.
  goProc
    :: forall v x
     . HasCallStack
    => ProcessId
    -> Process v
    -> (v -> Eff r x)
    -> Eff r x
  goProc pid SelfPid k = k pid
  -- Trapping exits is stored inverted, as the 'exitOnShutdown' flag.
  goProc pid (TrapExit s) k =
    overProcessInfo pid (exitOnShutdown .= (not s)) >>= k
  goProc pid GetTrapExit k =
    overProcessInfo pid (use exitOnShutdown) >>= k . not
  goProc pid (RaiseError msg) _k = do
    logMsg (show pid ++ " error raised: " ++ msg)
    throwError (ProcessException msg pid)

-- | Register the calling thread as a new process. Then run @m@ with the new
-- 'ProcessInfo' and a 'CleanUpAction' that unregisters the process again.
--
-- Registration allocates a fresh 'ProcessId' and 'STM.TQueue' and records
-- the calling thread in the thread table. Throws 'DispatcherShuttingDown'
-- if the dispatcher is already shutting down.
--
-- Running the clean up action more than once only logs that the queue was
-- already destroyed. Exceptions while updating the dispatcher state are
-- caught and logged.
withMessageQueue
  :: HasDispatcherIO r => (CleanUpAction -> ProcessInfo -> Eff r a) -> Eff r a
withMessageQueue m = do
  mpinfo <- createQueue
  lc     <- getLogChannel
  case mpinfo of
    Just pinfo -> do
      cleanUpAction <-
        CleanUpAction . destroyQueue lc (pinfo ^. processId)
          <$> getDispatcherTVar
      m cleanUpAction pinfo
    Nothing -> throwError DispatcherShuttingDown
 where
  -- In one transaction, allocate a 'ProcessId' and a message queue and
  -- register the calling thread. Returns 'Nothing', registering nothing,
  -- while the dispatcher is shutting down.
  createQueue = do
    myTId <- lift myThreadId
    overDispatcher
      (do
        abortNow <- use schedulerShuttingDown
        if abortNow
          then return Nothing
          else do
            pid     <- nextPid <<+= 1
            channel <- Mtl.lift newTQueue
            let pinfo = ProcessInfo pid channel True
            threadIdTable . at pid .= Just myTId
            processTable . at pid .= Just pinfo
            return (Just pinfo)
      )
  -- Remove process @pid@ from the process table and the thread table, then
  -- log the outcome.
  --
  -- The removal must also happen while the dispatcher is shutting down. The
  -- tear down in 'runMainProcess' kills the remaining process threads and
  -- then blocks until both tables are empty. Killed threads unregister
  -- through this action, from the finally handler in 'spawn'. Skipping the
  -- removal during shutdown would therefore block the tear down forever.
  destroyQueue lc pid psVar = do
    didWork <- Exc.try
      (overDispatcherIO
        psVar
        (do
          os <- processTable . at pid <<.= Nothing
          ot <- threadIdTable . at pid <<.= Nothing
          return (os, isJust os || isJust ot)
        )
      )
    let getCause =
          Exc.try @Exc.SomeException
              (overDispatcherIO psVar (preuse (processTable . at pid)))
            >>= either
                  (return . (show pid ++) . show)
                  (return . (maybe (show pid) show))

    case didWork of
      Right (pinfo, True) ->
        logChannelPutIO lc ("destroying queue: " ++ show pinfo)
      Right (pinfo, False) ->
        logChannelPutIO lc ("queue already destroyed: " ++ show pinfo)
      Left (e :: Exc.SomeException) ->
        getCause
          >>= logChannelPutIO lc
          .   (("failed to destroy queue: " ++ show e ++ " ") ++)


-- | Atomically run a state action on the shared 'Dispatcher' state (see
-- 'overDispatcherIO'). An exception raised by the transaction is rethrown
-- as a 'LowLevelIOException'.
overDispatcher
  :: HasDispatcherIO r => Mtl.StateT Dispatcher STM.STM a -> Eff r a
overDispatcher stAction = do
  dispatcherTVar <- getDispatcherTVar
  let transaction = overDispatcherIO dispatcherTVar stAction
  liftRethrow LowLevelIOException transaction

-- | In a single STM transaction, read the 'Dispatcher' from the given
-- 'STM.TVar' and run the state action on it. The resulting state is written
-- back, and the action's result is returned.
overDispatcherIO
  :: STM.TVar Dispatcher -> Mtl.StateT Dispatcher STM.STM a -> IO a
overDispatcherIO psVar stAction =
  STM.atomically (STM.readTVar psVar >>= Mtl.runStateT stAction >>= commit)
 where
  commit (result, dispatcher') = result <$ STM.writeTVar psVar dispatcher'

-- | Fetch the 'TVar' holding the 'Dispatcher' state from the 'Reader'
-- effect.
getDispatcherTVar :: HasDispatcherIO r => Eff r (TVar Dispatcher)
getDispatcherTVar = ask >>= \(DispatcherVar dispatcherTVar) -> return dispatcherTVar

-- | Read a snapshot of the current 'Dispatcher' state.
getDispatcher :: HasDispatcherIO r => Eff r Dispatcher
getDispatcher = getDispatcherTVar >>= lift . readTVarIO