{-# Language MultiParamTypeClasses, FunctionalDependencies #-}

{-# LANGUAGE FlexibleInstances #-}

module MXNet.NN.DataIter.Class where



import GHC.Exts (Constraint)



-- | Instance-specific constraints that a 'Dataset' @d@ imposes on the monad
-- @m@ in which its effectful operations are run.
--
-- It is required, together with @Monad m@, by 'sizeD', 'forEachD',
-- 'forEachD_i' and 'forEachD_ni'. Because this is an open type family, its
-- equations come from separate @type instance@ declarations (not visible
-- here).
type family DatasetConstraint
    (dataset :: * -> *)
    (monad   :: * -> *)
    :: Constraint



-- | Abstract interface over a collection of elements (a dataset).
--
-- @d e@ is a dataset whose elements have type @e@. The effectful operations
-- ('sizeD', 'forEachD', 'forEachD_i', 'forEachD_ni') run in any monad @m@
-- satisfying both 'Monad' and the instance-specific 'DatasetConstraint'
-- @d m@.
class Dataset (d :: * -> *) where
    -- | Build a dataset from a list.
    --
    -- Whether an infinite list is accepted depends on the instance.
    fromListD   :: [e] -> d e

    -- | Zip two datasets into a dataset of pairs.
    zipD        :: d e1 -> d e2 -> d (e1, e2)

    -- | Number of elements in the dataset.
    sizeD       :: (DatasetConstraint d m, Monad m) => d e -> m Int

    -- | Run an action on each element and collect the results.
    forEachD    :: (DatasetConstraint d m, Monad m) => d e -> (e -> m a) -> m [a]

    -- | Like 'forEachD', but the action also receives the element's index.
    --
    -- The default numbers elements from 1 by zipping the dataset with
    -- @fromListD [1 ..]@. It therefore needs an instance whose 'fromListD'
    -- accepts an infinite list, and presumably one whose 'zipD' stops at the
    -- shorter input. Confirm this per instance.
    forEachD_i  :: (DatasetConstraint d m, Monad m) => d e -> ((Int, e) -> m a) -> m [a]
    forEachD_i dat proc = forEachD numbered proc
      where
        numbered = fromListD [1 ..] `zipD` dat

    -- | Like 'forEachD', but the action receives @((n, i), e)@. Here @n@ is
    -- the total element count reported by 'sizeD' and @i@ is the element's
    -- index.
    --
    -- The default calls 'sizeD' once and numbers elements @1 .. n@. Unlike
    -- the default 'forEachD_i', it only builds finite lists.
    forEachD_ni :: (DatasetConstraint d m, Monad m) => d e -> (((Int, Int), e) -> m a) -> m [a]
    forEachD_ni dat proc = sizeD dat >>= \total ->
        let totals  = fromListD (replicate total total)
            indices = fromListD [1 .. total]
        in  forEachD ((totals `zipD` indices) `zipD` dat) proc