{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StaticPointers #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | This module provides composable batched marshalling.
--
-- === Batching
--
-- Calls to Java methods via JNI are slow in general. Marshalling an array of
-- primitive values can be as slow as marshalling a single value.
--
-- Because of this, reifying an iterator or a container is best done by
-- accumulating multiple elements on the java side before passing them to the
-- Haskell side. And conversely, when reflecting an iterator or container,
-- multiple Haskell values are put together before marshalling to the Java
-- side.
--
-- Some Haskell values can be batched trivially into arrays of primitive values.
-- 'Int32' can be batched in a java @int[]@, 'Double' can be batched in a java
-- @double[]@, etc. However, other types like @Tuple2 Int32 Double@ would
-- require more primitive arrays. Values of type @Tuple2 Int32 Double@ are
-- batched in a pair of java arrays of type @int[]@ and @double[]@.
--
-- > data Tuple2 a b = Tuple2 a b
--
-- More generally, the design aims to provide composable batchers. If one knows
-- how to batch types @a@ and @b@, one can also batch @Tuple2 a b@, @[a]@,
-- @Vector a@, etc.
--
-- A reference to a batch of values in Java has the type @J (Batch a)@, where
-- @a@ is the Haskell type of the elements in the batch. e.g.
--
-- > type instance Batch Int32 = 'Array ('Prim "int")
-- > type instance Batch Double = 'Array ('Prim "double")
-- > type instance Batch (Tuple2 a b) =
-- >                 'Class "scala.Tuple2" <> '[Batch a, Batch b]
--
-- When defining batching for a new type, one needs to tell how batches are
-- represented in Java by adding a type instance to the type family @Batch@.
-- In addition, procedures for adding and extracting values from the batch
-- need to be specified on both the Haskell and the Java side.
--
-- On the Java side, batches are built using the interface
-- @io.tweag.jvm.batching.BatchWriter@. On the Haskell side, these
-- batches are read using @reifyBatch@.
--
-- > class ( ... ) => BatchReify a where
-- >   newBatchWriter
-- >     :: proxy a
-- >     -> IO (J ('Iface "io.tweag.jvm.batching.BatchWriter"
-- >                  <> [Interp a, Batch a]
-- >              )
-- >           )
-- >   reifyBatch :: J (Batch a) -> Int32 -> IO (V.Vector a)
--
-- @newBatchWriter@ produces a java object implementing the @BatchWriter@
-- interface, and @reifyBatch@ allows to read a batch created in this fashion.
--
-- Conversely, batches can be read on the Java side using the interface
-- @io.tweag.jvm.batching.BatchReader@. And on the Haskell side, these
-- batches can be created with @reflectBatch@.
--
-- > class ( ... ) => BatchReflect a where
-- >  newBatchReader
-- >    :: proxy a
-- >    -> IO (J ('Iface "io.tweag.jvm.batching.BatchReader"
-- >                 <> [Batch a, Interp a]
-- >              )
-- >          )
-- >  reflectBatch :: V.Vector a -> IO (J (Batch a))
--
-- @newBatchReader@ produces a java object implementing the @BatchReader@
-- interface, and @reflectBatch@ allows to create these batches from vectors of
-- Haskell values.
--
-- The methods of @BatchReify@ and @BatchReflect@ offer default
-- implementations which marshal elements in the batch one at a time. Taking
-- advantage of batching requires defining the methods explicitly. The default
-- implementations are useful for cases where speed is not important, for
-- instance when the iterators to reflect or reify contain a single element or
-- just very few.
--
-- 'Vector's and 'ByteString's are batched with the following scheme.
--
-- > type instance Batch BS.ByteString
-- >   = 'Class "io.tweag.jvm.batching.Tuple2" <>
-- >        '[ 'Array ('Prim "byte")
-- >         , 'Array ('Prim "int")
-- >         ]
--
-- We use two arrays. One of the arrays contains the result of appending all of
-- the 'ByteString's in the batch. The other array contains the offset of each
-- vector in the resulting array. See 'ArrayBatch'.
--
module Language.Java.Batching
  ( Batchable(..)
  , BatchReify(..)
  , BatchReflect(..)
    -- * Array batching
  , ArrayBatch
  ) where

import Control.Distributed.Closure.TH
import Control.Exception (bracket)
import Control.Monad (forM_, foldM)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Unsafe      as BS
import Data.Int
import Data.Singletons (SingI, Proxy(..))
import qualified Data.Text as Text
import qualified Data.Text.Foreign as Text
import qualified Data.Text.Unsafe as Text (lengthWord16)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Storable as VS
import Data.Word
import Foreign.C.Types (CChar)
import Foreign.ForeignPtr (newForeignPtr_, withForeignPtr)
import Foreign.JNI
import Foreign.Ptr
import Foreign.Storable
import Language.Java
import Language.Java.Inline

imports "io.tweag.jvm.batching.*"

-- | Types whose values can be marshalled to and from Java in batches.
class (Interpretation a, SingI (Batch a)) => Batchable (a :: k) where
  -- | The Java type of the batches used when reifying and reflecting
  -- values of type @a@.
  type Batch a :: JType

-- | A class for batching reification of values.
--
-- It has a method to create a batcher that creates batches in Java, and
-- another method that reifies a batch into a vector of Haskell values.
--
-- The type of the batch used to appear as a class parameter but we ran into
-- https://ghc.haskell.org/trac/ghc/ticket/13582
--
class Batchable a => BatchReify a where
  -- | Produces a batcher that aggregates elements of type @Interp a@ (such as
  -- @int@) and produces collections of type @Batch a@ (such as @int[]@).
  newBatchWriter
    :: proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchWriter"
                 <> [Interp a, Batch a]
             )
          )

  -- The default implementation makes calls to the JVM for each element in the
  -- batch. It is only available when the batch is a plain Java array of
  -- @Interp a@.
  default newBatchWriter
    :: (Batch a ~ 'Array (Interp a))
    => proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchWriter"
                 <> [Interp a, Batch a]
             )
          )
  newBatchWriter _ = generic <$> [java| new BatchWriters.ObjectBatchWriter() |]

  -- | Reifies the values in a batch of type @Batch a@.
  -- Gets the batch and the amount of elements it contains.
  reifyBatch :: J (Batch a) -> Int32 -> IO (V.Vector a)

  -- The default implementation makes calls to the JVM for each element in the
  -- batch: every position of the object array is fetched and reified
  -- individually.
  default reifyBatch
    :: (Reify a, Batch a ~ 'Array (Interp a))
    => J (Batch a) -> Int32 -> IO (V.Vector a)
  reifyBatch jxs size =
      V.generateM (fromIntegral size) $ \i ->
        getObjectArrayElement jxs (fromIntegral i) >>= reify . unsafeCast

-- | Helper for reifying batches of primitive types.
--
-- The first argument obtains a pointer to the elements of the Java array and
-- the second one releases it. The pointer is held only while the elements are
-- copied into the resulting vector; 'bracket' guarantees that it is released
-- even if reading fails with an exception.
reifyPrimitiveBatch
  :: Storable a
  => (J ('Array ty) -> IO (Ptr a))      -- ^ get a pointer to the array elements
  -> (J ('Array ty) -> Ptr a -> IO ())  -- ^ release the pointer
  -> J ('Array ty)                      -- ^ the batch to reify
  -> Int32                              -- ^ amount of elements to read
  -> IO (V.Vector a)
reifyPrimitiveBatch getArrayElements releaseArrayElements jxs size =
    bracket (getArrayElements jxs) (releaseArrayElements jxs)
      $ V.generateM (fromIntegral size) . peekElemOff

-- | The Java representation of a batch of variable-length arrays.
--
-- It is a pair. The first component, B, is an array (or batch) holding the
-- elements of every array of the batch, one after the other. The second
-- component, F, is an array of offsets: F[i] is the first position in B
-- after the ith array of the batch.
--
-- Hence the first array of the batch spans the indices 0 to F[0] of B, the
-- second array spans the indices F[0] to F[1], and so on.
--
type ArrayBatch ty =
    'Class "io.tweag.jvm.batching.Tuple2" <> '[ty, 'Array ('Prim "int")]

-- | Helper for reifying batches of vectors
--
-- Arrays are batched with two arrays. One of the arrays contains the result
-- of appending all of the vectors in the batch. The other array contains the
-- offset of each vector in the resulting array.
--
-- The offsets are reified first; the last offset gives the total amount of
-- elements, which is handed to the function reifying the values. The values
-- are then cut into one slice per array of the batch.
--
reifyArrayBatch
  :: forall a b ty.
     (Int32 -> J ty -> IO a) -- ^ reify the array/batch of values
                             -- (takes the amount of elements in the array)
  -> (Int -> Int -> a -> IO b) -- ^ slice at a given offset of given length of some array a
  -> J (ArrayBatch ty)       -- ^ the batch to reify
  -> Int32                   -- ^ amount of arrays in the batch
  -> IO (V.Vector b)
reifyArrayBatch reifyB slice batch0 batchSize = do
    let batch = unsafeUngeneric batch0
    arrayEnds <- reifyArrayOffsets batch
    arrayValues <- reifyArrayValues arrayEnds batch
    reifySlices arrayEnds arrayValues
  where
    fromObject :: JObject -> J x
    fromObject = unsafeCast

    -- Reifies the second component of the pair: the end offset of each array.
    reifyArrayOffsets
      :: J ('Class "io.tweag.jvm.batching.Tuple2") -> IO (VS.Vector Int32)
    reifyArrayOffsets batch = [java| (int[])$batch._2 |] >>= reify

    -- Reifies the first component of the pair: the concatenated values. The
    -- total amount of values is the last end offset, or 0 for an empty batch.
    reifyArrayValues
      :: VS.Vector Int32
      -> J ('Class "io.tweag.jvm.batching.Tuple2")
      -> IO a
    reifyArrayValues arrayEnds batch = do
      let count = if VS.null arrayEnds then 0 else VS.last arrayEnds
      [java| $batch._1 |] >>= reifyB count . fromObject

    -- Cuts the concatenated values into @batchSize@ slices. The fold threads
    -- the start offset of the next slice from one iteration to the next.
    reifySlices :: VS.Vector Int32 -> a -> IO (V.Vector b)
    reifySlices arrayEnds arrayValues = do
      result <- VM.new (fromIntegral batchSize)
      _ <- foldM
           (writeSliceToVector result arrayEnds arrayValues)
           0
           [0 .. fromIntegral batchSize - 1]
      -- No reference to the mutable vector survives this point, so freezing
      -- it without copying is fine.
      V.unsafeFreeze result

    writeSliceToVector
      :: VM.IOVector b   -- ^ output vector to write to
      -> VS.Vector Int32 -- ^ ends[i] holds the offset of the (i+1)th slice
      -> a               -- ^ input vector to read slices from
      -> Int32           -- ^ offset of the slice to read from the input vector
      -> Int             -- ^ index of the position to write in the output vector
      -> IO Int32        -- ^ offset of the next slice to read
    writeSliceToVector output arrayEnds arrayValues offset i = do
        -- The ith slice starts at @offset@ and ends right before ends[i].
        slice (fromIntegral offset)
              (fromIntegral $ arrayEnds VS.! i - offset) arrayValues
          >>= VM.unsafeWrite output i
        return $ arrayEnds VS.! i

-- | A class for batching reflection of values.
--
-- It has a method to create a batch reader that reads batches in Java, and
-- another method that reflects a vector of Haskell values into a batch.
--
-- We considered having the type of the batch appear as a class parameter but
-- we ran into https://ghc.haskell.org/trac/ghc/ticket/13582
--
class Batchable a => BatchReflect a where
  -- | Produces a batch reader that receives collections of type @Batch a@
  -- (such as @int[]@) and produces values of type @Interp a@ (such as @int@).
  newBatchReader
    :: proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchReader"
                 <> [Batch a, Interp a]
             )
          )

  -- The default implementation makes calls to the JVM for each element in the
  -- batch. It is only available when the batch is a plain Java array of
  -- @Interp a@.
  default newBatchReader
    :: (Batch a ~ 'Array (Interp a))
    => proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchReader"
                       <> [Batch a, Interp a]
             )
          )
  newBatchReader _ =
      generic <$> [java| new BatchReaders.ObjectBatchReader() |]

  -- | Reflects the values in a vector to a batch of type @Batch a@.
  reflectBatch :: V.Vector a -> IO (J (Batch a))
  -- The default implementation makes calls to the JVM for each element in the
  -- batch: every value is reflected individually and stored in a freshly
  -- allocated object array.
  default reflectBatch
    :: (Reflect a, Batch a ~ 'Array (Interp a))
    => V.Vector a -> IO (J (Batch a))
  reflectBatch v = do
      jxs <- newArray $ fromIntegral (V.length v)
      forM_ [0 .. V.length v - 1] $ \i ->
        -- The local reference to each reflected element is released as soon
        -- as the element has been stored in the array.
        withLocalRef (reflect (v V.! i))
                     (setObjectArrayElement jxs (fromIntegral i))
      return jxs

-- | Helper for reflecting batches of primitive types.
--
-- The vector is converted to a storable vector, a Java array of the same
-- length is allocated, and the elements are copied into it with the given
-- region-setting function.
reflectPrimitiveBatch
  :: forall a ty. (Storable a, IsPrimitiveType ty)
  => (J ('Array ty) -> Int32 -> Int32 -> Ptr a -> IO ())
     -- ^ copy a region into the array; it is called with the array, the
     -- start index, the amount of elements and the source pointer
  -> V.Vector a -> IO (J ('Array ty))
reflectPrimitiveBatch setArrayRegion v = do
    let (fptr, offset, len) = VS.unsafeToForeignPtr (V.convert v)
    withForeignPtr fptr $ \ptr -> do
      jxs <- newArray (fromIntegral len)
      -- @offset@ is counted in elements while 'plusPtr' advances in bytes.
      let aOffset = offset * sizeOf (undefined :: a)
      setArrayRegion jxs 0 (fromIntegral len) (plusPtr ptr aOffset)
      return jxs

-- | Helper for reflecting batches of vectors
--
-- The vector type is a, and vectors are manipulated exclusively with the
-- polymorphic functions given as arguments.
--
-- Vectors are batched with two arrays. One of the arrays contains the result
-- of appending all of the vectors in the batch. The other array contains the
-- offset of each vector in the resulting array.
--
reflectArrayBatch
  :: forall a b ty.
     (b -> IO (J ty)) -- ^ reflect the concatenated vectors
  -> (a -> Int) -- ^ get length
  -> ([a] -> IO b) -- ^ concat
  -> V.Vector a
  -> IO (J (ArrayBatch ty))
reflectArrayBatch reflectB getLength concatenate vecs = do
    -- ends[i] is the running total of the lengths of the vectors 0..i, i.e.
    -- the position right after the ith vector in the concatenation.
    let ends = V.convert $ V.postscanl' (+) 0 $
                 V.map (fromIntegral . getLength) vecs
                   :: VS.Vector Int32
    bigvec <- concatenate $ V.toList vecs
    jvec <- reflectB bigvec
    jends <- reflect ends
    generic <$> Language.Java.new (upcast jvec) (upcast jends)

-- Instances for primitive types, for arrays of primitive types, and for
-- boxed vectors of batchable values. They are declared inside 'withStatic'
-- as in the rest of the module.
withStatic [d|
  instance Batchable Bool where
    type Batch Bool = 'Array ('Prim "boolean")

  instance BatchReify Bool where
    newBatchWriter _ = [java| new BatchWriters.BooleanBatchWriter() |]
    -- Booleans are read as numeric values: zero is False and anything else
    -- is True.
    reifyBatch jxs size =
        bracket (getBooleanArrayElements jxs)
                (releaseBooleanArrayElements jxs) $
          \arr -> V.generateM (fromIntegral size)
                              (fmap (/= 0) . peekElemOff arr)

  instance Batchable CChar where
    type Batch CChar = 'Array ('Prim "byte")

  instance BatchReify CChar where
    newBatchWriter _ = [java| new BatchWriters.ByteBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getByteArrayElements releaseByteArrayElements

  instance Batchable Word16 where
    type Batch Word16 = 'Array ('Prim "char")

  instance BatchReify Word16 where
    newBatchWriter _ = [java| new BatchWriters.CharacterBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getCharArrayElements releaseCharArrayElements

  instance Batchable Int16 where
    type Batch Int16 = 'Array ('Prim "short")

  instance BatchReify Int16 where
    newBatchWriter _ = [java| new BatchWriters.ShortBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getShortArrayElements releaseShortArrayElements

  instance Batchable Int32 where
    type Batch Int32 = 'Array ('Prim "int")

  instance BatchReify Int32 where
    newBatchWriter _ = [java| new BatchWriters.IntegerBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getIntArrayElements releaseIntArrayElements

  instance Batchable Int64 where
    type Batch Int64 = 'Array ('Prim "long")

  instance BatchReify Int64 where
    newBatchWriter _ = [java| new BatchWriters.LongBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getLongArrayElements releaseLongArrayElements

  instance Batchable Float where
    type Batch Float = 'Array ('Prim "float")

  instance BatchReify Float where
    newBatchWriter _ = [java| new BatchWriters.FloatBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getFloatArrayElements releaseFloatArrayElements

  instance Batchable Double where
    type Batch Double = 'Array ('Prim "double")

  instance BatchReify Double where
    newBatchWriter _ = [java| new BatchWriters.DoubleBatchWriter() |]
    reifyBatch =
      reifyPrimitiveBatch getDoubleArrayElements releaseDoubleArrayElements

  instance BatchReflect Bool where
    newBatchReader _ = [java| new BatchReaders.BooleanBatchReader() |]
    -- Booleans are written as the numeric values 1 (True) and 0 (False).
    reflectBatch = reflectPrimitiveBatch setBooleanArrayRegion
                 . V.map (\w -> if w then 1 else 0)

  instance BatchReflect CChar where
    newBatchReader _ = [java| new BatchReaders.ByteBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setByteArrayRegion

  instance BatchReflect Word16 where
    newBatchReader _ = [java| new BatchReaders.CharacterBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setCharArrayRegion

  instance BatchReflect Int16 where
    newBatchReader _ = [java| new BatchReaders.ShortBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setShortArrayRegion

  instance BatchReflect Int32 where
    newBatchReader _ = [java| new BatchReaders.IntegerBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setIntArrayRegion

  instance BatchReflect Int64 where
    newBatchReader _ = [java| new BatchReaders.LongBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setLongArrayRegion

  instance BatchReflect Float where
    newBatchReader _ = [java| new BatchReaders.FloatBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setFloatArrayRegion

  instance BatchReflect Double where
    newBatchReader _ = [java| new BatchReaders.DoubleBatchReader() |]
    reflectBatch = reflectPrimitiveBatch setDoubleArrayRegion

#if ! (__GLASGOW_HASKELL__ == 800 && __GLASGOW_HASKELL_PATCHLEVEL1__ == 1)
  -- Batches of arrays of primitive values: a pair made of the concatenated
  -- elements and the end offset of each array.

  instance Batchable BS.ByteString where
    type Batch BS.ByteString
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "byte")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Word16) where
    type Batch (VS.Vector Word16)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "char")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Int16) where
    type Batch (VS.Vector Int16)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "short")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Int32) where
    type Batch (VS.Vector Int32)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "int")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Int64) where
    type Batch (VS.Vector Int64)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "long")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Float) where
    type Batch (VS.Vector Float)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "float")
          , 'Array ('Prim "int")
          ]

  instance Batchable (VS.Vector Double) where
    type Batch (VS.Vector Double)
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "double")
          , 'Array ('Prim "int")
          ]

  instance Batchable Text.Text where
    type Batch Text.Text
      = 'Class "io.tweag.jvm.batching.Tuple2" <>
         '[ 'Array ('Prim "char")
          , 'Array ('Prim "int")
          ]

  instance BatchReify BS.ByteString where
    newBatchWriter _ = [java| new BatchWriters.ByteArrayBatchWriter() |]
    reifyBatch = reifyArrayBatch (const reify) bsUnsafeSlice
      where
        bsUnsafeSlice :: Int -> Int -> BS.ByteString -> IO BS.ByteString
        bsUnsafeSlice offset sz = return . BS.unsafeTake sz . BS.unsafeDrop offset

  instance BatchReify (VS.Vector Word16) where
    newBatchWriter _ = [java| new BatchWriters.CharArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify (VS.Vector Int16) where
    newBatchWriter _ = [java| new BatchWriters.ShortArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify (VS.Vector Int32) where
    newBatchWriter _ = [java| new BatchWriters.IntArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify (VS.Vector Int64) where
    newBatchWriter _ = [java| new BatchWriters.LongArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify (VS.Vector Float) where
    newBatchWriter _ = [java| new BatchWriters.FloatArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify (VS.Vector Double) where
    newBatchWriter _ = [java| new BatchWriters.DoubleArrayBatchWriter() |]
    reifyBatch =
        reifyArrayBatch (const reify) (fmap (fmap return) . VS.unsafeSlice)

  instance BatchReify Text.Text where
    newBatchWriter _ = [java| new BatchWriters.StringArrayBatchWriter() |]
    -- Each slice of UTF-16 code units is turned into a Text with
    -- Text.fromPtr while the storable vector is kept alive by unsafeWith.
    reifyBatch = reifyArrayBatch (const reify) $ \o n vs ->
                  (VS.unsafeWith (VS.unsafeSlice o n vs) $ \ptr ->
                      Text.fromPtr ptr (fromIntegral n)
                  )

  instance BatchReflect BS.ByteString where
    newBatchReader _ = [java| new BatchReaders.ByteArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect BS.length (return . BS.concat)

  instance BatchReflect (VS.Vector Word16) where
    newBatchReader _ = [java| new BatchReaders.CharArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect (VS.Vector Int16) where
    newBatchReader _ = [java| new BatchReaders.ShortArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect (VS.Vector Int32) where
    newBatchReader _ = [java| new BatchReaders.IntArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect (VS.Vector Int64) where
    newBatchReader _ = [java| new BatchReaders.LongArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect (VS.Vector Float) where
    newBatchReader _ = [java| new BatchReaders.FloatArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect (VS.Vector Double) where
    newBatchReader _ = [java| new BatchReaders.DoubleArrayBatchReader() |]
    reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)

  instance BatchReflect Text.Text where
    newBatchReader _ = [java| new BatchReaders.StringArrayBatchReader() |]
    -- The batch is a Java char[] of UTF-16 code units, so the offsets must
    -- be measured in code units (lengthWord16) rather than in characters:
    -- a character outside the BMP takes two code units.
    --
    -- The concatenated text is copied with asForeignPtr, which returns a
    -- buffer that stays alive as long as the vector referring to it. The
    -- buffer handed out by useAsPtr is only valid inside its callback and
    -- must not escape in the returned vector.
    reflectBatch = reflectArrayBatch reflect Text.lengthWord16 $ \ts -> do
        (fptr, len) <- Text.asForeignPtr (Text.concat ts)
        return $ VS.unsafeFromForeignPtr0 fptr (fromIntegral len)
#endif

  instance Interpretation a => Interpretation (V.Vector a) where
    type Interp (V.Vector a) = 'Array (Interp a)

  -- TODO: Fix GHC so it doesn't complain that variables used exclusively in
  -- quasiquotes are unused. Thus we can stop prepending variable names with
  -- '_'.

  -- Reifies a Java array by copying all of its elements into a single batch
  -- on the Java side and then reifying that batch in one go.
  instance (Interpretation a, BatchReify a)
           => Reify (V.Vector a) where
    reify jv = do
        _batcher <- unsafeUngeneric <$> newBatchWriter (Proxy :: Proxy a)
        n <- getArrayLength jv
        let _jvo = arrayUpcast jv
        batch <- [java| {
          $_batcher.start($n);
          for(int i=0;i<$_jvo.length;i++)
             $_batcher.set(i, $_jvo[i]);
          return $_batcher.getBatch();
          } |]
        reifyBatch (fromObject batch) n
      where
        fromObject :: JObject -> J x
        fromObject = unsafeCast

  -- Reflects a vector by building a single batch on the Haskell side and
  -- then unpacking it into a Java array on the Java side.
  instance (Interpretation a, BatchReflect a)
           => Reflect (V.Vector a) where
    reflect v = do
        _batch <- upcast <$> reflectBatch v
        _batchReader <-
          unsafeUngeneric <$> newBatchReader (Proxy :: Proxy a)
        jv <- [java| {
          $_batchReader.setBatch($_batch);
          return $_batchReader.getSize();
          } |] >>= newArray :: IO (J ('Array (Interp a)))
        let _jvo = arrayUpcast jv
        () <- [java| {
          for(int i=0;i<$_jvo.length;i++)
            $_jvo[i] = $_batchReader.get(i);
          } |]
        return jv

  -- Vectors of batchable values are themselves batchable: a batch of
  -- vectors is an 'ArrayBatch' over the batch type of the elements.
  instance Batchable a => Batchable (V.Vector a) where
    type Batch (V.Vector a) = ArrayBatch (Batch a)

  instance BatchReify a => BatchReify (V.Vector a) where
    newBatchWriter _ = do
        _b <- unsafeUngeneric <$> newBatchWriter (Proxy :: Proxy a)
        generic <$> [java| new BatchWriters.ObjectArrayBatchWriter($_b) |]
    reifyBatch =
        reifyArrayBatch (flip reifyBatch) (fmap (fmap return) . V.unsafeSlice)

  instance BatchReflect a => BatchReflect (V.Vector a) where
    newBatchReader _ = do
        _b <- unsafeUngeneric <$> newBatchReader (Proxy :: Proxy a)
        generic <$> [java| new BatchReaders.ObjectArrayBatchReader($_b) |]
    reflectBatch =
        reflectArrayBatch reflectBatch V.length (return . V.concat)
 |]