{-# Language MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
module MXNet.NN.DataIter.LazyVec where
import Prelude hiding (zip)
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Data.IORef
import Control.Monad
import Control.Monad.IO.Class
import Control.Exception.Base (assert)
import MXNet.NN.DataIter.Class
-- | A value that is either already available or produced on demand.
--
-- * 'Direct' wraps a value that is already at hand.
-- * 'Make' wraps a function from unit to an 'IO' action that yields the
--   value when invoked.
data Lazy a = Direct a | Make (() -> IO a)
-- | Mapping over a 'Lazy' value applies the function right away to a
-- 'Direct' payload, and to the result of the 'IO' action for a 'Make'
-- payload (without running that action).
instance Functor Lazy where
  fmap f lazy = case lazy of
    Direct value -> Direct (f value)
    -- Post-compose @f@ onto the result of the deferred action.
    Make produce -> Make (fmap f . produce)
-- | Obtain the value held by a 'Lazy': a 'Direct' value is returned as is,
-- a 'Make' value is produced by running its action.
force :: Lazy a -> IO a
force lazy = case lazy of
  Direct value -> return value
  Make produce -> produce ()
-- | A vector whose contents may be deferred, paired with an element count
-- that is available without forcing the contents.
--
-- 'size' is stored separately from the vector; the code in this module
-- constructs it so that it matches the length of the vector obtained from
-- 'unLVec' (nothing in the type enforces this).
data LVec a = LVec { size :: Int, unLVec :: Lazy (Vector a)}
-- | Wrap an already materialised vector as an 'LVec', recording its length
-- as the size.
fromVec :: Vector a -> LVec a
fromVec v =
  LVec
    { size = V.length v
    , unLVec = Direct v
    }
-- | Materialise the contents of an 'LVec', running the deferred action if
-- there is one.
toVec :: LVec a -> IO (Vector a)
toVec (LVec _ contents) = force contents
-- | Group the elements of an 'LVec' into chunks of exactly @chunksize@
-- elements.
--
-- The result holds @ceiling (size vec / chunksize)@ chunks. The source is
-- treated as a ring: when a chunk runs past the end of the vector it
-- continues from the beginning, wrapping around as many times as necessary
-- (so @chunksize@ may exceed the size of the vector).
--
-- A read cursor is kept in an 'IORef' created once per call to 'batch' and
-- shared by every forcing of the returned 'LVec'; forcing the result again
-- therefore resumes from where the previous forcing stopped instead of
-- restarting at index 0.
--
-- Fails with an 'IOError' when @chunksize@ is not positive.
batch :: Int -> LVec a -> IO (LVec (Vector a))
batch chunksize vec
  | chunksize <= 0 =
      -- Reject up front; otherwise the failure would surface later as an
      -- opaque divide-by-zero / negative-length error.
      ioError . userError $
        "batch: chunk size must be positive, but got " ++ show chunksize
  | otherwise = do
      pos <- newIORef 0
      return $ case unLVec vec of
        Direct v -> makeChunk pos v
        -- Defer chunking until the underlying vector has been produced.
        Make f -> LVec new_vec_size . Make $ f >=> (toVec . makeChunk pos)
  where
    total = size vec
    (quotient, remainder) = divMod total chunksize
    -- Number of chunks, rounding up so that a partial tail gets its own chunk.
    new_vec_size = if remainder > 0 then quotient + 1 else quotient
    makeChunk cur_pos vector =
      assert (V.length vector == total) $
        LVec new_vec_size . Make $ \_ -> do
          vec' <- VM.new new_vec_size
          -- The loop body only runs when new_vec_size > 0, which implies
          -- total > 0, so the `mod` operations below never divide by zero.
          forM_ [0 .. new_vec_size - 1] $ \i -> do
            j <- readIORef cur_pos
            let chunk
                  -- Fits without wrapping: share the storage via a slice.
                  | j + chunksize <= total = V.slice j chunksize vector
                  -- Wraps (possibly several times): index cyclically.
                  | otherwise =
                      V.generate chunksize $ \k ->
                        vector V.! ((j + k) `mod` total)
            VM.write vec' i chunk
            -- Keep the cursor inside [0, total) and evaluated, so no chain
            -- of thunks accumulates in the IORef.
            writeIORef cur_pos $! (j + chunksize) `mod` total
          V.freeze vec'
-- | Pair up the elements of two 'LVec's position by position.
--
-- The size of the result is the smaller of the two input sizes. The result
-- is 'Direct' only when both inputs are; otherwise zipping is deferred until
-- the deferred input(s) have been produced. When both inputs are deferred,
-- the left one is run before the right one.
zip :: LVec a -> LVec b -> LVec (a,b)
zip (LVec n1 lazyA) (LVec n2 lazyB) =
  case (lazyA, lazyB) of
    (Direct a, Direct b) -> sized (Direct (V.zip a b))
    (Direct a, Make g)   -> sized (Make (fmap (V.zip a) . g))
    (Make f, Direct b)   -> sized (Make (fmap (`V.zip` b) . f))
    (Make f, Make g)     -> sized (Make (\_ -> liftM2 V.zip (f ()) (g ())))
  where
    sized = LVec (min n1 n2)
-- | Apply an effectful function to every element of an 'LVec'.
--
-- The result is always deferred: the function is only run (over all
-- elements, via 'V.mapM') when the resulting 'LVec' is forced. The size is
-- carried over unchanged.
map :: (a -> IO b) -> LVec a -> LVec b
map f (LVec n contents) =
  case contents of
    Direct vec   -> LVec n (Make (const (V.mapM f vec)))
    -- Produce the source vector first, then map over it.
    Make produce -> LVec n (Make (produce >=> V.mapM f))
-- | Constraint placed on the monad used with 'LVec' as a dataset: it must be
-- able to lift 'IO' actions.
-- NOTE(review): 'DatasetConstraint' is not defined in this file; presumably
-- it is the type family exported by "MXNet.NN.DataIter.Class" — confirm.
type instance DatasetConstraint LVec m = MonadIO m
-- | 'Dataset' operations for 'LVec'.
instance Dataset LVec where
  -- Build a materialised 'LVec' from a list.
  fromListD xs = fromVec (V.fromList xs)

  zipD lhs rhs = zip lhs rhs

  -- The size is stored in the 'LVec', so no forcing is needed.
  sizeD = return . size

  -- Force the vector, run the action on each element in order, and collect
  -- the results as a list.
  forEachD dat proc =
    liftIO (toVec dat) >>= liftM V.toList . V.mapM proc

  -- Same as 'forEachD', but each element is paired with its 1-based index.
  forEachD_i dat = forEachD (zipD indices dat)
    where
      indices = fromListD [1 .. size dat]