{-# LANGUAGE
    DataKinds
  , RankNTypes
  , NamedFieldPuns
  , FlexibleContexts
  , ScopedTypeVariables
  #-}

module Control.Concurrent.Threaded.Hash where

import Control.Concurrent.Async (Async, async, cancel)
import Control.Concurrent.Chan.Extra (ChanScoped (readOnly, allowReading, writeOnly, allowWriting))
import Control.Concurrent.Chan.Scope (Scope (Read, Write))
import Control.Concurrent.MVar (newEmptyMVar, putMVar, readMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TChan.Typed (TChanRW, newTChanRW, readTChanRW, writeTChanRW)
import Control.Concurrent.STM.TMapMVar.Hash (TMapMVar, newTMapMVar, tryObserve, insertForce, tryLookup)
import Control.Concurrent.Threaded (ThreadedInternal (..))
import Control.Monad (forever)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Trans.Control.Aligned (MonadBaseControl (liftBaseWith))
import Data.Hashable (Hashable)
import Data.Singleton.Class (Extractable (runSingleton))


-- | Segregates concurrently operating threads by some key type @k@. Returns the
-- thread that processes all other threads (this function is non-blocking), and the
-- channel that dispenses the outputs from each thread.
--
-- The returned runner loops forever over the incoming channel. For each
-- @(k, input)@ it reads:
--
-- * if no worker is registered under @k@, it creates a fresh pair of channels,
--   seeds the input channel with @input@, sparks the supplied process in a new
--   thread together with a relay that tags every output with @k@ and forwards it
--   to the shared output channel, and registers all of that under @k@;
--
-- * otherwise it writes @input@ to the registered worker's input channel.
--
-- A worker only starts its own clean-up after the runner has finished registering
-- it, so a process that returns immediately can never observe a missing
-- registration, and the runner can never register an already cleaned-up worker.
threaded :: forall m stM k input output
          . Hashable k
         => Eq k
         => Show k
         => MonadIO m
         => MonadBaseControl IO m stM
         => Extractable stM
         => -- | Incoming messages, identified by thread @k@
            TChanRW 'Write (k, input)
         -> -- | Process to spark in a new thread. When @m ()@ returns, the thread is considered \"dead\",
            -- and is internally cleaned up.
            (TChanRW 'Read input -> TChanRW 'Write output -> m ())
         -> m (Async (), TChanRW 'Read (k, output))
threaded incoming process = do
  ( threads :: TMapMVar k (ThreadedInternal incoming outgoing)
    ) <- liftIO (atomically newTMapMVar)
  outgoing <- liftIO (atomically (readOnly <$> newTChanRW))

  -- the main function that organizes the execution and plumbing of the threads
  threadRunner <- liftBaseWith $ \runInBase -> fmap (fmap runSingleton) $ async $ runInBase $ forever $ do
    (k, input) <- liftIO (atomically (readTChanRW (allowReading incoming)))

    mThread <- liftIO (atomically (tryObserve threads k))
    case mThread of
      Nothing -> do
        -- thread-specific channels
        threadInput' <- liftIO $ atomically $ do
          i <- newTChanRW
          -- initial input
          writeTChanRW i input
          pure i
        let threadInput = readOnly threadInput'
        threadOutput' <- liftIO (atomically newTChanRW)
        let threadOutput = writeOnly threadOutput'

        -- Gate filled by the runner once the worker below has been stored in
        -- @threads@. The worker waits on it before cleaning up; without it, a
        -- process that returns before the 'insertForce' further down would find
        -- no entry for @k@, and the runner would then register a dead worker.
        registered <- liftIO newEmptyMVar

        -- relays the process's output to the whole output - can't be done in lock-step, must let one finish
        -- before writing
        -- NOTE(review): the relay is cancelled as soon as the process returns, so
        -- outputs still sitting in the thread's channel at that moment may never be
        -- forwarded - confirm whether that is acceptable.
        outputRelay <- liftIO $ async $ forever $ do
          output <- atomically (readTChanRW threadOutput')
          atomically (writeTChanRW (allowWriting outgoing) (k, output))

        -- main thread
        thread <- liftBaseWith $ \runInBase' -> fmap (fmap runSingleton) $ async $ runInBase' $ do
          process threadInput threadOutput
          -- thread finished processing; make sure the runner has registered us
          -- before touching the registry ('readMVar' leaves the gate filled)
          liftIO (readMVar registered)
          -- NOTE(review): presumably 'tryLookup' also takes the entry out of the
          -- map (as opposed to 'tryObserve') - confirm against the TMapMVar docs.
          mThread' <- liftIO (atomically (tryLookup threads k))
          case mThread' of
            Nothing -> error ("Thread's facilities don't exist: " ++ show k)
            Just ThreadedInternal{thread = thread'} -> liftIO $ do
              -- kill the thread's supervisors
              cancel outputRelay
              -- @thread'@ is the handle stored under @k@, i.e. this very thread
              cancel thread'

        -- store threads and channels
        liftIO $ atomically $
          insertForce threads k ThreadedInternal{thread,outputRelay,threadInput,threadOutput}

        -- registration is complete: let the worker clean up whenever it finishes
        liftIO (putMVar registered ())


      -- thread is still processing - relay input to its channel
      Just ThreadedInternal{threadInput} ->
        liftIO (atomically (writeTChanRW (allowWriting threadInput) input))

  pure (threadRunner, outgoing)