{-# LANGUAGE BangPatterns #-}

module Instrument.Counter
    ( Counter
    , newCounter
    , readCounter
    , resetCounter
    , add
    , increment
    ) where

-------------------------------------------------------------------------------
import           Control.Monad
import           Data.IORef
-------------------------------------------------------------------------------

-- | A mutable integer counter.
--
-- The operations in this module change the count through atomic
-- read-modify-write operations on the wrapped 'IORef'.
newtype Counter = Counter
    { unCounter :: IORef Int
      -- ^ The mutable cell holding the current count.
    }

-------------------------------------------------------------------------------
-- | Allocate a fresh 'Counter' whose initial value is zero.
newCounter :: IO Counter
newCounter = liftM Counter (newIORef 0)


-------------------------------------------------------------------------------
-- | Read the current value of the counter without modifying it.
readCounter :: Counter -> IO Int
readCounter = readIORef . unCounter


-------------------------------------------------------------------------------
-- | Atomically reset the counter to zero and return the value it held
-- immediately before the reset.
--
-- Uses the strict 'atomicModifyIORef'', which forces both the new contents
-- (@0@) and the returned old value. The lazy 'atomicModifyIORef' would
-- instead store a thunk that keeps the previous value reachable until
-- something reads the counter again.
resetCounter :: Counter -> IO Int
resetCounter (Counter ref) = atomicModifyIORef' ref swapOut
  where
    -- Replace the stored count with zero; hand the old count back.
    swapOut :: Int -> (Int, Int)
    swapOut old = (0, old)

-------------------------------------------------------------------------------
-- | Add one to the counter; shorthand for @'add' 1@.
increment :: Counter -> IO ()
increment counter = add 1 counter


-------------------------------------------------------------------------------
-- | Atomically add @x@ to the counter.
--
-- Uses the strict 'atomicModifyIORef'' so the new total is evaluated as
-- part of the update. With the lazy 'atomicModifyIORef', every call stores
-- an unevaluated thunk in the 'IORef'. A bang pattern inside the update
-- function does not prevent that, because the function itself is only
-- applied when the contents are demanded. A long run of updates without an
-- intervening read would therefore build an ever-growing thunk chain.
add :: Int -> Counter -> IO ()
add x c = atomicModifyIORef' (unCounter c) bump
  where
    -- New count is the old count plus @x@; the update returns nothing.
    bump :: Int -> (Int, ())
    bump i = (i + x, ())