-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Vector.Storable.Buffer
-- Copyright   :  (c) A.V.H. McPhail 2011
-- License     :  BSD3
--
-- Maintainer  :  Vivian McPhail <haskell.vivian.mcphail@gmail.com>
-- Stability   :  provisional
--
-- A buffer that can be used as a vector
-----------------------------------------------------------------------------

module Data.Vector.Storable.Buffer (
    Buffer,
    newBuffer,
    pushNextElement,
    toVector,
    mapBufferM, mapBufferM_,
) where

import Data.IORef
import Foreign hiding(new)
import System.IO.Unsafe (unsafePerformIO)

import qualified Data.Vector.Storable         as V
import qualified Data.Vector.Storable.Mutable as M

-------------------------------------------------------------------

-- | A fixed-capacity circular buffer of 'Storable' elements: a write cursor
-- plus the mutable vector it indexes into.
data Buffer a = B
  { end :: {-# UNPACK #-} !(IORef Int)
    -- ^ index of the slot the next pushed element is written to
  , dat :: {-# UNPACK #-} !(M.IOVector a)
    -- ^ backing storage; its length is the buffer's capacity
  }

-- | Allocate a buffer with room for the requested number of elements.
--
-- The write cursor starts at slot 0.  The backing storage is whatever
-- 'M.new' returns; this function does not write to it.
newBuffer :: Storable a
          => Int            -- ^ capacity, in elements
          -> IO (Buffer a)
newBuffer capacity =
  M.new capacity >>= \storage ->
    fmap (\cursor -> B cursor storage) (newIORef 0)
{-# INLINE newBuffer #-}

-- | Write an element into the slot at the write cursor and advance the
-- cursor by one, wrapping back to slot 0 after the last slot.
--
-- Once every slot has been written, each push therefore overwrites the
-- element that was pushed longest ago.
--
-- Pushing into a zero-capacity buffer is a no-op: there is no slot to write
-- to, and the unchecked write below would otherwise touch memory outside
-- the storage.
--
-- The cursor is read and written in two separate steps, so concurrent
-- pushes from several threads are not coordinated by this function.
pushNextElement :: Storable a => Buffer a -> a -> IO ()
pushNextElement (B o v) e
  | n <= 0    = return ()
  | otherwise = do
      i <- readIORef o
      -- unchecked write: i is in [0, n) because the cursor starts at 0 and
      -- is only ever advanced by the wrap-around step below
      M.unsafeWrite v i e
      -- store the evaluated index so the IORef never holds a thunk
      writeIORef o $! (if i + 1 >= n then 0 else i + 1)
  where
    n = M.length v
{-# INLINE pushNextElement #-}

-- | Copy the buffer's contents into a fresh immutable vector of the same
-- length as the backing storage.
--
-- The result starts with the slot at the write cursor, runs to the end of
-- the storage, and is followed by the slots before the cursor.  Once every
-- slot has been pushed to at least once, that is oldest-to-newest order.
-- Slots that were never pushed to are copied as they are in the storage.
--
-- CAUTION: this reads mutable state through 'unsafePerformIO' behind a pure
-- type.  The copy is made when the result is evaluated, not at a fixed
-- point in the surrounding 'IO' sequence, and the compiler is entitled to
-- share or duplicate calls on the same buffer.  Force the result (e.g. with
-- 'seq' or @evaluate@) before pushing further elements if a snapshot at a
-- particular moment is required.
toVector :: Storable a => Buffer a -> V.Vector a
toVector (B o v) = unsafePerformIO $ do
  let n = M.length v
  w <- M.new n
  i <- readIORef o
  M.unsafeWith v $ \src ->
    M.unsafeWith w $ \dst -> do
      -- tail of the storage (cursor .. n-1) becomes the front of the result
      copyArray dst (src `advancePtr` i) (n - i)
      -- head of the storage (0 .. cursor-1) follows; copies nothing when
      -- the cursor is at 0
      copyArray (dst `advancePtr` (n - i)) src i
  V.unsafeFreeze w
{-# INLINE toVector #-}

-- | Apply an 'IO' action to every slot of the buffer and collect the results
-- in a new immutable vector of the same length.
--
-- Slots are visited starting at the write cursor and wrapping around the
-- end of the storage, i.e. in the same order 'toVector' lays them out; the
-- result of the k-th visited slot is stored at index k.  The cursor is read
-- once, before the traversal starts.
--
-- A zero-capacity buffer yields an empty vector without running the action.
mapBufferM :: (Storable a, Storable b) => (a -> IO b) -> Buffer a -> IO (V.Vector b)
mapBufferM f (B o v) = do
  let n = M.length v
  w <- M.new n
  start <- readIORef o
  let -- storage index of the k-th slot counted from the cursor
      slot k = let j = k + start in if j >= n then j - n else j
      go k
        | k >= n    = return ()
        | otherwise = do
            x <- M.unsafeRead v (slot k)
            y <- f x
            M.unsafeWrite w k y
            go (k + 1)
  go 0
  V.unsafeFreeze w
{-# INLINE mapBufferM #-}

-- | Run an 'IO' action on every slot of the buffer, discarding the results.
--
-- Slots are visited starting at the write cursor and wrapping around the
-- end of the storage, i.e. in the same order 'toVector' lays them out.  The
-- cursor is read once, before the traversal starts.
--
-- A zero-capacity buffer runs the action zero times.
mapBufferM_ :: (Storable a) => (a -> IO b) -> Buffer a -> IO ()
mapBufferM_ f (B o v) = do
  start <- readIORef o
  let n = M.length v
      -- storage index of the k-th slot counted from the cursor
      slot k = let j = k + start in if j >= n then j - n else j
      go k
        | k >= n    = return ()
        | otherwise = do
            x <- M.unsafeRead v (slot k)
            _ <- f x
            go (k + 1)
  go 0
{-# INLINE mapBufferM_ #-}