module Potoki.Core.Transform.Concurrency
where

import Potoki.Core.Prelude hiding (take, takeWhile, filter)
import Control.Concurrent.MVar (modifyMVar)
import Control.Exception (finally)
import Potoki.Core.Transform.Instances ()
import Potoki.Core.Transform.Basic
import Potoki.Core.Types
import qualified Potoki.Core.Fetch as A
import qualified Acquire.Acquire as M


{-|
Moves the fetching of inputs onto a separate thread, which accumulates them
in a bounded queue of the specified capacity.

Every fetch from the resulting stream drains whatever has been accumulated
so far and yields it as a single list. While the queue is empty the fetch
blocks until either new elements arrive or the producer thread reports that
the input is depleted, in which case the stream ends.
-}
bufferizeFlushing :: Int -> Transform input [input]
bufferizeFlushing maxSize =
  Transform $ \ (A.Fetch fetchInput) -> liftIO $ do
    queue <- newTBQueueIO (fromIntegral maxSize)
    producingVar <- newTVarIO True

    -- Producer thread: enqueues elements (evaluated to WHNF by the bang pattern)
    -- until the input is depleted, then lowers the flag.
    forkIO $ fix $ \ proceed ->
      fetchInput >>= \ case
        Just !element -> atomically (writeTBQueue queue element) >> proceed
        Nothing -> atomically (writeTVar producingVar False)

    -- Consumer side: the queue and the flag are inspected in one transaction,
    -- so the end of the stream is only reported once the queue is drained.
    return $ Fetch $ atomically $
      flushTBQueue queue >>= \ case
        [] -> do
          producing <- readTVar producingVar
          if producing then retry else return Nothing
        elements -> return (Just elements)

{-|
Prefetches the inputs on a separate thread into a bounded queue of the
specified capacity.

Each element is evaluated to normal form on the producer thread before it
is put into the queue. A fetch from the resulting stream blocks while the
queue is empty and the producer is still running, and reports the end of
the stream once the producer has finished and the queue is drained.
-}
{-# INLINE bufferize #-}
bufferize :: NFData element => Int -> Transform element element
bufferize size =
  Transform $ \ (A.Fetch fetchInput) -> liftIO $ do
    queue <- newTBQueueIO (fromIntegral size)
    producingVar <- newTVarIO True

    -- Producer thread: force, enqueue, repeat; lower the flag at the end of input.
    forkIO $ fix $ \ proceed ->
      fetchInput >>= \ case
        Just element -> do
          evaluated <- evaluate (force element)
          atomically (writeTBQueue queue evaluated)
          proceed
        Nothing -> atomically (writeTVar producingVar False)

    -- Consumer side: prefer a queued element; otherwise decide between
    -- waiting (producer still running) and ending the stream.
    return $ Fetch $ atomically $ mplus
      (Just <$> readTBQueue queue)
      (do
        producing <- readTVar producingVar
        if producing then retry else return Nothing)

{-|
Identity Transform, which ensures that the inputs are fetched synchronously.

Useful for concurrent transforms.

The upstream fetch is guarded by an 'MVar', so at most one thread executes it
at a time. The value in the 'MVar' tells whether the upstream may still
produce elements: once it has returned 'Nothing', it is never called again and
all further fetches return 'Nothing' immediately.

If the upstream fetch throws an exception, the exception is propagated to the
caller and the guard is released with its previous value, so that other
threads fetching from the same stream do not get blocked forever.
-}
{-# INLINABLE sync #-}
sync :: Transform a a
sync =
  Transform $ \ (A.Fetch fetchIO) -> liftIO $ do
    activeVar <- newMVar True
    -- 'modifyMVar' restores the original value if the action throws,
    -- which a bare 'takeMVar' / 'putMVar' pair would not do.
    return $ A.Fetch $ modifyMVar activeVar $ \ active ->
      if active
        then fetchIO >>= \ case
          Just !element -> return (True, Just element)
          Nothing -> return (False, Nothing)
        else return (False, Nothing)

{-|
Runs the transform on the specified number of worker threads.
No guarantees are given about the order in which the outputs are produced.

When exactly one worker is requested, the transform is returned as is.
Otherwise the input is first passed through 'sync' and then processed
by 'unsafeConcurrently'.
-}
{-# INLINABLE concurrently #-}
concurrently :: NFData output => Int -> Transform input output -> Transform input output
concurrently workersAmount transform
  | workersAmount == 1 = transform
  | otherwise = sync >>> unsafeConcurrently workersAmount transform

{-|
Runs the transform on the specified number of worker threads,
without synchronizing access to the input.

Every worker acquires its own instance of the transform over the same
input fetch, so that fetch gets called from several threads at once.
'concurrently' passes the input through 'sync' before using this function.

Each output is evaluated to normal form on the worker thread and written to
a bounded queue with room for twice the number of workers. The resulting
stream ends once all workers have reached the end of their output and the
queue is drained.

The finalizer of each acquired transform instance is run when its worker
stops, regardless of whether it stopped normally or because of an exception.
Note that the workers counter is only decremented on normal completion.
-}
{-# INLINE unsafeConcurrently #-}
unsafeConcurrently :: NFData output => Int -> Transform input output -> Transform input output
unsafeConcurrently workersAmount (Transform syncTransformIO) =
  Transform $ \ inputFetch -> liftIO $ do
    chan <- newTBQueueIO (fromIntegral (workersAmount * 2))
    workersCounter <- newTVarIO workersAmount

    replicateM_ workersAmount $ forkIO $ do
      -- Unwrap the acquisition by hand to get hold of the fetch and its finalizer.
      (A.Fetch fetchIO, finalize) <- case syncTransformIO inputFetch of M.Acquire io -> io
      let
        loop = do
          fetchResult <- fetchIO
          case fetchResult of
            Just result -> do
              forcedResult <- evaluate (force result)
              atomically (writeTBQueue chan forcedResult)
              loop
            Nothing -> atomically (modifyTVar' workersCounter pred)
        -- 'finally' guarantees the release of the acquired resources
        -- even when fetching or forcing throws.
        in loop `finally` finalize

    return $ A.Fetch $ let
      readChan = Just <$> readTBQueue chan
      -- Only reached when the queue is empty:
      -- keep waiting while any worker is still running.
      terminate = do
        workersActive <- readTVar workersCounter
        if workersActive > 0
          then empty
          else return Nothing
      in atomically (readChan <|> terminate)

{-|
Runs the transform on the specified number of worker threads,
producing the outputs in the order in which the inputs were taken.

A feeder thread moves the inputs into a bounded queue shared by the workers.
Every worker runs its own instance of the transform and owns a private
output queue. Whenever a worker takes an input, it registers its output
queue as the next \"slot\" in a shared slot queue within the same transaction,
so the order of the slots matches the order in which the inputs were taken.
The consumer drains the slots strictly one after another.

Each output is evaluated to normal form on the worker thread.
The finalizer of each acquired transform instance is run when its worker
stops, regardless of whether it stopped normally or because of an exception.
-}
concurrentlyInOrder :: NFData b => Int -> Transform a b -> Transform a b
concurrentlyInOrder concurrency (Transform transform) = Transform $ \ (Fetch fetchA) -> liftIO $ do
  -- Inputs wrapped in 'Just'; 'Nothing' signals the end of input to a worker.
  inputQueue <- newTBQueueIO (fromIntegral concurrency)
  -- Output queues of the workers, one entry per input taken, in the order of taking.
  outputSlotQueue <- newTQueueIO
  liveWorkersVar <- newTVarIO concurrency

  -- Feeder thread: forwards the inputs and finally one 'Nothing' per worker.
  forkIO $ let
    loop = do
      fetchAResult <- fetchA
      case fetchAResult of
        Just a -> do
          atomically $ writeTBQueue inputQueue (Just a)
          loop
        Nothing -> atomically $ replicateM_ concurrency $ writeTBQueue inputQueue Nothing
    in loop

  replicateM_ concurrency $ forkIO $ do
    outputQueue <- newTQueueIO
    -- False until the first fetch of this worker; afterwards every fetch
    -- has to close the slot opened by the previous one.
    needsSwitchVar <- newTVarIO False

    let
      -- The input as seen by this worker's instance of the transform.
      -- In a single transaction: close the previous slot (if any) by writing
      -- 'Nothing' to the output queue, register a new slot, take the next input.
      localizedFetchA = Fetch $ atomically $ do
        needsSwitch <- readTVar needsSwitchVar
        if needsSwitch
          then writeTQueue outputQueue Nothing
          else writeTVar needsSwitchVar True
        writeTQueue outputSlotQueue outputQueue
        readTBQueue inputQueue

      in do
        (Fetch fetchB, finalize) <- case transform localizedFetchA of M.Acquire io -> io
        let
          loop = do
            fetchBResult <- fetchB
            case fetchBResult of
              Just b -> do
                forcedB <- evaluate (force b)
                atomically $ writeTQueue outputQueue (Just forcedB)
                loop
              -- End of this worker's output: close the last slot
              -- and report the worker as finished.
              Nothing ->
                atomically $ do
                  writeTQueue outputQueue Nothing
                  modifyTVar' liveWorkersVar pred
            -- 'finally' guarantees the release of the acquired resources
            -- even when fetching or forcing throws.
            in loop `finally` finalize

  -- Consumer: read from the output queue of the front slot.
  -- A 'Nothing' there means the slot is complete, so drop it and move on to the next.
  -- When nothing can be read, end the stream if no workers are left, otherwise wait.
  return $ Fetch $ atomically $ fix $ \ loop -> mplus
    (do
      outputQueue <- peekTQueue outputSlotQueue
      bIfAny <- readTQueue outputQueue
      case bIfAny of
        Just b -> return (Just b)
        Nothing -> do
          readTQueue outputSlotQueue
          loop)
    (do
      liveWorkers <- readTVar liveWorkersVar
      guard (liveWorkers <= 0)
      return Nothing)

{-|
A transform, which fetches the inputs asynchronously on the specified number of threads.

Every worker thread repeatedly fetches an input, evaluates it to normal form
and hands it over to the consumer through a single shared slot.
The input fetch is called from all the workers without any synchronization.
The resulting stream ends once all workers have reached the end of input.
-}
async :: NFData input => Int -> Transform input input
async workersAmount =
  Transform $ \ (A.Fetch fetchIO) -> liftIO $ do
    chan <- atomically newEmptyTMVar
    workersCounter <- atomically (newTVar workersAmount)

    replicateM_ workersAmount $ forkIO $ let
      loop = do
        fetchResult <- fetchIO
        case fetchResult of
          Just input -> do
            -- Force on the worker thread, so that the consumer receives
            -- a fully evaluated value instead of a thunk.
            forcedInput <- evaluate (force input)
            atomically (putTMVar chan forcedInput)
            loop
          Nothing -> atomically (modifyTVar' workersCounter pred)
      in loop

    return $ A.Fetch $ let
      readChan = Just <$> takeTMVar chan
      -- Only reached when the slot is empty:
      -- keep waiting while any worker is still running.
      terminate = do
        workersActive <- readTVar workersCounter
        if workersActive > 0
          then empty
          else return Nothing
      in atomically (readChan <|> terminate)

{-|
A batching variant of 'unsafeConcurrently': the workers exchange vectors
instead of individual elements.

The first argument is passed to @batch@ and the second one determines both
the size of the prefetching buffer and the number of workers.
NOTE(review): @batch@ and @vector@ are defined outside of this module;
presumably they group elements into vectors and flatten them back — confirm.
-}
concurrentlyWithBatching :: (NFData a, NFData b) => Int -> Int -> Transform a b -> Transform a b
concurrentlyWithBatching batching concurrency transform =
  batch @Vector batching
    >>> bufferize concurrency
    >>> unsafeConcurrently concurrency perBatch
    >>> vector
  where
    -- The stage executed by every worker, from a vector of inputs to a vector of outputs.
    perBatch = vector >>> transform >>> batch @Vector batching

{-|
A batching variant of 'concurrentlyInOrder': the workers exchange vectors
instead of individual elements.

The first argument is passed to @batch@ and the second one is
the number of workers.
NOTE(review): @batch@ and @vector@ are defined outside of this module;
presumably they group elements into vectors and flatten them back — confirm.
-}
concurrentlyInOrderWithBatching :: (NFData b) => Int -> Int -> Transform a b -> Transform a b
concurrentlyInOrderWithBatching batching concurrency transform =
  batch @Vector batching
    >>> concurrentlyInOrder concurrency perBatch
    >>> vector
  where
    -- The stage executed by every worker, from a vector of inputs to a vector of outputs.
    perBatch = vector >>> transform >>> batch @Vector batching