{-# LANGUAGE
    DataKinds
  , RankNTypes
  , NamedFieldPuns
  , FlexibleContexts
  , ScopedTypeVariables
  #-}
module Control.Concurrent.Threaded where
import Control.Exception (finally)
import Control.Monad (forever)
import Control.Monad.IO.Class (MonadIO (liftIO))

import Control.Concurrent.Async (Async, async, cancel)
import Control.Concurrent.Chan.Extra (ChanScoped (readOnly, allowReading, writeOnly, allowWriting))
import Control.Concurrent.Chan.Scope (Scope (Read, Write))
import Control.Concurrent.STM (atomically, check, newTVarIO, orElse, readTVar, retry, writeTVar)
import Control.Concurrent.STM.TChan.Typed (TChanRW, newTChanRW, readTChanRW, writeTChanRW)
import Control.Concurrent.STM.TMapMVar (TMapMVar, newTMapMVar, tryObserve, insertForce, tryLookup)
import Control.Monad.Trans.Control.Aligned (MonadBaseControl (liftBaseWith))
import Data.Singleton.Class (Extractable (runSingleton))
-- | Bookkeeping for one per-key worker started by 'threaded'; one value of
-- this type is stored in the worker registry for each live key.
--
-- Fields are strict, so every handle is forced to WHNF when an entry is
-- built, and long-lived registry entries do not retain thunks.
data ThreadedInternal incoming outgoing = ThreadedInternal
  { thread       :: !(Async ())
    -- ^ the worker thread running the user-supplied process
  , outputRelay  :: !(Async ())
    -- ^ the thread that forwards the worker's outputs, tagged with its key
  , threadInput  :: !(TChanRW 'Read incoming)
    -- ^ the worker's private input channel
  , threadOutput :: !(TChanRW 'Write outgoing)
    -- ^ the worker's private output channel
  }
-- | Demultiplex a stream of keyed messages onto per-key worker threads.
--
-- Every @(k, input)@ read from the first channel is routed to the worker for
-- key @k@. If no live worker exists for @k@, a new one is started. It runs
-- @process@ with its own private input and output channels, and the
-- triggering message is already queued on its input. Everything a worker
-- writes to its output channel is tagged with its key and forwarded to the
-- returned channel.
--
-- When @process@ returns or throws, the worker does two things:
--
-- * It removes its own registry entry, so the next message for that key
--   starts a fresh worker.
--
-- * It signals its relay, which forwards any outputs still queued and then
--   exits.
--
-- Caveats:
--
-- * A message can be dropped if it is routed to a worker after that
--   worker's @process@ has stopped reading, but before the worker has
--   removed its entry.
--
-- * NOTE(review): cancelling the returned dispatcher 'Async' does not cancel
--   workers or relays that are already running.
--
-- The @Show k@ constraint is not used by the body. It is kept so the
-- signature stays unchanged for existing callers.
threaded :: forall m stM k input output
          . Ord k
         => Show k
         => MonadIO m
         => MonadBaseControl IO m stM
         => Extractable stM
         => TChanRW 'Write (k, input)
            -- ^ keyed messages; read by the dispatcher
         -> (TChanRW 'Read input -> TChanRW 'Write output -> m ())
            -- ^ per-key worker, given its private input and output channels
         -> m (Async (), TChanRW 'Read (k, output))
            -- ^ the dispatcher thread and the channel of key-tagged outputs
threaded incoming process = do
  -- Registry of live workers, keyed by message key.
  (threads :: TMapMVar k (ThreadedInternal input output))
    <- liftIO (atomically newTMapMVar)
  outgoing <- liftIO (atomically (readOnly <$> newTChanRW))

  threadRunner <- liftBaseWith $ \runInBase ->
    fmap (fmap runSingleton) $ async $ runInBase $ forever $ do
      (k, input) <- liftIO (atomically (readTChanRW (allowReading incoming)))
      mThread <- liftIO (atomically (tryObserve threads k))
      case mThread of
        -- A worker for this key is registered: hand it the message.
        Just ThreadedInternal{threadInput} ->
          liftIO (atomically (writeTChanRW (allowWriting threadInput) input))

        -- No worker for this key: start one, seeded with this message.
        Nothing -> do
          threadInput' <- liftIO $ atomically $ do
            i <- newTChanRW
            writeTChanRW i input
            pure i
          let threadInput = readOnly threadInput'
          threadOutput' <- liftIO (atomically newTChanRW)
          let threadOutput = writeOnly threadOutput'

          -- Set once the worker has left the registry. It tells the relay
          -- to exit after the output channel has been drained.
          finished <- liftIO (newTVarIO False)

          -- forwardOne: move one queued output to 'outgoing' in a single
          -- transaction, so an output is never read without being forwarded.
          let forwardOne = do
                output <- readTChanRW threadOutput'
                writeTChanRW (allowWriting outgoing) (k, output)
                pure True
              -- Succeeds only once the worker is finished. Because
              -- 'forwardOne' is tried first, this also means the output
              -- channel is empty.
              drained = False <$ (readTVar finished >>= check)
              relayOutputs = do
                more <- atomically (forwardOne `orElse` drained)
                if more then relayOutputs else pure ()
              -- Take this worker's own entry out of the registry and then
              -- release the relay. 'retry' waits until the dispatcher has
              -- inserted the entry, which closes the race where 'process'
              -- finishes before the entry is registered.
              cleanUp = atomically $ do
                tryLookup threads k >>= maybe retry (\_ -> pure ())
                writeTVar finished True

          outputRelay <- liftIO (async relayOutputs)

          -- 'finally' makes sure the registry entry is removed even if
          -- 'process' throws. Otherwise later messages for this key would
          -- be sent to a dead worker.
          thread <- liftBaseWith $ \runInBase' ->
            fmap (fmap runSingleton) $ async $
              runInBase' (process threadInput threadOutput) `finally` cleanUp

          liftIO $ atomically $
            insertForce threads k ThreadedInternal{thread,outputRelay,threadInput,threadOutput}
  pure (threadRunner, outgoing)