{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Data.Utils
  ( pmap,
    pmap',
    pmapGroup,
    bufferedCollate,
    collate,
    enumerateData,
    CachedDataset,
    cache,
  )
where

import qualified Control.Foldl as L
import Control.Monad.Cont
import Control.Monad.Trans.Control
import Data.Kind (Type)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as M
import qualified Data.Set as S
import Lens.Family
import Pipes
import Pipes.Concurrent
import Pipes.Group
import qualified Pipes.Prelude as P
import Torch.Data.Internal
import Torch.Data.Pipeline

-- | Run a pure map function in parallel over the given stream.
--
-- This is 'pmap'' specialised to a pure function: the function is lifted
-- into a 'Pipe' with 'P.map' and all buffering is delegated to 'pmap''.
-- The mapped elements pass through the given 'Buffer', and the resulting
-- stream is handed to the 'ContT' continuation.
pmap :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> (a -> b) -> ListT m a -> ContT r m (ListT m b)
pmap buffer f = pmap' buffer (P.map f)

-- | Run a pipe in parallel over the given stream.
--
-- Elements of the input stream are pushed through the pipe and written to
-- the 'Output' end supplied by 'withBufferLifted'. The continuation receives
-- a stream that reads from the matching 'Input' end. Only the continuation's
-- result is returned; the unit result of the writer side is discarded.
pmap' :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> Pipe a b m () -> ListT m a -> ContT r m (ListT m b)
pmap' buffer f prod = ContT $ \cont ->
  fmap snd (withBufferLifted buffer feed (drain cont))
  where
    -- Writer side: input stream -> pipe -> buffer.
    feed output = runEffect (enumerate prod >-> f >-> toOutput' output)
    -- Reader side: expose the buffer as a stream and run the continuation on it.
    drain k input = k (Select (fromInput' input))

-- | Map a 'ListT' transform over the given stream in parallel.
--
-- Useful for functions that group elements of a stream and yield the groups
-- downstream. The transform is applied to the whole input stream; its output
-- is written to the 'Output' end supplied by 'withBufferLifted', and the
-- continuation receives a stream reading from the matching 'Input' end.
pmapGroup :: (MonadIO m, MonadBaseControl IO m) => Buffer b -> (ListT m a -> ListT m b) -> ListT m a -> ContT r m (ListT m b)
pmapGroup buffer f prod = ContT $ \cont ->
  fmap snd (withBufferLifted buffer feed (drain cont))
  where
    -- Writer side: transformed stream -> buffer.
    feed output = runEffect (enumerate (f prod) >-> toOutput' output)
    -- Reader side: expose the buffer as a stream and run the continuation on it.
    drain k input = k (Select (fromInput' input))

-- | Pair every element of the stream with its position.
--
-- The element comes first and the index second; indices start at 0 and
-- increase by one per element.
enumerateData :: Monad m => ListT m a -> Producer (a, Int) m ()
enumerateData input = P.zip (enumerate input) indices
  where
    -- Unbounded supply of indices; 'P.zip' stops when the data runs out.
    indices = each [0 ..]

-- | Run a batching function in parallel.
--
-- The stream is batched with 'collate' (see there for how samples are
-- grouped and which batches are emitted) and the batching itself is run
-- through 'pmapGroup' using the given 'Buffer'.
bufferedCollate :: (MonadIO m, MonadBaseControl IO m) => Buffer batch -> Int -> ([sample] -> Maybe batch) -> ListT m sample -> ContT r m (ListT m batch)
bufferedCollate buffer batchSize collateFn samples =
  pmapGroup buffer batcher samples
  where
    batcher = collate batchSize collateFn

-- | Batch a stream with an integer batch size.
--
-- The stream is split into chunks of @batchSize@ elements, each chunk is
-- gathered into a list, and the list is passed to the collate function.
-- Only 'Just' results are yielded downstream; 'Nothing' results are dropped.
-- If the final chunk holds fewer than @batchSize@ samples, the collate
-- function is called with that shorter list.
--
-- NOTE(review): behaviour for a non-positive @batchSize@ is whatever
-- 'chunksOf' does with it -- confirm before passing such values.
collate :: Monad m => Int -> ([sample] -> Maybe batch) -> ListT m sample -> ListT m batch
collate batchSize collateFn samples =
  Select (batches >-> P.mapFoldable collateFn)
  where
    -- Split the flat stream into sub-streams of at most @batchSize@ elements.
    chunks = view (chunksOf batchSize) (enumerate samples)
    -- Collapse every sub-stream into a single list.
    batches = L.purely folds L.list chunks

-- | An in-memory cached dataset. See the 'cache' function for how to
-- create one.
--
-- The @m@ parameter is phantom: it does not occur in the stored field.
newtype CachedDataset (m :: Type -> Type) sample = CachedDataset
  { -- | The cached samples, keyed by 'Int'.
    cached :: IntMap sample
  }

-- | Enumerate a given stream and store it as a 'CachedDataset'.
--
-- Intended to be used after a time-consuming preprocessing pipeline so that
-- subsequent epochs can read the cached samples instead of repeating the
-- preprocessing. Samples are stored under consecutive 'Int' keys starting
-- at 0, in the order they arrive from the stream.
cache :: Monad m => ListT m sample -> m (CachedDataset m sample)
cache datastream = P.fold step begin done (enumerate datastream)
  where
    -- Insert the sample under the next free index and advance the index.
    -- The accumulator is a lazy pair, so both components are forced whenever
    -- the pair itself is forced; otherwise the inserts and increments would
    -- pile up as unevaluated thunks until the cache is first inspected.
    step (cacheMap, ix) sample =
      let cacheMap' = M.insert ix sample cacheMap
          ix' = ix + 1
       in cacheMap' `seq` ix' `seq` (cacheMap', ix')
    -- Empty cache; the first sample gets index 0.
    begin = (M.empty, 0)
    -- Drop the running index and wrap the finished map.
    done = CachedDataset . fst

-- | Index-based access to a 'CachedDataset'.
--
-- Samples are addressed by the 'Int' key they are stored under in the
-- underlying 'IntMap'.
instance Applicative m => Dataset m (CachedDataset m sample) Int sample where
  -- Wrap the stored sample in the applicative. The lookup uses the partial
  -- 'M.!', so a key that is not present raises an error once the sample is
  -- demanded; callers should only pass keys obtained from 'keys'.
  getItem (CachedDataset samples) key = pure (samples M.! key)

  -- Report exactly the keys present in the map. Using the map's own keys
  -- (rather than a range up to its size) guarantees that every reported key
  -- can be looked up with 'getItem'. 'M.keys' returns the keys in ascending
  -- order without duplicates, which is what 'S.fromDistinctAscList' requires.
  keys (CachedDataset samples) = S.fromDistinctAscList (M.keys samples)