{-|
Module      : Control.Concurrent.Throttle
Description : Throttling async mechanism
Copyright   : (c) CNRS, 2024-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

-}

{-# LANGUAGE ScopedTypeVariables #-}

module Control.Concurrent.Throttle
  ( throttle )
where

import Control.Concurrent (threadDelay)
import Control.Concurrent.Async qualified as Async
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TChan qualified as TChan
import Control.Concurrent.STM.TVar qualified as TVar
import Control.Monad (forever)
import Data.Map.Strict qualified as Map
import Data.Maybe (isNothing)
import Data.Time.Clock.POSIX (getPOSIXTime)


-- TODO Add a ThrottleHash typeclass which converts 'a' to 'id'?


{-| Throttle a stream of messages read from a 'TChan.TChan'.

Here, throttling means: perform 'action' only as frequently as allowed,
and DROP the other calls. This is in contrast to things like Conduit
throttling, where actions are just SLOWED DOWN. We use this for
asynchronous notifications where, if messages are the same, we can
safely drop them.

We provide separate 'id' and 'a'. 'id' uniquely identifies a group of
throttled messages, while 'a' is the actual message passed to 'action'.

How it works:

* Every message read from @tchan@ is stored in a map keyed by its 'id',
  together with its arrival time. A newer message for the same 'id'
  replaces the older one (which is therefore dropped) and resets the
  timestamp.

* A cleaner thread wakes up every half @delay@. In a single STM
  transaction it removes every entry that is at least @delay@ old, then
  runs 'action' on each removed message. Messages that arrive while
  'action' is running are kept and delivered on a later pass.

@delay@ is in microseconds: it is compared with differences of
'unixTime' values and passed to 'threadDelay'.

NOTE(review): because every new message resets its id's timestamp, an
id that keeps receiving messages more often than @delay@ never fires
(debounce-like behaviour). Confirm that this is intended.

This function never returns, so it should be spawned as a thread. If
'action' throws, the exception propagates out of 'throttle' instead of
silently stopping delivery.
-}
throttle :: forall id a. (Ord id, Eq id, Show id)
         => Int -> TChan.TChan (id, a) -> (a -> IO ()) -> IO ()
throttle delay tchan action = do
  smap <- TVar.newTVarIO Map.empty :: IO (TVar.TVar (Map.Map id (a, Int)))

  -- 'concurrently_' cancels the other side and rethrows if either side
  -- fails, so a crash in 'action' cannot go unnoticed while the map
  -- keeps growing.
  Async.concurrently_ (mapCleaner smap) $ forever $ do
    (msgId, msg) <- atomically $ TChan.readTChan tchan
    now <- unixTime
    -- Strict modify: avoid accumulating a chain of 'Map.insert' thunks
    -- inside the TVar between two cleaner passes.
    atomically $ TVar.modifyTVar' smap (Map.insert msgId (msg, now))

  where
    -- | Periodically fire and remove the entries that are old enough.
    mapCleaner :: TVar.TVar (Map.Map id (a, Int)) -> IO ()
    mapCleaner smap = forever $ do
      now <- unixTime

      -- Take the due entries out of the map atomically. A message
      -- inserted after this transaction (even for an id we are about
      -- to fire) therefore stays in the map and is delivered after its
      -- own delay, instead of being wiped by a later filter pass.
      canRun <- atomically $ do
        m <- TVar.readTVar smap
        let (due, pending) = Map.partition (\(_, t) -> now - t >= delay) m
        TVar.writeTVar smap pending
        pure due

      mapM_ (\(msg, _) -> action msg) canRun

      threadDelay (delay `div` 2)


-- | Current Unix time in __microseconds__ (not milliseconds), rounded to
-- the nearest integer.
--
-- Microseconds are used so that differences between two timestamps can
-- be compared directly with the @delay@ given to 'throttle', which is
-- also passed to 'threadDelay' (whose unit is microseconds).
--
-- This reads the wall clock via 'getPOSIXTime', so it is not monotonic:
-- a system clock adjustment can make consecutive calls go backwards.
unixTime :: IO Int
unixTime = (round . (* 1000000)) <$> getPOSIXTime