{-# LANGUAGE TupleSections #-}

{-|
Module      : TreeCounter
Description : Wait-free Tree Counter.
License     : BSD3
Maintainer  : Julian Sutherland (julian.sutherland10@imperial.ac.uk)

A wait-free tree counter. Creates a binary tree of counters, with each leaf associated with a thread. Leaves can be split, creating a new leaf for the current thread and another that can be used by another thread. Because each thread increments only its own leaf, increments are wait-free. A read is performed on the counter by recursively traversing it and summing the values of the counters in the nodes and leaves of the tree.
-}

module Data.NonBlocking.WaitFree.TreeCounter(TreeCounter(), TreeCounterIO, TreeCounterSTM, newTreeCounter, splitTreeCounter, incTreeCounter, readTreeCounter) where

import Control.Concurrent.STM (STM())
import Control.Concurrent.STM.TVar (TVar())
import Control.Monad (join)
import Control.Monad.Ref (MonadAtomicRef, newRef, readRef, writeRef, atomicModifyRef)
import Data.IORef(IORef())

-- |A 'TreeCounter' whose mutable cells are 'IORef's, for use in 'IO'.
type TreeCounterIO = TreeCounter IORef

-- |A 'TreeCounter' whose mutable cells are 'TVar's, for use in 'STM'.
type TreeCounterSTM = TreeCounter TVar

-- Representation: the wrapped reference holds a pair (leaf, root). The first
-- component is the leaf this instance increments; the second is the root of
-- the tree shared by every instance split from the same original counter.

-- |A wait-free concurrent tree counter: a binary tree of counters in which
-- each leaf is associated with one thread. Leaves can be split with
-- 'splitTreeCounter', producing a new leaf for the current thread and another
-- one that can be handed to a different thread.
--
-- Increments are wait-free as long as each thread increments only through its
-- own instance, obtained from an initial instance via 'splitTreeCounter'.
-- Do not share one instance between threads. An increment that races with a
-- split of the same instance can land on a counter that is no longer part of
-- the tree, and is then lost.
data TreeCounter r = TreeCounter (r (r (CounterTree r), r (CounterTree r)))

-- Shared tree of counters. A 'Leaf' holds a mutable count incremented by its
-- owning instance. A 'Node' is a former leaf that has been split: it keeps the
-- count the leaf had at that moment, plus references to its two children.
data CounterTree r
  = Node Integer (r (CounterTree r)) (r (CounterTree r))
  | Leaf (r Integer)

-- |Creates a new 'TreeCounter' whose initial value is the given 'Integral'
-- number (converted with 'toInteger').
--
-- The fresh tree consists of a single leaf. That leaf is both the leaf owned
-- by the returned instance and the root of the tree.
{-# SPECIALIZE newTreeCounter :: (Integral a) => a -> IO TreeCounterIO #-}
{-# SPECIALIZE newTreeCounter :: (Integral a) => a -> STM TreeCounterSTM #-}
newTreeCounter :: (MonadAtomicRef r m, Integral a) => a -> m (TreeCounter r)
newTreeCounter n = do
  countRef <- newRef (toInteger n)
  rootRef  <- newRef (Leaf countRef)
  -- The owned leaf and the root are the same reference until the first split.
  pairRef  <- newRef (rootRef, rootRef)
  return (TreeCounter pairRef)

-- |Splits a 'TreeCounter' instance so that another thread can increment the
-- same logical counter in a wait-free manner.
--
-- The leaf owned by this instance is replaced by a 'Node' holding the leaf's
-- count at the time of the split. The node gets two fresh leaves, both
-- starting at 0, as children. This instance is moved onto the left child and
-- the returned instance owns the right child. Both share the same root, so
-- 'readTreeCounter' on either one sees the full total.
--
-- Must only be called by the thread that owns this instance. An increment on
-- the same instance that races with the split may be lost.
{-# SPECIALIZE splitTreeCounter :: TreeCounterIO -> IO TreeCounterIO #-}
{-# SPECIALIZE splitTreeCounter :: TreeCounterSTM -> STM TreeCounterSTM #-}
splitTreeCounter :: (MonadAtomicRef r m) => TreeCounter r -> m (TreeCounter r)
splitTreeCounter (TreeCounter tupleRef) = do
  (leafRef, rootRef) <- readRef tupleRef
  -- Match with 'case' rather than a refutable do-bind. The latter would
  -- require a 'MonadFail' constraint, which 'MonadAtomicRef' does not supply.
  leaf <- readRef leafRef
  case leaf of
    Node {} ->
      error "Data.NonBlocking.WaitFree.TreeCounter.splitTreeCounter: \
            \owned leaf was already split; a TreeCounter instance must not \
            \be shared between threads"
    Leaf lCountRef -> do
      lCount   <- readRef lCountRef
      leftRef  <- newRef 0 >>= newRef . Leaf
      rightRef <- newRef 0 >>= newRef . Leaf
      -- Freeze the current count into the node, then move this instance
      -- onto the new left leaf.
      writeRef leafRef (Node lCount leftRef rightRef)
      writeRef tupleRef (leftRef, rootRef)
      TreeCounter <$> newRef (rightRef, rootRef)

-- |Increments the counter by one.
--
-- The increment is applied with 'atomicModifyRef' to the leaf owned by this
-- instance. It is wait-free as long as this thread is the only one using this
-- 'TreeCounter' instance. Use 'splitTreeCounter' to obtain instances for
-- other threads.
{-# SPECIALIZE incTreeCounter :: TreeCounterIO -> IO () #-}
{-# SPECIALIZE incTreeCounter :: TreeCounterSTM -> STM () #-}
incTreeCounter :: (MonadAtomicRef r m) => TreeCounter r -> m ()
incTreeCounter (TreeCounter tupleRef) = do
  (leafRef, _) <- readRef tupleRef
  -- Match with 'case' rather than a refutable do-bind. The latter would
  -- require a 'MonadFail' constraint, which 'MonadAtomicRef' does not supply.
  leaf <- readRef leafRef
  case leaf of
    Node {} ->
      error "Data.NonBlocking.WaitFree.TreeCounter.incTreeCounter: \
            \owned leaf was split concurrently; a TreeCounter instance must \
            \not be shared between threads"
    Leaf lCountRef -> do
      -- Store and return the same shared new count, then force it. This
      -- evaluates the stored value too, so the reference never accumulates
      -- a chain of unevaluated (+1) thunks between reads.
      newCount <- atomicModifyRef lCountRef (\c -> let c' = c + 1 in (c', c'))
      newCount `seq` return ()

-- |Returns the total of the whole tree of counters that this instance belongs
-- to. The total is obtained by walking the tree from the shared root and
-- adding up every node's stored count and every leaf's current count. The
-- result is converted to any 'Num' type with 'fromInteger'.
--
-- In 'IO' the individual cell reads are separate operations. Increments made
-- by other threads during the walk may or may not be reflected.
{-# SPECIALIZE readTreeCounter :: (Num a) => TreeCounterIO -> IO a #-}
{-# SPECIALIZE readTreeCounter :: (Num a) => TreeCounterSTM -> STM a #-}
readTreeCounter :: (MonadAtomicRef r m, Num a) => TreeCounter r -> m a
readTreeCounter (TreeCounter tupleRef) = do
  (_, rootRef) <- readRef tupleRef
  total <- sumTree =<< readRef rootRef
  return (fromInteger total)

-- Sums a counter tree recursively. A leaf contributes its current count. A
-- node contributes the count frozen into it when it was split, plus the
-- totals of its left and then its right subtree.
{-# SPECIALIZE sumTree :: CounterTree IORef -> IO Integer #-}
{-# SPECIALIZE sumTree :: CounterTree TVar -> STM Integer #-}
sumTree :: (MonadAtomicRef r m) => CounterTree r -> m Integer
sumTree tree = case tree of
  Leaf countRef -> readRef countRef
  Node frozen leftRef rightRef -> do
    leftTotal  <- sumTree =<< readRef leftRef
    rightTotal <- sumTree =<< readRef rightRef
    return (frozen + leftTotal + rightTotal)