module Data.Conduit.Vector
  ( sourceVector
  , sourceMVector
  , consumeVector
  , consumeMVector
  , takeVector
  , takeMVector
  , thawConduit
  , freezeConduit
  ) where

import Control.Monad.Primitive
import Control.Monad.ST
import Data.Conduit
import qualified Data.Conduit.List as L
import Data.Conduit.Util
import qualified Data.Vector.Generic as V
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Fusion.Stream as S
import qualified Data.Vector.Fusion.Stream.Monadic as SM

-- | Use an immutable vector as a source.
--
-- Elements are yielded in index order, from @0@ up to @length - 1@.
sourceVector :: (Monad m, V.Vector v a) => v a -> Source m a
sourceVector vec = sourceState 0 pull
  where
    len = V.length vec

    -- The state is the index of the next element to yield. Keeping an index
    -- (rather than a stream that is repeatedly 'tail'ed) makes every pull
    -- O(1) instead of re-skipping all previously yielded elements.
    pull i
      | i >= len = return StateClosed
      | otherwise = do
          -- In bounds thanks to the guard above.
          x <- V.unsafeIndexM vec i
          return $ StateOpen (i + 1) x

-- | Use a mutable vector as a source in the ST or IO monad.
--
-- Each element is read from the vector at the moment it is pulled, in index
-- order.
sourceMVector :: (PrimMonad m, M.MVector v a)
              => v (PrimState m) a
              -> Source m a
sourceMVector vec = sourceState 0 pull
  where
    len = M.length vec

    -- The state is the index of the next element to read; see 'pull' being
    -- O(1) per element as opposed to nesting one stream 'tail' per pull.
    pull i
      | i >= len = return StateClosed
      | otherwise = do
          -- In bounds thanks to the guard above.
          x <- M.unsafeRead vec i
          return $ StateOpen (i + 1) x

-- | Consumes all values from the stream and returns them as an immutable
-- vector. Due to the way it operates, it requires the ST monad at the
-- minimum, although it can also operate in IO. This is due to its dependency
-- on a mutable vector.
consumeVector :: (PrimMonad m, V.Vector v a) => Sink a m (v a)
consumeVector = sinkState (Nothing, 0) push close
  where
    -- State: the buffer (allocated lazily on the first push) and the number
    -- of elements written so far.
    push (buffer, index) x = do
      vec <- maybe (M.new 10) return buffer
      let capacity = M.length vec
      -- Grow by the current capacity (i.e. double it) once the buffer is full.
      vec' <- if index >= capacity
                then M.grow vec capacity
                else return vec
      M.write vec' index x
      return $ StateProcessing (Just vec', index + 1)

    -- Nothing was ever pushed: no buffer exists, so return an empty vector.
    close (Nothing, _) = return V.empty
    -- Drop the unused tail of the buffer before freezing. The mutable buffer
    -- is not used again after this point.
    close (Just vec, index) = V.unsafeFreeze $ M.take index vec

-- | Consumes the first n values from a source and returns them as an
-- immutable vector.
--
-- Exactly @n@ values are consumed: the sink finishes as soon as the @n@-th
-- value has been stored. If the stream ends early, only the values actually
-- received are returned. For @n <= 0@ the result is empty and the first
-- pushed value is handed back as a leftover.
takeVector :: (PrimMonad m, V.Vector v a) => Int -> Sink a m (v a)
takeVector n = sinkState (Nothing, 0) push close
  where
    -- State: the buffer (allocated on the first push) and the number of
    -- elements written so far.
    push (buffer, index) x
      -- Nothing was requested: do not swallow the value, return it upstream.
      | n <= 0 = return $ StateDone (Just x) V.empty
      | otherwise = do
          vec <- maybe (M.new n) return buffer
          M.write vec index x
          let index' = index + 1
          if index' >= n
            then do
              -- Buffer is full; finish now instead of waiting for (and
              -- discarding) one more input value.
              frozen <- V.unsafeFreeze vec
              return $ StateDone Nothing frozen
            else return $ StateProcessing (Just vec, index')

    close (Nothing, _) = return V.empty
    -- The stream ended early: keep only the slots that were written.
    close (Just vec, index) = V.unsafeFreeze $ M.take index vec

-- | Consumes all values from the stream and returns them as a mutable vector.
consumeMVector :: (PrimMonad m, M.MVector v a)
               => Sink a m (v (PrimState m) a)
consumeMVector = sinkState (Nothing, 0) push close
  where
    -- State: the buffer (allocated lazily on the first push) and the number
    -- of elements written so far.
    push (buffer, index) x = do
      vec <- maybe (M.new 10) return buffer
      let capacity = M.length vec
      -- Grow by the current capacity (i.e. double it) once the buffer is full.
      vec' <- if index >= capacity
                then M.grow vec capacity
                else return vec
      M.write vec' index x
      return $ StateProcessing (Just vec', index + 1)

    -- Nothing was ever pushed: return a fresh empty vector.
    close (Nothing, _) = M.new 0
    -- Drop the unused tail of the buffer.
    close (Just vec, index) = return $ M.take index vec

-- | Consumes the first n values from the stream and returns them as a
-- mutable vector.
--
-- Exactly @n@ values are consumed: the sink finishes as soon as the @n@-th
-- value has been stored. If the stream ends early, the result contains only
-- the values actually received. For @n <= 0@ the result is empty and the
-- first pushed value is handed back as a leftover.
takeMVector :: (PrimMonad m, M.MVector v a)
            => Int
            -> Sink a m (v (PrimState m) a)
takeMVector n = sinkState (Nothing, 0) push close
  where
    -- State: the buffer (allocated on the first push) and the number of
    -- elements written so far.
    push (buffer, index) x
      -- Nothing was requested: do not swallow the value, return it upstream.
      | n <= 0 = do
          emptyVec <- M.new 0
          return $ StateDone (Just x) emptyVec
      | otherwise = do
          vec <- maybe (M.new n) return buffer
          M.write vec index x
          let index' = index + 1
          -- Finish as soon as the buffer is full instead of waiting for (and
          -- discarding) one more input value.
          return $ if index' >= n
            then StateDone Nothing vec
            else StateProcessing (Just vec, index')

    close (Nothing, _) = M.new 0
    -- The stream ended early: keep only the slots that were written.
    close (Just vec, index) = return $ M.take index vec

-- | Conduit which thaws immutable vectors into mutable vectors.
--
-- NOTE: this uses 'V.unsafeThaw', which per the vector documentation does
-- not copy; an incoming immutable vector must therefore not be used again
-- once it has passed through this conduit.
thawConduit :: (PrimMonad m, V.Vector v a)
            => Conduit (v a) m (V.Mutable v (PrimState m) a)
thawConduit = L.mapM V.unsafeThaw

-- | Conduit which freezes mutable vectors into immutable vectors.
--
-- NOTE: this uses 'V.unsafeFreeze', which per the vector documentation does
-- not copy; an incoming mutable vector must therefore not be modified again
-- once it has passed through this conduit.
freezeConduit :: (PrimMonad m, V.Vector v a)
              => Conduit (V.Mutable v (PrimState m) a) m (v a)
freezeConduit = L.mapM V.unsafeFreeze