-- Some code modified from the atomic-primops library; license included below.
-- Copyright (c)2012-2013, Ryan R. Newton
--
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright
--       notice, this list of conditions and the following disclaimer.
--
--     * Redistributions in binary form must reproduce the above
--       copyright notice, this list of conditions and the following
--       disclaimer in the documentation and/or other materials provided
--       with the distribution.
--
--     * Neither the name of Ryan R. Newton nor the names of other
--       contributors may be used to endorse or promote products derived
--       from this software without specific prior written permission.
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

module Ki.Internal.Counter
  ( Counter,
    newCounter,
    incrCounter,
  )
where

import Data.Bits
import GHC.Base
import Ki.Internal.Prelude

-- | A mutable counter that may be shared between threads. All updates are
-- performed with an atomic fetch-and-add on the wrapped byte array.
data Counter = Counter (MutableByteArray# RealWorld)

-- | Create a new counter initialized to 0.
--
-- The counter is backed by a freshly allocated byte array exactly one 'Int'
-- wide. That word is explicitly zeroed before the 'Counter' is returned,
-- because the contents of a fresh 'newByteArray#' are not guaranteed to be
-- initialized.
newCounter :: IO Counter
newCounter =
  IO $ \s0# ->
    case intSizeInBytes of
      I# size# ->
        case newByteArray# size# s0# of
          (# s1#, arr# #) ->
            -- Write 0 to element index 0 (the index is in units of Int).
            case writeIntArray# arr# 0# 0# s1# of
              s2# -> (# s2#, Counter arr# #)
  where
    -- Width of a machine 'Int' in bytes. 'finiteBitSize' never inspects its
    -- argument; passing a real value (rather than 'undefined') keeps this total.
    intSizeInBytes :: Int
    intSizeInBytes = finiteBitSize (0 :: Int) `div` 8
{-# INLINE newCounter #-}

-- | Increment a counter by one and return the value it held just before the
-- increment.
--
-- The read-modify-write is a single atomic 'fetchAddIntArray#', so concurrent
-- callers never lose an update and each observes the value immediately prior
-- to its own increment. The addition is machine-word 'Int' arithmetic.
-- NOTE(review): it presumably wraps around past 'maxBound'; confirm whether
-- callers care.
incrCounter :: Counter -> IO Int
incrCounter (Counter arr#) =
  IO $ \s0# ->
    -- Atomically add 1 to the Int at element index 0; the primop yields the
    -- old value.
    case fetchAddIntArray# arr# 0# 1# s0# of
      (# s1#, n# #) -> (# s1#, I# n# #)
{-# INLINE incrCounter #-}