{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StaticPointers #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Java.Batching
( Batchable(..)
, BatchReify(..)
, BatchReflect(..)
, ArrayBatch
) where
import Control.Distributed.Closure.TH
import Control.Exception (bracket)
import Control.Monad (forM_, foldM)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import Data.Int
import Data.Singletons (SingI, Proxy(..))
import qualified Data.Text as Text
import qualified Data.Text.Foreign as Text
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Storable as VS
import Data.Word
import Foreign.C.Types (CChar)
import Foreign.ForeignPtr (newForeignPtr_, withForeignPtr)
import Foreign.JNI
import Foreign.Ptr
import Foreign.Storable
import Language.Java
import Language.Java.Inline
imports "io.tweag.jvm.batching.*"
-- | Types whose values can be grouped into batches on the Java side.
class (Interpretation a, SingI (Batch a)) => Batchable (a :: k) where
  -- | The Java type that represents a batch of values of type @a@.
  type Batch a :: JType
-- | Types whose values can be reified from a Java batch.
class Batchable a => BatchReify a where
  -- | Creates the Java @BatchWriter@ object for the type selected by the
  -- proxy argument.
  --
  -- The default applies when a batch is a plain array of @Interp a@; it
  -- instantiates @BatchWriters.ObjectBatchWriter@.
  newBatchWriter
    :: proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchWriter"
               <> [Interp a, Batch a]
             )
          )
  default newBatchWriter
    :: (Batch a ~ 'Array (Interp a))
    => proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchWriter"
               <> [Interp a, Batch a]
             )
          )
  newBatchWriter _ = generic <$> [java| new BatchWriters.ObjectBatchWriter() |]

  -- | Reifies a batch into a vector. The 'Int32' argument is the number of
  -- elements to read from the batch.
  --
  -- The default fetches every element of the Java array individually and
  -- reifies it with 'reify'.
  reifyBatch :: J (Batch a) -> Int32 -> IO (V.Vector a)
  default reifyBatch
    :: (Reify a, Batch a ~ 'Array (Interp a))
    => J (Batch a) -> Int32 -> IO (V.Vector a)
  reifyBatch jxs size =
      V.generateM (fromIntegral size) $ \i ->
        getObjectArrayElement jxs (fromIntegral i) >>= reify . unsafeCast
-- | Helper for reifying batches held in Java primitive arrays.
--
-- The first argument obtains a pointer to the array elements and the second
-- one releases it. The elements are read while the pointer is held;
-- 'bracket' guarantees the release action runs even if reading fails.
-- The 'Int32' argument is the number of elements to read.
reifyPrimitiveBatch
  :: Storable a
  => (J ('Array ty) -> IO (Ptr a))
  -> (J ('Array ty) -> Ptr a -> IO ())
  -> J ('Array ty) -> Int32 -> IO (V.Vector a)
reifyPrimitiveBatch getArrayElements releaseArrayElements jxs size =
    bracket (getArrayElements jxs) (releaseArrayElements jxs)
      $ V.generateM (fromIntegral size) . peekElemOff
-- | Java type of a batch of array-like values: a @Tuple2@ pairing a value of
-- type @ty@ with an @int[]@.
type ArrayBatch ty =
  'Class "io.tweag.jvm.batching.Tuple2" <> '[ty, 'Array ('Prim "int")]
-- | Helper for reifying batches of array-like values.
--
-- The batch is a @Tuple2@: its second component is an @int[]@ with the end
-- offset of every value, and its first component holds all the values
-- together.
--
-- * The first argument reifies the first component. It also receives the
--   last end offset, or 0 when there are no offsets.
-- * The second argument extracts one value given an offset, a length and the
--   reified first component.
-- * The last two arguments are the batch and the number of values in it.
reifyArrayBatch
  :: forall a b ty.
     (Int32 -> J ty -> IO a)
  -> (Int -> Int -> a -> IO b)
  -> J (ArrayBatch ty)
  -> Int32
  -> IO (V.Vector b)
reifyArrayBatch reifyB slice batch0 batchSize = do
    let batch = unsafeUngeneric batch0
    arrayEnds <- reifyArrayOffsets batch
    arrayValues <- reifyArrayValues arrayEnds batch
    reifySlices arrayEnds arrayValues
  where
    -- Fixes the type of the objects returned by the java quasiquotes.
    fromObject :: JObject -> J x
    fromObject = unsafeCast

    -- Reads the end offsets stored in the second component of the tuple.
    reifyArrayOffsets
      :: J ('Class "io.tweag.jvm.batching.Tuple2") -> IO (VS.Vector Int32)
    reifyArrayOffsets batch = [java| (int[])$batch._2 |] >>= reify

    -- Reifies the first component of the tuple, passing along the last end
    -- offset (0 if there are none).
    reifyArrayValues
      :: VS.Vector Int32 -> J ('Class "io.tweag.jvm.batching.Tuple2") -> IO a
    reifyArrayValues arrayEnds batch = do
      let count = if VS.null arrayEnds then 0 else VS.last arrayEnds
      [java| $batch._1 |] >>= reifyB count . fromObject

    -- Cuts the reified values into @batchSize@ slices, threading the end of
    -- each slice as the start offset of the next one.
    reifySlices :: VS.Vector Int32 -> a -> IO (V.Vector b)
    reifySlices arrayEnds arrayValues = do
      result <- VM.new (fromIntegral batchSize)
      _ <- foldM
        (writeSliceToVector result arrayEnds arrayValues)
        0
        [0 .. fromIntegral batchSize - 1]
      V.unsafeFreeze result

    -- Writes the i-th slice, which spans from @offset@ to the i-th end
    -- offset, and returns that end offset.
    writeSliceToVector
      :: VM.IOVector b
      -> VS.Vector Int32
      -> a
      -> Int32
      -> Int
      -> IO Int32
    writeSliceToVector output arrayEnds arrayValues offset i = do
      let end = arrayEnds VS.! i
      slice (fromIntegral offset) (fromIntegral (end - offset)) arrayValues
        >>= VM.unsafeWrite output i
      return end
-- | Types whose values can be reflected into a Java batch.
class Batchable a => BatchReflect a where
  -- | Creates the Java @BatchReader@ object for the type selected by the
  -- proxy argument.
  --
  -- The default applies when a batch is a plain array of @Interp a@; it
  -- instantiates @BatchReaders.ObjectBatchReader@.
  newBatchReader
    :: proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchReader"
               <> [Batch a, Interp a]
             )
          )
  default newBatchReader
    :: (Batch a ~ 'Array (Interp a))
    => proxy a
    -> IO (J ('Iface "io.tweag.jvm.batching.BatchReader"
               <> [Batch a, Interp a]
             )
          )
  newBatchReader _ =
      generic <$> [java| new BatchReaders.ObjectBatchReader() |]

  -- | Reflects a vector of values into a batch.
  --
  -- The default allocates a Java array as long as the vector and reflects
  -- the elements one by one, storing each of them in the array.
  reflectBatch :: V.Vector a -> IO (J (Batch a))
  default reflectBatch
    :: (Reflect a, Batch a ~ 'Array (Interp a))
    => V.Vector a -> IO (J (Batch a))
  reflectBatch v = do
      jxs <- newArray $ fromIntegral (V.length v)
      forM_ [0 .. V.length v - 1] $ \i ->
        -- withLocalRef brackets the use of the reference to the reflected
        -- element around the store into the array.
        withLocalRef (reflect (v V.! i))
                     (setObjectArrayElement jxs (fromIntegral i))
      return jxs
-- | Helper for reflecting vectors into Java primitive arrays.
--
-- The vector is converted to a storable vector, a Java array of the same
-- length is allocated, and the first argument is used to copy the whole
-- vector into the array starting at index 0.
reflectPrimitiveBatch
  :: forall a ty. (Storable a, IsPrimitiveType ty)
  => (J ('Array ty) -> Int32 -> Int32 -> Ptr a -> IO ())
  -> V.Vector a -> IO (J ('Array ty))
reflectPrimitiveBatch setArrayRegion v = do
    let (fptr, offset, len) = VS.unsafeToForeignPtr (V.convert v)
    withForeignPtr fptr $ \ptr -> do
      jxs <- newArray (fromIntegral len)
      -- The offset is multiplied by the element size because plusPtr
      -- advances the pointer by bytes.
      let aOffset = offset * sizeOf (undefined :: a)
      setArrayRegion jxs 0 (fromIntegral len) (plusPtr ptr aOffset)
      return jxs
-- | Helper for reflecting batches of array-like values.
--
-- The result is a @Tuple2@ pairing the reflected concatenation of all the
-- values with an @int[]@ holding the running sum of their lengths, i.e. the
-- end offset of every value inside the concatenation.
--
-- * The first argument reflects the concatenated values.
-- * The second argument gives the length of a single value.
-- * The third argument concatenates the values.
reflectArrayBatch
  :: forall a b ty.
     (b -> IO (J ty))
  -> (a -> Int)
  -> ([a] -> IO b)
  -> V.Vector a
  -> IO (J (ArrayBatch ty))
reflectArrayBatch reflectB getLength concatenate vecs = do
    let ends = V.convert $ V.postscanl' (+) 0 $
                 V.map (fromIntegral . getLength) vecs
                 :: VS.Vector Int32
    bigvec <- concatenate $ V.toList vecs
    jvec <- reflectB bigvec
    jends <- reflect ends
    generic <$> Language.Java.new (upcast jvec) (upcast jends)
withStatic [d|
-- Booleans are batched as a Java boolean[].
instance Batchable Bool where
  type Batch Bool = 'Array ('Prim "boolean")
instance BatchReify Bool where
  newBatchWriter _ = [java| new BatchWriters.BooleanBatchWriter() |]
  -- Every array element is read while the element pointer is held and
  -- mapped to a Bool: zero is False, anything else is True.
  reifyBatch jxs size =
      bracket (getBooleanArrayElements jxs)
              (releaseBooleanArrayElements jxs) $ \arr ->
        V.generateM (fromIntegral size) (fmap (/= 0) . peekElemOff arr)
-- Primitive element types: a batch is the Java primitive array of the
-- matching type, and it is reified with 'reifyPrimitiveBatch' using the
-- get/release functions for that array type.
instance Batchable CChar where
  type Batch CChar = 'Array ('Prim "byte")
instance BatchReify CChar where
  newBatchWriter _ = [java| new BatchWriters.ByteBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getByteArrayElements releaseByteArrayElements
instance Batchable Word16 where
  type Batch Word16 = 'Array ('Prim "char")
instance BatchReify Word16 where
  newBatchWriter _ = [java| new BatchWriters.CharacterBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getCharArrayElements releaseCharArrayElements
instance Batchable Int16 where
  type Batch Int16 = 'Array ('Prim "short")
instance BatchReify Int16 where
  newBatchWriter _ = [java| new BatchWriters.ShortBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getShortArrayElements releaseShortArrayElements
instance Batchable Int32 where
  type Batch Int32 = 'Array ('Prim "int")
instance BatchReify Int32 where
  newBatchWriter _ = [java| new BatchWriters.IntegerBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getIntArrayElements releaseIntArrayElements
instance Batchable Int64 where
  type Batch Int64 = 'Array ('Prim "long")
instance BatchReify Int64 where
  newBatchWriter _ = [java| new BatchWriters.LongBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getLongArrayElements releaseLongArrayElements
instance Batchable Float where
  type Batch Float = 'Array ('Prim "float")
instance BatchReify Float where
  newBatchWriter _ = [java| new BatchWriters.FloatBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getFloatArrayElements releaseFloatArrayElements
instance Batchable Double where
  type Batch Double = 'Array ('Prim "double")
instance BatchReify Double where
  newBatchWriter _ = [java| new BatchWriters.DoubleBatchWriter() |]
  reifyBatch =
    reifyPrimitiveBatch getDoubleArrayElements releaseDoubleArrayElements
-- Booleans are converted to 1 (True) or 0 (False) before being copied into
-- the Java boolean[].
instance BatchReflect Bool where
  newBatchReader _ = [java| new BatchReaders.BooleanBatchReader() |]
  reflectBatch bools =
      reflectPrimitiveBatch setBooleanArrayRegion (V.map toJBoolean bools)
    where
      toJBoolean b = if b then 1 else 0
-- Primitive element types: the vector is reflected with
-- 'reflectPrimitiveBatch', using the array-region setter of the matching
-- Java primitive array type.
instance BatchReflect CChar where
  newBatchReader _ = [java| new BatchReaders.ByteBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setByteArrayRegion
instance BatchReflect Word16 where
  newBatchReader _ = [java| new BatchReaders.CharacterBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setCharArrayRegion
instance BatchReflect Int16 where
  newBatchReader _ = [java| new BatchReaders.ShortBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setShortArrayRegion
instance BatchReflect Int32 where
  newBatchReader _ = [java| new BatchReaders.IntegerBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setIntArrayRegion
instance BatchReflect Int64 where
  newBatchReader _ = [java| new BatchReaders.LongBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setLongArrayRegion
instance BatchReflect Float where
  newBatchReader _ = [java| new BatchReaders.FloatBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setFloatArrayRegion
instance BatchReflect Double where
  newBatchReader _ = [java| new BatchReaders.DoubleBatchReader() |]
  reflectBatch = reflectPrimitiveBatch setDoubleArrayRegion
#if ! (__GLASGOW_HASKELL__ == 800 && __GLASGOW_HASKELL_PATCHLEVEL1__ == 1)
-- Array-like values: a batch is a Tuple2 of a Java primitive array and an
-- int[], spelled with the 'ArrayBatch' synonym.
instance Batchable BS.ByteString where
  type Batch BS.ByteString = ArrayBatch ('Array ('Prim "byte"))
instance Batchable (VS.Vector Word16) where
  type Batch (VS.Vector Word16) = ArrayBatch ('Array ('Prim "char"))
instance Batchable (VS.Vector Int16) where
  type Batch (VS.Vector Int16) = ArrayBatch ('Array ('Prim "short"))
instance Batchable (VS.Vector Int32) where
  type Batch (VS.Vector Int32) = ArrayBatch ('Array ('Prim "int"))
instance Batchable (VS.Vector Int64) where
  type Batch (VS.Vector Int64) = ArrayBatch ('Array ('Prim "long"))
instance Batchable (VS.Vector Float) where
  type Batch (VS.Vector Float) = ArrayBatch ('Array ('Prim "float"))
instance Batchable (VS.Vector Double) where
  type Batch (VS.Vector Double) = ArrayBatch ('Array ('Prim "double"))
instance Batchable Text.Text where
  type Batch Text.Text = ArrayBatch ('Array ('Prim "char"))
-- The byte[] of the batch is reified as one ByteString, and every value is
-- obtained by slicing it without bounds checks.
instance BatchReify BS.ByteString where
  newBatchWriter _ = [java| new BatchWriters.ByteArrayBatchWriter() |]
  reifyBatch = reifyArrayBatch (const reify) sliceBytes
    where
      sliceBytes :: Int -> Int -> BS.ByteString -> IO BS.ByteString
      sliceBytes start len bytes =
        return (BS.unsafeTake len (BS.unsafeDrop start bytes))
-- Storable vectors: the primitive array of the batch is reified as one
-- vector, and every value is an unchecked slice of it.
instance BatchReify (VS.Vector Word16) where
  newBatchWriter _ = [java| new BatchWriters.CharArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
instance BatchReify (VS.Vector Int16) where
  newBatchWriter _ = [java| new BatchWriters.ShortArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
instance BatchReify (VS.Vector Int32) where
  newBatchWriter _ = [java| new BatchWriters.IntArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
instance BatchReify (VS.Vector Int64) where
  newBatchWriter _ = [java| new BatchWriters.LongArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
instance BatchReify (VS.Vector Float) where
  newBatchWriter _ = [java| new BatchWriters.FloatArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
instance BatchReify (VS.Vector Double) where
  newBatchWriter _ = [java| new BatchWriters.DoubleArrayBatchWriter() |]
  reifyBatch =
    reifyArrayBatch (const reify) $ \start len xs ->
      return (VS.unsafeSlice start len xs)
-- The char[] of the batch is reified as a vector of Word16, and every text
-- is built with Text.fromPtr from the slice that belongs to it.
instance BatchReify Text.Text where
  newBatchWriter _ = [java| new BatchWriters.StringArrayBatchWriter() |]
  reifyBatch = reifyArrayBatch (const reify) sliceText
    where
      sliceText start len units =
        VS.unsafeWith (VS.unsafeSlice start len units) $ \ptr ->
          Text.fromPtr ptr (fromIntegral len)
-- Array-like values are reflected with 'reflectArrayBatch': the values are
-- concatenated, the concatenation is reflected with 'reflect', and the
-- length function provides the size of each value.
instance BatchReflect BS.ByteString where
  newBatchReader _ = [java| new BatchReaders.ByteArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect BS.length (return . BS.concat)
instance BatchReflect (VS.Vector Word16) where
  newBatchReader _ = [java| new BatchReaders.CharArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
instance BatchReflect (VS.Vector Int16) where
  newBatchReader _ = [java| new BatchReaders.ShortArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
instance BatchReflect (VS.Vector Int32) where
  newBatchReader _ = [java| new BatchReaders.IntArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
instance BatchReflect (VS.Vector Int64) where
  newBatchReader _ = [java| new BatchReaders.LongArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
instance BatchReflect (VS.Vector Float) where
  newBatchReader _ = [java| new BatchReaders.FloatArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
instance BatchReflect (VS.Vector Double) where
  newBatchReader _ = [java| new BatchReaders.DoubleArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect VS.length (return . VS.concat)
-- Texts are reflected as the Word16 (UTF-16) code units of their
-- concatenation plus the end offset of every text.
instance BatchReflect Text.Text where
  newBatchReader _ = [java| new BatchReaders.StringArrayBatchReader() |]
  reflectBatch = reflectArrayBatch reflect utf16Length concatUtf16
    where
      -- The offsets index the Word16 array, so lengths must be counted in
      -- UTF-16 code units: characters above U+FFFF take two units, which
      -- Text.length (a count of characters) would miss.
      utf16Length :: Text.Text -> Int
      utf16Length =
        Text.foldl' (\n c -> n + (if c > '\xFFFF' then 2 else 1)) 0

      -- asForeignPtr yields a copy owned by the returned ForeignPtr, so the
      -- vector stays valid after this action returns. A pointer taken from
      -- useAsPtr must not escape: its buffer is only valid inside the
      -- callback.
      concatUtf16 :: [Text.Text] -> IO (VS.Vector Word16)
      concatUtf16 ts = do
        (fptr, len) <- Text.asForeignPtr (Text.concat ts)
        return (VS.unsafeFromForeignPtr0 fptr (fromIntegral len))
#endif
-- A boxed vector corresponds to a Java array of the element's Java type.
instance Interpretation a => Interpretation (V.Vector a) where
  type Interp (V.Vector a) = 'Array (Interp a)
-- Reifying a Java array as a boxed vector: the Java-side batch writer packs
-- all the array elements into a single batch, which is then reified with
-- 'reifyBatch'.
instance (Interpretation a, BatchReify a)
    => Reify (V.Vector a) where
  reify jv = do
      _batcher <- unsafeUngeneric <$> newBatchWriter (Proxy :: Proxy a)
      n <- getArrayLength jv
      let _jvo = arrayUpcast jv
      -- Feed every element of the array to the writer and collect the batch.
      batch <- [java| {
          $_batcher.start($n);
          for(int i=0;i<$_jvo.length;i++)
            $_batcher.set(i, $_jvo[i]);
          return $_batcher.getBatch();
        } |]
      reifyBatch (fromObject batch) n
    where
      -- Casts the batch object returned by the quasiquote to the batch type
      -- expected by 'reifyBatch'.
      fromObject :: JObject -> J x
      fromObject = unsafeCast
-- Reflecting a boxed vector as a Java array: the elements are reflected as a
-- single batch, and the Java-side batch reader unpacks that batch into a
-- newly allocated array.
instance (Interpretation a, BatchReflect a)
    => Reflect (V.Vector a) where
  reflect v = do
      _batch <- upcast <$> reflectBatch v
      _batchReader <-
        unsafeUngeneric <$> newBatchReader (Proxy :: Proxy a)
      -- Hand the batch to the reader and allocate the result array with the
      -- size the reader reports.
      jv <- [java| {
          $_batchReader.setBatch($_batch);
          return $_batchReader.getSize();
        } |] >>= newArray :: IO (J ('Array (Interp a)))
      let _jvo = arrayUpcast jv
      -- Copy every element out of the reader into the result array.
      () <- [java| {
          for(int i=0;i<$_jvo.length;i++)
            $_jvo[i] = $_batchReader.get(i);
        } |]
      return jv
-- A batch of vectors is an 'ArrayBatch' over the batch type of the elements.
instance Batchable a => Batchable (V.Vector a) where
  type Batch (V.Vector a) = ArrayBatch (Batch a)
-- Nested vectors: the writer wraps the element writer, and the inner batch
-- is reified with the element's 'reifyBatch' before being cut into unchecked
-- slices, one per vector.
instance BatchReify a => BatchReify (V.Vector a) where
  newBatchWriter _ = do
      _b <- unsafeUngeneric <$> newBatchWriter (Proxy :: Proxy a)
      generic <$> [java| new BatchWriters.ObjectArrayBatchWriter($_b) |]
  reifyBatch =
    reifyArrayBatch (flip reifyBatch) $ \start len xs ->
      return (V.unsafeSlice start len xs)
-- Nested vectors: the reader wraps the element reader, and the vectors are
-- concatenated and reflected as one batch of elements by 'reflectArrayBatch'.
instance BatchReflect a => BatchReflect (V.Vector a) where
  newBatchReader _ = do
      _b <- unsafeUngeneric <$> newBatchReader (Proxy :: Proxy a)
      generic <$> [java| new BatchReaders.ObjectArrayBatchReader($_b) |]
  reflectBatch =
    reflectArrayBatch reflectBatch V.length (return . V.concat)
|]