module Data.Atomics.Counter.Fat (
      AtomicCounter()
    , newCounter
    , incrCounter
    , readCounter
    ) where

-- An atomic counter stored at the start of its own 64-byte, 64-byte-aligned
-- buffer (the size of an x86 cache line), to try to avoid false sharing.

import Data.Primitive.MachDeps(sIZEOF_INT)
import Control.Monad.Primitive(RealWorld)
import Data.Primitive.ByteArray
import Data.Atomics(fetchAddIntArray)
import Control.Exception(assert)

-- | An 'Int' counter backed by a 'MutableByteArray'; the value is kept in
-- the first 'Int'-sized slot (index 0) of the array. The constructor is not
-- exported, so values are created with 'newCounter'.
newtype AtomicCounter
    = AtomicCounter (MutableByteArray RealWorld)

-- | Cache line size, in bytes, that this module assumes (64, the x86 cache
-- line size). It is used both as the size and as the alignment of the buffer
-- backing each counter.
sIZEOF_CACHELINE :: Int
{-# INLINE sIZEOF_CACHELINE #-}
sIZEOF_CACHELINE = 64

-- | Create a new counter whose initial value is @n@.
--
-- The counter is stored at index 0 of a freshly allocated pinned byte array
-- that is 'sIZEOF_CACHELINE' bytes long and aligned to 'sIZEOF_CACHELINE'
-- bytes.
newCounter :: Int -> IO AtomicCounter
{-# INLINE newCounter #-}
newCounter n = do
    -- Both the size and the alignment are one cache line.
    arr <- newAlignedPinnedByteArray
                sIZEOF_CACHELINE
                sIZEOF_CACHELINE
    writeByteArray arr 0 n
    -- out of principle: an Int must be smaller than the buffer holding it.
    assert (sIZEOF_INT < sIZEOF_CACHELINE) $
      return (AtomicCounter arr)

-- | Atomically add @incr@ to the counter using 'fetchAddIntArray' on index 0
-- of the underlying array, and return that call's result.
--
-- NOTE(review): whether the result is the counter's value before or after the
-- addition is determined by 'fetchAddIntArray' in the atomic-primops version
-- in use -- confirm against that version's documentation.
incrCounter :: Int -> AtomicCounter -> IO Int
{-# INLINE incrCounter #-}
incrCounter incr (AtomicCounter arr) =
    fetchAddIntArray arr 0 incr

-- | Read the counter's current value, i.e. the 'Int' stored at index 0 of
-- the underlying array.
readCounter :: AtomicCounter -> IO Int
{-# INLINE readCounter #-}
readCounter (AtomicCounter arr) =
    readByteArray arr 0