{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Datasets exposed as 'Stream's from the @streaming@ package.
--
-- The MXNet-backed constructors ('imageRecordIter', 'mnistIter', 'csvIter',
-- 'libSVMIter') create a fresh native data iterator each time the stream is
-- run, and yield one @(data, label)@ pair per successful advance of it.
module MXNet.Core.IO.DataIter.Streaming (
    StreamData,
    Dataset(..),
    imageRecordIter,
    mnistIter,
    csvIter,
    libSVMIter
) where

import Control.Exception (onException)
import Data.IORef
import Streaming
import Streaming.Prelude (Of(..), yield, length_, toList_)
import qualified Streaming.Prelude as S

import MXNet.Core.Base
import MXNet.Core.Base.NDArray (NDArray(..))
import MXNet.Core.Base.Internal
import qualified MXNet.Core.IO.Internal as I
import MXNet.NN.Types (TrainM)
import MXNet.NN.DataIter.Class

-- | A dataset represented as a 'Stream' of items in the monad @m@.
newtype StreamData m a = StreamData { getStream :: Stream (Of a) m () }

-- | Stream @(data, label)@ pairs from an MXNet @ImageRecordIter@ built from
-- the given keyword arguments.
imageRecordIter :: (MatchKVList kvs I.ImageRecordIter_Args, ShowKV kvs, DType a, MonadIO m)
                => HMap kvs -> StreamData m (NDArray a, NDArray a)
imageRecordIter = makeIter I.imageRecordIter

-- | Stream @(data, label)@ pairs from an MXNet @MNISTIter@ built from the
-- given keyword arguments.
mnistIter :: (MatchKVList kvs I.MNISTIter_Args, ShowKV kvs, DType a, MonadIO m)
          => HMap kvs -> StreamData m (NDArray a, NDArray a)
mnistIter = makeIter I.mNISTIter

-- | Stream @(data, label)@ pairs from an MXNet @CSVIter@ built from the given
-- keyword arguments.
csvIter :: (MatchKVList kvs I.CSVIter_Args, ShowKV kvs, DType a, MonadIO m)
        => HMap kvs -> StreamData m (NDArray a, NDArray a)
csvIter = makeIter I.cSVIter

-- | Stream @(data, label)@ pairs from an MXNet @LibSVMIter@ built from the
-- given keyword arguments.
libSVMIter :: (MatchKVList kvs I.LibSVMIter_Args, ShowKV kvs, DType a, MonadIO m)
           => HMap kvs -> StreamData m (NDArray a, NDArray a)
libSVMIter = makeIter I.libSVMIter

-- | Wrap an MXNet data-iterator constructor as a 'StreamData'.
--
-- Each run of the resulting stream calls @creator args@ to obtain a fresh
-- iterator handle. It then repeatedly advances the handle with
-- 'mxDataIterNext', yielding the current data and label arrays after every
-- successful step. When the iterator reports the end (a zero result), the
-- handle is freed with 'mxDataIterFree'.
--
-- If one of the MXNet calls inside the loop throws, the handle is freed
-- before the exception propagates. A consumer that stops pulling early
-- (e.g. 'S.take') never resumes the stream, so the handle is NOT freed in
-- that case.
--
-- No top-level signature: the handle type comes from the MXNet internals and
-- is left to inference.
makeIter creator args = StreamData $ do
    iter <- liftIO (creator args)
    let -- Release the native handle if an MXNet call fails, then rethrow.
        releaseOnError :: IO b -> IO b
        releaseOnError act = act `onException` mxDataIterFree iter

        loop = do
            valid <- liftIO $ releaseOnError $ checked (mxDataIterNext iter)
            if valid == 0
                then liftIO (checked (mxDataIterFree iter))
                else do
                    item <- liftIO $ releaseOnError $ do
                        dat <- checked (mxDataIterGetData iter)
                        lbl <- checked (mxDataIterGetLabel iter)
                        return (NDArray dat, NDArray lbl)
                    yield item
                    loop
    loop

-- | A 'StreamData' can only be consumed in the monad it streams in.
type instance DatasetConstraint (StreamData m1) m2 = m1 ~ m2

instance Monad m => Dataset (StreamData m) where
    -- | Stream the list's elements in order.
    fromListD = StreamData . S.each
    -- | Pair elements of the two streams; 'S.zip' stops at the shorter one.
    zipD s1 s2 = StreamData $ S.zip (getStream s1) (getStream s2)
    -- | Count elements by running the whole stream. For MXNet-backed streams
    -- this is a full pass over a freshly created iterator.
    sizeD = length_ . getStream
    -- | Run @proc@ on every element and collect all results into a list.
    forEachD dat proc = toList_ $ S.mapM proc (getStream dat)