{-# LANGUAGE
DataKinds
, RankNTypes
, NamedFieldPuns
, FlexibleContexts
, ScopedTypeVariables
#-}
module Control.Concurrent.Threaded where
import Control.Concurrent.Async (Async, async, cancel)
import Control.Concurrent.Chan.Scope (Scope (Read, Write))
import Control.Concurrent.Chan.Extra (ChanScoped (readOnly, allowReading, writeOnly, allowWriting))
import Control.Concurrent.STM (atomically, check, newTVarIO, orElse, readTVar, writeTVar)
import Control.Concurrent.STM.TChan.Typed (TChanRW, newTChanRW, readTChanRW, writeTChanRW)
import Control.Concurrent.STM.TMapMVar (TMapMVar, newTMapMVar, tryObserve, insertForce, tryLookup)
import Control.Exception (finally)
import Control.Monad (forever, when)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Trans.Control.Aligned (MonadBaseControl (liftBaseWith))
import Data.Singleton.Class (Extractable (runSingleton))
-- | Handles kept for one running per-key worker spawned by 'threaded'.
--
-- @thread@ is the worker running the user's process and @outputRelay@ is
-- the thread that forwards that worker's output. @threadInput@ and
-- @threadOutput@ are the channels the worker was given to read its inputs
-- from and to write its outputs to.
data ThreadedInternal incoming outgoing = ThreadedInternal
  { thread, outputRelay :: Async ()
  , threadInput         :: TChanRW 'Read incoming
  , threadOutput        :: TChanRW 'Write outgoing
  }
-- | Route a stream of keyed inputs to one worker thread per key.
--
-- The router reads @(k, input)@ pairs from @incoming@. If a worker is
-- registered for @k@, @input@ is queued on that worker's input channel.
-- Otherwise a new worker is started. It runs @process@ with a fresh input
-- channel (already holding @input@) and a fresh output channel. A relay
-- thread forwards everything written to that output channel, tagged with
-- @k@, onto the single shared output channel.
--
-- When @process@ returns or throws, the worker removes itself from the
-- registry, so the next input for @k@ starts a new worker. Its relay keeps
-- running until every output already produced has been forwarded, then
-- stops.
--
-- Returns the router thread and the shared, read-only output channel.
--
-- NOTE(review): an input routed to a worker after its @process@ has stopped
-- reading, but before the worker has deregistered, is never consumed.
-- Cancelling the returned router does not cancel running workers or relays.
-- The 'Show' constraint is currently unused; it is kept so the signature
-- stays compatible with existing callers.
threaded :: forall m stM k input output
          . Ord k
         => Show k
         => MonadIO m
         => MonadBaseControl IO m stM
         => Extractable stM
         => TChanRW 'Write (k, input)
         -> (TChanRW 'Read input -> TChanRW 'Write output -> m ())
         -> m (Async (), TChanRW 'Read (k, output))
threaded incoming process = do
  -- Registry of live workers, keyed by routing key.
  (threads :: TMapMVar k (ThreadedInternal input output))
    <- liftIO (atomically newTMapMVar)
  outgoing <- liftIO (atomically (readOnly <$> newTChanRW))
  threadRunner <- liftBaseWith $ \runInBase ->
    fmap (fmap runSingleton) $ async $ runInBase $ forever $ do
      (k, input) <- liftIO (atomically (readTChanRW (allowReading incoming)))
      mThread <- liftIO (atomically (tryObserve threads k))
      case mThread of
        Just ThreadedInternal{threadInput} ->
          liftIO (atomically (writeTChanRW (allowWriting threadInput) input))
        Nothing -> do
          threadInput' <- liftIO $ atomically $ do
            i <- newTChanRW
            writeTChanRW i input
            pure i
          let threadInput = readOnly threadInput'
          threadOutput' <- liftIO (atomically newTChanRW)
          let threadOutput = writeOnly threadOutput'
          -- Becomes True once this worker is in the registry. The worker
          -- waits for it, so it can never try to deregister itself before
          -- it has been registered (which previously left a dead entry).
          registered <- liftIO (newTVarIO False)
          -- Becomes True once the worker has finished. The relay then stops
          -- as soon as the worker's output channel has been drained.
          finished <- liftIO (newTVarIO False)
          let relay = do
                -- Move one message per transaction: a message is always
                -- either still queued here or already on @outgoing@, never
                -- lost in between.
                forwarded <- atomically $
                  (do output <- readTChanRW threadOutput'
                      writeTChanRW (allowWriting outgoing) (k, output)
                      pure True)
                  `orElse` (readTVar finished >>= check >> pure False)
                when forwarded relay
              -- Drop this worker's registry entry and let its relay wind down.
              deregister =
                atomically (tryLookup threads k >> writeTVar finished True)
          outputRelay <- liftIO (async relay)
          thread <- liftBaseWith $ \runInBase' ->
            fmap (fmap runSingleton) $ async $
              (atomically (readTVar registered >>= check)
                 >> runInBase' (process threadInput threadOutput))
                -- Deregister even if 'process' throws, so the key is not
                -- left pointing at a dead worker.
                `finally` deregister
          liftIO $ atomically $ do
            insertForce threads k
              ThreadedInternal{thread, outputRelay, threadInput, threadOutput}
            writeTVar registered True
  pure (threadRunner, outgoing)