module TensorFlow.Records
(
putTFRecord
, getTFRecord
, getTFRecords
, getTFRecordLength
, getTFRecordData
, putTFRecordLength
, putTFRecordData
) where
import Control.Exception (evaluate)
import Control.Monad (when)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import qualified Data.ByteString.Builder as B (Builder)
import Data.ByteString.Builder.Extra (runBuilder, Next(..))
import qualified Data.ByteString.Lazy as BL
import Data.Serialize.Get
( Get
, getBytes
, getWord32le
, getWord64le
, getLazyByteString
, isEmpty
, lookAhead
)
import Data.Serialize
( Put
, execPut
, putLazyByteString
, putWord32le
, putWord64le
)
import Data.Word (Word8, Word64)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
-- | Parse one complete record: the checksummed length header first, then a
-- checksummed payload of exactly that many bytes. The payload is returned.
getTFRecord :: Get BL.ByteString
getTFRecord = do
  payloadLen <- getTFRecordLength
  getTFRecordData payloadLen
-- | Parse records back to back until the input is exhausted, returning their
-- payloads in the order they were read. Empty input yields an empty list.
getTFRecords :: Get [BL.ByteString]
getTFRecords = do
  exhausted <- isEmpty
  if exhausted
    then return []
    else do
      record <- getTFRecord
      rest <- getTFRecords
      return (record : rest)
-- | Read a little-endian 32-bit checksum from the input and compare it with
-- the masked CRC32C computed over the given bytes. The parser fails with a
-- message showing both values when they differ.
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
getCheckMaskedCRC32C bs = do
  expected <- getWord32le
  let computed = crc32cLBSMasked bs
      mismatch =
        "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show computed
          ++ ", expected: " ++ show expected
  when (computed /= expected) (fail mismatch)
-- | Parse a record's length header: a little-endian 'Word64' followed by the
-- masked CRC32C of those same eight bytes. Fails if the checksum is wrong.
getTFRecordLength :: Get Word64
getTFRecordLength = do
  -- Peek at the raw eight bytes without consuming them so the checksum can be
  -- verified against exactly what was on the wire.
  rawLen <- lookAhead (getBytes 8)
  len <- getWord64le
  getCheckMaskedCRC32C (BL.fromStrict rawLen)
  return len
-- | Parse a record payload of the given length followed by its masked CRC32C.
--
-- Fails if the length does not fit in a signed 64-bit integer, or if the
-- checksum read from the input does not match the payload.
getTFRecordData :: Word64 -> Get BL.ByteString
getTFRecordData len
  | len > 0x7fffffffffffffff =
      fail "getTFRecordData: Record size overflows Int64"
  | otherwise = do
      payload <- getLazyByteString (fromIntegral len)
      payload <$ getCheckMaskedCRC32C payload
-- | Serialize the masked CRC32C of the given bytes as a little-endian 32-bit
-- word.
putMaskedCRC32C :: BL.ByteString -> Put
putMaskedCRC32C bs = putWord32le (crc32cLBSMasked bs)
-- | Run a 'B.Builder' that is expected to emit exactly @n@ bytes into a
-- temporary @n@-byte buffer, then hand a pointer to that buffer to @act@.
--
-- The buffer comes from 'allocaBytes', so the pointer must not be used once
-- @act@ has returned.
--
-- Calls 'error' if the builder does not finish inside the buffer (it signals
-- 'More' or 'Chunk'), or if it finishes after writing a number of bytes other
-- than @n@. In the latter case part of the buffer would be uninitialised, and
-- @act@ must never get to read it.
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
  (written, signal) <- runBuilder b ptr n
  case signal of
    Done
      | written == n -> act ptr
      | otherwise ->
          error $
            "unsafeWithFixedWidthBuilder: Builder wrote " ++ show written
              ++ " bytes, expected " ++ show n ++ "."
    More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
    Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."
-- | Serialize a record length header: the length as a little-endian 'Word64'
-- followed by the masked CRC32C of those eight bytes.
--
-- The checksum is computed by rendering the length into a temporary 8-byte
-- buffer and running 'crc32cUpdate' over it. 'evaluate' forces the checksum
-- inside the callback, i.e. while the temporary buffer is still alive.
putTFRecordLength :: Word64 -> Put
putTFRecordLength x = encodedLength *> putWord32le maskedCRC
  where
    encodedLength = putWord64le x
    width = 8 :: Int
    maskedCRC = crc32cMask (unsafePerformIO rawCRC)
    rawCRC =
      unsafeWithFixedWidthBuilder width (execPut encodedLength) $ \buf -> do
        bytes <- unsafePackCStringLen (castPtr buf, width)
        evaluate (crc32cUpdate 0 bytes)
-- | Serialize a record payload: the bytes themselves, then their masked
-- CRC32C.
putTFRecordData :: BL.ByteString -> Put
putTFRecordData bs = do
  putLazyByteString bs
  putMaskedCRC32C bs
-- | Serialize one complete record: the checksummed length header for the
-- payload, followed by the checksummed payload.
putTFRecord :: BL.ByteString -> Put
putTFRecord bs = do
  let payloadLen = fromIntegral (BL.length bs)
  putTFRecordLength payloadLen
  putTFRecordData bs