{-# LANGUAGE RecordWildCards #-}
module Numeric.SGD.DataSet
(
DataSet (..)
, shuffle
, loadData
, randomSample
, withVect
, withDisk
) where
import Control.Monad (forM_, replicateM)
import System.IO.Unsafe (unsafeInterleaveIO)

import qualified Control.Monad.State.Strict as S
import Data.Binary (Binary, encodeFile, decode)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map.Strict as M
import qualified Data.Vector as V
import System.FilePath ((</>))
import System.IO.Temp (withTempDirectory)
import qualified System.Random as R
import System.Random.Shuffle (shuffleM)
-- | A dataset of values of type @elem@, described by its element count and
-- an effectful accessor.  The functions in this module address elements
-- with indices in the range @0 .. size - 1@.
data DataSet elem = DataSet
  { size   :: Int
    -- ^ Number of elements in the dataset.
  , elemAt :: Int -> IO elem
    -- ^ IO action fetching the element stored under the given index.
  }
-- | Load every element of the dataset as a list, in index order
-- (@0 .. size - 1@).
--
-- The elements are fetched through 'lazyMapM', so reading of the list's tail
-- is deferred until that part of the list is demanded.
loadData :: DataSet a -> IO [a]
loadData dataSet = lazyMapM (elemAt dataSet) [0 .. size dataSet - 1]
-- | Produce a dataset of the same size whose indices are randomly permuted
-- with respect to the input dataset.
--
-- No element is loaded here: the result only redirects each index to a
-- shuffled index of the underlying dataset before calling its 'elemAt'.
shuffle :: DataSet a -> IO (DataSet a)
shuffle dataSet = do
  let n = size dataSet
      positions = [0 .. n - 1]
  permuted <- shuffleM positions
  -- Lookup table: index in the shuffled view -> index in the original.
  let remap = M.fromList (zip positions permuted)
  pure DataSet
    { size = n
    , elemAt = \ix -> elemAt dataSet (remap M.! ix)
    }
-- | Draw @k@ elements from the dataset, each one fetched from an
-- independently chosen random index in @0 .. size - 1@ (so the same element
-- may be picked more than once).  All elements are loaded before the
-- function returns.
--
-- The result is empty when @k <= 0@ or when the dataset has no elements.
randomSample :: Int -> DataSet a -> IO [a]
randomSample k dataSet
  -- An empty dataset has no valid index; without this guard 'R.randomRIO'
  -- would be asked for the inverted range @(0, -1)@ and the outcome would be
  -- handed to 'elemAt'.
  | n <= 0 = return []
  -- 'replicateM' yields an empty list for non-positive @k@.
  | otherwise = replicateM k draw
  where
    n = size dataSet
    -- Pick one random index and fetch the element stored under it.
    draw = R.randomRIO (0, n - 1) >>= elemAt dataSet
-- | Run the handler on a dataset backed by an in-memory vector built from
-- the given list.  Element lookup is plain vector indexing ('V.!').
withVect :: [a] -> (DataSet a -> IO b) -> IO b
withVect xs handler =
  handler $
    DataSet
      { size = V.length items
      , elemAt = pure . (items V.!)
      }
  where
    items = V.fromList xs
-- | Run the handler on a dataset whose elements are stored on disk.
--
-- Every element of the input list is serialised with 'encodeFile' into its
-- own file, named after the element's index, inside a temporary directory
-- created under @"."@ from the name template @".sgd"@.  Looking an element
-- up reads its file strictly and decodes the contents.
--
-- The directory is managed by 'withTempDirectory', so the dataset should not
-- be used once the handler has returned.
withDisk :: Binary a => [a] -> (DataSet a -> IO b) -> IO b
withDisk xs handler = withTempDirectory "." ".sgd" $ \tmpDir -> do
  let pathFor :: Int -> FilePath
      pathFor ix = tmpDir </> show ix
  -- The list is traversed only once: elements are written out and counted
  -- in the same pass.
  n <- flip S.execStateT 0 $ forM_ (zip xs [0 :: Int ..]) $ \(x, ix) -> do
    S.lift $ encodeFile (pathFor ix) x
    -- Strict update: the lazy 'S.modify' would accumulate one unevaluated
    -- @(+1)@ application per element.
    S.modify' (+ 1)
  let at ix = decode . BL.fromStrict <$> B.readFile (pathFor ix)
  handler DataSet { size = n, elemAt = at }
-- | Lazy counterpart of 'sequence' for 'IO' actions.
--
-- The first action runs immediately; the remainder of the list is wrapped
-- in 'unsafeInterleaveIO', so each following action is executed only when
-- the corresponding part of the result list is demanded.
lazySequence :: [IO a] -> IO [a]
lazySequence actions = case actions of
  [] -> pure []
  act : rest -> do
    first <- act
    (first :) <$> unsafeInterleaveIO (lazySequence rest)
-- | Apply an 'IO' function to every element of a list and collect the
-- results, running the resulting actions through 'lazySequence'.
lazyMapM :: (a -> IO b) -> [a] -> IO [b]
lazyMapM f xs = lazySequence [f x | x <- xs]