{-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}

-- |
-- Module      : Data.Vector.Generic.Mutable
-- Copyright   : (c) Roman Leshchinskiy 2008-2009
-- License     : BSD-style
--
-- Maintainer  : Roman Leshchinskiy
-- Stability   : experimental
-- Portability : non-portable
--
-- Generic interface to mutable vectors
--

#include "phases.h"

module Data.Vector.Generic.Mutable (
  MVectorPure(..), MVector(..),

  slice, new, newWith, read, write, copy, grow,
  unstream, transform,
  accum, update, reverse
) where

import qualified Data.Vector.Fusion.Stream as Stream
import           Data.Vector.Fusion.Stream ( Stream, MStream )
import qualified Data.Vector.Fusion.Stream.Monadic as MStream
import           Data.Vector.Fusion.Stream.Size

import Control.Exception ( assert )

import GHC.Float ( double2Int, int2Double )

import Prelude hiding ( length, reverse, map, read )

-- Multiplier applied to the current length of a vector to decide by how many
-- elements it is enlarged when 'unstream' runs out of room while consuming a
-- stream without a known upper bound.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 1.5

-- | Basic pure functions on mutable vectors
class MVectorPure v a where
  -- | Length of the mutable vector
  length      :: v a -> Int

  -- | Yield a part of the mutable vector without copying it. No range checks!
  unsafeSlice :: v a
              -> Int  -- ^ starting index
              -> Int  -- ^ length of the slice
              -> v a

  -- | Check whether two vectors overlap.
  overlaps    :: v a -> v a -> Bool

-- | Class of mutable vectors. The type @m@ is the monad in which the mutable
-- vector can be transformed and @a@ is the type of elements.
--
class (Monad m, MVectorPure v a) => MVector v m a where
  -- | Create a mutable vector of the given length. Length is not checked!
  unsafeNew     :: Int -> m (v a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. Length is not checked!
  --
  -- The default implementation allocates with 'unsafeNew' and then fills the
  -- vector with 'set'.
  unsafeNewWith :: Int -> a -> m (v a)

  -- | Yield the element at the given position. Index is not checked!
  unsafeRead    :: v a -> Int -> m a

  -- | Replace the element at the given position. Index is not checked!
  unsafeWrite   :: v a -> Int -> a -> m ()

  -- | Clear all references to external objects
  clear         :: v a -> m ()

  -- | Write the value at each position.
  --
  -- The default implementation calls 'unsafeWrite' for every index from @0@
  -- up to (but excluding) the length of the vector.
  set           :: v a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This is not checked!
  --
  -- The default implementation copies as many elements as the source holds,
  -- one at a time, in ascending index order.
  unsafeCopy    :: v a   -- ^ target
                -> v a   -- ^ source
                -> m ()

  -- | Grow a vector by the given number of elements. The length is not
  -- checked!
  --
  -- The default implementation allocates a new vector with 'unsafeNew' and
  -- copies the old contents into its leading slice; the old vector is left
  -- untouched.
  unsafeGrow    :: v a -> Int -> m (v a)

  {-# INLINE unsafeNewWith #-}
  unsafeNewWith n x = do
      v <- unsafeNew n
      set v x
      return v

  {-# INLINE set #-}
  set v x = do_set 0
    where
      n = length v

      do_set i
        | i < n     = do
            unsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = do_copy 0
    where
      -- The element count is taken from the source; the target is assumed
      -- to be at least that long (not checked here).
      n = length src

      do_copy i
        | i < n     = do
            x <- unsafeRead src i
            unsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  {-# INLINE unsafeGrow #-}
  unsafeGrow v by = do
      v' <- unsafeNew (n+by)
      -- Only the first n slots of the new vector receive the old contents;
      -- the remaining slots are whatever 'unsafeNew' produced.
      unsafeCopy (unsafeSlice v' 0 n) v
      return v'
    where
      n = length v

-- | Test whether the index is valid for the vector
inBounds :: MVectorPure v a => v a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds v i = i >= 0 && i < length v

-- | Yield a part of the mutable vector without copying it. Safer version of
-- 'unsafeSlice'.
--
-- The range is checked with 'assert': the starting index and the slice
-- length must be non-negative and the slice must end within the vector.
slice :: MVectorPure v a
      => v a
      -> Int  -- ^ starting index
      -> Int  -- ^ length of the slice
      -> v a
{-# INLINE slice #-}
-- The upper bound is tested as @n <= length v - i@ rather than
-- @i+n <= length v@: the sum can wrap around for very large arguments and
-- would then wrongly pass the check, whereas the difference cannot overflow
-- once @i >= 0@ has been established.
slice v i n = assert (i >= 0 && n >= 0 && n <= length v - i)
            $ unsafeSlice v i n

-- | Create a mutable vector of the given length. Safer version of
-- 'unsafeNew'.
new :: MVector v m a => Int -> m (v a)
{-# INLINE new #-}
new n = assert (n >= 0) $ unsafeNew n

-- | Create a mutable vector of the given length and fill it with an
-- initial value. Safer version of 'unsafeNewWith'.
newWith :: MVector v m a => Int -> a -> m (v a)
{-# INLINE newWith #-}
newWith n x = assert (n >= 0) $ unsafeNewWith n x

-- | Yield the element at the given position. Safer version of 'unsafeRead'.
read :: MVector v m a => v a -> Int -> m a
{-# INLINE read #-}
read v i = assert (inBounds v i) $ unsafeRead v i

-- | Replace the element at the given position. Safer version of
-- 'unsafeWrite'.
write :: MVector v m a => v a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = assert (inBounds v i) (unsafeWrite v i x)

-- | Copy a vector. The two vectors may not overlap. Safer version of
-- 'unsafeCopy'.
--
-- An 'assert' checks that the vectors do not overlap and have equal length.
copy :: MVector v m a
     => v a   -- ^ target
     -> v a   -- ^ source
     -> m ()
{-# INLINE copy #-}
copy dst src = assert compatible (unsafeCopy dst src)
  where
    compatible = not (dst `overlaps` src) && length dst == length src

-- | Grow a vector by the given number of elements. Safer version of
-- 'unsafeGrow'.
--
-- An 'assert' checks that the number of additional elements is not negative.
grow :: MVector v m a => v a -> Int -> m (v a)
{-# INLINE grow #-}
grow v by = assert (by >= 0) (unsafeGrow v by)

-- Produce the elements of a mutable vector as a monadic stream, in ascending
-- index order. The stream carries an exact size hint equal to the length of
-- the vector.
mstream :: MVector v m a => v a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step ix
      | ix >= len = return Nothing
      | otherwise = do
          e <- unsafeRead v ix
          return (Just (e, ix + 1))

-- Write the elements of a monadic stream into the given vector, starting at
-- index 0, and return the slice covering exactly the elements written. Each
-- store goes through the checked 'write'.
munstream :: MVector v m a => v a -> MStream m a -> m (v a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
      filled <- MStream.foldM store 0 s
      return (slice v 0 filled)
  where
    {-# INLINE_INNER store #-}
    store ix e = write v ix e >> return (ix + 1)

-- | Run a stream transformer over the contents of a mutable vector, writing
-- the results back into the same vector. The result is the slice of that
-- vector holding the transformed elements.
transform :: MVector v m a => (MStream m a -> MStream m a) -> v a -> m (v a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v . f $ mstream v

-- | Create a new mutable vector and fill it with elements from the 'Stream'.
-- If the 'Size' hint of the 'Stream' provides no upper bound, the vector
-- starts out empty and is enlarged whenever it runs out of room; each
-- enlargement is proportional to the current length (and adds at least one
-- element).
unstream :: MVector v m a => Stream a -> m (v a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s)
                   (upperBound (Stream.size s))

-- Fill a freshly allocated vector of the given capacity from the stream and
-- return the slice that was actually written. Every store goes through the
-- checked 'write'.
unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    buf <- new n
    let store ix e = write buf ix e >> return (ix + 1)
    used <- Stream.foldM' store 0 s
    return (slice buf 0 used)

-- Fill a vector from a stream without a known upper bound. The vector starts
-- out empty and is replaced by a larger one whenever the next index no longer
-- fits; the result is the slice holding the elements written.
unstreamUnknown :: MVector v m a => Stream a -> m (v a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    empty <- new 0
    (final, used) <- Stream.foldM append (empty, 0) s
    return (slice final 0 used)
  where
    {-# INLINE_INNER append #-}
    append (buf, ix) e = do
        buf' <- ensureRoom buf ix
        unsafeWrite buf' ix e
        return (buf', ix + 1)

    -- Hand back a vector that has room for index ix, growing the given one
    -- if necessary. The number of added elements is the current length
    -- times gROWTH_FACTOR, converted with double2Int, but never less than 1.
    {-# INLINE_INNER ensureRoom #-}
    ensureRoom buf ix
      | ix >= cap = unsafeGrow buf extra
      | otherwise = return buf
      where
        cap   = length buf
        extra = max 1 (double2Int (int2Double cap * gROWTH_FACTOR))

-- | For every pair @(i, b)@ of the 'Stream', replace the element @a@ at
-- position @i@ by @f a b@. Indices are accessed through the checked 'read'
-- and 'write'.
accum :: MVector v m a => (a -> b -> a) -> v a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (ix, arg) = read v ix >>= \old -> write v ix (f old arg)

-- | For every pair @(i, x)@ of the 'Stream', replace the element at position
-- @i@ by @x@.
update :: MVector v m a => v a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update = accum (\_ replacement -> replacement)

-- | Reverse a mutable vector in place by swapping elements pairwise, working
-- from both ends towards the middle.
reverse :: MVector v m a => v a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          a <- unsafeRead v lo
          b <- unsafeRead v hi
          unsafeWrite v lo b
          unsafeWrite v hi a
          go (lo + 1) (hi - 1)