----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.Counter.Lifted.IO
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
--
-- Lifted 'Control.Concurrent.Counter.Lifted.Counter' specialized to
-- operate in the 'IO' monad.
----------------------------------------------------------------------------

{-# LANGUAGE TypeApplications #-}

module Control.Concurrent.Counter.Lifted.IO
  ( Counter

  -- * Create
  , new

  -- * Read/write
  , get
  , set
  , cas

  -- * Arithmetic operations
  , add
  , sub

  -- * Bitwise operations
  , and
  , or
  , xor
  , nand
  ) where

import Prelude hiding (and, or)

import Data.Coerce
import GHC.Exts (RealWorld)
import GHC.IO
import GHC.ST

import qualified Control.Concurrent.Counter.Lifted.ST as Lifted

-- | A memory location holding an 'Int' that supports a select few atomic
-- operations.
--
-- Isomorphic to @IORef Int@. It is a newtype over the 'ST'-based
-- 'Lifted.Counter' with the state thread fixed to 'RealWorld', so every
-- operation below is obtained from its 'ST' counterpart via 'coerce'.
newtype Counter
  = Counter (Lifted.Counter RealWorld)

-- | Pointer equality, delegated to the 'Eq' instance of the wrapped
-- 'Lifted.Counter' 'RealWorld'.
instance Eq Counter where
  -- Reuse the wrapped counter's instance. The explicit type application
  -- fixes which 'Eq' instance 'coerce' lifts through the newtype.
  (==) = coerce ((==) @(Lifted.Counter RealWorld))

{-# INLINE new #-}
-- | Create new counter with initial value.
--
-- Runs 'Lifted.new' in the 'RealWorld' state thread via 'stToIO' and then
-- coerces the resulting counter into the 'IO'-specialised 'Counter'.
new
  :: Int -- ^ Initial value
  -> IO Counter
new = coerce . stToIO . Lifted.new


{-# INLINE get #-}
-- | Atomically read the counter's value.
get :: Counter -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
get = coerce Lifted.get

{-# INLINE set #-}
-- | Atomically assign new value to the counter.
set
  :: Counter
  -> Int -- ^ New value
  -> IO ()
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
set = coerce Lifted.set

{-# INLINE cas #-}
-- | Atomic compare and swap, i.e. write the new value if the current
-- value matches the provided old value. Returns the value of the
-- element before the operation.
cas
  :: Counter
  -> Int -- ^ Expected old value
  -> Int -- ^ New value
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
cas = coerce Lifted.cas

{-# INLINE add #-}
-- | Atomically add an amount to the counter and return its old value.
add
  :: Counter
  -> Int -- ^ Amount to add
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
add = coerce Lifted.add

{-# INLINE sub #-}
-- | Atomically subtract an amount from the counter and return its old value.
sub
  :: Counter
  -> Int -- ^ Amount to subtract
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
sub = coerce Lifted.sub


{-# INLINE and #-}
-- | Atomically combine old value with a new one via bitwise and. Returns
-- old counter value.
and
  :: Counter
  -> Int -- ^ Operand to AND with the current value
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
and = coerce Lifted.and

{-# INLINE or #-}
-- | Atomically combine old value with a new one via bitwise or. Returns
-- old counter value.
or
  :: Counter
  -> Int -- ^ Operand to OR with the current value
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
or = coerce Lifted.or

{-# INLINE xor #-}
-- | Atomically combine old value with a new one via bitwise xor. Returns
-- old counter value.
xor
  :: Counter
  -> Int -- ^ Operand to XOR with the current value
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
xor = coerce Lifted.xor

{-# INLINE nand #-}
-- | Atomically combine old value with a new one via bitwise nand, i.e.
-- store @complement (old .&. x)@. Returns old counter value.
nand
  :: Counter
  -> Int -- ^ Operand to NAND with the current value
  -> IO Int
-- 'ST' 'RealWorld' and 'IO' are both newtypes over the same state-passing
-- function, so the 'ST' operation is coerced wholesale.
nand = coerce Lifted.nand