{-# LANGUAGE DeriveFoldable #-}
{-# language ScopedTypeVariables #-}
{-# options_ghc -Wno-unused-imports #-}
{-# options_ghc -Wno-unused-matches #-}
{-|

Streaming (de)serialization and encode-decode functions for the IDX format used in the MNIST handwritten digit recognition dataset [1].

Both sparse and dense decoders are provided. In either case, the range of the data is the same as the raw data (one unsigned byte per pixel).


== Links

1) http://yann.lecun.com/exdb/mnist/

-}
module Data.IDX.Conduit (
  -- * Source
  -- ** Labels
  sourceIdxLabels,
  mnistLabels,
  -- ** Data
  -- *** Dense
  sourceIdx,
  -- *** Sparse
  sourceIdxSparse,
  -- * Sink
  -- ** Data
  -- *** Dense
  sinkIdx,
  -- *** Sparse
  sinkIdxSparse,
  -- * Types
  Sparse,
  sBufSize, sNzComponents,
  -- * Debug
  readHeader
                        )where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Either (isRight)
import Data.Foldable (Foldable(..), traverse_, for_)
import Data.Int (Int8, Int16, Int32)
import Data.Word (Word8, Word16, Word32)
import Data.Void (Void)
import GHC.IO.Handle (Handle, hSeek, SeekMode(..), hClose)
import System.IO (IOMode(..), withBinaryFile, openBinaryFile)
-- binary
import Data.Binary (Binary(..), Get, getWord8, putWord8, encode, decode, decodeOrFail)
import Data.Binary.Get (runGetOrFail)
-- bytestring
import qualified Data.ByteString as BS (ByteString)
import qualified Data.ByteString.Lazy as LBS (ByteString, hGet, readFile, toStrict, map)
import qualified Data.ByteString.Lazy.Internal as LBS (unpackBytes, packBytes)
-- conduit
import Conduit (MonadResource, runResourceT, (.|), runConduitRes)
import qualified Data.Conduit as C (ConduitT, runConduit, bracketP, yield)
import qualified Data.Conduit.Combinators as C (sinkFile, map, takeExactly, print, takeExactlyE)
-- containers
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as SQ (fromList)
-- vector
import qualified Data.Vector as V (Vector, replicateM, length, forM_, head, tail)
import qualified Data.Vector.Unboxed as VU (Unbox, Vector, length, fromList, toList, foldl, (!))


-- | Outputs dense data buffers in the 0-255 range, one unboxed vector per data item
--
-- In the case of MNIST dataset, 0 corresponds to the background of the image.
sourceIdx :: MonadResource m =>
             FilePath -- ^ filepath of uncompressed IDX data file
          -> Maybe Int -- ^ optional maximum number of entries to retrieve
          -> C.ConduitT () (VU.Vector Word8) m ()
sourceIdx = sourceIDX_ toDense
  where
    -- the buffer size passed by 'sourceIDX_' is not needed: every byte of the item is kept
    toDense _ bs = VU.fromList (components bs)


-- | Outputs sparse data buffers (i.e without zero components)
--
-- This incurs at least one additional data copy of each vector, but the resulting vectors take up less space.
sourceIdxSparse :: MonadResource m =>
                   FilePath -- ^ filepath of uncompressed IDX data file
                -> Maybe Int -- ^ optional maximum number of entries to retrieve
                -> C.ConduitT () (Sparse Word8) m ()
sourceIdxSparse = sourceIDX_ toSparse
  where
    -- keep the dense buffer size next to the nonzero components, so the item can be densified again
    toSparse n bs = Sparse n (sparsify (components bs))

-- | Parser for the labels, can be plugged in as an argument to 'sourceIdxLabels'
--
-- Succeeds only when the buffer holds exactly one byte and that byte is a digit
-- in the 0-9 range; any other input is rejected with an error message.
mnistLabels :: LBS.ByteString
            -> Either String Int
mnistLabels l = case LBS.unpackBytes l of
  [w] | w <= 9 -> Right (fromEnum w)
  _ -> Left "MNIST labels are the 0-9 digits"

-- | Outputs the labels corresponding to the data
--
-- The header is read first: 4 bytes of "magic number", then the number of items as a
-- big-endian 32 bit integer. After that, one byte is read per label and handed to the
-- supplied parser. The stream ends once all labels (or the requested maximum, whichever
-- is smaller) have been produced.
--
-- Calls 'error' if the header cannot be decoded.
sourceIdxLabels :: MonadResource m =>
                   (LBS.ByteString -> Either e o) -- ^ parser for the labels, where the bytestring buffer contains exactly one unsigned byte
                -> FilePath -- ^ filepath of uncompressed IDX labels file
                -> Maybe Int -- ^ optional maximum number of entries to retrieve
                -> C.ConduitT () (Either e o) m ()
sourceIdxLabels buildf fp mmax = withReadHdl fp $ \handle -> do
  hlbs <- liftIO $ LBS.hGet handle 4
  case decodeE hlbs of
    Left e -> error e
    Right IDXMagic{} -> do
      nitbs <- liftIO $ LBS.hGet handle 4 -- number of items is 32 bit (4 byte)
      -- decode as Int32: the 'Binary' instance of 'Int' expects 8 bytes and would
      -- always fail on this 4 byte field
      case decodeE nitbs of
        Left e -> error e
        Right (nitems :: Int32) -> do
          let bufsize = 1 -- each label is one unsigned byte
              ndata = fromIntegral nitems :: Int
              nmax = maybe ndata (min ndata) mmax
              go i = when (i < nmax) $ do
                -- 'LBS.hGet' already advances the handle, so no explicit seek is
                -- needed to reach the next label
                b <- liftIO $ LBS.hGet handle bufsize
                C.yield $ buildf b
                go (succ i)
          go 0

-- | Decode a value via its 'Binary' instance, keeping only the error message on failure.
--
-- The unconsumed input and the byte offset reported by 'decodeOrFail' are discarded.
decodeE :: Binary b => LBS.ByteString -> Either String b
decodeE l = case decodeOrFail l of
  Left (_, _, e) -> Left e
  Right (_, _, x) -> Right x


{-# WARNING sinkIdx "this produces an incomplete header for some reason, causing the decoder to chop the data items at the wrong length. Do not use until https://github.com/ocramz/mnist-idx-conduit/issues/1 is resolved." #-}
-- | Write a dataset of dense buffers to disk
--
-- Contents are written as unsigned bytes, so make sure 8 bit data comes in without losses
sinkIdx :: (MonadResource m, Foldable t) =>
           FilePath -- ^ file to write
        -> Int -- ^ number of data items that will be written
        -> t Word32 -- ^ data dimension sizes
        -> C.ConduitT (VU.Vector Word8) Void m ()
sinkIdx = sinkIDX_ (LBS.toStrict . fromComponents . VU.toList)

{-# WARNING sinkIdxSparse "this produces an incomplete header for some reason, causing the decoder to chop the data items at the wrong length. Do not use until https://github.com/ocramz/mnist-idx-conduit/issues/1 is resolved." #-}
-- | Write a sparse dataset to disk
--
-- Each sparse buffer is expanded to its dense form (zeros included) before being written.
--
-- Contents are written as unsigned bytes, so make sure 8 bit data comes in without losses
sinkIdxSparse :: (Foldable t, MonadResource m) =>
                 FilePath -- ^ file to write
              -> Int -- ^ number of data items that will be written
              -> t Word32 -- ^ data dimension sizes
              -> C.ConduitT (Sparse Word8) Void m ()
sinkIdxSparse = sinkIDX_ toBytes
  where
    toBytes (Sparse n vu) = LBS.toStrict (fromComponents (densify n vu))

{-# WARNING sinkIDX_ "this produces an incomplete header for some reason, causing the decoder to chop the data items at the wrong length. Do not use until https://github.com/ocramz/mnist-idx-conduit/issues/1 is resolved." #-}
-- | Shared implementation of the IDX sinks: writes the header, then the encoded data items.
--
-- The header is made of the 4 byte magic number (the content type is always
-- unsigned byte), the number of data items and the size of each item dimension,
-- each encoded as a 32 bit integer. At most @ndata@ items are then taken from
-- upstream, converted with the supplied function and appended to the file.
sinkIDX_ :: (MonadResource m, Foldable t) =>
            (i -> BS.ByteString)
         -> FilePath
         -> Int -- ^ number of data items that will be written
         -> t Word32 -- ^ data dimension sizes
         -> C.ConduitT i Void m ()
sinkIDX_ buildf fp ndata ds = src .| C.sinkFile fp
  where
    -- The number of items is itself the first dimension of the dataset, so it
    -- must be counted in the magic number: the readers in this module take the
    -- item count from the first of the declared dimension sizes.
    ndims = length ds + 1
    magicbs = encodeBS (IDXMagic IDXUnsignedByte ndims)
    ndatabs = encodeBS (fromIntegral ndata :: Word32)
    src = do
      C.yield magicbs -- magic number
      C.yield ndatabs -- number of data items
      for_ ds (C.yield . encodeBS) -- data dimension sizes
      C.takeExactly ndata $ C.map buildf


-- | Serialise a value with its 'Binary' instance into a strict bytestring.
encodeBS :: (Binary b) => b -> BS.ByteString
encodeBS = LBS.toStrict . encode

-- | Shared implementation of the IDX data sources.
--
-- Reads the 4 byte magic number, then as many 32 bit dimension sizes as it declares.
-- The first size is the number of data items; the product of the remaining ones is
-- the number of bytes read per item. The content type declared in the magic number
-- is ignored, i.e. one byte per component is assumed. Each item buffer is passed,
-- together with its size, to the supplied function and the result is sent downstream.
--
-- Calls 'error' if the header cannot be decoded or declares no dimensions.
sourceIDX_ :: MonadResource m =>
              (Int -> LBS.ByteString -> o)
           -> FilePath -- ^ filepath of uncompressed IDX data file
           -> Maybe Int -- ^ optional maximum number of entries to retrieve
           -> C.ConduitT i o m ()
sourceIDX_ buildf fp mmax = withReadHdl fp $ \handle -> do
  hlbs <- liftIO $ LBS.hGet handle 4
  case decodeOrFail hlbs of
    Left (_, _, e) -> error e
    Right (_, _, IDXMagic _ ndims) -> do
      let
        bytesDimsVec = 4 * ndims -- each dim is a 32 bit (4 byte) int
      dvlbs <- liftIO $ LBS.hGet handle bytesDimsVec
      case getDims ndims dvlbs of
        Left e -> error e
        Right vv
          -- without at least one dimension there is no item count to read
          | V.length vv == 0 -> error "IDX header declares zero dimensions"
          | otherwise -> do
              let
                ndata = V.head vv
                bufsize = product (V.tail vv)
                nmax = maybe ndata (min ndata) mmax
                go i = when (i < nmax) $ do
                  -- 'LBS.hGet' already advances the handle past the item, so no
                  -- explicit seek is needed to reach the next one
                  b <- liftIO $ LBS.hGet handle bufsize
                  C.yield $ buildf bufsize b
                  go (succ i)
              go 0

-- | Keep only the nonzero components of a buffer, each paired with its
-- (0-based) position in the original sequence.
sparsify :: (Foldable t) => t Word8 -> VU.Vector (Int, Word8)
sparsify xs = VU.fromList [ (i, x) | (i, x) <- zip [0 ..] (toList xs), x /= 0 ]

-- | Expand sparse components into a dense list of the given length, filling the
-- positions that have no sparse entry with zeros.
--
-- Entries are matched in order against the positions @0 .. n - 1@, so their
-- indices must be strictly increasing for all of them to be placed: an entry
-- whose index is not reached blocks the entries after it.
densify :: Int -> VU.Vector (Int, Word8) -> [Word8]
densify n vu = go 0 (VU.toList vu)
  where
    go i nz
      | i >= n = []
      | otherwise = case nz of
          (iv, x) : rest | iv == i -> x : go (i + 1) rest
          _ -> 0 : go (i + 1) nz


-- | Unpack a lazy bytestring into the list of its bytes.
components :: LBS.ByteString -> [Word8]
components = LBS.unpackBytes

-- | Pack a list of bytes into a lazy bytestring.
fromComponents :: [Word8] -> LBS.ByteString
fromComponents = LBS.packBytes

-- | Sparse buffer (containing only nonzero entries)
data Sparse a = Sparse {
  sBufSize :: !Int -- ^ total number of entries in the _dense_ buffer, i.e. including the zeros
  , sNzComponents :: VU.Vector (Int, a) -- ^ nonzero components, together with the linear index into their dense counterpart
  }  deriving (Eq, Show)

-- | Decode the dimension sizes stored in an IDX header.
--
-- Each size is read as a 32 bit signed integer (see 'getInt32') and converted to
-- the requested numeric type. Returns the parser's error message if the input
-- holds fewer sizes than requested.
getDims :: Num a =>
           Int -- ^ number of dimensions
        -> LBS.ByteString -> Either String (V.Vector a)
getDims n lbs = case runGetOrFail dims lbs of
  Left (_, _, e) -> Left e
  Right (_, _, x) -> Right x
  where
    dims = V.replicateM n (fromIntegral <$> getInt32)

-- | Run a conduit with a binary read handle on the given file.
--
-- The handle is opened and closed through 'C.bracketP'.
withReadHdl :: MonadResource m =>
               FilePath
            -> (Handle -> C.ConduitT i o m r) -- ^ read from the handle
            -> C.ConduitT i o m r
withReadHdl fp = C.bracketP (openBinaryFile fp ReadMode) hClose

-- | Run an 'IO' action with a binary read handle on the given file.
withReadHdl_ :: FilePath -> (Handle -> IO r) -> IO r
withReadHdl_ fp = withBinaryFile fp ReadMode

-- | Decode the header of an IDX data file and return its contents
--
-- Calls 'error' if the header cannot be decoded or declares no dimensions.
readHeader :: FilePath -- ^ path of IDX file
           -> IO (IDXMagic, Int32, V.Vector Int32) -- ^ "magic number", number of data items, list of dimension sizes of each data item
readHeader fp = withReadHdl_ fp $ \handle -> do
  hlbs <- LBS.hGet handle 4
  case decodeOrFail hlbs of
    Left (_, _, e) -> error e
    Right (_, _, mg@(IDXMagic _ ndims)) -> do
      let
        bytesDimsVec = 4 * ndims -- each dim is a 32 bit (4 byte) int
      dvlbs <- LBS.hGet handle bytesDimsVec
      case getDims ndims dvlbs of
        Left e -> error e
        Right vv
          -- without at least one dimension there is no item count to report
          | V.length vv == 0 -> error "IDX header declares zero dimensions"
          -- the first size is the number of items, the rest describe one item
          | otherwise -> pure (mg, V.head vv, V.tail vv)


-- | "magic number" starting the file header for the IDX format
--
-- as per http://yann.lecun.com/exdb/mnist/
--
-- 32 bit (4 bytes) header ("magic number")
data IDXMagic = IDXMagic {
  idxType :: IDXContentType -- ^ type of the data components
  , idxNumDims :: Int -- ^ number of dimension sizes that follow the magic number
                         } deriving (Show)

-- | The magic number takes 4 bytes: two zero bytes, the content type and the number of dimensions.
instance Binary IDXMagic where
  get = do
    -- first 2 bytes are 0; checking them rejects input that is not IDX at all,
    -- e.g. a file that is still gzip-compressed
    b0 <- getWord8
    b1 <- getWord8
    when (b0 /= 0 || b1 /= 0) $
      fail "IDX magic number must start with two zero bytes"
    -- third byte encodes the type of data
    ty <- get
    -- fourth byte encode the number of dimensions
    nDims <- fromIntegral <$> getWord8
    pure (IDXMagic ty nDims)
  put (IDXMagic ty nDims) = do
    -- first 2 bytes are 0
    putWord8 0
    putWord8 0
    -- third byte encodes the type of data
    put ty
    -- fourth byte encode the number of dimensions
    putWord8 (fromIntegral nDims)


-- | A type to describe the content, according to IDX spec
data IDXContentType =
   IDXUnsignedByte
   | IDXSignedByte
   | IDXShort
   | IDXInt
   | IDXFloat
   | IDXDouble
   deriving Show

-- | The content type is stored as a single byte; unknown codes make the decoder fail.
instance Binary IDXContentType where
    get = do
      w <- getWord8
      case w of
        0x08 -> return IDXUnsignedByte
        0x09 -> return IDXSignedByte
        0x0B -> return IDXShort
        0x0C -> return IDXInt
        0x0D -> return IDXFloat
        0x0E -> return IDXDouble
        _ -> fail $ "Unrecognized IDX content type: " ++ (show w)

    put IDXUnsignedByte = putWord8 0x08
    put IDXSignedByte   = putWord8 0x09
    put IDXShort        = putWord8 0x0B
    put IDXInt          = putWord8 0x0C
    put IDXFloat        = putWord8 0x0D
    put IDXDouble       = putWord8 0x0E

-- Data.Binary uses big-endian format
-- getInt8 :: Get Int8
-- getInt8 = get

-- getInt16 :: Get Int16
-- getInt16 = get

-- | Parser for a 32 bit signed integer, using the 'Binary' instance of 'Int32'.
getInt32 :: Get Int32
getInt32 = get

-- getFloat :: Get Float
-- getFloat = get

-- getDouble :: Get Double
-- getDouble = get