{-# language
    BangPatterns
  , DeriveAnyClass
  , DeriveGeneric
  , DerivingStrategies
  , LambdaCase
  , MagicHash
  , ScopedTypeVariables
  , UnboxedTuples
  #-}

-- | This module provides concurrent monoidal folds
--   for commutative monoids.
--
--   Some notes (these apply to all folds):
--
--     1. This module is intended to be imported qualified
--        to avoid name clashes (e.g. with 'Prelude.foldMap').
--
--     2. Accumulation is strict.
--
--     3. Exceptions thrown by any of the worker threads are
--        accumulated into a 'CmfException', which is re-thrown
--        once every thread has reported back.
module Cmf
  ( -- * Folds
    foldMap
  , foldMapWithKey

    -- * Exception Type
  , CmfException(..)
  ) where

import Control.Concurrent
import Control.Exception
import Control.Monad (void)
import Data.Foldable (foldlM)
import Data.Monoid (Sum(..))
import GHC.Conc (ThreadId(..))
import GHC.Exts (fork#)
import GHC.Generics (Generic)
import GHC.IO (IO(..))

import qualified Data.Map as Map

import Prelude hiding (foldMap)

-- | The exception re-thrown by the folds in this module when at least
--   one worker thread fails. It carries every exception that was caught
--   among the running threads.
newtype CmfException = CmfException [SomeException]
  deriving stock (Generic, Show)
  deriving anyclass Exception

-- | A concurrent monoidal fold over some 'Foldable'.
--
-- One thread is forked per element of the container. Each thread runs
-- the supplied action and forces its result to weak head normal form
-- before handing it back. The work of producing the value therefore
-- happens in the worker thread. Any exception raised while evaluating
-- the value is collected together with exceptions raised by the action
-- itself.
--
-- Results are combined with '<>' in the order in which the threads
-- report back, which is why the monoid is expected to be commutative.
--
-- This operation may fail with:
--
--   * 'CmfException' if any of the threads throws an exception.
foldMap :: forall t m a. (Foldable t, Monoid m)
  => (a -> IO m)
  -> t a
  -> IO m
foldMap f xs = do
  var <- newEmptyMVar
  total <- foldlM
    (\ !n a -> do
      -- 'evaluate' runs inside 'try' so that a result which throws when
      -- forced becomes a 'Left' here, rather than escaping from the
      -- collector when it combines results.
      void $ fork $ try (f a >>= evaluate) >>= putMVar var
      pure (n + 1)
    ) 0 xs
  internal total var
{-# inlineable foldMap #-}

-- | A concurrent monoidal fold (with keys) over a 'Map.Map'.
--
-- One thread is forked per key/value pair, in ascending key order. Each
-- thread runs the supplied action and forces its result to weak head
-- normal form before handing it back. Exceptions raised while evaluating
-- the value are therefore collected just like exceptions raised by the
-- action.
--
-- Results are combined with '<>' in the order in which the threads
-- report back, which is why the monoid is expected to be commutative.
--
-- This operation may fail with:
--
--   * 'CmfException' if any of the threads throws an exception.
foldMapWithKey :: (Monoid m)
  => (k -> a -> IO m)
  -> Map.Map k a
  -> IO m
foldMapWithKey f mp = do
  var <- newEmptyMVar
  Map.foldrWithKey
    (\k a rest -> do
      -- Force the result inside 'try' so evaluation errors are caught
      -- in the worker instead of surfacing in the collector.
      void $ fork $ try (f k a >>= evaluate) >>= putMVar var
      rest
    ) (pure ()) mp
  -- Exactly one thread was forked per entry; 'Map.size' is O(1) and
  -- avoids accumulating a thunk chain of per-entry counts.
  internal (Map.size mp) var
{-# inlineable foldMapWithKey #-}

-- | Fork a thread that runs the given action, discarding its 'ThreadId'.
--
-- Unlike 'forkIO', this calls the 'fork#' primop directly. It therefore
-- does not wrap the action in an exception handler of its own. The
-- callers in this module already wrap their work in 'try', so skipping
-- that handler avoids a nested @catch#@.
--
-- The action is unwrapped from its 'IO' newtype before being passed to
-- 'fork#'. Newer GHCs type 'fork#' as taking a state-passing function,
-- while older ones accept any lifted value. Passing the unwrapped
-- function type-checks against both, and it has the same runtime
-- representation as the 'IO' value.
fork :: IO () -> IO ()
fork (IO action) = IO $ \s -> case fork# action s of
  (# s1, _ #) -> (# s1, () #)

-- | Collect exactly @total@ results from the shared 'MVar' and combine
--   them.
--
--   While every result so far has succeeded, successful values are
--   combined with '<>' in arrival order, with the accumulator forced to
--   WHNF at each step. After the first failure, the remaining results
--   are still drained, so that every thread can deliver its value. Only
--   the exceptions are kept, and they are re-thrown together as a
--   'CmfException'.
internal :: forall m. (Monoid m)
  => Int -- total number of threads to spawn
  -> MVar (Either SomeException m)
  -> IO m
internal total var = collect 0 mempty >>= either (throwIO . CmfException) pure
  where
    -- Success path: no failure seen yet, keep combining values.
    collect :: Int -> m -> IO (Either [SomeException] m)
    collect !done !acc
      | done >= total = pure (Right acc)
      | otherwise = do
          result <- takeMVar var
          case result of
            Right x -> collect (done + 1) (acc <> x)
            Left e -> drain (done + 1) [e]

    -- Failure path: drop successful values and gather the remaining
    -- exceptions, the most recent one first.
    drain :: Int -> [SomeException] -> IO (Either [SomeException] m)
    drain !done !errs
      | done >= total = pure (Left errs)
      | otherwise =
          takeMVar var >>= either
            (\e -> drain (done + 1) (e : errs))
            (const (drain (done + 1) errs))