module Data.Vector.Generic.Mutable (
MVector(..),
length, overlaps, new, newWith, read, write, swap, clear, set, copy, grow,
slice, take, drop, init, tail,
unsafeSlice, unsafeInit, unsafeTail,
unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite, unsafeSwap,
unsafeCopy, unsafeGrow,
unstream, transform, unstreamR, transformR,
unsafeAccum, accum, unsafeUpdate, update, reverse,
unstablePartition, unstablePartitionStream, partitionStream
) where
import qualified Data.Vector.Fusion.Stream as Stream
import Data.Vector.Fusion.Stream ( Stream, MStream )
import qualified Data.Vector.Fusion.Stream.Monadic as MStream
import Data.Vector.Fusion.Stream.Size
import Control.Monad.Primitive ( PrimMonad, PrimState )
import GHC.Float (
double2Int, int2Double
)
import Prelude hiding ( length, reverse, map, read,
take, drop, init, tail )
#include "vector.h"
-- | Mutable vectors: @v s a@ is a mutable vector with state token @s@ and
-- element type @a@.
--
-- Instances must define @basicLength@, @basicUnsafeSlice@, @basicOverlaps@,
-- @basicUnsafeNew@, @basicUnsafeRead@ and @basicUnsafeWrite@. The remaining
-- methods have element-by-element defaults built from those.
class MVector v a where
  -- | Number of elements in the vector.
  basicLength :: v s a -> Int

  -- | Sub-vector with the given start index and length; no bounds checks.
  basicUnsafeSlice :: Int    -- ^ starting index
                   -> Int    -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Whether two vectors overlap. Used by @copy@ and @unsafeCopy@ to
  -- reject overlapping arguments.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Allocate a vector of the given length. The default of
  -- @basicUnsafeNewWith@ fills the elements explicitly afterwards.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Allocate a vector of the given length with every element set to the
  -- given value.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Read the element at an index; no bounds checks.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Write the element at an index; no bounds checks.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset the vector. The default does nothing.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Overwrite every element with the given value.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy the source (second argument) into the target (first argument).
  -- The default copies @basicLength src@ elements one at a time.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Allocate a vector longer by the given amount and copy the original
  -- into its prefix.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  basicUnsafeNewWith len val =
    basicUnsafeNew len >>= \mv -> basicSet mv val >> return mv

  basicClear _ = return ()

  basicSet vec val = fill 0
    where
      len = basicLength vec
      fill k
        | k >= len  = return ()
        | otherwise = basicUnsafeWrite vec k val >> fill (k + 1)

  basicUnsafeCopy dst src = step 0
    where
      len = basicLength src
      step k
        | k >= len  = return ()
        | otherwise = do
            e <- basicUnsafeRead src k
            basicUnsafeWrite dst k e
            step (k + 1)

  basicUnsafeGrow vec extra = do
    let len = basicLength vec
    bigger <- basicUnsafeNew (len + extra)
    -- only the first len slots of the new vector receive the old contents
    basicUnsafeCopy (basicUnsafeSlice 0 len bigger) vec
    return bigger
-- | Check whether two vectors overlap (delegates to @basicOverlaps@).
overlaps :: MVector v a => v s a -> v s a -> Bool
xs `overlaps` ys = basicOverlaps xs ys
-- | Write an element at index @i@. If @i@ is not below the current length,
-- the vector is first enlarged once with @enlarge@.
--
-- Returns the vector that now holds the element. The caller must use this
-- result from then on, because it may be a fresh allocation.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
unsafeAppend1 vec idx el
  | idx >= length vec = do
      bigger <- enlarge vec
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" idx (length bigger)
        $ unsafeWrite bigger idx el
      return bigger
  | otherwise = unsafeWrite vec idx el >> return vec
-- | Write an element just before index @i@, i.e. at @i - 1@. When @i@ is 0,
-- the vector is first grown at the front with @enlargeFront@.
--
-- Returns the vector holding the element, together with the index it was
-- written to (the new front position).
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
unsafePrepend1 v i x
  | i /= 0 = do
      let i' = i - 1
      unsafeWrite v i' x
      return (v, i')
  | otherwise = do
      -- enlargeFront reports where the old contents now begin; the free
      -- slot we want is the one right before that position
      (v', front) <- enlargeFront v
      let i' = front - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" i' (length v')
        $ unsafeWrite v' i' x
      return (v', i')
-- | Stream the elements of a mutable vector from first to last.
--
-- Each element is read when the stream demands it. The stream size is
-- exactly the vector length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
mstream v = v `seq` MStream.sized (MStream.unfoldrM next 0) (Exact len)
  where
    len = length v
    next k
      | k >= len  = return Nothing
      | otherwise = do
          e <- unsafeRead v k
          return (Just (e, k + 1))
-- | Write the elements of a monadic stream into the given vector, starting
-- at index 0, and return the prefix that was filled.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
munstream v s =
  v `seq` (MStream.foldM store 0 s >>= \filled -> return (unsafeSlice 0 filled v))
  where
    store k e = do
      INTERNAL_CHECK(checkIndex) "munstream" k (length v)
        $ unsafeWrite v k e
      return (k + 1)
-- | Apply a monadic stream transformer to the elements of a vector.
--
-- The results are written back into the same vector from index 0, and the
-- filled prefix is returned.
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
transform f v = munstream v $ f $ mstream v
-- | Stream the elements of a mutable vector from last to first.
--
-- Each element is read when the stream demands it. The stream size is
-- exactly the vector length.
mrstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
mrstream v = v `seq` (MStream.unfoldrM get n `MStream.sized` Exact n)
  where
    n = length v
    -- the seed is one past the index to read next
    get i
      | j >= 0    = do
          x <- unsafeRead v j
          return $ Just (x, j)
      | otherwise = return Nothing
      where
        j = i - 1
-- | Write the elements of a monadic stream into the given vector from the
-- back: the first stream element goes to the last slot.
--
-- Returns the suffix that was filled.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
munstreamR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n - i) v
  where
    n = length v
    -- the accumulator is the lowest index written so far
    put i x = do
        unsafeWrite v j x
        return j
      where
        j = i - 1
-- | Like @transform@, but the vector is streamed from last to first.
--
-- The results are written back from the end, and the filled suffix is
-- returned.
transformR :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
transformR f v = munstreamR v $ f $ mrstream v
-- | Create a mutable vector from a pure stream.
--
-- If the stream size has a known upper bound, that many slots are
-- allocated up front. Otherwise the vector is grown as elements arrive.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstream s =
  maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
-- | Build a vector from a stream of at most @n@ elements.
--
-- Allocates @n@ slots, fills them from index 0, and returns the filled
-- prefix.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
unstreamMax s n = do
  buf <- INTERNAL_CHECK(checkLength) "unstreamMax" n
           $ unsafeNew n
  let store k e = do
        INTERNAL_CHECK(checkIndex) "unstreamMax" k n
          $ unsafeWrite buf k e
        return (k + 1)
  count <- Stream.foldM' store 0 s
  return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 count n
         $ unsafeSlice 0 count buf
-- | Build a vector from a stream of unknown size.
--
-- Starts from an empty vector, appends each element with @unsafeAppend1@
-- (which grows the buffer when needed), and returns the filled prefix.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstreamUnknown s = do
    v0 <- unsafeNew 0
    (grown, count) <- Stream.foldM step (v0, 0) s
    return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 count (length grown)
           $ unsafeSlice 0 count grown
  where
    step (acc, k) e = unsafeAppend1 acc k e >>= \acc' -> return (acc', k + 1)
-- | Create a mutable vector from a pure stream, filling it from the back.
--
-- Uses a preallocated buffer when the stream size has a known upper bound,
-- and a growing one otherwise.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstreamR s =
  maybe (unstreamRUnknown s) (unstreamRMax s) (upperBound (Stream.size s))
-- | Build a vector from a stream of at most @n@ elements, filling an
-- @n@-slot buffer from the back.
--
-- The first stream element lands in the last slot. Returns the filled
-- suffix.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
unstreamRMax s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
           $ unsafeNew n
      -- the accumulator is the lowest index written so far
      let put i x = do
            let i' = i - 1
            INTERNAL_CHECK(checkIndex) "unstreamRMax" i' n
              $ unsafeWrite v i' x
            return i'
      i <- Stream.foldM' put n s
      return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" i (n - i) n
             $ unsafeSlice i (n - i) v
-- | Build a vector from a stream of unknown size by prepending each element
-- with @unsafePrepend1@.
--
-- The first stream element ends up last. Returns the filled suffix of the
-- final buffer.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
unstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- Stream.foldM put (v, 0) s
      let n = length v'
      return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" i (n - i) n
             $ unsafeSlice i (n - i) v'
  where
    put (v, i) x = unsafePrepend1 v i x
-- | Number of elements in the vector (delegates to @basicLength@).
length :: MVector v a => v s a -> Int
length v = basicLength v
-- | Whether the vector has no elements.
null :: MVector v a => v s a -> Bool
null = (== 0) . length
-- | Allocate a mutable vector of the given length.
--
-- The length is validated with @checkLength@ under the @BOUNDS_CHECK@
-- macro from vector.h.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
new n = BOUNDS_CHECK(checkLength) "new" n (unsafeNew n)
-- | Allocate a mutable vector of the given length with every element set to
-- the given value.
--
-- The length is validated under @BOUNDS_CHECK@.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
newWith n x = BOUNDS_CHECK(checkLength) "newWith" n (unsafeNewWith n x)
-- | Allocate a mutable vector of the given length.
--
-- The length is validated only under @UNSAFE_CHECK@.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n (basicUnsafeNew n)
-- | Allocate a mutable vector of the given length filled with the given
-- value.
--
-- The length is validated only under @UNSAFE_CHECK@.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
unsafeNewWith n x =
  UNSAFE_CHECK(checkLength) "unsafeNewWith" n (basicUnsafeNewWith n x)
-- | Return a vector longer by @by@ elements, with the original elements
-- copied into its prefix.
--
-- The growth amount is validated under @BOUNDS_CHECK@.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
grow v by = BOUNDS_CHECK(checkLength) "grow" by (unsafeGrow v by)
-- | Like @grow@, but the new space is added at the front, so the original
-- elements end up at the end of the result.
--
-- The growth amount is validated under @BOUNDS_CHECK@.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by (unsafeGrowFront v by)
-- | Amount by which @enlarge@ and @enlargeFront@ grow a vector.
--
-- This is the current length, but at least 1, so that an empty vector can
-- grow too.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max (length v) 1
-- | Grow a vector at the back by @enlarge_delta@ elements, which at least
-- doubles a non-empty vector.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
enlarge v = unsafeGrow v delta
  where
    delta = enlarge_delta v
-- | Grow a vector at the front by @enlarge_delta@ elements.
--
-- Returns the new vector together with the amount added, which is also the
-- index where the original contents now start.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
enlargeFront v = unsafeGrowFront v delta >>= \v' -> return (v', delta)
  where
    delta = enlarge_delta v
-- | Grow a vector at the back via @basicUnsafeGrow@.
--
-- The growth amount is validated only under @UNSAFE_CHECK@.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n (basicUnsafeGrow v n)
-- | Allocate a vector longer by @by@ elements and copy the original into its
-- last part, starting at index @by@.
--
-- The growth amount is validated only under @UNSAFE_CHECK@.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by grown
  where
    len = length v
    grown = do
      bigger <- basicUnsafeNew (by + len)
      basicUnsafeCopy (basicUnsafeSlice by len bigger) v
      return bigger
-- | Read the element at an index, with the index validated under
-- @BOUNDS_CHECK@.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
read v i = BOUNDS_CHECK(checkIndex) "read" i (length v) (unsafeRead v i)
-- | Write the element at an index, with the index validated under
-- @BOUNDS_CHECK@.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
write v i x = BOUNDS_CHECK(checkIndex) "write" i (length v) (unsafeWrite v i x)
-- | Swap the elements at two indices.
--
-- Both indices are validated under @BOUNDS_CHECK@, the first one before
-- the second.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i len
           $ BOUNDS_CHECK(checkIndex) "swap" j len
           $ unsafeSwap v i j
  where
    len = length v
-- | Replace the element at an index and return the previous one.
--
-- The index is validated under @BOUNDS_CHECK@.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
exchange v i x =
  BOUNDS_CHECK(checkIndex) "exchange" i (length v) (unsafeExchange v i x)
-- | Read the element at an index; the index is validated only under
-- @UNSAFE_CHECK@.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
unsafeRead v i =
  UNSAFE_CHECK(checkIndex) "unsafeRead" i (length v) (basicUnsafeRead v i)
-- | Write the element at an index; the index is validated only under
-- @UNSAFE_CHECK@.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
unsafeWrite v i x =
  UNSAFE_CHECK(checkIndex) "unsafeWrite" i (length v) (basicUnsafeWrite v i x)
-- | Swap the elements at two indices.
--
-- Both indices are validated only under @UNSAFE_CHECK@.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
unsafeSwap v i j =
  UNSAFE_CHECK(checkIndex) "unsafeSwap" i len $
  UNSAFE_CHECK(checkIndex) "unsafeSwap" j len $ do
    xi <- unsafeRead v i
    xj <- unsafeRead v j
    unsafeWrite v i xj
    unsafeWrite v j xi
  where
    len = length v
-- | Replace the element at an index and return the previous one.
--
-- The index is validated only under @UNSAFE_CHECK@.
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
                     $ unsafeRead v i >>= \old -> unsafeWrite v i x >> return old
-- | Clear the vector via @basicClear@, whose class default does nothing.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
clear v = basicClear v
-- | Overwrite every element with the given value (delegates to @basicSet@).
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
set v x = basicSet v x
-- | Copy the source (second argument) into the target (first argument).
--
-- Under @BOUNDS_CHECK@, the two vectors must not overlap (checked first)
-- and must have the same length.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
             $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
             $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
-- | Copy the source (second argument) into the target (first argument).
--
-- The equal-length check (first) and the no-overlap check apply only under
-- @UNSAFE_CHECK@.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
                   $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
                   $ basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
-- | Sub-vector starting at index @i@ with length @n@.
--
-- The range is validated under @BOUNDS_CHECK@.
slice :: MVector v a => Int -> Int -> v s a -> v s a
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n (length v) (unsafeSlice i n v)
-- | The first @n@ elements.
--
-- The count is clamped to the range from 0 to the length, so this never
-- fails.
take :: MVector v a => Int -> v s a -> v s a
take n v = unsafeSlice 0 count v
  where
    count = min (max n 0) (length v)
-- | All but the first @n@ elements.
--
-- Negative counts act as 0, and counts beyond the length give an empty
-- vector, so this never fails.
drop :: MVector v a => Int -> v s a -> v s a
drop n v = unsafeSlice (min m n') (max 0 (m - n')) v
  where
    n' = max n 0
    m  = length v
-- | All elements but the last.
--
-- An empty vector is rejected by the bounds check in @slice@.
init :: MVector v a => v s a -> v s a
init v = slice 0 (length v - 1) v
-- | All elements but the first.
--
-- An empty vector is rejected by the bounds check in @slice@.
tail :: MVector v a => v s a -> v s a
tail v = slice 1 (length v - 1) v
-- | Sub-vector starting at index @i@ with length @n@.
--
-- The range is validated only under @UNSAFE_CHECK@.
unsafeSlice :: MVector v a => Int    -- ^ starting index
                           -> Int    -- ^ length of the slice
                           -> v s a
                           -> v s a
unsafeSlice i n v =
  UNSAFE_CHECK(checkSlice) "unsafeSlice" i n (length v) (basicUnsafeSlice i n v)
-- | All elements but the last, with no bounds check beyond @UNSAFE_CHECK@.
unsafeInit :: MVector v a => v s a -> v s a
unsafeInit v = unsafeSlice 0 (length v - 1) v
-- | All elements but the first, with no bounds check beyond @UNSAFE_CHECK@.
unsafeTail :: MVector v a => v s a -> v s a
unsafeTail v = unsafeSlice 1 (length v - 1) v
-- | The first @n@ elements, with no bounds check beyond @UNSAFE_CHECK@.
unsafeTake :: MVector v a => Int -> v s a -> v s a
unsafeTake = unsafeSlice 0
-- | All but the first @n@ elements, with no bounds check beyond
-- @UNSAFE_CHECK@.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
unsafeDrop n v = unsafeSlice n (length v - n) v
-- | For each @(i, b)@ in the stream, replace element @i@ with @f old b@.
--
-- Indices are validated under @BOUNDS_CHECK@.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
accum f !v s = Stream.mapM_ combine s
  where
    combine (i, b) =
      BOUNDS_CHECK(checkIndex) "accum" i (length v) (unsafeRead v i)
        >>= \old -> unsafeWrite v i (f old b)
-- | For each @(i, x)@ in the stream, write @x@ at index @i@.
--
-- Indices are validated under @BOUNDS_CHECK@.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
update !v s = Stream.mapM_ store s
  where
    store (k, new_x) =
      BOUNDS_CHECK(checkIndex) "update" k (length v) (unsafeWrite v k new_x)
-- | For each @(i, b)@ in the stream, replace element @i@ with @f old b@.
--
-- Indices are validated only under @UNSAFE_CHECK@. The check is labelled
-- with this function name so failures point at the right caller.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    upd (i, b) = do
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
             $ unsafeRead v i
      unsafeWrite v i (f a b)
-- | For each @(i, x)@ in the stream, write @x@ at index @i@.
--
-- Indices are validated only under @UNSAFE_CHECK@. The check is labelled
-- with this function name rather than the copy-pasted accum label.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    upd (i, b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
               $ unsafeWrite v i b
-- | Reverse the vector in place by swapping elements pairwise, moving
-- inward from both ends.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
reverse !v = reverse_loop 0 (length v - 1)
  where
    reverse_loop i j
      | i < j     = do
          unsafeSwap v i j
          reverse_loop (i + 1) (j - 1)
      | otherwise = return ()
-- | Reorder the vector in place so that elements satisfying the predicate
-- come first.
--
-- Returns the number of such elements, which is the index of the first
-- element that does not satisfy it. Relative order is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
unstablePartition f !v = from_left 0 (length v)
  where
    -- Scan forward from i while elements satisfy the predicate; j is the
    -- exclusive end of the region not yet examined.
    from_left :: Int -> Int -> m Int
    from_left i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v i
          if f x
            then from_left (i + 1) j
            else from_right i (j - 1)

    -- Scan backward from j (inclusive) for an element satisfying the
    -- predicate, and swap it into position i, whose element is known not to.
    from_right :: Int -> Int -> m Int
    from_right i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v j
          if f x
            then do
              y <- unsafeRead v i
              unsafeWrite v i x
              unsafeWrite v j y
              from_left (i + 1) j
            else from_right i (j - 1)
-- | Split a stream into the elements that satisfy the predicate and those
-- that do not.
--
-- Uses a single preallocated buffer when the size has a known upper bound,
-- and two growing buffers otherwise.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
unstablePartitionStream f s =
  maybe (partitionUnknown f s) (unstablePartitionMax f s)
        (upperBound (Stream.size s))
-- | Partition a stream of at most @n@ elements into one @n@-slot buffer.
--
-- Elements satisfying the predicate are written from the front and the
-- rest from the back. Returns both filled regions; the second one holds
-- its elements in reverse stream order.
unstablePartitionMax :: (PrimMonad m, MVector v a)
                     => (a -> Bool) -> Stream a -> Int
                     -> m (v (PrimState m) a, v (PrimState m) a)
unstablePartitionMax f s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
           $ unsafeNew n
      let
        -- i: next free slot at the front; j: lowest slot used at the back
        put (i, j) x
          | f x       = do
              unsafeWrite v i x
              return (i + 1, j)
          | otherwise = do
              unsafeWrite v (j - 1) x
              return (i, j - 1)
      (i, j) <- Stream.foldM' put (0, n) s
      return (unsafeSlice 0 i v, unsafeSlice j (n - j) v)
-- | Split a stream into the elements that satisfy the predicate and those
-- that do not, keeping stream order within each part.
--
-- Uses a preallocated buffer when the size has a known upper bound, and
-- growing buffers otherwise.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
partitionStream f s =
  maybe (partitionUnknown f s) (partitionMax f s) (upperBound (Stream.size s))
-- | Stable partition of a stream of at most @n@ elements in one @n@-slot
-- buffer.
--
-- Matching elements fill the front, and the others fill the back in
-- reverse. The back region is then reversed in place so both parts keep
-- stream order.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
partitionMax f s n
  = do
      v <- INTERNAL_CHECK(checkLength) "partitionMax" n
           $ unsafeNew n
      let
        -- i: next free slot at the front; j: lowest slot used at the back
        put (i, j) x
          | f x       = do
              unsafeWrite v i x
              return (i + 1, j)
          | otherwise = do
              let j' = j - 1
              unsafeWrite v j' x
              return (i, j')
      (i, j) <- Stream.foldM' put (0, n) s
      INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
        $ return ()
      let l = unsafeSlice 0 i v
          r = unsafeSlice j (n - j) v
      -- the back region was written last-to-first; restore stream order
      reverse r
      return (l, r)
-- | Stable partition of a stream of unknown size into two growing vectors.
--
-- Each element is appended with @unsafeAppend1@ to the vector for its side
-- of the predicate. Returns the filled prefixes of both vectors.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
partitionUnknown f s
  = do
      yes0 <- unsafeNew 0
      no0  <- unsafeNew 0
      (yes, nYes, no, nNo) <- Stream.foldM' step (yes0, 0, no0, 0) s
      INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
        $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
        $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    step (yes, ky, no, kn) x
      | f x       = unsafeAppend1 yes ky x >>= \yes' -> return (yes', ky + 1, no, kn)
      | otherwise = unsafeAppend1 no kn x >>= \no' -> return (yes, ky, no', kn + 1)