{-# language MagicHash #-}
{-# language UnboxedTuples #-}

module Data.Primitive.Array.Atomic
  ( casArray
  ) where

import Control.Monad.Primitive (PrimMonad,PrimState,primitive)
import Data.Primitive (MutableArray(..))
import GHC.Exts (Int(I#),casArray#,isTrue#,(==#))

-- | Atomic compare-and-swap on one element of a boxed 'MutableArray'.
--
-- Given an array, an element index, the expected old value, and the new
-- value, atomically write the new value at that index if and only if the
-- element currently stored there is pointer-equal to the expected old
-- value. Implies a full memory barrier.
--
-- The result pair describes the outcome:
--
-- * @(True, new)@ when the swap succeeded. The second component is the
--   element after the operation, i.e. the new value.
--
-- * @(False, current)@ when the swap failed. The second component is the
--   element actually found at the index (not the expected value). It can
--   be passed back in as the expected value of a retry.
--
-- The index is not bounds-checked. The caller must ensure that it lies
-- within the array.
--
-- Note that lifted values in GHC have limited guarantees concerning
-- pointer equality. In particular, data constructor applications of
-- single-constructor data types may be mangled by GHC Core optimizations.
-- Users of this function are expected to understand how to make
-- pointer equality survive GHC's optimization passes.
casArray :: PrimMonad m
  => MutableArray (PrimState m) a -- ^ mutable array
  -> Int -- ^ element index
  -> a -- ^ expected old value
  -> a -- ^ new value
  -> m (Bool,a) -- ^ success flag and the element after the operation
{-# INLINE casArray #-}
casArray (MutableArray arr#) (I# i#) old new =
  primitive $ \s0 -> case casArray# arr# i# old new s0 of
    -- casArray# reports 0# when the swap happened and 1# when it did not.
    (# s1, n, r #) -> (# s1, (isTrue# (n ==# 0#), r) #)