-- Copyright 2016 TensorFlow authors. -- -- Licensed under the Apache License, Version 2.0 (the "License"); -- you may not use this file except in compliance with the License. -- You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -- See the License for the specific language governing permissions and -- limitations under the License. -- | Encoder and decoder for the TensorFlow \"TFRecords\" format. {-# LANGUAGE Rank2Types #-} module TensorFlow.Records ( -- * Records putTFRecord , getTFRecord , getTFRecords -- * Implementation -- | These may be useful for encoding or decoding to types other than -- 'ByteString' that have their own Cereal codecs. , getTFRecordLength , getTFRecordData , putTFRecordLength , putTFRecordData ) where import Control.Exception (evaluate) import Control.Monad (when) import Data.ByteString.Unsafe (unsafePackCStringLen) import qualified Data.ByteString.Builder as B (Builder) import Data.ByteString.Builder.Extra (runBuilder, Next(..)) import qualified Data.ByteString.Lazy as BL import Data.Serialize.Get ( Get , getBytes , getWord32le , getWord64le , getLazyByteString , isEmpty , lookAhead ) import Data.Serialize ( Put , execPut , putLazyByteString , putWord32le , putWord64le ) import Data.Word (Word8, Word64) import Foreign.Marshal.Alloc (allocaBytes) import Foreign.Ptr (Ptr, castPtr) import System.IO.Unsafe (unsafePerformIO) import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask) -- | Parse one TFRecord. getTFRecord :: Get BL.ByteString getTFRecord = getTFRecordLength >>= getTFRecordData -- | Parse many TFRecords as a list. Note you probably want streaming instead -- as provided by the tensorflow-records-conduit package. getTFRecords :: Get [BL.ByteString] getTFRecords = do e <- isEmpty if e then return [] else (:) <$> getTFRecord <*> getTFRecords getCheckMaskedCRC32C :: BL.ByteString -> Get () getCheckMaskedCRC32C bs = do wireCRC <- getWord32le let maskedCRC = crc32cLBSMasked bs when (maskedCRC /= wireCRC) $ fail $ "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++ ", expected: " ++ show wireCRC -- | Get a length and verify its checksum. getTFRecordLength :: Get Word64 getTFRecordLength = do buf <- lookAhead (getBytes 8) getWord64le <* getCheckMaskedCRC32C (BL.fromStrict buf) -- | Get a record payload and verify its checksum. getTFRecordData :: Word64 -> Get BL.ByteString getTFRecordData len = if len > 0x7fffffffffffffff then fail "getTFRecordData: Record size overflows Int64" else do bs <- getLazyByteString (fromIntegral len) getCheckMaskedCRC32C bs return bs putMaskedCRC32C :: BL.ByteString -> Put putMaskedCRC32C = putWord32le . crc32cLBSMasked -- Runs a Builder that's known to write a fixed number of bytes on an 'alloca' -- buffer, and runs the given IO action on the result. Raises exceptions if -- the Builder yields ByteString chunks or attempts to write more bytes than -- expected. unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do (_, signal) <- runBuilder b ptr n case signal of Done -> act ptr More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More." Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk." -- | Put a record length and its checksum. putTFRecordLength :: Word64 -> Put putTFRecordLength x = let put = putWord64le x len = 8 crc = crc32cMask $ unsafePerformIO $ -- Serialized Word64 is always 8 bytes, so we can go fast by using -- alloca. unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do str <- unsafePackCStringLen (castPtr ptr, len) -- Force the result to ensure it's evaluated before freeing ptr. evaluate $ crc32cUpdate 0 str in put *> putWord32le crc -- | Put a record payload and its checksum. putTFRecordData :: BL.ByteString -> Put putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs -- | Put one TFRecord with the given contents. putTFRecord :: BL.ByteString -> Put putTFRecord bs = putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs