{-# Language TypeFamilies #-}

{-# LANGUAGE FlexibleInstances #-}

module MXNet.Core.IO.DataIter.Conduit (

    ConduitData,

    Dataset(..),

    imageRecordIter, mnistIter, csvIter, libSVMIter

) where



import Data.IORef

import Data.Conduit

import qualified Data.Conduit.Combinators as C

import qualified Data.Conduit.List as CL

import Control.Monad.IO.Class

import MXNet.Core.Base

import MXNet.Core.Base.NDArray (NDArray(..))

import MXNet.Core.Base.Internal

import qualified MXNet.Core.IO.Internal as I



import MXNet.NN.Types (TrainM)

import MXNet.NN.DataIter.Class



-- | A dataset represented as a conduit source: 'getConduit' is a
-- @ConduitM () a m ()@, i.e. a stream that consumes no input, yields values
-- of type @a@, runs in the monad @m@ and finishes with @()@.
newtype ConduitData m a = ConduitData { getConduit :: ConduitM () a m () }



-- | Stream batches from the data iterator created by 'I.imageRecordIter'.
--
-- The key-value argument map is passed to the creator unchanged. Every
-- element of the resulting 'ConduitData' is a @(data, label)@ pair, as
-- assembled by 'makeIter'.
imageRecordIter
    :: (MatchKVList kvs I.ImageRecordIter_Args, ShowKV kvs, DType a, MonadIO m)
    => HMap kvs
    -> ConduitData m (NDArray a, NDArray a)
imageRecordIter args = makeIter I.imageRecordIter args



-- | Stream batches from the data iterator created by 'I.mNISTIter'.
--
-- The key-value argument map is passed to the creator unchanged. Every
-- element of the resulting 'ConduitData' is a @(data, label)@ pair, as
-- assembled by 'makeIter'.
mnistIter
    :: (MatchKVList kvs I.MNISTIter_Args, ShowKV kvs, DType a, MonadIO m)
    => HMap kvs
    -> ConduitData m (NDArray a, NDArray a)
mnistIter args = makeIter I.mNISTIter args



-- | Stream batches from the data iterator created by 'I.cSVIter'.
--
-- The key-value argument map is passed to the creator unchanged. Every
-- element of the resulting 'ConduitData' is a @(data, label)@ pair, as
-- assembled by 'makeIter'.
csvIter
    :: (MatchKVList kvs I.CSVIter_Args, ShowKV kvs, DType a, MonadIO m)
    => HMap kvs
    -> ConduitData m (NDArray a, NDArray a)
csvIter args = makeIter I.cSVIter args



-- | Stream batches from the data iterator created by 'I.libSVMIter'.
--
-- The key-value argument map is passed to the creator unchanged. Every
-- element of the resulting 'ConduitData' is a @(data, label)@ pair, as
-- assembled by 'makeIter'.
libSVMIter
    :: (MatchKVList kvs I.LibSVMIter_Args, ShowKV kvs, DType a, MonadIO m)
    => HMap kvs
    -> ConduitData m (NDArray a, NDArray a)
libSVMIter args = makeIter I.libSVMIter args



-- | Turn an iterator-creating action into a 'ConduitData' source.
--
-- When the conduit starts running, @creator args@ is executed once to obtain
-- the iterator handle. The source then repeatedly advances the iterator with
-- 'mxDataIterNext'; for every successful step it yields the pair built from
-- 'mxDataIterGetData' and 'mxDataIterGetLabel' (data first, label second).
-- Once 'mxDataIterNext' reports @0@, the handle is released with
-- 'mxDataIterFree' and the stream ends.
--
-- NOTE(review): the handle is released only on the exhaustion branch. A
-- consumer that stops pulling before the iterator is drained never reaches
-- that branch, so the free call would be skipped -- confirm whether callers
-- always drain the stream.
makeIter creator args = ConduitData $ do
    handle <- liftIO (creator args)
    let -- Read the data and label arrays of the batch the iterator points at.
        fetchBatch = do
            dat <- checked (mxDataIterGetData handle)
            lbl <- checked (mxDataIterGetLabel handle)
            return (NDArray dat, NDArray lbl)
        -- Advance the iterator; emit a batch and recurse, or free and stop.
        pump = do
            status <- liftIO (checked (mxDataIterNext handle))
            case status of
                0 -> liftIO (checked (mxDataIterFree handle))
                _ -> yieldM (liftIO fetchBatch) >> pump
    pump



-- | For 'ConduitData' the dataset constraint is a type equality: the monad
-- @m1@ that the conduit is parameterised over must be the same type as the
-- second argument @m2@ (presumably the monad the 'Dataset' operations run
-- in -- confirm against the class definition).
type instance DatasetConstraint (ConduitData m1) m2 = m1 ~ m2



-- | 'Dataset' operations for conduit-backed datasets.
instance Monad m => Dataset (ConduitData m) where
    -- | Wrap a list as a source that yields its elements in order.
    fromListD xs = ConduitData (CL.sourceList xs)

    -- | Pair up the elements of two datasets via 'ZipSource'.
    zipD lhs rhs =
        ConduitData . getZipSource $
            (,) <$> ZipSource (getConduit lhs) <*> ZipSource (getConduit rhs)

    -- | Run the source to completion and count its elements.
    sizeD ds = runConduit $ getConduit ds .| C.length

    -- | Apply a monadic action to every element and collect the results.
    forEachD ds proc = sourceToList (getConduit ds .| CL.mapM proc)