{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Data.Dataset where
import qualified Control.Foldl as L
import Lens.Family (view)
import Pipes (ListT (Select), Pipe, Producer, enumerate, (>->))
import Pipes.Group (chunksOf, folds)
import Torch.Data.StreamedPipeline
-- | A wrapper that pairs an underlying dataset with a chunk size and a
-- collation pipe.
--
-- The type parameters are:
--
-- * @m@ - the base monad the collation 'Pipe' runs in
-- * @dataset@ - the wrapped dataset
-- * @batch@ - the element type collected into lists before collation
-- * @collatedBatch@ - the element type emitted by 'collateFn'
data CollatedDataset m dataset batch collatedBatch = CollatedDataset
  { -- | The wrapped dataset whose samples are streamed and grouped.
    set :: dataset,
    -- | Group size handed to 'chunksOf' when splitting the sample stream.
    chunkSize :: Int,
    -- | Pipe that consumes lists of @batch@ values and yields
    -- @collatedBatch@ values.
    collateFn :: Pipe [batch] collatedBatch m ()
  }
-- | Stream collated batches out of a 'CollatedDataset'.
--
-- Given a seed, the samples of the wrapped dataset are streamed, split into
-- groups with @'chunksOf' 'chunkSize'@, each group is folded into a list, and
-- the resulting lists are fed through 'collateFn'.
instance
  Datastream m seed dataset batch =>
  Datastream m seed (CollatedDataset m dataset batch collatedBatch) collatedBatch
  where
  streamSamples CollatedDataset {..} =
    -- Read bottom-to-top: each stage consumes the output of the one below it.
    Select -- wrap the final producer back into a 'ListT'
      . (>-> collateFn) -- run every list of samples through the collation pipe
      . L.purely folds L.list -- fold each group into a single list
      . view (chunksOf chunkSize) -- split the stream into groups
      . enumerate -- turn the 'ListT' of samples into a 'Producer'
      . streamSamples set -- stream samples from the wrapped dataset