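{-| This module provides a growable variant of the mutable boxed vector, equipped with a @push_back@ operation.
The 'IOVector' defined here is intended for single-threaded use only: its operations read and update the buffer reference and the length cell in separate, unsynchronised steps.
-}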
module Data.Vector.Mutable.PushBack where

import Prelude hiding (length, read)
import Control.Monad
import Data.IORef
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import System.IO.Unsafe

-- | A mutable vector that can grow at its end.
--
-- It is made of two parts: a reference to the backing buffer and a cell
-- holding the logical length. "Data.Vector.Mutable" stores the array
-- directly, whereas this type only stores a pointer to it, so every read
-- and write pays for one extra dereference.
data IOVector a = IOVector
  !(IORef (VM.IOVector a))
  -- ^ Reference to the backing buffer. Its length is the capacity; 'push'
  -- swaps in a larger buffer when it is full.
  !(VUM.IOVector Int)
  -- ^ One-element unboxed cell whose only slot (index 0) is the logical
  -- length.

-- | Create an empty 'IOVector' whose backing buffer has room for @p + 10@
-- elements before the first reallocation.
--
-- The extra slack avoids repeatedly regrowing tiny (1- or 2-element)
-- buffers. The logical length starts at 0. @p@ is expected to be at least
-- @-10@; a smaller value would request a negative size from 'VM.new'.
new :: Int -> IO (IOVector a)
new p = IOVector <$> (newIORef =<< VM.new (p + 10)) <*> VUM.replicate 1 0

-- | Read the element at index @k@.
--
-- The index is bounds-checked by 'VM.read' against the backing buffer, whose
-- size is the capacity, not the logical length. An index in
-- @[length, capacity)@ is therefore not rejected here. It yields whatever
-- that unused slot currently holds.
read :: IOVector a -> Int -> IO a
read (IOVector vref _) k = do
  vec <- readIORef vref
  VM.read vec k

-- | Return the logical length, i.e. the number of elements currently stored.
-- The index of the last element is one less than this.
safeLength :: IOVector a -> IO Int
safeLength (IOVector _ lenCell) = VUM.read lenCell 0

-- | Pure view of 'safeLength'.
--
-- This is not safe because it uses 'unsafePerformIO'. The result reflects
-- the length at whatever moment the thunk happens to be forced, and the
-- optimiser may share a single result between calls that look identical.
-- Prefer 'safeLength' inside 'IO' code. The NOINLINE pragma stops the inner
-- read from being inlined and floated/shared at call sites. It does not
-- make the function pure.
length :: IOVector a -> Int
length pvec = unsafePerformIO (safeLength pvec)
{-# NOINLINE length #-}

-- | Return the capacity: the size of the current backing buffer, i.e. how
-- many elements fit before 'push' has to reallocate.
safeCapacity :: IOVector a -> IO Int
safeCapacity (IOVector vref _) = VM.length <$> readIORef vref

-- | Pure view of 'safeCapacity'.
--
-- This is not safe because it uses 'unsafePerformIO'. The result reflects
-- the capacity at whatever moment the thunk is forced, and may be shared
-- between calls. Prefer 'safeCapacity' inside 'IO' code. The NOINLINE
-- pragma keeps the inner read from being inlined and floated at call sites.
capacity :: IOVector a -> Int
capacity pvec = unsafePerformIO (safeCapacity pvec)
{-# NOINLINE capacity #-}

-- | Overwrite the element at index @i@ and leave the logical length unchanged.
--
-- As with 'read', 'VM.write' checks the index against the capacity rather
-- than the logical length. 'push' relies on this to fill slot @length@
-- before it increments the length.
write :: IOVector a -> Int -> a -> IO ()
write (IOVector vref _) i v = do
  vec <- readIORef vref
  VM.write vec i v

-- | /O(n)/ Insert a value at index @i@, shifting the elements at and after
-- @i@ one place to the right. This is a slow operation.
--
-- Valid positions are @0 <= i <= length@. @i == length@ appends, which also
-- makes inserting into an empty vector work. Any other index raises an
-- 'IOError' before the vector is modified.
insert
  :: IOVector a  -- ^ The vector to modify
  -> Int         -- ^ Insertion position, @0 <= i <= length@
  -> a           -- ^ The value to insert
  -> IO ()
insert pvec i v = do
  len <- safeLength pvec
  -- Validate up front so a bad index never leaves a half-shifted vector.
  when (i < 0 || i > len) $
    ioError . userError $
      "Data.Vector.Mutable.PushBack.insert: index " ++ show i
        ++ " out of range [0, " ++ show len ++ "]"
  if i == len
    then push pvec v
    else do
      -- Duplicate the last element into a fresh slot (growing if needed),
      -- then shift [i .. len-2] right by one. Walk from the end so that no
      -- element is overwritten before it has been copied.
      read pvec (len - 1) >>= push pvec
      forM_ [len - 2, len - 3 .. i] $ \j ->
        read pvec j >>= write pvec (j + 1)
      write pvec i v

-- | /O(n)/ Remove the element at index @i@, shifting the elements after it
-- one place to the left. This is a slow operation.
--
-- Raises an 'IOError', leaving the vector untouched, unless
-- @0 <= i < length@. The vacated slot at the old last position is not
-- cleared. It keeps referencing the former last element until it is
-- overwritten.
delete :: IOVector a -> Int -> IO ()
delete pvec@(IOVector _ uvec) i = do
  len <- safeLength pvec
  -- Without this check an index past the end would skip the shift loop and
  -- still decrement the length (dropping the last element, or making an
  -- empty vector's length negative).
  when (i < 0 || i >= len) $
    ioError . userError $
      "Data.Vector.Mutable.PushBack.delete: index " ++ show i
        ++ " out of range [0, " ++ show len ++ ")"
  forM_ [i + 1 .. len - 1] $ \j ->
    read pvec j >>= write pvec (j - 1)
  VUM.modify uvec (subtract 1) 0

-- | Amortised /O(1)/. Append a value at the end.
--
-- When the backing buffer is full it is replaced by a copy with twice the
-- capacity, or one slot if the capacity was 0.
push :: IOVector a -> a -> IO ()
push pvec@(IOVector vref uvec) v = do
  len <- safeLength pvec
  cap <- safeCapacity pvec
  when (len == cap) $ do
    vec <- readIORef vref
    -- Grow by at least one slot. With a zero capacity, growing by 'cap'
    -- would add nothing and the write below would go out of bounds.
    vec' <- VM.grow vec (max 1 cap)
    writeIORef vref vec'
  -- 'write' re-reads the reference, so it targets the (possibly new) buffer.
  write pvec len v
  VUM.modify uvec (+ 1) 0

-- | Build an 'IOVector' holding the elements of the list, in order.
--
-- For a list of @n@ elements, the backing buffer gets @(n + 5) * 2@ spare
-- slots on top of the @n@ elements, so the first pushes do not reallocate.
fromList :: [a] -> IO (IOVector a)
fromList xs = do
  vec <- V.thaw (V.fromList xs)
  let n = VM.length vec
  vec' <- VM.grow vec ((n + 5) * 2)
  vref <- newIORef vec'
  uvec <- VU.thaw (VU.fromList [n])
  return (IOVector vref uvec)

-- | View the stored elements as a plain 'VM.IOVector': a slice of the
-- backing buffer covering indices @[0, length)@.
--
-- No elements are copied. The slice shares storage with the current
-- backing buffer, so writes through either are visible in the other. After
-- a 'push' that reallocates, the 'IOVector' points at a new buffer and the
-- slice no longer reflects it.
asIOVector :: IOVector a -> IO (VM.IOVector a)
asIOVector pvec@(IOVector vref _) = do
  len <- safeLength pvec
  VM.slice 0 len <$> readIORef vref

-- | Pure variant of 'asIOVector'.
--
-- This is not safe because it uses 'unsafePerformIO'. The slice covers
-- whatever length and buffer were current when the thunk was forced, and
-- the result may be shared between calls. Prefer 'asIOVector' inside 'IO'
-- code. The NOINLINE pragma keeps the inner action from being inlined and
-- floated at call sites.
asUnsafeIOVector :: IOVector a -> VM.IOVector a
asUnsafeIOVector pvec = unsafePerformIO (asIOVector pvec)
{-# NOINLINE asUnsafeIOVector #-}