{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Data.CsvDatastream
( BufferSize,
NamedColumns (..),
CsvDatastream' (..),
CsvDatastream,
CsvDatastreamNamed,
csvDatastream,
tsvDatastream,
FromField (..),
FromRecord (..),
FromNamedRecord (..),
)
where
import qualified Control.Foldl as L
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Char (ord)
import Data.Csv (DecodeOptions (decDelimiter))
import Data.STRef
import Data.Vector (Vector)
import qualified Data.Vector as V
import Lens.Family (view)
import Pipes
import qualified Pipes.ByteString as B
import Pipes.Csv
import Pipes.Group (chunksOf, folds)
import qualified Pipes.Prelude as P
import qualified Pipes.Safe as Safe
import qualified Pipes.Safe.Prelude as Safe
import System.IO (IOMode (ReadMode))
import System.Random
import Torch.Data.StreamedPipeline
-- | Tag (used at the type level via @DataKinds@) selecting how the columns of
-- a CSV file are mapped onto a record: 'Unnamed' streams are decoded by
-- column position ('FromRecord'), 'Named' streams by header name
-- ('FromNamedRecord').
data NamedColumns = Unnamed | Named
-- | Number of decoded records gathered into one buffer before that buffer is
-- shuffled (see the 'bufferedShuffle' field).
type BufferSize = Int
-- | Description of a CSV file that is streamed as batches of decoded records.
--
-- @batches@ is the record type each row is decoded into; @named@ is a phantom
-- tag choosing between positional ('Unnamed') and header-based ('Named')
-- decoding. Neither parameter occurs in a field.
data CsvDatastream' batches (named :: NamedColumns) = CsvDatastream'
  { -- | Path of the CSV file to read.
    filePath :: FilePath,
    -- | Byte that separates columns (e.g. the code of @','@ or @'\t'@).
    delimiter :: !B.Word8,
    -- | Whether the file starts with a header row. Only the positional
    -- ('Unnamed') decoder consults this field.
    hasHeader :: HasHeader,
    -- | Number of records per emitted batch.
    batchSize :: Int,
    -- | When @Just n@, records are buffered @n@ at a time and each buffer is
    -- shuffled before batching; @Nothing@ keeps the file order.
    bufferedShuffle :: Maybe BufferSize,
    -- | When 'True', batches whose length differs from 'batchSize' (i.e. a
    -- short trailing batch) are dropped.
    dropLast :: Bool
  }
-- | A 'CsvDatastream'' whose records are decoded by column position
-- (requires a 'FromRecord' instance for @batches@).
type CsvDatastream batches = CsvDatastream' batches Unnamed
-- | A 'CsvDatastream'' whose records are decoded by header name
-- (requires a 'FromNamedRecord' instance for @batches@).
type CsvDatastreamNamed batches = CsvDatastream' batches Named
-- | Build a datastream description for a tab-separated file.
--
-- Identical to 'csvDatastream' except that the column delimiter is the tab
-- character; all other options keep the 'csvDatastream' defaults.
tsvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
tsvDatastream path = (csvDatastream path) {delimiter = fromIntegral $ ord '\t'}
-- | Build a datastream description for a comma-separated file with default
-- options: no header row, batches of a single record, no shuffling, and
-- incomplete trailing batches dropped.
--
-- Override individual options with record-update syntax, e.g.
-- @(csvDatastream "data.csv") {batchSize = 32, hasHeader = HasHeader}@.
csvDatastream :: forall (isNamed :: NamedColumns) batches. FilePath -> CsvDatastream' batches isNamed
csvDatastream path =
  CsvDatastream'
    { filePath = path,
      delimiter = fromIntegral $ ord ',',
      hasHeader = NoHeader,
      batchSize = 1,
      bufferedShuffle = Nothing,
      dropLast = True
    }
-- | Streams batches from a CSV file whose rows are decoded by column
-- position. The configured 'delimiter' and 'hasHeader' are passed to the
-- decoder; the seed argument is ignored.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromRecord batch
  ) =>
  Datastream m () (CsvDatastream batch) (Vector batch)
  where
  streamSamples csv@CsvDatastream' {..} _ =
    readCsv csv (decodeWith decodeOptions hasHeader)
    where
      decodeOptions = defaultDecodeOptions {decDelimiter = delimiter}
-- | Streams batches from a CSV file whose rows are decoded by header name.
-- Only the configured 'delimiter' is passed to the decoder ('hasHeader' is
-- not consulted here); the seed argument is ignored.
instance
  ( MonadBaseControl IO m,
    Safe.MonadSafe m,
    FromNamedRecord batch
  ) =>
  Datastream m () (CsvDatastreamNamed batch) (Vector batch)
  where
  streamSamples csv@CsvDatastream' {..} _ =
    readCsv csv (decodeByNameWith decodeOptions)
    where
      decodeOptions = defaultDecodeOptions {decDelimiter = delimiter}
-- | Stream the file described by the configuration as a 'ListT' of batches.
--
-- The file is opened with 'Safe.withFile' and read in chunks of at most 1000
-- bytes. The supplied @decode@ function turns the byte stream into a stream
-- of 'Foldable' containers of records, which is flattened with 'P.concat';
-- when the container is an 'Either' (as with the cassava decoders used by the
-- instances in this module) rows that failed to decode are therefore
-- silently discarded.
--
-- Records are grouped into 'Vector's of 'batchSize' elements. With
-- @bufferedShuffle = Just n@ the records are first gathered @n@ at a time and
-- every such buffer is shuffled (with a fresh generator from 'newStdGen')
-- before batching. With 'dropLast' set, any batch whose length is not exactly
-- 'batchSize' is filtered out.
readCsv ::
  (Foldable f, Safe.MonadSafe m, MonadIO m, MonadBase IO m) =>
  CsvDatastream' batches named ->
  (Producer B.ByteString m () -> Producer (f a) m ()) ->
  ListT m (Vector a)
readCsv CsvDatastream' {..} decode =
  Select $
    Safe.withFile filePath ReadMode $ \fh ->
      if dropLast
        then streamRecords fh >-> P.filter (\batch -> V.length batch == batchSize)
        else streamRecords fh
  where
    -- Batches of records, optionally shuffled buffer by buffer.
    streamRecords fh = case bufferedShuffle of
      Nothing -> batched (records fh)
      Just bufferSize -> batched (buffered bufferSize (records fh) >-> shuffleRecords)

    -- Decoded records in file order; undecodable rows vanish in 'P.concat'.
    records fh = decode (produceLine fh) >-> P.concat

    -- Group a record stream into vectors of (at most) 'batchSize' records.
    batched source = L.purely folds L.vector $ view (chunksOf batchSize) source

    -- Group a record stream into lists of (at most) @bufferSize@ records.
    buffered bufferSize source = L.purely folds L.list $ view (chunksOf bufferSize) source

    -- Raw bytes from the handle, at most 1000 per chunk.
    produceLine fh = B.hGetSome 1000 fh

    -- Shuffle every incoming buffer and re-yield its records one by one.
    -- The loop must run 'forever': a pipe that returns terminates the whole
    -- '>->' pipeline, so awaiting only once would end the stream after the
    -- first buffer and silently drop the rest of the file. Termination comes
    -- from the upstream producer running out of records.
    shuffleRecords = forever $ do
      buffer <- await
      gen <- Torch.Data.StreamedPipeline.liftBase newStdGen
      mapM_ yield $ fst $ shuffle' buffer gen
-- | Return a random permutation of the list together with the advanced
-- generator.
--
-- The elements are copied into a 1-based mutable array. For each position
-- @i@ from 1 to @n@ an index @j@ is drawn from the inclusive range @(i, n)@;
-- the element at @j@ is emitted and the element at @i@ is moved into slot
-- @j@ (a Fisher–Yates style shuffle). An empty list yields an empty list and
-- leaves the generator unchanged.
shuffle' :: [a] -> StdGen -> ([a], StdGen)
shuffle' xs gen =
  runST
    ( do
        genRef <- newSTRef gen
        let -- Draw an index from the inclusive bounds and store the advanced
            -- generator back into the reference.
            drawIndex bounds = do
              (ix, nextGen) <- randomR bounds <$> readSTRef genRef
              writeSTRef genRef nextGen
              pure ix
        cells <- arrayFromList n xs
        shuffled <- forM [1 .. n] $ \i -> do
          j <- drawIndex (i, n)
          vi <- readArray cells i
          vj <- readArray cells j
          -- Keep the not-yet-emitted element of slot i alive in slot j.
          writeArray cells j vi
          pure vj
        finalGen <- readSTRef genRef
        pure (shuffled, finalGen)
    )
  where
    n = Prelude.length xs
    -- The signature pins the array type to 'STArray', which 'newListArray'
    -- alone would leave ambiguous.
    arrayFromList :: Int -> [e] -> ST s (STArray s Int e)
    arrayFromList size = newListArray (1, size)