{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
module Torch.Data.Pipeline
(
Dataset (..),
DatasetOptions (..),
datasetOpts,
Sample (..),
streamFromMap,
)
where
import Control.Concurrent.Async.Lifted
import Control.Concurrent.STM hiding (atomically)
import Control.Monad
import Control.Monad.Base (MonadBase)
import Control.Monad.Cont (ContT)
import Control.Monad.Trans.Control (MonadBaseControl (..))
import Data.IntMap (IntMap)
import qualified Data.IntMap as I
import qualified Data.List as List
import Data.Set
import Pipes
import Pipes.Concurrent hiding (atomically)
import System.Random
import Torch.Data.Internal
-- | A map-style dataset: a finite set of keys, each of which can be loaded
-- into a sample inside the monad @m@.
--
-- The functional dependencies determine the monad, key and sample types
-- from the dataset type alone.
class (Ord k) => Dataset m dataset k sample | dataset -> m, dataset -> sample, dataset -> k where
  -- | Load the sample stored under the given key.
  getItem :: dataset -> k -> m sample

  -- | All keys this dataset can serve.
  keys :: dataset -> Set k
-- | Options controlling how 'streamFromMap' loads a 'Dataset'.
data DatasetOptions = DatasetOptions
  { -- | Size handed to the stream buffer between the workers and the
    -- consumer ('runWithBuffer' in 'streamFromMap').
    dataBufferSize :: Int,
    -- | Number of workers that call 'getItem' concurrently.
    numWorkers :: Int,
    -- | Key ordering for the next pass: 'Sequential' or 'Shuffle'.
    shuffle :: Sample
  }
-- | Default 'DatasetOptions' for the given worker count.
--
-- The buffer size is set equal to the worker count and keys are visited
-- 'Sequential'ly.
datasetOpts :: Int -> DatasetOptions
datasetOpts workers =
  -- Named 'workers' (not 'numWorkers') so the field selector is not shadowed.
  DatasetOptions
    { dataBufferSize = workers,
      numWorkers = workers,
      shuffle = Sequential
    }
-- | How the keys of a dataset are ordered when it is streamed.
data Sample where
  -- | Keys are visited in the order produced by 'keyTVarSet' (ascending
  -- 'Set' order).
  Sequential :: Sample
  -- | Keys are shuffled with the wrapped generator; 'streamFromMap' returns
  -- a new 'Shuffle' holding the advanced generator for the next pass.
  Shuffle :: RandomGen g => g -> Sample
-- | Stream the samples of a map-style 'Dataset', loading them concurrently.
--
-- Every key is paired with an empty 'TVar' result slot (the pairs are
-- shuffled first when 'shuffle' is 'Shuffle'). All pairs are pushed into an
-- unbounded queue, which is then sealed. @numWorkers@ workers drain that
-- queue, calling 'getItem' and filling the matching slot. Meanwhile a single
-- reader walks the slots in list order and forwards each filled sample to
-- the returned 'ListT'. Output order therefore follows the key list, not
-- worker completion order.
--
-- Returns the stream together with the 'Sample' for the next pass:
--
-- * 'Sequential' is returned unchanged;
-- * 'Shuffle' carries the generator produced by 'fisherYates'.
streamFromMap ::
  forall m dataset k sample r.
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  DatasetOptions ->
  dataset ->
  ContT r m (ListT m sample, Sample)
streamFromMap DatasetOptions {..} dataset = do
  -- Work queue of (key, result slot) pairs shared by all workers.
  (keyOutput, keyInput, seal) <- liftIO $ spawn' unbounded

  let retrieveSet :: ContT r m [(k, TVar (Maybe sample))]
      retrieveSet = liftIO $ keyTVarSet $ keys dataset

  -- Bound as 'keyTVars' so it does not shadow the top-level 'keyTVarSet'.
  (keyTVars, updatedSample) <- case shuffle of
    Sequential -> (,Sequential) <$> retrieveSet
    Shuffle g -> fmap Shuffle . fisherYates g <$> retrieveSet

  -- Enqueue every key up front, then seal the mailbox. Per pipes-concurrent,
  -- this lets the workers' input stream end once the queue is drained.
  keyQueue keyOutput keyTVars
  atomically seal

  let workers :: m ()
      workers = runWorkers numWorkers dataset keyInput

      datastream :: Output sample -> m ()
      datastream = awaitNextItem keyTVars

  listT <- runWithBuffer dataBufferSize $ \output ->
    concurrently_ workers (datastream output)
  pure (listT, updatedSample)
-- | Run the given number of workers concurrently over the shared key queue.
--
-- Each worker repeatedly takes a @(key, slot)@ pair from @keyInput@, loads
-- the sample with 'getItem' and writes it into the slot as 'Just'. The call
-- returns once every worker has finished. That happens when its
-- 'fromInput'' producer ends, presumably once the sealed queue is empty
-- (see "Torch.Data.Internal").
runWorkers ::
  (Dataset m dataset k sample, MonadIO m, MonadBaseControl IO m) =>
  Int ->
  dataset ->
  Input (k, TVar (Maybe sample)) ->
  m ()
runWorkers workerCount dataset keyInput =
  replicateConcurrently_ workerCount $
    runEffect $ fromInput' keyInput >-> fillSlots
  where
    -- Load each requested sample and publish it into its result slot.
    fillSlots = forever $ do
      (key, slot) <- await
      item <- lift $ getItem dataset key
      atomically $ writeTVar slot (Just item)
-- | Forward samples to @output@ in the order of the given slots.
--
-- For each slot in turn, this blocks (via STM 'retry') until a worker has
-- filled it. It then resets the slot to 'Nothing' and sends the sample
-- downstream.
awaitNextItem ::
  (MonadBase IO m, MonadIO m) =>
  [(k, TVar (Maybe sample))] ->
  Output sample ->
  m ()
awaitNextItem tvars output =
  runEffect $ each tvars >-> takeFilled >-> toOutput' output
  where
    -- Wait for each slot to be filled, empty it, and yield its sample.
    takeFilled = forever $ do
      (_, slot) <- await
      item <- atomically $ do
        filled <- readTVar slot
        case filled of
          Nothing -> retry
          -- Clearing the slot presumably stops it from keeping the sample
          -- alive once it has been forwarded.
          Just x -> x <$ writeTVar slot Nothing
      yield item
-- | Pair every key of the set with a fresh, empty result slot.
--
-- Keys come out in ascending order ('toList' on a 'Set'). All slots are
-- allocated in a single STM transaction.
keyTVarSet :: MonadIO m => Set k -> m [(k, TVar (Maybe sample))]
keyTVarSet keySet = atomically $ traverse withEmptySlot (toList keySet)
  where
    withEmptySlot key = (,) key <$> newTVar Nothing
-- | Send every @(key, slot)@ pair, in list order, into the work queue.
keyQueue :: MonadBase IO m => Output (k, TVar (Maybe sample)) -> [(k, TVar (Maybe sample))] -> m ()
keyQueue keyOutput keyTVars =
  -- Parameter named 'keyTVars' so it does not shadow the top-level 'keyTVarSet'.
  runEffect $ each keyTVars >-> toOutput' keyOutput
-- | One step of the "inside-out" Fisher–Yates shuffle.
--
-- The map is expected to hold positions @0 .. i-1@. The step:
--
-- 1. picks @j@ uniformly from @[0, i]@;
-- 2. moves the element at @j@ to position @i@;
-- 3. places the new element @x@ at @j@.
--
-- When @j == i@, @x@ simply lands at @i@. That case is handled explicitly,
-- rather than by relying on a lazily discarded lookup of the absent key @i@.
-- Both result components are forced so that a left fold does not accumulate
-- thunks.
fisherYatesStep :: RandomGen g => (IntMap a, g) -> (Int, a) -> (IntMap a, g)
fisherYatesStep (m, gen) (i, x) = m' `seq` gen' `seq` (m', gen')
  where
    (j, gen') = randomR (0, i) gen
    m'
      | j == i = I.insert i x m
      | otherwise = case I.lookup j m of
          -- Take the element itself (not a thunk over the old map), so that
          -- earlier map versions are not retained.
          Just y -> I.insert j x (I.insert i y m)
          Nothing -> error "fisherYatesStep: position j < i missing; expected keys 0 .. i-1"
-- | Shuffle a list with the inside-out Fisher–Yates algorithm.
--
-- Returns the shuffled list and the advanced generator. The permutation is
-- uniform assuming 'randomR' is uniform for the generator. An empty list is
-- returned with the generator unchanged.
fisherYates :: RandomGen g => g -> [a] -> ([a], g)
fisherYates gen [] = ([], gen)
fisherYates gen (x : xs) =
  -- Position 0 is seeded with the head; the remaining elements are inserted
  -- at positions 1, 2, ... with a strict left fold (no thunk chain).
  -- 'List.foldl'' is qualified because "Data.Set" also exports 'foldl''.
  case List.foldl' fisherYatesStep (I.singleton 0 x, gen) (zip [1 ..] xs) of
    -- 'I.elems' lists values in ascending key (= position) order.
    (shuffled, gen') -> (I.elems shuffled, gen')