{-|
Module      : Data.Conduit.Algorithms.Storable
Copyright   : 2018 Luis Pedro Coelho
License     : MIT
Maintainer  : luis@luispedro.org

Read/write Storable vectors
-}
{-# LANGUAGE FlexibleContexts, ScopedTypeVariables #-}
module Data.Conduit.Algorithms.Storable
    ( writeStorableV
    , readStorableV
    ) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as BU
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Control.Monad.IO.Class

import Foreign.Ptr
import Foreign.Marshal.Utils
import Foreign.Storable
import Control.Monad (when)

import qualified Data.Conduit.List as CL
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit as C
import           Data.Conduit ((.|))

-- | write a Storable vector
--
-- Every incoming vector is emitted as one 'B.ByteString' containing the raw
-- bytes of the vector's memory, i.e. the same format as in-memory. No header
-- or length prefix is added.
--
-- See 'readStorableV'
writeStorableV :: forall m a. (MonadIO m, Monad m, Storable a) => C.ConduitT (VS.Vector a) B.ByteString m ()
writeStorableV = CL.mapM toBytes
    where
        -- Build a ByteString from the vector's memory while 'VS.unsafeWith'
        -- gives access to it; the byte count is element count * element size.
        toBytes :: VS.Vector a -> m B.ByteString
        toBytes vec = liftIO $ VS.unsafeWith vec $ \ptr ->
            B.packCStringLen (castPtr ptr, byteLen)
          where
            byteLen = VS.length vec * sizeOf (undefined :: a)


-- | read a Storable vector
--
-- This expects the same format as the in-memory vector.
--
-- This will break up the incoming data into vectors of the given size. The
-- last vector may be smaller if there is not enough data. Any trailing bytes
-- that do not form a complete element are pushed back with 'C.leftover' on
-- this conduit's input, so a conduit composed after this one can consume
-- them.
--
-- @nelems@ is expected to be positive (NOTE(review): non-positive values are
-- not handled here).
--
-- See 'writeStorableV'
readStorableV :: forall m a. (MonadIO m, Storable a) => Int -> C.ConduitM B.ByteString (VS.Vector a) m ()
readStorableV nelems = do
        trailing <- CC.chunksOfE blockBytes .| parseBlocks
        -- The leftover is issued here, outside the inner '.|', on purpose.
        -- A leftover issued inside 'parseBlocks' would be handed straight
        -- back to 'parseBlocks' by the fuse. It would then be re-read forever
        -- as a chunk with zero complete elements.
        mapM_ C.leftover trailing
    where
        elemSize :: Int
        elemSize = sizeOf (undefined :: a)

        -- Size in bytes of the chunks requested from 'CC.chunksOfE': a whole
        -- number (nelems) of elements.
        blockBytes :: Int
        blockBytes = nelems * elemSize

        -- Decode chunks into vectors until the input ends (result 'Nothing')
        -- or a chunk ends in a partial element. In that case the unused
        -- trailing bytes are returned instead of being decoded.
        parseBlocks :: C.ConduitT B.ByteString (VS.Vector a) m (Maybe B.ByteString)
        parseBlocks = do
            next <- C.await
            case next of
                Nothing -> return Nothing
                Just bs -> do
                    let (n, rest) = B.length bs `divMod` elemSize
                    -- Never emit empty vectors.
                    when (n > 0) $
                        C.yield =<< liftIO (decodeElems n bs)
                    if rest > 0
                        then return (Just (B.drop (n * elemSize) bs))
                        else parseBlocks

        -- Copy the first @n@ elements' worth of bytes of @bs@ into a newly
        -- allocated vector. The vector therefore does not alias the
        -- ByteString's buffer. The regions cannot overlap, so 'copyBytes' is
        -- sufficient.
        decodeElems :: Int -> B.ByteString -> IO (VS.Vector a)
        decodeElems n bs = do
            v <- VSM.new n
            BU.unsafeUseAsCStringLen bs $ \(src, _) ->
                VSM.unsafeWith v $ \dst ->
                    copyBytes (castPtr dst) src (n * elemSize)
            VS.unsafeFreeze v