{-# LANGUAGE
DataKinds
, RankNTypes
, NamedFieldPuns
, FlexibleContexts
, ScopedTypeVariables
#-}
module Control.Concurrent.Threaded.Hash where
import Control.Concurrent.STM (atomically, orElse, retry)
import Control.Exception (finally)
import Control.Monad (forever)
import Control.Monad.IO.Class (MonadIO (liftIO))

import Control.Concurrent.Async (Async, async, cancel)
import Control.Concurrent.Chan.Extra (ChanScoped (readOnly, allowReading, writeOnly, allowWriting))
import Control.Concurrent.Chan.Scope (Scope (Read, Write))
import Control.Concurrent.STM.TChan.Typed (TChanRW, newTChanRW, readTChanRW, writeTChanRW)
import Control.Concurrent.STM.TMapMVar.Hash (TMapMVar, newTMapMVar, tryObserve, insertForce, tryLookup)
import Control.Concurrent.Threaded (ThreadedInternal (..))
import Control.Monad.Trans.Control.Aligned (MonadBaseControl (liftBaseWith))
import Data.Hashable (Hashable)
import Data.Singleton.Class (Extractable (runSingleton))
-- | Demultiplex a keyed input channel into one worker per key.
--
-- A runner thread reads @(k, input)@ pairs from @incoming@. When no worker is
-- registered for @k@, a fresh one is spawned. It gets its own input channel,
-- pre-seeded with the triggering input, and its own output channel, and runs
-- @process@ on them. Later inputs for a key whose worker is still registered
-- are appended to that worker's input channel.
--
-- Every output a worker writes is tagged with its key and forwarded to the
-- returned read-only channel by a per-worker relay thread.
--
-- When @process@ returns or throws, the worker unregisters itself, stops its
-- relay, and forwards any outputs the relay had not yet moved, so no output
-- is dropped on shutdown. Inputs that arrive for a key after its @process@
-- has returned, but before the worker unregisters, land in a channel nobody
-- reads and are lost.
--
-- Returns the runner's 'Async' and the tagged output channel. Workers are
-- not linked to the runner, so cancelling the runner does not cancel
-- workers that are already running.
threaded
  :: forall m stM k input output
   . Hashable k
  => Eq k
  => Show k
  => MonadIO m
  => MonadBaseControl IO m stM
  => Extractable stM
  => TChanRW 'Write (k, input)
  -> (TChanRW 'Read input -> TChanRW 'Write output -> m ())
  -> m (Async (), TChanRW 'Read (k, output))
threaded incoming process = do
  -- Registered workers, indexed by key. The type variables in this pattern
  -- signature are bound here and simply name ThreadedInternal's parameters.
  (threads :: TMapMVar k (ThreadedInternal incoming outgoing)) <- liftIO (atomically newTMapMVar)
  outgoing <- liftIO (atomically (readOnly <$> newTChanRW))
  threadRunner <- liftBaseWith $ \runInBase ->
    fmap (fmap runSingleton) $ async $ runInBase $ forever $ do
      (k, input) <- liftIO (atomically (readTChanRW (allowReading incoming)))
      mThread <- liftIO (atomically (tryObserve threads k))
      case mThread of
        Nothing -> do
          -- Fresh input channel for this key, seeded with the input that
          -- triggered the spawn.
          threadInput' <- liftIO $ atomically $ do
            i <- newTChanRW
            writeTChanRW i input
            pure i
          let threadInput = readOnly threadInput'
          threadOutput' <- liftIO (atomically newTChanRW)
          let threadOutput = writeOnly threadOutput'
              -- Move one output from the worker to the shared channel. Reading
              -- and writing in a single transaction means cancelling the relay
              -- can never drop an output it has already taken.
              relayOne = do
                output <- readTChanRW threadOutput'
                writeTChanRW (allowWriting outgoing) (k, output)
              -- Forward whatever the worker left behind without blocking.
              -- relayOne retries on an empty channel, so orElse ends the loop.
              drainOutput = do
                relayed <- atomically ((True <$ relayOne) `orElse` pure False)
                if relayed then drainOutput else pure ()
          outputRelay <- liftIO $ async $ forever $ atomically relayOne
          -- Worker teardown. The runner registers the worker only after
          -- spawning it, so block (retry) until the entry exists instead of
          -- failing when process finishes early. tryLookup then removes it,
          -- so the next input for k spawns a new worker.
          let retire = do
                _ <- atomically (tryLookup threads k >>= maybe retry pure)
                cancel outputRelay
                drainOutput
          thread <- liftBaseWith $ \runInBase' ->
            fmap (fmap runSingleton) $ async $
              -- finally: tear down even if process throws, so the key is
              -- never left pointing at a dead worker.
              runInBase' (process threadInput threadOutput) `finally` retire
          liftIO $ atomically $
            insertForce threads k ThreadedInternal{thread,outputRelay,threadInput,threadOutput}
        Just ThreadedInternal{threadInput} ->
          liftIO (atomically (writeTChanRW (allowWriting threadInput) input))
  pure (threadRunner, outgoing)