-------------------------------------------------------------------------------
-- |
-- Module      : Numeric.Dataloader
-- Stability   : experimental
-- Portability : non-portable
--
-- A Dataloader is an extension of a Dataset and is primarily intended for
-- compute-intensive, batch loading interfaces. When used with ImageFolder
-- representations of Datasets, it shuffles the order of files to be loaded
-- and evaluates each transformation in parallel via 'Control.Parallel.Strategies'.
--
-- Concurrent loading primarily takes place in 'batchStream'; 'stream' exists
-- mainly to provide a unified API for training that is not batch-oriented.
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
module Numeric.Dataloader
  ( Dataloader(..)
  , uniformIxline
  , stream
  , batchStream
  ) where

import Control.Monad ((>=>))
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Vector (Vector)
import Streaming (Stream, Of(..))
import System.Random.MWC (GenIO)
import qualified Data.Vector as V
import qualified Streaming as S
import qualified Streaming.Prelude as S
import qualified System.Random.MWC.Distributions as MWC
import Control.Exception.Safe (MonadThrow)
import Streaming.Instances ()
import Control.Parallel.Strategies

import Numeric.Datasets

-- * Configuring data loaders

-- | Options for data loading functions.
data Dataloader a b = Dataloader
  { batchSize :: Int
    -- ^ Batch size used with 'batchStream'.
  , shuffle :: Maybe (Vector Int)
    -- ^ Optional shuffle order (forces the dataset to be loaded in memory if it wasn't already).
  , dataset :: Dataset a
    -- ^ Dataset associated with the dataloader.
  , transform :: a -> b
    -- ^ Transformation associated with the dataloader, which will be run in parallel. If using an
    -- ImageFolder, this is where you would transform image filepaths into an image (or another
    -- compute-optimized form). Additionally, this is where you should perform any
    -- static normalization.
  }

-- | Generate a uniformly random index permutation from a dataset and a generator.
uniformIxline :: Dataset a -> GenIO -> IO (Vector Int)
uniformIxline ds gen = do
  len <- V.length <$> getDatavec ds
  MWC.uniformPermutation len gen

-- * Data loading functions

-- | Stream a dataset in-memory, applying a transformation function.
stream
  :: (MonadThrow io, MonadIO io)
  => Dataloader a b
  -> Stream (Of b) io ()
stream dl = S.maps (\(a :> b) -> (transform dl a `using` rpar) :> b) (sourceStream dl)

-- | Stream batches of a dataset, concurrently processing each element.
--
-- NOTE: Run with @-threaded -rtsopts@ to concurrently load data in-memory.
batchStream
  :: (MonadThrow io, MonadIO io, NFData b)
  => Dataloader a b
  -> Stream (Of [b]) io ()
batchStream dl
  = S.mapsM (S.toList >=> liftIO . firstOfM go)
  $ S.chunksOf (batchSize dl)
  $ sourceStream dl
  where
    go as = fmap (transform dl) as `usingIO` parList rdeepseq

-- * Helper functions (not for export)

-- | Stream a dataset in-memory, in the dataloader's (possibly shuffled) order.
sourceStream
  :: (MonadThrow io, MonadIO io)
  => Dataloader a b
  -> Stream (Of a) io ()
sourceStream loader = permute loader <$> getDatavec (dataset loader) >>= S.each
  where
    -- Use a dataloader's shuffle order to return a permuted vector (or return the
    -- identity vector).
    permute :: Dataloader a b -> Vector a -> Vector a
    permute dl va = maybe va (V.backpermute va) (shuffle dl)

-- | Monadic, concrete version of Control.Arrow.first for @Of@.
firstOfM :: Monad m => (a -> m b) -> Of a c -> m (Of b c)
firstOfM fm (a :> c) = do
  b <- fm a
  pure (b :> c)
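
-- * Usage sketch

-- A minimal sketch of how the pieces fit together, kept in comments so it is
-- not compiled or exported. It assumes a hypothetical dataset
-- @myDataset :: Dataset FilePath@ (e.g. an ImageFolder-style dataset), a
-- hypothetical decoder @decodeImage :: FilePath -> Image@ with an 'NFData'
-- instance for @Image@, and a hypothetical consumer @useBatch :: [Image] -> IO ()@.
-- None of these names are defined by this module; @createSystemRandom@ comes
-- from "System.Random.MWC".
--
-- > example :: IO ()
-- > example = do
-- >   gen <- createSystemRandom
-- >   ixs <- uniformIxline myDataset gen
-- >   let loader = Dataloader
-- >         { batchSize = 32
-- >         , shuffle   = Just ixs          -- stream in a shuffled order
-- >         , dataset   = myDataset
-- >         , transform = decodeImage       -- evaluated in parallel per batch
-- >         }
-- >   -- Each element of the stream is a batch of 32 decoded images,
-- >   -- forced with @parList rdeepseq@ inside 'batchStream'.
-- >   S.mapM_ useBatch (batchStream loader)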