{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2015 Edward Kmett, Paul Khuong and Ted Cooper
-- License     :  BSD-style (see the file LICENSE)
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>,
--                Ted Cooper <anthezium@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- QSBR-based RCU
-----------------------------------------------------------------------------
module Control.Concurrent.RCU.GC.Internal
  ( SRef(..)
  , RCUThread(..)
  , RCU(..)
  , runRCU
  , runOnRCU
  , ReadingRCU(..)
  , WritingRCU(..)
  , RCUState(..)
#if BENCHMARKS
  , unRCU
  , runWritingRCU
  , runReadingRCU
  , writeSRefIO
#endif
  ) where

import Control.Applicative
import Control.Concurrent
import Control.Concurrent.RCU.Class
import Control.Exception (finally)
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Control.Parallel
import Data.Atomics
import Data.IORef
import Data.List
import Data.Primitive
import Prelude hiding (Read(..))
import System.Mem


--------------------------------------------------------------------------------
-- * Shared References
--------------------------------------------------------------------------------

-- | Shared references.
--
-- A thin wrapper around an 'IORef'. The @s@ parameter is a phantom: it does
-- not occur in the representation, it only appears in the type so that a
-- reference is tagged with the same @s@ as the RCU monads it is used in.
newtype SRef s a = SRef { unSRef :: IORef a }
  deriving Eq

-- | Allocate the 'IORef' that backs a shared reference, holding the given
-- initial value.
newSRefIO :: a -> IO (IORef a)
newSRefIO initial = newIORef initial
{-# INLINE newSRefIO #-}

-- | Read the current contents of the 'IORef' that backs a shared reference.
readSRefIO :: IORef a -> IO a
readSRefIO ref = readIORef ref
{-# INLINE readSRefIO #-}

-- | Store a new value in the 'IORef' that backs a shared reference.
--
-- The value is forced with 'pseq' before 'writeBarrier' is issued, and only
-- then is it written, so the reference never receives an unevaluated thunk
-- and the barrier precedes the store.
writeSRefIO :: IORef a -> a ->  IO ()
writeSRefIO ref new =
  (new `pseq` writeBarrier) >> writeIORef ref new
{-# INLINE writeSRefIO #-}

--------------------------------------------------------------------------------
-- * Shared state
--------------------------------------------------------------------------------

-- | Counter for causal ordering.
--
-- Represented as a 'MutableByteArray' whose element at index 0 is read and
-- written as an 'Int' (see 'newCounter', 'readCounter', 'writeCounter' and
-- 'incCounter').
newtype Counter = Counter (MutableByteArray RealWorld)

-- Counters are compared by identity of the underlying array, not by the
-- number they currently hold.
instance Eq Counter where
  (==) (Counter l) (Counter r) = sameMutableByteArray l r

-- | Allocate a fresh 'Counter' initialised to the given value.
--
-- The backing array is 8 bytes long and the value is stored at index 0.
newCounter :: Int -> IO Counter
newCounter initial = do
  cell <- newByteArray 8
  Counter cell <$ writeByteArray cell 0 initial
{-# INLINE newCounter #-}

-- | Read the 'Int' currently stored at index 0 of the counter's array.
readCounter :: Counter -> IO Int
readCounter (Counter cell) = cell `readByteArray` 0
{-# INLINE readCounter #-}

-- | Overwrite the 'Int' stored at index 0 of the counter's array.
writeCounter :: Counter -> Int -> IO ()
writeCounter (Counter cell) = writeByteArray cell 0
{-# INLINE writeCounter #-}

-- | Atomically add one to the counter via 'fetchAddIntArray' and return the
-- fetched value plus one, evaluated strictly.
--
-- NOTE(review): this is the post-increment value provided
-- 'fetchAddIntArray' yields the value held /before/ the addition — confirm
-- against the version of atomic-primops in use.
incCounter :: Counter -> IO Int
incCounter (Counter cell) =
  fetchAddIntArray cell 0 1 >>= \ previous -> return $! previous + 1
{-# INLINE incCounter #-}

-- | A version token. It carries no data (the 'IORef' holds @()@); what
-- matters is the identity and lifetime of the 'IORef': 'stuff' attaches a
-- finalizer to the old token, and 'reading' keeps the token it observed alive
-- with 'touch' until its section has finished.
newtype Version = Version (IORef ())

-- | Allocate a fresh 'Version' token backed by a brand new @'IORef' ()@.
newVersion :: IO Version
newVersion = fmap Version (newIORef ())

-- | State for an RCU computation.
--
-- How the fields are used within this module:
--
-- * 'rcuStateGlobalCounter' is incremented by 'synchronizeIO'; 'reading'
--   copies its value into the thread's own counter when a read section ends,
--   and 'forking' uses its value to seed a new thread's counter.
--
-- * 'rcuStateGlobalVersion' holds the current 'Version'; 'reading' keeps the
--   version it saw alive for the duration of the section and 'stuff' replaces
--   it.
--
-- * 'rcuStateThreadCountersV' holds the counters registered by 'forking'. It
--   is held (via 'withMVar') for the whole of 'synchronizeIO'.
--
-- * 'rcuStateWriterLockV' is held for the duration of a 'writing' section.
--
-- * 'rcuStateMyCounter' is the counter of the thread using this state;
--   'forking' hands the child a copy of the state with this field replaced.
--
-- * 'rcuStatePinned' is 'Nothing' under 'runRCU' and @'Just' i@ under
--   'runOnRCU'. NOTE(review): nothing in this module reads it.
data RCUState = RCUState
  { -- | Global state
    rcuStateGlobalCounter   :: {-# UNPACK #-} !Counter
  , rcuStateGlobalVersion   :: {-# UNPACK #-} !(IORef Version)
  , rcuStateThreadCountersV :: {-# UNPACK #-} !(MVar [Counter])
  , rcuStateWriterLockV     :: {-# UNPACK #-} !(MVar ())
    -- | Thread state
  , rcuStateMyCounter       :: {-# UNPACK #-} !Counter  -- each thread's state gets its own counter
  , rcuStatePinned          ::                !(Maybe Int)
  }

--------------------------------------------------------------------------------
-- * Read-Side Critical Sections
--------------------------------------------------------------------------------

-- | This is the basic read-side critical section for an RCU computation.
--
-- It is a function from the thread's 'RCUState' to an 'IO' action. The
-- instances in this module let it create ('MonadNew') and read
-- ('MonadReading') shared references; it has no 'MonadWriting' instance here.
newtype ReadingRCU s a = ReadingRCU { runReadingRCU :: RCUState -> IO a }
  deriving Functor

-- Both sides receive the same 'RCUState'; effects are sequenced left to right
-- by the underlying 'IO' applicative.
instance Applicative (ReadingRCU s) where
  pure x = ReadingRCU (const (pure x))
  ff <*> fx = ReadingRCU $ \ st -> runReadingRCU ff st <*> runReadingRCU fx st

-- Reader-over-'IO' style monad: the 'RCUState' is threaded unchanged to both
-- the first action and its continuation.
instance Monad (ReadingRCU s) where
  return x = ReadingRCU (const (pure x))
  section >>= k = ReadingRCU $ \ st ->
    runReadingRCU section st >>= \ x -> runReadingRCU (k x) st
  -- Failure is delegated to 'IO'.
  -- NOTE(review): this relies on 'fail' being a method of 'Monad'; confirm
  -- against the range of base versions this module must build with.
  fail msg = ReadingRCU (const (fail msg))

-- Choice and failure are lifted pointwise from the 'Alternative' instance of
-- 'IO'.
instance Alternative (ReadingRCU s) where
  empty = ReadingRCU (const empty)
  l <|> r = ReadingRCU $ \ st -> runReadingRCU l st <|> runReadingRCU r st

-- 'mzero' and 'mplus' are lifted pointwise from the 'MonadPlus' instance of
-- 'IO'.
instance MonadPlus (ReadingRCU s) where
  mzero = ReadingRCU (const mzero)
  mplus l r = ReadingRCU $ \ st -> runReadingRCU l st `mplus` runReadingRCU r st

-- Creating a reference ignores the 'RCUState' and simply allocates an 'IORef'.
instance MonadNew (SRef s) (ReadingRCU s) where
  newSRef x = ReadingRCU (const (fmap SRef (newSRefIO x)))

-- Reading a reference ignores the 'RCUState' and reads the backing 'IORef'.
instance MonadReading (SRef s) (ReadingRCU s) where
  readSRef ref = ReadingRCU (const (readSRefIO (unSRef ref)))
  {-# INLINE readSRef #-}

--------------------------------------------------------------------------------
-- * Write-Side Critical Sections
--------------------------------------------------------------------------------

-- | This is the basic write-side critical section for an RCU computation.
--
-- It is a function from the thread's 'RCUState' to an 'IO' action. In
-- addition to creating and reading shared references it can write them and
-- call 'synchronize' (see the 'MonadWriting' instance).
newtype WritingRCU s a = WritingRCU { runWritingRCU :: RCUState -> IO a }
  deriving Functor

-- Both sides receive the same 'RCUState'; effects are sequenced left to right
-- by the underlying 'IO' applicative.
instance Applicative (WritingRCU s) where
  pure x = WritingRCU (const (pure x))
  ff <*> fx = WritingRCU $ \ st -> runWritingRCU ff st <*> runWritingRCU fx st

-- Reader-over-'IO' style monad: the 'RCUState' is threaded unchanged to both
-- the first action and its continuation.
instance Monad (WritingRCU s) where
  return x = WritingRCU (const (pure x))
  section >>= k = WritingRCU $ \ st ->
    runWritingRCU section st >>= \ x -> runWritingRCU (k x) st
  -- Failure is delegated to 'IO'.
  -- NOTE(review): this relies on 'fail' being a method of 'Monad'; confirm
  -- against the range of base versions this module must build with.
  fail msg = WritingRCU (const (fail msg))

-- Choice and failure are lifted pointwise from the 'Alternative' instance of
-- 'IO'.
instance Alternative (WritingRCU s) where
  empty = WritingRCU (const empty)
  l <|> r = WritingRCU $ \ st -> runWritingRCU l st <|> runWritingRCU r st

-- 'mzero' and 'mplus' are lifted pointwise from the 'MonadPlus' instance of
-- 'IO'.
instance MonadPlus (WritingRCU s) where
  mzero = WritingRCU (const mzero)
  mplus l r = WritingRCU $ \ st -> runWritingRCU l st `mplus` runWritingRCU r st

-- Creating a reference ignores the 'RCUState' and simply allocates an 'IORef'.
instance MonadNew (SRef s) (WritingRCU s) where
  newSRef x = WritingRCU (const (fmap SRef (newSRefIO x)))

-- Reading a reference ignores the 'RCUState' and reads the backing 'IORef'.
instance MonadReading (SRef s) (WritingRCU s) where
  readSRef ref = WritingRCU (const (readSRefIO (unSRef ref)))
  {-# INLINE readSRef #-}

-- Writes go through 'writeSRefIO'; 'synchronize' runs 'synchronizeIO' on the
-- current thread's 'RCUState'.
instance MonadWriting (SRef s) (WritingRCU s) where
  writeSRef ref new = WritingRCU (const (writeSRefIO (unSRef ref) new))
  {-# INLINE writeSRef #-}
  synchronize = WritingRCU synchronizeIO

-- | Advance the global counter and wait until the registered thread counters
-- have caught up with it.
--
-- The list of thread counters is held (via 'withMVar') for the whole wait.
-- The caller's own counter is set to the new global value first. Then:
--
-- * Fast path: the counters are polled in list order, sleeping with
--   @'threadDelay' 1@ after every poll that finds a counter not yet equal to
--   the new value.
--
-- * Slow path: if more than 2000 polls have been made in total (successful
--   polls count too) before the list is exhausted, a new 'Version' is
--   installed with 'stuff', a minor GC is requested, and 'sitAndSpin' waits
--   for the old version's finalizer to fire.
--
-- A 'storeLoadBarrier' is issued after the thread-counter list is released.
synchronizeIO :: RCUState -> IO ()
synchronizeIO s = do
  withMVar (rcuStateThreadCountersV s) $ \ counters -> do
    target <- incCounter (rcuStateGlobalCounter s)
    writeCounter (rcuStateMyCounter s) target
    gaveUp <- pollCounters target (0 :: Int) counters
    when gaveUp $ do
      -- slow path
      finalized <- newEmptyMVar
      stuff s finalized
      performMinorGC
      sitAndSpin finalized
  storeLoadBarrier
  where
    -- Returns False once every counter has been seen equal to the target and
    -- True as soon as the poll budget is exceeded.
    pollCounters :: Int -> Int -> [Counter] -> IO Bool
    pollCounters _ _ [] = return False
    pollCounters target polls pending@(c:rest)
      | polls > 2000 = return True
      | otherwise = do
          seen <- readCounter c
          if seen == target
            then pollCounters target (polls + 1) rest
            else threadDelay 1 >> pollCounters target (polls + 1) pending

-- | Install a fresh 'Version' and arrange for the given 'MVar' to be filled
-- when the previous version's 'IORef' is finalized.
--
-- This is the first step of the slow path of 'synchronizeIO', which then
-- waits on the 'MVar' with 'sitAndSpin'.
stuff :: RCUState -> MVar ()  -> IO ()
stuff s m = do
  -- The version that read sections may currently be holding on to.
  Version v <- readIORef (rcuStateGlobalVersion s)
  v' <- newVersion
  -- From here on, 'reading' picks up the new version.
  atomicWriteIORef (rcuStateGlobalVersion s) v'
  -- The finalizer performs @putMVar m ()@. 'reading' applies 'touch' to the
  -- version it read, so presumably the old 'IORef' only becomes collectable
  -- once no read section still holds it — NOTE(review): confirm.
  _ <- mkWeakIORef v $ putMVar m ()
  return ()
-- NOTE(review): presumably NOINLINE so that the reference to the old version
-- does not outlive this call inside 'synchronizeIO' — confirm the intent.
{-# NOINLINE stuff #-}

-- | Busy-wait until the 'MVar' can be taken, requesting a major garbage
-- collection after every failed attempt.
--
-- This is awful. It should just takeMVar
sitAndSpin :: MVar () -> IO ()
sitAndSpin signal = do
  attempt <- tryTakeMVar signal
  case attempt of
    Nothing -> performMajorGC >> sitAndSpin signal
    Just () -> return ()

--------------------------------------------------------------------------------
-- * RCU Context
--------------------------------------------------------------------------------

-- | This is an RCU computation. It can use 'forking' and 'joining' to spawn
-- and wait for threads, and 'reading' and 'writing' to run classic read-side
-- and write-side RCU critical sections. Writers are serialized using an
-- 'MVar' ('rcuStateWriterLockV'); readers do not take that lock, so they can
-- proceed while a writer is updating.
--
-- Like 'ReadingRCU' and 'WritingRCU' it is a function from the thread's
-- 'RCUState' to an 'IO' action.
newtype RCU s a = RCU { unRCU :: RCUState -> IO a }
  deriving Functor

-- Runs the function computation, then the argument computation, against the
-- same 'RCUState', and applies one result to the other.
instance Applicative (RCU s) where
  pure x = RCU (const (return x))
  rf <*> rx = RCU $ \ st -> do
    f <- unRCU rf st
    x <- unRCU rx st
    return (f x)

-- Reader-over-'IO' style monad: the 'RCUState' is threaded unchanged to both
-- the first action and its continuation.
instance Monad (RCU s) where
  return x = RCU (const (return x))
  act >>= k = RCU $ \ st -> unRCU act st >>= \ x -> unRCU (k x) st

-- Creating a reference ignores the 'RCUState' and simply allocates an 'IORef'.
instance MonadNew (SRef s) (RCU s) where
  newSRef x = RCU (const (fmap SRef (newSRefIO x)))

-- | This is a basic 'RCU' thread. It may be embellished when running in a more
-- exotic context.
--
-- 'rcuThreadId' is the 'ThreadId' returned when the thread was spawned by
-- 'forking'; 'rcuThreadVar' is the 'MVar' into which the thread puts its
-- result and from which 'joining' reads it. The @s@ parameter is a phantom.
data RCUThread s a = RCUThread
  { rcuThreadId  :: {-# UNPACK #-} !ThreadId
  , rcuThreadVar :: {-# UNPACK #-} !(MVar a)
  }

-- The 'MonadRCU' instance ties the pieces together: thread creation and
-- joining, read-side sections and write-side sections.
instance MonadRCU (SRef s) (RCU s) where
  type Reading (RCU s) = ReadingRCU s
  type Writing (RCU s) = WritingRCU s
  type Thread (RCU s) = RCUThread s

  -- Spawn a thread running the given computation with its own counter.
  forking (RCU m) = RCU $ \ s -> do
    result <- newEmptyMVar
    -- Seed the child's counter with the current global counter value and
    -- register it so that 'synchronizeIO' will poll it.
    gc <- readCounter (rcuStateGlobalCounter s)
    threadCounter <- newCounter gc
    modifyMVar_ (rcuStateThreadCountersV s) $ return . (threadCounter :)
    tid <- forkIO $
      -- Always unregister the counter, even if the computation throws or the
      -- thread is killed; a stale counter would never advance again and
      -- would push every later 'synchronizeIO' onto its slow path.
      -- If the computation fails, 'result' is never filled, so 'joining'
      -- that thread will not return a value.
      (m (s { rcuStateMyCounter = threadCounter }) >>= putMVar result)
        `finally`
          modifyMVar_ (rcuStateThreadCountersV s)
                      (return . delete threadCounter)
    return (RCUThread tid result)
  {-# INLINE forking #-}

  -- Wait for the thread's result; 'readMVar' leaves it in place so a thread
  -- can be joined more than once.
  joining (RCUThread _ m) = RCU $ \ _ -> readMVar m
  {-# INLINE joining #-}

  -- Run a read-side section.
  reading (ReadingRCU m) = RCU $ \ s -> do
    v <- readIORef (rcuStateGlobalVersion s)
    x <- m s
    -- Keep the version read at entry alive until the section has finished.
    touch v
    -- Publish the global counter value into this thread's counter; this is
    -- what the polling loop in 'synchronizeIO' compares against.
    writeCounter (rcuStateMyCounter s) =<< readCounter (rcuStateGlobalCounter s)
    return x
  {-# INLINE reading #-}

  -- Run a write-side section while holding the writer-serializing lock, then
  -- synchronize before releasing it. 'withMVar' guarantees the lock is
  -- released even if the section (or the synchronization) throws, so one
  -- failed writer cannot deadlock all later writers.
  writing (WritingRCU m) = RCU $ \ s ->
    withMVar (rcuStateWriterLockV s) $ \ _ -> do
      x <- m s
      synchronizeIO s
      return x
  {-# INLINE writing #-}

-- Lifting an 'IO' action ignores the 'RCUState' and just runs the action.
instance MonadIO (RCU s) where
  liftIO = RCU . const
  {-# INLINE liftIO #-}

-- | Run an RCU computation.
--
-- Builds a fresh 'RCUState' — global counter and calling thread's counter
-- both starting at 0, a new 'Version', an empty list of thread counters, an
-- available writer lock and no pinning ('Nothing') — and runs the computation
-- against it in the calling thread.
runRCU :: (forall s. RCU s a) -> IO a
runRCU m = do
  version        <- newVersion
  globalCounter  <- newCounter 0
  versionRef     <- newIORef version
  threadCounters <- newMVar []
  writerLock     <- newMVar ()
  ownCounter     <- newCounter 0
  unRCU m $ RCUState
    { rcuStateGlobalCounter   = globalCounter
    , rcuStateGlobalVersion   = versionRef
    , rcuStateThreadCountersV = threadCounters
    , rcuStateWriterLockV     = writerLock
    , rcuStateMyCounter       = ownCounter
    , rcuStatePinned          = Nothing
    }
{-# INLINE runRCU #-}

-- | Run an RCU computation in a thread pinned to a particular core.
--
-- Builds the same fresh 'RCUState' as 'runRCU' except that 'rcuStatePinned'
-- is set to @'Just' i@.
--
-- NOTE(review): within this module 'rcuStatePinned' is only stored. The
-- computation runs in the calling thread and 'forking' spawns threads with
-- 'forkIO' without consulting the field, so no pinning is performed here —
-- confirm whether a consumer elsewhere is expected to honour it.
runOnRCU :: Int -> (forall s. RCU s a) -> IO a
runOnRCU i m = do
  version        <- newVersion
  globalCounter  <- newCounter 0
  versionRef     <- newIORef version
  threadCounters <- newMVar []
  writerLock     <- newMVar ()
  ownCounter     <- newCounter 0
  unRCU m $ RCUState
    { rcuStateGlobalCounter   = globalCounter
    , rcuStateGlobalVersion   = versionRef
    , rcuStateThreadCountersV = threadCounters
    , rcuStateWriterLockV     = writerLock
    , rcuStateMyCounter       = ownCounter
    , rcuStatePinned          = Just i
    }
{-# INLINE runOnRCU #-}