{-# language BangPatterns #-}
{-# language MagicHash #-}
{-# language UnboxedTuples #-}
{-# language ScopedTypeVariables #-}

module Data.Primitive.Unlifted.Atomic
  ( casUnliftedArray
  ) where

import Control.Monad.Primitive (PrimMonad,PrimState,primitive)
import Data.Primitive.Unlifted.Array (MutableUnliftedArray,MutableUnliftedArray_(..))
import Data.Primitive.Unlifted.Array.Primops (MutableUnliftedArray#(..))
import Data.Primitive.Unlifted.Class (PrimUnlifted,toUnlifted#,fromUnlifted#)
import GHC.Exts (Int(I#))
import GHC.Exts (casArray#,isTrue#,(==#))

-- | Atomic compare-and-swap on a single slot of a mutable unlifted array.
--
-- Given an array, an index, the expected old value, and the new value,
-- write the new value into the slot if the value currently stored there
-- matches the expected old value. The comparison is delegated to the
-- @casArray#@ primop, i.e. it is a pointer comparison on the unlifted
-- representations produced by 'toUnlifted#'. Implies a full memory barrier.
--
-- The result is a pair:
--
-- * The 'Bool' is 'True' exactly when @casArray#@ reports status @0@, which
--   GHC documents as a successful swap.
--
-- * The second component is the element handed back by @casArray#@,
--   converted with 'fromUnlifted#'. NOTE(review): GHC's primop
--   documentation describes this as the element at the index /after/ the
--   operation completes (so the new value on success, the conflicting
--   value on failure) -- confirm before relying on it being the
--   pre-operation value.
--
-- No bounds check is performed on the index in this function.
--
-- Some unlifted types, in particular the ones that correspond to mutable
-- resources, have good guarantees about pointer equality. With these
-- types, this function is much easier to reason about than @casArray@.
casUnliftedArray :: forall m a. (PrimMonad m, PrimUnlifted a)
  => MutableUnliftedArray (PrimState m) a -- ^ array
  -> Int -- ^ index
  -> a -- ^ expected old value
  -> a -- ^ new value
  -> m (Bool,a) -- ^ success flag and the element reported by @casArray#@
{-# INLINE casUnliftedArray #-}
casUnliftedArray (MutableUnliftedArray (MutableUnliftedArray# arr#)) (I# i#) old new =
  -- The conversions between the lifted wrapper and its unlifted
  -- representation are nasty business. This should go away once
  -- https://github.com/ghc-proposals/ghc-proposals/pull/203 happens.
  -- Also, this is unsound if the result is immediately consumed by
  -- the FFI.
  primitive $ \s0 ->
    -- Unwrap both values up front so that casArray# compares and stores
    -- the underlying unlifted objects rather than the lifted wrappers.
    let !uold = toUnlifted# old
        !unew = toUnlifted# new
     in case casArray# arr# i# uold unew s0 of
          -- casArray# yields a status code (0 on success) together with
          -- an element; translate the status into a Bool and re-wrap
          -- the element for the caller.
          (# s1, n, ur #) -> (# s1, (isTrue# (n ==# 0#), fromUnlifted# ur) #)