{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Torch.Data.Dataset where

import qualified Control.Foldl as L
import Lens.Family (view)
import Pipes (ListT (Select), Pipe, Producer, enumerate, (>->))
import Pipes.Group (chunksOf, folds)
import Torch.Data.StreamedPipeline

-- | A wrapper around another dataset that groups its samples into chunks and
-- runs every chunk through a collation pipe.
--
-- This type is actually not very useful. It would be better to define a
-- transform on top of another dataset, since then the collation could be done
-- in parallel.
data CollatedDataset m dataset batch collatedBatch = CollatedDataset
  { -- | The underlying dataset whose samples are grouped and collated.
    set :: dataset,
    -- | Chunk size passed to 'chunksOf' when grouping the samples.
    chunkSize :: Int,
    -- | Pipe that consumes each grouped list of samples and emits collated
    -- batches.
    collateFn :: Pipe [batch] collatedBatch m ()
  }

-- | Streams collated batches by chunking the sample stream of the wrapped
-- dataset and feeding each chunk through 'collateFn'.
--
-- NOTE(review): 'chunkSize' is handed to 'chunksOf' unchecked; confirm how a
-- non-positive value behaves before relying on it.
instance Datastream m seed dataset batch => Datastream m seed (CollatedDataset m dataset batch collatedBatch) collatedBatch where
  streamSamples CollatedDataset {..} =
    -- Read bottom-up: each stage consumes the output of the one below it.
    Select -- wrap the resulting producer back into a 'ListT'
      . (>-> collateFn) -- collate every list of samples
      . L.purely folds L.list -- fold each group of samples into a list
      . view (chunksOf chunkSize) -- split the stream into groups
      . enumerate -- unwrap the underlying 'ListT' into a producer
      . streamSamples set -- sample stream of the wrapped dataset