{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

{-# OPTIONS_GHC -Wall #-}
{- | Create a circular buffer over shared memory that can be accessed from
   separate processes.

   This module assumes that exactly one WriteBuffer and one ReadBuffer will be
   used to access the same shared memory object.  The ReadBuffer and
   WriteBuffer may exist in separate processes.

   To use this library, in one process:

   > bracket (createBuffer "aBuffer" "aSemaphore" 256 0o600) (removeBuffer)
   >         (\buffer -> doSomethingWith buffer)

   and in the other

   > bracket (openBuffer "aBuffer" "aSemaphore" 256 0o600) (closeBuffer)
   >         (\buffer -> doSomethingWith buffer)

   The buffer may be opened from either the reader or writer end, but you
   should ensure that the buffer is created before it is opened.

   As the underlying objects (shm and named posix semaphores) exist in the file
   system, failing to call removeBuffer will leave stale objects in the
   filesystem.

   Opening multiple ReadBuffers or WriteBuffers with the same names (whether in
   one process or several) results in undefined behavior.
 -}
module System.Posix.CircularBuffer (
  WriteBuffer
, ReadBuffer
, Shared (..)
-- * normal interface
, putBuffer
, getBuffer
-- ** batch operations
, putBufferList
, getAvailable
, sizeOfInt
) where

import System.Posix.SharedBuffer

import Control.Concurrent.MVar
import Control.Exception (try)
import Control.Monad
import Data.Bits
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Marshal.Array (advancePtr)
import Foreign.Storable
import System.Posix.Semaphore.Unsafe
import System.Posix (FileMode)

import Debug.Trace (traceEventIO)

-- we could use Ptr's instead of ForeignPtr's, but then we'd need to free them
-- in the case of an exception, which would require exporting more from this
-- module.

-- | Writer end of a shared circular buffer.  Pairs the underlying
-- 'CircularBuffer' with an 'MVar' holding the sequence number of the next
-- slot to write; 'putBuffer' and 'putBufferList' update that number while
-- holding the 'MVar'.  The type parameter @a@ is phantom: it does not appear
-- in the fields and only fixes the element type at the use sites.
data WriteBuffer a = WB CircularBuffer (MVar Int)
-- | Reader end of a shared circular buffer.  Pairs the underlying
-- 'CircularBuffer' with a 'ForeignPtr' to the sequence number of the next
-- slot to read; 'getBuffer' and 'getAvailable' peek and poke it without any
-- locking.  The type parameter @a@ is phantom, as for 'WriteBuffer'.
data ReadBuffer a  = RB CircularBuffer (ForeignPtr Int)

-- | Functions for creating/opening/closing/removing shared buffers.
--
-- For 'createBuffer' and 'openBuffer' the arguments are, in order: the name
-- of the shared memory object, the name of the semaphore, the requested
-- capacity in elements, and the file mode passed to both open calls.
class Shared b where
    -- | Create a new buffer.  The instances in this module open both the
    -- shared memory object and the semaphore with their create and exclusive
    -- flags set.
    createBuffer :: String -> String -> Int -> FileMode -> IO b
    -- | Open a buffer that already exists.  The instances in this module
    -- open both objects with the create and exclusive flags cleared.
    openBuffer   :: String -> String -> Int -> FileMode -> IO b
    -- | Close this handle's shared memory (the instances in this module end
    -- up calling 'closeSharedBuffer').  The semaphore name is not unlinked.
    closeBuffer  :: b -> IO ()
    -- | Remove the shared memory via 'removeSharedBuffer' and then unlink the
    -- semaphore by name; an 'IOError' from the semaphore unlink is ignored.
    removeBuffer :: b -> IO ()
    -- | Unlink the shared memory via 'unlinkSharedBuffer' and then the
    -- semaphore by name; an 'IOError' from either step is ignored.
    -- NOTE(review): presumably this leaves the current mapping usable, unlike
    -- 'removeBuffer' -- confirm against "System.Posix.SharedBuffer".
    unlinkBuffer :: b -> IO ()

-- | Writer-side instance.  The shared memory object is opened read/write and
-- 'writeProtection' is passed as the protection list; the element size is
-- taken from the 'Storable' instance of @a@.
instance Storable a => Shared (WriteBuffer a) where
    createBuffer =
        openSharedBuffer makeWB semFlags shmFlags writeProtection elemBytes
      where
        -- creating: both objects must not exist yet
        semFlags  = OpenSemFlags{semCreate = True, semExclusive = True}
        shmFlags  = ShmOpenFlags
            { shmCreate    = True
            , shmReadWrite = True
            , shmExclusive = True
            , shmTrunc     = False
            }
        elemBytes = sizeOf (undefined :: a)

    openBuffer =
        openSharedBuffer makeWB semFlags shmFlags writeProtection elemBytes
      where
        -- opening: both objects are expected to exist already
        semFlags  = OpenSemFlags{semCreate = False, semExclusive = False}
        shmFlags  = ShmOpenFlags
            { shmCreate    = False
            , shmReadWrite = True
            , shmExclusive = False
            , shmTrunc     = False
            }
        elemBytes = sizeOf (undefined :: a)

    -- the remaining operations just delegate to the wrapped CircularBuffer
    closeBuffer  (WB ring _) = closeBuffer ring
    removeBuffer (WB ring _) = removeBuffer ring
    unlinkBuffer (WB ring _) = unlinkBuffer ring

-- | Reader-side instance.  The shared memory object is opened without the
-- read/write flag and @[ProtRead]@ is passed as the protection list; the
-- element size is taken from the 'Storable' instance of @a@.
instance Storable a => Shared (ReadBuffer a) where
    createBuffer =
        openSharedBuffer makeRB semFlags shmFlags [ProtRead] elemBytes
      where
        -- creating: both objects must not exist yet
        semFlags  = OpenSemFlags{semCreate = True, semExclusive = True}
        shmFlags  = ShmOpenFlags
            { shmCreate    = True
            , shmReadWrite = False
            , shmExclusive = True
            , shmTrunc     = False
            }
        elemBytes = sizeOf (undefined :: a)

    openBuffer =
        openSharedBuffer makeRB semFlags shmFlags [ProtRead] elemBytes
      where
        -- opening: both objects are expected to exist already
        semFlags  = OpenSemFlags{semCreate = False, semExclusive = False}
        shmFlags  = ShmOpenFlags
            { shmCreate    = False
            , shmReadWrite = False
            , shmExclusive = False
            , shmTrunc     = False
            }
        elemBytes = sizeOf (undefined :: a)

    -- the remaining operations just delegate to the wrapped CircularBuffer
    closeBuffer  (RB ring _) = closeBuffer ring
    removeBuffer (RB ring _) = removeBuffer ring
    unlinkBuffer (RB ring _) = unlinkBuffer ring

-- Assemble a ReadBuffer from the pieces allocated by 'openSharedBuffer':
-- keeps the sequence-number pointer and drops the MVar, which only the
-- writer side uses.
makeRB :: CircularBuffer -> MVar Int -> ForeignPtr Int -> ReadBuffer a
makeRB ring _writerSeq readerSeq = RB ring readerSeq

-- Assemble a WriteBuffer from the pieces allocated by 'openSharedBuffer':
-- keeps the MVar holding the write sequence number and drops the pointer,
-- which only the reader side uses.
makeWB :: CircularBuffer -> MVar Int -> ForeignPtr Int -> WriteBuffer a
makeWB ring writerSeq _readerSeq = WB ring writerSeq

-- | Open (or create, depending on the flags) the shared memory object and
-- named semaphore backing a circular buffer, and wrap them with @maker@.
--
-- Arguments, in order:
--
-- * @maker@: builds the result from the 'CircularBuffer', a fresh 'MVar'
--   holding 0 and a freshly allocated 'ForeignPtr' holding 0 (the initial
--   write and read sequence numbers).
--
-- * @semFlags@, @shmFlags@: flags used to open the semaphore and the shared
--   memory object.
--
-- * @prot@: protection list handed to 'openSBuffer'.
--
-- * @elemBytes@: size of one element in bytes (callers pass 'sizeOf').
--
-- * @shmName@, @cbSemName@: names of the shared memory object and semaphore.
--
-- * @reqCbSize@: requested capacity in elements.
--
-- * @mode@: file mode passed to both open calls.
--
-- The number of slots is the smallest power of two that is at least
-- @reqCbSize + 1@.  A power of two is required because 'readSeq' and
-- 'writeSeq' turn sequence numbers into slot indices by masking with
-- @cbSize-1@.  Both ends must be opened with the same @reqCbSize@ so that
-- they agree on the slot count.
openSharedBuffer :: (CircularBuffer -> MVar Int -> ForeignPtr Int -> b)
                 -> OpenSemFlags
                 -> ShmOpenFlags
                 -> [Protection]
                 -> Int
                 -> String
                 -> String
                 -> Int
                 -> FileMode
                 -> IO b
openSharedBuffer maker semFlags shmFlags prot elemBytes shmName cbSemName reqCbSize mode = do
    let needed = max 1 (1+reqCbSize)
        -- the buffer is effectively 1 element smaller than specified
        -- (due to a race condition in the reader, see 'readSeqBlocking')
        -- so make it 1 larger so that the full requested size is always
        -- available.  Non-positive requests are clamped to a single slot.
        cbSize = bit (finiteBitSize needed - countLeadingZeros (needed-1))
        -- smallest power of two >= needed, using exact integer arithmetic:
        -- the index of the highest set bit of (needed-1), plus one, is the
        -- exponent.  A floating-point logBase/ceiling can round an exact
        -- power of two upwards and double the allocation.
        bufsz = fromIntegral $ elemBytes*cbSize
    cbBuf <- openSBuffer shmName bufsz shmFlags prot mode
    -- the initial value 0 means "no items available to the reader"
    cbSem <- semOpen cbSemName semFlags mode 0
    seqref <- newMVar 0
    seqptr <- mallocForeignPtr
    withForeignPtr seqptr $ flip poke 0
    return $ maker (CircularBuffer{cbBuf,cbSize,cbSem,cbSemName}) seqref seqptr

-- | Append a single value at the writer end, blocking (by spinning) while
-- the buffer is full.
--
-- Thread-safe: the write position is kept in an 'MVar', so concurrent
-- callers are serialized.
putBuffer :: Storable a => WriteBuffer a -> a -> IO ()
putBuffer (WB ring nextSeq) val = modifyMVar_ nextSeq advance
  where
    -- write at the current position, then hand back the (forced) successor
    advance cur = do
        writeSeqBlocking ring cur val
        let !next = cur+1
        return next
{-# INLINEABLE putBuffer #-}

-- | Append every value of a list at the writer end, in order.
--
-- Thread-safe: the write position is kept in an 'MVar', which is held for
-- the whole batch, so concurrent callers are serialized.
putBufferList :: Storable a => WriteBuffer a -> [a] -> IO ()
putBufferList (WB ring nextSeq) vals = modifyMVar_ nextSeq $ \start ->
    -- advance the stored position by however many items were written
    writeSeqList ring start vals >>= \written -> return $! start+written
{-# INLINE putBufferList #-}

-- | Take the next value from the reader end, blocking until one is
-- available.
--
-- *NOT* thread-safe: the read position is peeked and poked through a plain
-- pointer without any locking.
getBuffer :: Storable a => ReadBuffer a -> IO a
getBuffer (RB ring seqRef) = withForeignPtr seqRef step
  where
    step posPtr = do
        cur  <- peek posPtr
        item <- readSeqBlocking ring cur
        -- only advance once the value has actually been read
        poke posPtr (cur+1)
        return item
{-# INLINEABLE getBuffer #-}

-- | Take every value that is available right now from the reader end,
-- oldest first.  Returns the empty list instead of blocking when nothing is
-- available.
--
-- *NOT* thread-safe: the read position is peeked and poked through a plain
-- pointer without any locking.
getAvailable :: Storable a => ReadBuffer a -> IO [a]
getAvailable (RB ring seqRef) = withForeignPtr seqRef $ \posPtr -> do
    start <- peek posPtr
    items <- readSeqReady ring start
    -- advance the read position past everything just consumed
    items <$ poke posPtr (start + length items)
{-# INLINEABLE getAvailable #-}

------------------------------------------------------------------
-- circular buffer interface

-- intended use:
-- a single producer and single consumer in separate processes
--
-- invariants:
--   cbSem contains the number of items available to be read in buffer.  That
--   is, the number of items that are immediately available to the reader.
--
-- The record itself carries no read or write position; the WriteBuffer and
-- ReadBuffer wrappers track their own sequence numbers.
data CircularBuffer = CircularBuffer
    -- number of element slots.  openSharedBuffer always makes this a power of
    -- two; readSeq/readSeqs/writeSeq rely on that when masking with cbSize-1.
    { cbSize :: {-# UNPACK #-} !Int
    -- the shared memory holding the slots (accessed through sbPtr)
    , cbBuf  :: {-# UNPACK #-} !SharedBuffer
    -- named semaphore counting the items available to the reader
    , cbSem  :: {-# UNPACK #-} !Semaphore
    -- name of cbSem, kept so the semaphore can be unlinked later
    , cbSemName :: String
    }

-- The raw buffer can only be closed/removed/unlinked through this class;
-- creating or opening one has to go through 'openSharedBuffer', so those two
-- methods are errors.
instance Shared CircularBuffer where
    createBuffer = error "can't create a CircularBuffer directly"
    openBuffer   = error "can't open a CircularBuffer directly"

    closeBuffer cb = closeSharedBuffer (cbBuf cb)

    -- remove the shared memory, then unlink the semaphore on a best-effort
    -- basis (an IOError from the unlink is ignored)
    removeBuffer cb = do
        removeSharedBuffer (cbBuf cb)
        outcome <- try (semUnlink (cbSemName cb))
        case (outcome :: Either IOError ()) of
            Left _   -> return ()
            Right () -> return ()

    -- unlink both names on a best-effort basis; an IOError from the first
    -- step also skips the second
    unlinkBuffer cb = do
        outcome <- try (unlinkSharedBuffer (cbBuf cb) >> semUnlink (cbSemName cb))
        case (outcome :: Either IOError ()) of
            Left _   -> return ()
            Right () -> return ()

-- Block until at least one item is available, then read the item stored at
-- the given sequence number.
--
-- The logical order would be: wait for cbSem > 0, read the value, then
-- decrement the semaphore.  Instead the wait and the decrement happen as one
-- operation *before* the read.  The writer therefore sees the slot as free
-- while it is still being read, and must always leave one slot unused so it
-- never overwrites it.
readSeqBlocking :: Storable a => CircularBuffer -> Int -> IO a
readSeqBlocking ring = \pos ->
    waitAndLock (cbSem ring) >> readSeq ring pos
{-# INLINEABLE readSeqBlocking #-}

-- Read everything that is available right now, starting at the given
-- sequence number.  Never blocks; yields [] when nothing is available.
readSeqReady :: Storable a => CircularBuffer -> Int -> IO [a]
readSeqReady ring rseq = do
    let sem = cbSem ring
    pending <- unsafeSemGetValue sem
    items   <- readSeqs ring rseq pending
    -- Give the slots back only after the values have been read.
    -- unsafeSemLock is ok if there are no other readers on this semaphore.
    replicateM_ pending (unsafeSemLock sem)
    return items
{-# INLINEABLE readSeqReady #-}

-- Write one value at the given sequence number, waiting first if the reader
-- is too far behind.
--
-- Steps:
--   1. while the buffer is full (cbSem >= cbSize-1), spin.  One slot is
--      always kept free because the reader decrements the semaphore before
--      it reads the value.
--   2. store the value.
--   3. post cbSem to announce that another value is ready.
--
-- This is only safe if all writes are serialized at a higher level.
writeSeqBlocking :: Storable a => CircularBuffer -> Int -> a -> IO ()
writeSeqBlocking ring writePos val = do
    spinUntilRoom
    writeSeq ring writePos val
    unsafeSemPost sem
  where
    sem      = cbSem ring
    capacity = cbSize ring - 1
    -- Plain busy-wait: a full buffer is expected to be a rare occurrence.
    spinUntilRoom = do
        lag <- unsafeSemGetValue sem
        unless (lag < capacity) $ do
            traceEventIO "writeSeqBlocking: waitUntilAvailable spinning"
            spinUntilRoom
{-# INLINEABLE writeSeqBlocking #-}

-- | Write a list of items, starting from the given sequence number.
--
-- Items are written in batches: each round writes as many items as there are
-- free slots, posts the semaphore once per item written, and then continues
-- with the remainder.  When no slot is free the loop spins until the reader
-- catches up.
--
-- One slot is always left unused (at most @cbSize-1@ items are outstanding),
-- exactly as in 'writeSeqBlocking'.  The reader decrements the semaphore
-- *before* it reads a value, so filling all @cbSize@ slots could overwrite
-- the slot that is currently being read.
--
-- Returns the number of items written (which should be the length of the
-- list).  This is only safe if all writes are serialized at a higher level.
writeSeqList :: Storable a => CircularBuffer -> Int -> [a] -> IO Int
writeSeqList cb@CircularBuffer{..} writePos = go 0
  where
    go !len [] = return len
    go !prevWritten xs = do
        currentLag <- unsafeSemGetValue cbSem
        -- POSIX allows sem_getvalue to report a negative number when
        -- processes are blocked on the semaphore; treat that as "empty".
        let numReady = cbSize-1 - max 0 currentLag
            (toWrite,toWait) = splitAt numReady xs
            offset = writePos+prevWritten
        numWritten <- foldM (\(!ix) x -> writeSeq cb (offset+ix) x >> return (ix+1)) 0 toWrite
        -- publish the items only after all of them have been stored
        replicateM_ numWritten (unsafeSemPost cbSem)
        go (prevWritten+numWritten) toWait
{-# INLINE writeSeqList #-}

------------------------------------------------------------------
-- low-level interface
--
-- these functions deal *only* with reading from/writing to the buffer.  They
-- do not synchronize results.

-- Read the element stored for the given sequence number.
-- Precondition (unchecked): the sequence number refers to a written slot.
-- The slot index is the sequence number masked with cbSize-1.
readSeq :: Storable a => CircularBuffer -> Int -> IO a
readSeq ring rseq = peek (base `advancePtr` slot)
  where
    base = castPtr (sbPtr (cbBuf ring))
    slot = rseq .&. (cbSize ring - 1)
-- ghc inlines these already

-- Read `count` consecutive elements starting at the given sequence number,
-- returned oldest first.  A non-positive count yields [].
-- Precondition (unchecked): every sequence number in the range refers to a
-- written slot.
readSeqs :: Storable a => CircularBuffer -> Int -> Int -> IO [a]
readSeqs CircularBuffer{..} rseq count =
    -- walk from the newest element down to the oldest, consing as we go, so
    -- the finished list comes out in ascending order
    foldM prepend [] [count-1, count-2 .. 0]
  where
    base = castPtr $ sbPtr cbBuf
    mask = cbSize-1
    prepend acc n = do
        x <- peek (base `advancePtr` ((n+rseq) .&. mask))
        return (x:acc)
{-# INLINE readSeqs #-}

-- Store a value in the slot for the given sequence number (the sequence
-- number masked with cbSize-1).  Performs no checks and no synchronization.
writeSeq :: Storable a => CircularBuffer -> Int -> a -> IO ()
writeSeq ring wseq val = poke (base `advancePtr` slot) val
  where
    base = castPtr (sbPtr (cbBuf ring))
    slot = wseq .&. (cbSize ring - 1)

-- Acquire the semaphore, like `semWait`, but cheaper when it is immediately
-- available: try a non-blocking wait first and only fall back to a timed
-- wait when that fails.  If the timed wait also fails, start over.
waitAndLock :: Semaphore -> IO ()
waitAndLock sem = do
    acquired <- unsafeSemTryWait sem
    unless acquired slowPath
  where
    slowPath = do
        -- NOTE(review): the arguments 10 0 are presumably a timeout of
        -- 10 seconds and 0 nanoseconds -- confirm against semTimedWait.
        acquired <- semTimedWait 10 0 sem
        unless acquired (waitAndLock sem)

------------------------------------------------------------------
-- The size of an 'Int' in bytes.  The value is computed inside a Template
-- Haskell splice, so it is embedded in the module as a literal at compile
-- time.
sizeOfInt :: Int
sizeOfInt =
    $( let intBytes :: Int
           intBytes = sizeOf (0 :: Int)
       in [| intBytes |] )