{-# LANGUAGE NamedFieldPuns #-}
{- |
  This module provides utilities for client-side load balancing.
-}
module Control.Concurrent.LoadDistribution (
  evenlyDistributed,
  LoadBalanced,
  withResource,
  map,
) where

import Prelude hiding (map)

import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar (TVar, newTVar, readTVar, writeTVar,
  modifyTVar)
import Control.Exception (bracket)
import Data.PSQueue (PSQ, minView, adjust, Binding((:->)), alter, insert)
import Data.Set (Set, member)
import qualified Data.PSQueue as Q (empty)
import qualified Data.Set as Set (foldr, map)
import qualified System.Log.Logger as L (debugM)

{- |
  A load balancing technique that evenly distributes load across multiple
  resources, where "even" is figured based on the current number of
  simultaneous threads using the resource.

  Some notes on implementation: The priority portion of the priority queue
  is a tuple of the form (p :: Int, r :: Int). The first element, @p@, stands
  for "processes", and indicates the number of simultaneous threads that are
  using the specific resource. This value is incremented and decremented as
  threads acquire and release the resource. The second element, @r@, stands for
  "round robin". Its purpose is to act as a tie breaker between resources that
  have the same @p@ value by giving priority to whichever resource has been
  chosen fewer times. It is always incremented and never decremented. Without
  the @r@ component, the deterministic nature of the priority queue will
  heavily favor resources that come first in the sorted list of resources.
-}
evenlyDistributed
  :: (Ord resource)
  => IO (Set resource)
    -- ^ An IO action that yields the current set of resources. This is
    --   an IO action because the set of resources is allowed to change
    --   dynamically.
  -> IO (LoadBalanced resource)
evenlyDistributed getResources =
    mkBalancer <$> atomically (newTVar Q.empty)
  where
    -- The balancer starts with an empty queue; resources are added to it
    -- lazily, the first time they show up in a checkout.
    mkBalancer psQueueT = LB {psQueueT, getResources}


{- |
  An object that keeps track of load balancing across threads. Values are
  created with 'evenlyDistributed' (or derived with 'map') and consumed
  with 'withResource'.
-}
data LoadBalanced resource = LB
  { psQueueT :: TVar (PSQ resource (Int, Int))
    -- ^ Shared priority queue of known resources. The priority is the
    --   pair @(p, r)@: @p@ is the number of threads currently holding
    --   the resource and @r@ is bumped every time the resource is
    --   chosen.
  , getResources :: IO (Set resource)
    -- ^ Action that yields the set of resources that may currently be
    --   chosen.
  }


{- |
  Build a load balancer over a different resource type by applying a
  function to every resource offered by an existing balancer. The result
  lives in IO (rather than being an @instance Functor@) because the derived
  balancer gets its own, freshly created, shared load balancing state; it
  does not share usage counts with the balancer it was derived from.
-}
map :: (Ord b) => (a -> b) -> LoadBalanced a -> IO (LoadBalanced b)
map f source = do
  freshQueue <- atomically (newTVar Q.empty)
  pure LB
    { psQueueT = freshQueue
    , getResources = fmap (Set.map f) (getResources source)
    }


{- |
  Run an IO action against the least loaded resource. The current set of
  resources is fetched first, then the chosen resource is handed to the
  action as @Just resource@. If, for whatever reason, there are no
  available resources, then @Nothing@ is passed to the user-provided
  action function instead.
-}
withResource
  :: (Show resource, Ord resource)
  => LoadBalanced resource
  -> (Maybe resource -> IO a)
  -> IO a
withResource LB {psQueueT, getResources} action =
  getResources >>= \available ->
    invokeResource psQueueT available action


{- |
  Check out the least loaded resource among @targets@, run the action with
  it, and check the resource back in afterwards (also when the action
  throws).

  Checking out means: make sure every target has an entry in the queue
  (new ones start at @(0, 0)@), take the entry with the smallest priority
  that is still a target, and bump both components of its priority.
  Entries encountered at the front of the queue that are no longer targets
  are dropped. When no target is available the action receives @Nothing@.
-}
invokeResource :: (Ord resource, Show resource)
  => TVar (PSQ resource (Int, Int))
  -> Set resource
  -> (Maybe resource -> IO a)
  -> IO a
invokeResource psQueueT targets action =
    bracket checkout checkin use
  where
    -- The acquire step is a single STM transaction and nothing else. Any
    -- further IO here (such as logging) could throw after the in-use
    -- count was already committed, in which case 'bracket' would never
    -- run 'checkin' and the count would leak.
    checkout = atomically $ do
      queue <- updateTargets <$> readTVar psQueueT
      case getBest queue of
        (newQueue, choice) -> do
          writeTVar psQueueT newQueue
          return choice

    -- Logging happens in the protected phase, so a failure while logging
    -- still releases the resource.
    use choice = do
      debugM ("Choosing: " ++ show choice)
      action choice

    -- Pop entries until one that is still a target is found; stale
    -- entries popped on the way are discarded.
    getBest queue =
      case minView queue of
        Nothing -> (queue, Nothing)
        Just (resource :-> (p, r), rest)
          | resource `member` targets ->
              (insert resource (p + 1, r + 1) rest, Just resource)
          | otherwise -> getBest rest

    checkin Nothing = return ()
    checkin (Just resource) = atomically $
      modifyTVar
        psQueueT
        (adjust release resource)

    -- Clamp at zero: a resource can be dropped from the queue while a
    -- thread still holds it and then be re-added at (0, 0). Without the
    -- clamp the late check-in would leave a permanent negative count,
    -- making that resource look less loaded than it is.
    release (p, r) = (max 0 (p - 1), r)

    -- Add every target that has no entry yet; existing entries keep
    -- their priority.
    updateTargets queue =
      Set.foldr
        (alter insertMissing)
        queue
        targets

    insertMissing Nothing = Just (0, 0)
    insertMissing existing = existing


{- |
  Emit a debug-level log message under this module's logger name,
  @load-balancing@.
-}
debugM :: String -> IO ()
debugM message = L.debugM "load-balancing" message