module TimerWheel.Internal.AlarmBuckets
  ( AlarmBuckets,
    AlarmId,
    insert,
    delete,
    deleteExpiredAt,
    timestampToIndex,
  )
where

import Data.Atomics qualified as Atomics
import Data.Primitive.Array (MutableArray)
import Data.Primitive.Array qualified as Array
import GHC.Base (RealWorld)
import TimerWheel.Internal.Alarm (Alarm (..))
import TimerWheel.Internal.Bucket (Bucket)
import TimerWheel.Internal.Bucket qualified as Bucket
import TimerWheel.Internal.Nanoseconds (Nanoseconds)
import TimerWheel.Internal.Prelude
import TimerWheel.Internal.Timestamp (Timestamp)
import TimerWheel.Internal.Timestamp qualified as Timestamp

-- | The array of buckets, one 'Bucket' of 'Alarm's per slot. Slots are read and replaced with the ticket-based
-- compare-and-swap operations from "Data.Atomics" by the functions in this module.
type AlarmBuckets = MutableArray RealWorld (Bucket Alarm)

-- | Identifies an alarm within a bucket; it is the key handed to 'Bucket.insert' and 'Bucket.deleteExpectingHit'.
type AlarmId = Int

-- | Insert @alarm@ (keyed by @alarmId@, due at @timestamp@) into the bucket that @timestamp@ maps to under
-- @resolution@ (see 'timestampToIndex').
--
-- The slot is updated with a compare-and-swap retry loop:
--
--   * read the current bucket as a ticket;
--   * compute the bucket with the alarm added;
--   * try to swap it in.
--
-- If another thread changed the slot in the meantime, the swap fails and we retry with the fresher ticket returned by
-- 'Atomics.casArrayElem'.
insert :: AlarmBuckets -> Nanoseconds -> AlarmId -> Timestamp -> Alarm -> IO ()
insert buckets resolution alarmId timestamp alarm = do
  ticket <- Atomics.readArrayElem buckets index
  loop ticket
  where
    loop :: Atomics.Ticket (Bucket Alarm) -> IO ()
    loop ticket = do
      -- `$!` forces the new bucket before it is published. Without it the slot would hold an unevaluated
      -- `Bucket.insert ...` thunk pointing at the previous (possibly also unevaluated) bucket. Repeated inserts into a
      -- slot that nobody reads would then pile up a chain of thunks.
      (success, ticket1) <-
        Atomics.casArrayElem buckets index ticket
          $! Bucket.insert alarmId timestamp alarm (Atomics.peekTicket ticket)
      if success then pure () else loop ticket1

    index :: Int
    index =
      timestampToIndex buckets resolution timestamp

-- | Remove the alarm keyed by @alarmId@ from the bucket that @timestamp@ maps to under @resolution@ (see
-- 'timestampToIndex').
--
-- The result depends on what 'Bucket.deleteExpectingHit' reports:
--
--   * 'Nothing': returns 'False' and writes nothing.
--   * Otherwise: returns 'True' once the reduced bucket has been swapped into the slot.
--
-- A failed compare-and-swap (a concurrent update to the same slot) retries against the fresher ticket.
delete :: AlarmBuckets -> Nanoseconds -> AlarmId -> Timestamp -> IO Bool
delete buckets resolution alarmId timestamp = do
  ticket <- Atomics.readArrayElem buckets index
  loop ticket
  where
    loop :: Atomics.Ticket (Bucket Alarm) -> IO Bool
    loop ticket =
      case Bucket.deleteExpectingHit alarmId (Atomics.peekTicket ticket) of
        Nothing -> pure False
        Just bucket -> do
          -- `$!` makes sure an evaluated bucket, not a possible thunk, is what gets stored in the shared slot.
          (success, ticket1) <- Atomics.casArrayElem buckets index ticket $! bucket
          if success then pure True else loop ticket1

    index :: Int
    index =
      timestampToIndex buckets resolution timestamp

-- | Split the bucket at @index@ with @'Bucket.partition' now@ and return the first part (the @expired@ alarms).
--
-- * If that part is empty, the slot is left untouched and 'Bucket.empty' is returned.
-- * Otherwise, the second part (the remaining alarms) is swapped into the slot with a compare-and-swap. On contention
--   the loop retries, re-partitioning the fresher contents.
deleteExpiredAt :: AlarmBuckets -> Int -> Timestamp -> IO (Bucket Alarm)
deleteExpiredAt buckets index now = do
  ticket <- Atomics.readArrayElem buckets index
  loop ticket
  where
    loop :: Atomics.Ticket (Bucket Alarm) -> IO (Bucket Alarm)
    loop ticket = do
      let Bucket.Pair expired bucket1 = Bucket.partition now (Atomics.peekTicket ticket)
      if Bucket.isEmpty expired
        then pure Bucket.empty
        else do
          -- `$!` makes sure an evaluated bucket, not a possible thunk, is what gets stored in the shared slot.
          (success, ticket1) <- Atomics.casArrayElem buckets index ticket $! bucket1
          if success then pure expired else loop ticket1

-- | @timestampToIndex buckets resolution timestamp@ figures out which index of @buckets@ @timestamp@ corresponds to,
-- where each bucket covers @resolution@ nanoseconds.
--
-- For example, consider a three-element @buckets@ with resolution @1000000000@:
--
-- >   +--------------------------------------+
-- >   | 1000000000 | 1000000000 | 1000000000 |
-- >   +--------------------------------------+
--
-- A timestamp such as @1053298012387@ is binned into one of the indices 0, 1, or 2 with simple arithmetic:
--
--   1. Figure out which index the timestamp would correspond to if there were infinitely many buckets
--      (@1053298012387 \`div\` 1000000000 = 1053@).
--
--   2. Wrap around per the actual length of the array (@1053 \`rem\` 3 = 0@).
--
-- NOTE(review): @buckets@ must be non-empty. With zero buckets the remainder divides by zero and throws.
timestampToIndex :: AlarmBuckets -> Nanoseconds -> Timestamp -> Int
timestampToIndex buckets resolution timestamp =
  -- The downcast is safe: the remainder is strictly smaller than the bucket count, which is itself an Int.
  fromIntegral @Word64 @Int (Timestamp.epoch resolution timestamp `rem` bucketCount)
  where
    bucketCount :: Word64
    bucketCount =
      fromIntegral @Int @Word64 (Array.sizeofMutableArray buckets)