module TensorFlow.Records.Conduit
(
encodeTFRecords
, decodeTFRecords
, sinkTFRecords
, sourceTFRecords
) where
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Trans.Resource (MonadResource)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit ((=$=), Conduit, Consumer, Producer)
import Data.Conduit.Binary (sinkFile, sourceFile)
import Data.Conduit.Cereal (conduitGet2, conduitPut)
import TensorFlow.Records (getTFRecord, putTFRecord)
-- | Conduit that turns a stream of raw strict 'B.ByteString' input into
-- records. It emits one lazy 'BL.ByteString' per record by running
-- 'getTFRecord' repeatedly via 'conduitGet2'.
--
-- The 'MonadThrow' constraint is required by 'conduitGet2' (presumably so
-- it can signal input that 'getTFRecord' fails to parse; see the
-- cereal-conduit documentation).
decodeTFRecords :: MonadThrow m => Conduit B.ByteString m BL.ByteString
decodeTFRecords = conduitGet2 recordParser
  where
    -- Single-record parser, applied once per record in the input stream.
    recordParser = getTFRecord
-- | Producer of the records stored in a file. It reads the raw bytes with
-- 'sourceFile' and decodes them with 'decodeTFRecords'.
--
-- The underlying file handle is managed through 'MonadResource', so this
-- must run in a resource-managing context such as 'runResourceT'.
-- 'MonadThrow' is needed by the decoding stage.
sourceTFRecords
  :: (MonadResource m, MonadThrow m)
  => FilePath                  -- ^ file to read records from
  -> Producer m BL.ByteString  -- ^ decoded records, in file order
sourceTFRecords inPath =
  sourceFile inPath
    =$= decodeTFRecords
-- | Conduit that serializes each incoming lazy 'BL.ByteString' record with
-- 'putTFRecord'. It emits the serialized bytes as strict 'B.ByteString'
-- output via 'conduitPut'.
encodeTFRecords :: Monad m => Conduit BL.ByteString m B.ByteString
encodeTFRecords = conduitPut recordWriter
  where
    -- Serializer applied to each record flowing through the conduit.
    recordWriter = putTFRecord
-- | Consumer that writes every incoming record to a file. It encodes the
-- records with 'encodeTFRecords' and hands the bytes to 'sinkFile'.
--
-- The output file handle is managed through 'MonadResource', so this must
-- run in a resource-managing context such as 'runResourceT'.
sinkTFRecords
  :: (MonadResource m)
  => FilePath                       -- ^ file to write records to
  -> Consumer BL.ByteString m ()
sinkTFRecords outPath =
  encodeTFRecords
    =$= sinkFile outPath