{-# LANGUAGE RecordWildCards #-}


-- | Provides the `DataSet` type which abstracts over the actual (IO-based)
-- representation of the training dataset.


module Numeric.SGD.DataSet
(
-- * Dataset
  DataSet (..)
, shuffle
-- * Reading
, loadData
, randomSample
-- * Construction
, withVect
, withDisk
-- , withData
) where


import           Control.Monad (forM_)
import qualified Control.Monad.State.Strict as S

import           System.IO.Temp (withTempDirectory)
import           System.IO.Unsafe (unsafeInterleaveIO)
import           System.FilePath ((</>))
import qualified System.Random as R
import           System.Random.Shuffle (shuffleM)

import           Data.Binary (Binary, encodeFile, decode)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Vector as V
import qualified Data.Map.Strict as M


------------------------------- 
-- Type
-------------------------------


-- | An abstract, indexable dataset.  Elements are retrieved one at a time
-- through an 'IO' action, so the underlying storage may live in memory
-- (see 'withVect') or on disk (see 'withDisk').
data DataSet elem = DataSet
  { size :: Int
    -- ^ The size of the dataset; the valid indices are
    -- [0, 1, ..., size - 1]
  , elemAt :: Int -> IO elem
    -- ^ Get the dataset element with the given index.  Indices outside
    -- [0, size - 1] are not guarded against by the record itself; what
    -- happens then depends on how the dataset was constructed.
  }


-------------------------------------------
-- Reading
-------------------------------------------


-- | Read every element of the dataset, in index order, as a lazily
-- produced list.  The first element is read immediately; each later
-- element is read only when the list is forced up to it, so IO errors for
-- those elements surface during evaluation rather than when 'loadData'
-- returns.
loadData :: DataSet a -> IO [a]
loadData dataSet =
  lazyMapM (elemAt dataSet) (enumFromTo 0 (size dataSet - 1))


-- -- | A dataset sample of the given size.
-- sample :: R.RandomGen g => g -> Int -> DataSet a -> IO ([a], g)
-- sample g 0 _       = return ([], g)
-- sample g n dataset = do
--     (xs, g') <- sample g (n-1) dataset
--     let (i, g'') = R.next g'
--     x <- dataset `elemAt` (i `mod` size dataset)
--     return (x:xs, g'')


-- | Shuffle the dataset.  The result has the same size and the same
-- elements, but its indices are routed through a random permutation drawn
-- (via `shuffleM`) when 'shuffle' runs.  No element is read or copied;
-- every lookup is forwarded to the original dataset's 'elemAt'.
shuffle :: DataSet a -> IO (DataSet a)
shuffle DataSet{..} = do
  -- perm V.! i is the original index that position i of the shuffled
  -- dataset refers to.  The indices are dense (0 .. size - 1), so a vector
  -- gives O(1) lookups of the permuted index.
  perm <- V.fromList <$> shuffleM [0 .. size - 1]
  return $ DataSet
    { size = size
    , elemAt = elemAt . (perm V.!)
    }


-- | Random dataset sample with a specified number of elements (loaded
-- eagerly).  Each element is drawn independently, with a fresh random
-- index per draw, so the same element may appear more than once.  A
-- non-positive count yields an empty list.
--
-- Throws an 'IOError' if a positive number of elements is requested from
-- an empty dataset, since there is no valid index to draw.
randomSample :: Int -> DataSet a -> IO [a]
randomSample k dataSet
  | k <= 0 = return []
  | size dataSet <= 0 =
      ioError . userError $
        "randomSample: cannot draw " ++ show k
          ++ " element(s) from an empty dataset"
  | otherwise = mapM (const draw) [1 .. k]
  where
    -- Pick a random valid index and load the element stored there.
    draw = R.randomRIO (0, size dataSet - 1) >>= elemAt dataSet


-------------------------------------------
-- Construction
-------------------------------------------


-- | Build an in-memory dataset from a list (the elements are copied into a
-- boxed vector) and pass it to the given handler.  Note that 'elemAt' does
-- not itself perform the lookup: an out-of-range index only fails when the
-- returned element is evaluated.
withVect :: [a] -> (DataSet a -> IO b) -> IO b
withVect xs handler = handler DataSet
    { size   = V.length vec
    , elemAt = return . (vec V.!)
    }
  where
    vec = V.fromList xs


-- | Construct dataset from a list of elements, store it on a disk and run the
-- given handler.  Training elements must have the `Binary` instance for this
-- function to work.
--
-- Each element is serialized to its own file, named after its index, in a
-- temporary directory created by `withTempDirectory` under the current
-- working directory.  The handler's 'DataSet' reads elements back from those
-- files on demand.  `withTempDirectory` removes the directory once the
-- handler finishes, so the 'DataSet' should not be used outside the handler.
withDisk :: Binary a => [a] -> (DataSet a -> IO b) -> IO b
withDisk xs handler = withTempDirectory "." ".sgd" $ \tmpDir -> do
    -- Path of the file holding the element with the given index.
    let pathOf :: Int -> FilePath
        pathOf ix = tmpDir </> show ix

    -- Write the elements out while counting them in the same pass, so `xs`
    -- is traversed only once.  `S.modify'` forces the counter at each step;
    -- plain `S.modify` would leave a chain of n (+1) thunks, because the
    -- strict StateT is strict in its sequencing, not in the state value.
    n <- flip S.execStateT 0 $ forM_ (zip xs [0 ..]) $ \(x, ix) -> do
        S.lift $ encodeFile (pathOf ix) x
        S.modify' (+1)

    -- Read the file contents strictly and decode from memory, to avoid
    -- `decodeFile` laziness with some older versions of the binary library
    -- (as of year 2019, this could be probably simplified).  Decoding itself
    -- happens when the returned element is evaluated.
    let at ix = do
          cs <- B.readFile (pathOf ix)
          return . decode $ BL.fromChunks [cs]

    handler $ DataSet {size = n, elemAt = at}


-------------------------------------------
-- Lazy IO Utils
-------------------------------------------


-- | Run a list of actions left to right and collect their results, deferring
-- all but the first.  The head action runs immediately.  Each later action
-- runs (via `unsafeInterleaveIO`) only when the corresponding tail of the
-- result list is forced, so its effects and exceptions happen at evaluation
-- time rather than here.
lazySequence :: [IO a] -> IO [a]
lazySequence = foldr step (return [])
  where
    -- Run one action now; postpone the rest of the chain.
    step action rest = do
      x <- action
      xs <- unsafeInterleaveIO rest
      return (x : xs)


-- | `lazyMapM` f is equivalent to `lazySequence` . `map` f: apply @f@ to
-- each input and run the resulting actions lazily, in order.
lazyMapM :: (a -> IO b) -> [a] -> IO [b]
lazyMapM f xs = lazySequence (fmap f xs)