#include "phases.h"
module Data.Vector.Generic.Mutable (
MVectorPure(..), MVector(..),
slice, new, newWith, read, write, copy, grow,
unstream, transform,
accum, update, reverse
) where
import qualified Data.Vector.Fusion.Stream as Stream
import Data.Vector.Fusion.Stream ( Stream, MStream )
import qualified Data.Vector.Fusion.Stream.Monadic as MStream
import Data.Vector.Fusion.Stream.Size
import Control.Exception ( assert )
import GHC.Float (
double2Int, int2Double
)
import Prelude hiding ( length, reverse, map, read )
-- | Growth multiplier used by 'unstreamUnknown' when its buffer is full.
-- The buffer is enlarged by @max 1 (double2Int (len * gROWTH_FACTOR))@
-- elements, where @len@ is its current length.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
-- | Mutable-vector operations that need no monadic context.
class MVectorPure v a where
  -- | Number of elements in the vector.
  length :: v a -> Int

  -- | @unsafeSlice v i n@ is the @n@-element slice of @v@ beginning at
  -- index @i@. No bounds checking is performed; 'slice' is the checked
  -- wrapper.
  unsafeSlice :: v a
              -> Int  -- ^ index of the first element of the slice
              -> Int  -- ^ number of elements in the slice
              -> v a

  -- | Whether the two vectors overlap. NOTE(review): presumably means
  -- "share underlying storage" — confirm against instances. 'copy'
  -- requires this to be 'False'.
  overlaps :: v a -> v a -> Bool
-- | Mutable vectors of element type @a@ whose operations run in the monad
-- @m@. None of the @unsafe@ methods validate their arguments; the checked
-- wrappers ('new', 'newWith', 'read', 'write', 'copy', 'grow') add
-- 'assert'-based checks on top of them.
--
-- Minimal complete definition: 'unsafeNew', 'unsafeRead', 'unsafeWrite'
-- and 'clear'. The remaining methods have defaults built from those.
class (Monad m, MVectorPure v a) => MVector v m a where
  -- | Allocate a vector of the given length. The length is not checked.
  unsafeNew :: Int -> m (v a)

  -- | Allocate a vector of the given length and fill it with the given
  -- value. The length is not checked.
  unsafeNewWith :: Int -> a -> m (v a)

  -- | Read the element at an index, without bounds checking.
  unsafeRead :: v a -> Int -> m a

  -- | Write an element at an index, without bounds checking.
  unsafeWrite :: v a -> Int -> a -> m ()

  -- | Instance-defined; no default is given. NOTE(review): the intended
  -- semantics are not visible in this module — confirm against instances.
  clear :: v a -> m ()

  -- | Overwrite every element of the vector with the given value.
  set :: v a -> a -> m ()

  -- | Copy the source vector's elements into the target vector. Neither
  -- length nor overlap is checked here; 'copy' asserts both.
  unsafeCopy :: v a  -- ^ target
             -> v a  -- ^ source
             -> m ()

  -- | @unsafeGrow v by@ yields a vector @by@ elements longer than @v@.
  -- The default copies @v@ into the leading part of the result. @by@ is
  -- not checked.
  unsafeGrow :: v a -> Int -> m (v a)

  -- Default: allocate, then fill with 'set'.
  unsafeNewWith n x =
    unsafeNew n >>= \fresh -> set fresh x >> return fresh

  -- Default: write x at indices 0 .. length v - 1 in ascending order.
  set v x = fill 0
    where
      len = length v
      fill k
        | k >= len  = return ()
        | otherwise = unsafeWrite v k x >> fill (k + 1)

  -- Default: forward, element-by-element copy of all of src.
  unsafeCopy dst src = step 0
    where
      count = length src
      step k
        | k >= count = return ()
        | otherwise  = unsafeRead src k >>= unsafeWrite dst k >> step (k + 1)

  -- Default: allocate a longer vector and copy v into its leading slice.
  -- This relies on writes made through 'unsafeSlice' being visible in the
  -- parent vector.
  unsafeGrow v by = do
    let oldLen = length v
    bigger <- unsafeNew (oldLen + by)
    unsafeCopy (unsafeSlice bigger 0 oldLen) v
    return bigger
-- | 'True' exactly when @i@ is a valid index into @v@, i.e.
-- @0 <= i < length v@.
inBounds :: MVectorPure v a => v a -> Int -> Bool
inBounds v i = 0 <= i && i < length v
-- | @slice v i n@ is the @n@ elements of @v@ starting at index @i@.
--
-- The range is checked only through 'assert', which GHC may compile away
-- (see @-fignore-asserts@). The call then delegates to 'unsafeSlice'.
slice :: MVectorPure v a => v a -> Int -> Int -> v a
slice v i n = assert rangeOk (unsafeSlice v i n)
  where
    rangeOk = i >= 0 && n >= 0 && i + n <= length v
-- | Allocate a vector of length @n@. Asserts that @n@ is non-negative.
new :: MVector v m a => Int -> m (v a)
new n = assert (0 <= n) (unsafeNew n)
-- | Allocate a vector of length @n@ filled with @x@. Asserts that @n@ is
-- non-negative.
newWith :: MVector v m a => Int -> a -> m (v a)
newWith n x = assert (0 <= n) (unsafeNewWith n x)
-- | Read the element at index @i@, asserting (via 'inBounds') that the
-- index is valid.
read :: MVector v m a => v a -> Int -> m a
read v i = assert inRange (unsafeRead v i)
  where
    inRange = inBounds v i
-- | Write @x@ at index @i@, asserting (via 'inBounds') that the index is
-- valid.
write :: MVector v m a => v a -> Int -> a -> m ()
write v i x = assert inRange (unsafeWrite v i x)
  where
    inRange = inBounds v i
-- | Copy @src@ into @dst@. Asserts that the two vectors do not overlap and
-- have equal lengths.
copy :: MVector v m a => v a -> v a -> m ()
copy dst src = assert precondition (unsafeCopy dst src)
  where
    precondition = not (dst `overlaps` src) && length dst == length src
-- | Return a vector @by@ elements longer than @v@ (see 'unsafeGrow').
-- Asserts that @by@ is non-negative.
grow :: MVector v m a => v a -> Int -> m (v a)
grow v by = assert (0 <= by) (unsafeGrow v by)
-- | Monadic stream of @v@'s elements in index order, tagged with size
-- @Exact (length v)@. Each element is fetched with 'unsafeRead' as the
-- unfold advances. @v@ is forced before the stream is built.
mstream :: MVector v m a => v a -> MStream m a
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v
    step k
      | k >= len  = return Nothing
      | otherwise = unsafeRead v k >>= \x -> return (Just (x, k + 1))
-- | Write the elements of stream @s@ into @v@ starting at index 0, and
-- return the prefix slice of @v@ that was filled.
--
-- Every store goes through 'write', so a stream longer than @v@ trips
-- its bounds assertion. @v@ is forced first.
munstream :: MVector v m a => v a -> MStream m a -> m (v a)
munstream v s =
  v `seq` (MStream.foldM store 0 s >>= \filled -> return (slice v 0 filled))
  where
    store k x = write v k x >> return (k + 1)
-- | Run the stream transformer @f@ over the contents of @v@. The results
-- are written back into @v@ from index 0, and the filled prefix is
-- returned.
--
-- NOTE(review): input and output share @v@, so @f@ is presumably expected
-- never to emit element @k@ before it has consumed input element @k@ —
-- confirm with callers.
transform :: MVector v m a => (MStream m a -> MStream m a) -> v a -> m (v a)
transform f v =
  let input = mstream v
  in munstream v (f input)
-- | Build a fresh mutable vector from a pure stream.
--
-- If the stream's size has an upper bound, that many elements are
-- preallocated ('unstreamMax'). Otherwise the buffer grows on demand
-- ('unstreamUnknown').
unstream :: MVector v m a => Stream a -> m (v a)
unstream s = maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
-- | Allocate @n@ elements, fill them from the stream with a strict
-- monadic fold, and return the filled prefix.
--
-- @n@ is assumed to bound the stream's length. Overrunning it trips the
-- assertions in 'new'/'write'.
unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
unstreamMax s n = do
  buf    <- new n
  filled <- Stream.foldM' (store buf) 0 s
  return (slice buf 0 filled)
  where
    store buf k x = write buf k x >> return (k + 1)
-- | Build a vector from a stream whose length has no known upper bound.
--
-- Starts from an empty buffer. Whenever the next write index reaches the
-- buffer's length, the buffer is grown by
-- @max 1 (double2Int (len * gROWTH_FACTOR))@ elements. The filled prefix
-- is returned.
unstreamUnknown :: MVector v m a => Stream a -> m (v a)
unstreamUnknown s = do
  empty        <- new 0
  (buf, count) <- Stream.foldM push (empty, 0) s
  return (slice buf 0 count)
  where
    -- Store one element at index k, growing the buffer first if needed.
    push (buf, k) x = do
      roomy <- ensureRoom buf k
      unsafeWrite roomy k x
      return (roomy, k + 1)

    ensureRoom buf k
      | k < length buf = return buf
      | otherwise      = unsafeGrow buf (extraFor (length buf))

    -- Always grow by at least one element so an empty buffer can expand.
    extraFor len = max 1 (double2Int (int2Double len * gROWTH_FACTOR))
-- | For each @(i, b)@ in the stream, in stream order, replace element @i@
-- of @v@ with @f old b@.
--
-- Indices are checked through the assertions in 'read'/'write'. @v@ is
-- forced before the stream is traversed.
accum :: MVector v m a => (a -> b -> a) -> v a -> Stream (Int, b) -> m ()
accum f v s = v `seq` Stream.mapM_ apply s
  where
    apply (i, b) = read v i >>= \old -> write v i (f old b)
-- | For each @(i, x)@ in the stream, overwrite element @i@ of the vector
-- with @x@. When an index repeats, the later pair wins.
update :: MVector v m a => v a -> Stream (Int, a) -> m ()
update = accum (flip const)
-- | Reverse the elements of a mutable vector in place.
--
-- Elements at positions @i@ (moving up from 0) and @j@ (moving down from
-- @length v - 1@) are swapped until the indices meet. For an empty or
-- single-element vector no swap happens.
--
-- Both indices stay within @[0, length v)@ by construction, so the
-- unchecked 'unsafeRead'/'unsafeWrite' are used.
reverse :: MVector v m a => v a -> m ()
-- Force v up front with an explicit 'seq'. This matches a @!v@ bang
-- pattern without requiring the BangPatterns extension; without it,
-- @reverse !v = ...@ would parse as a definition of the operator (!).
reverse v = v `seq` reverse_loop 0 (length v - 1)
  where
    reverse_loop i j
      | i < j = do
          x <- unsafeRead v i
          y <- unsafeRead v j
          unsafeWrite v i y
          unsafeWrite v j x
          reverse_loop (i + 1) (j - 1)
      | otherwise = return ()