{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
-- | A dataset backed by a 'Vector' whose contents are either already in
-- memory or produced on demand by an 'IO' action.
module MXNet.NN.DataIter.LazyVec where

import Prelude hiding (zip)
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Data.IORef
import Control.Monad
import Control.Monad.IO.Class
import Control.Exception.Base (assert)

import MXNet.NN.DataIter.Class

-- | A value that is either available right away ('Direct') or obtained by
-- running an 'IO' action ('Make').
data Lazy a
    = Direct a
    | Make (() -> IO a)

instance Functor Lazy where
    fmap f (Direct a) = Direct (f a)
    -- Post-process the result of the action; the action itself is untouched.
    fmap f (Make g)   = Make (fmap f . g)

-- | Obtain the value of a 'Lazy'.
--
-- For 'Make' the action is run on every call; nothing is cached here.
force :: Lazy a -> IO a
force (Direct a) = return a
force (Make f)   = f ()

-- | A vector with a known length whose elements may be produced lazily.
data LVec a = LVec
    { size   :: Int              -- ^ number of elements, recorded up front
    , unLVec :: Lazy (Vector a)  -- ^ the (possibly deferred) contents
    }

-- | Wrap an in-memory vector.
fromVec :: Vector a -> LVec a
fromVec v = LVec (V.length v) (Direct v)

-- | Materialise the contents, running the producing action if there is one.
toVec :: LVec a -> IO (Vector a)
toVec = force . unLVec

-- | Group the elements into chunks of exactly @chunksize@ elements.
--
-- The result holds @ceiling (size vec / chunksize)@ chunks. When the end of
-- the data is reached in the middle of a chunk, the chunk is filled up by
-- continuing from the start of the vector (cycling as many times as needed,
-- so @chunksize@ may exceed @size vec@).
--
-- The read position lives in an 'IORef' created by this call and shared by
-- every forcing of the returned 'LVec': each forcing continues where the
-- previous one stopped instead of restarting at index 0.
--
-- Fails with an 'IOError' when @chunksize@ is not positive.
batch :: Int -> LVec a -> IO (LVec (Vector a))
batch chunksize vec = do
    -- Reject sizes that would otherwise surface as a bare divide-by-zero
    -- or a negative allocation deep inside the lazy computation.
    when (chunksize <= 0) $
        ioError . userError $
            "LazyVec.batch: chunk size must be positive, got " ++ show chunksize
    pos <- newIORef 0
    return $ case unLVec vec of
        Direct v -> makeChunk pos v
        -- Produce the underlying vector first, then chunk it.
        Make f   -> LVec new_vec_size . Make $ f >=> (toVec . makeChunk pos)
  where
    total = size vec
    (quotient, remainder) = total `divMod` chunksize
    -- Round up so that a trailing partial chunk gets its own slot.
    new_vec_size
        | remainder > 0 = quotient + 1
        | otherwise     = quotient

    makeChunk cur_pos vector =
        assert (V.length vector == total) $
        LVec new_vec_size . Make $ \_ -> do
            chunks <- VM.new new_vec_size
            -- The loop body only runs when new_vec_size > 0, which implies
            -- total > 0, so the `mod` operations below are safe.
            forM_ [0 .. new_vec_size - 1] $ \i -> do
                j <- readIORef cur_pos
                let chunk
                        -- Fits without wrapping: share the storage.
                        | j + chunksize <= total = V.slice j chunksize vector
                        -- Wraps (possibly more than once): index cyclically.
                        | otherwise =
                            V.generate chunksize $ \k ->
                                vector V.! ((j + k) `mod` total)
                VM.write chunks i chunk
                writeIORef cur_pos $! (j + chunksize) `mod` total
            V.freeze chunks

-- | Pair up the elements of two lazy vectors.
--
-- The recorded size is the smaller of the two sizes. If both inputs are
-- in memory the result is too; otherwise the result is deferred.
zip :: LVec a -> LVec b -> LVec (a, b)
zip (LVec n1 l1) (LVec n2 l2) = LVec (min n1 n2) zipped
  where
    zipped = case (l1, l2) of
        (Direct a, Direct b) -> Direct (V.zip a b)
        (Direct a, Make g)   -> Make (fmap (V.zip a) . g)
        (Make f,   Direct b) -> Make (fmap (`V.zip` b) . f)
        (Make f,   Make g)   -> Make (\_ -> liftM2 V.zip (f ()) (g ()))

-- | Apply an effectful function to every element.
--
-- The result is always deferred: the function only runs when the returned
-- 'LVec' is forced, and it runs again on every forcing.
map :: (a -> IO b) -> LVec a -> LVec b
map f v = LVec (size v) . Make $ \_ -> force (unLVec v) >>= V.mapM f

type instance DatasetConstraint LVec m = MonadIO m

instance Dataset LVec where
    fromListD = fromVec . V.fromList
    zipD      = zip
    sizeD dat = return $ size dat
    forEachD dat proc = do
        vec <- liftIO $ toVec dat
        ret <- V.mapM proc vec
        return $ V.toList ret
    -- LVec does not support infinite streams, so the default implementation
    -- is overridden: elements are paired with a finite, 1-based index list
    -- of length @size dat@.
    forEachD_i dat = forEachD (zipD (fromListD [1 .. size dat]) dat)