-------------------------------------------------------------------------------
-- |
-- Module    :  Numeric.Datasets.CIFAR10
-- License   :  BSD-3-Clause
-- Stability :  experimental
-- Portability: non-portable
--
-- The binary version contains the files data_batch_1.bin, data_batch_2.bin,
-- ..., data_batch_5.bin, as well as test_batch.bin. Each of these files is
-- formatted as follows:
--
--     <1 x label><3072 x pixel>
--     ...
--     <1 x label><3072 x pixel>
--
-- In other words, the first byte is the label of the first image, which is a
-- number in the range 0-9. The next 3072 bytes are the values of the pixels of
-- the image. The first 1024 bytes are the red channel values, the next 1024
-- the green, and the final 1024 the blue. The values are stored in row-major
-- order, so the first 32 bytes are the red channel values of the first row of
-- the image.
-------------------------------------------------------------------------------
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
#if MIN_VERSION_GLASGOW_HASKELL(8,2,0,0)
{-# LANGUAGE DerivingStrategies #-}
#endif
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
#if MIN_VERSION_JuicyPixels(3,3,0)
#else
{-# LANGUAGE UndecidableInstances #-}
#endif
module Numeric.Datasets.CIFAR10
  ( Label(..)
  , CIFARImage(..), height, width, image, label
  , cifarURL
  , cifar10
  , parseCifar
  ) where

import Codec.Picture (Image, PixelRGB8(PixelRGB8), Pixel8, writePixel)
import Codec.Picture.Types (newMutableImage, freezeImage)
import Control.DeepSeq
import Control.Exception (throw)
import Control.Monad.ST (runST)
import Data.List (zipWith4)
import GHC.Generics (Generic)
import Network.HTTP.Req (Url, (/:), https, Scheme(..))
import qualified Codec.Archive.Tar as Tar
import qualified Codec.Compression.GZip as GZip
import qualified Data.Attoparsec.ByteString.Lazy as Atto
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL

import Numeric.Datasets

#if MIN_VERSION_JuicyPixels(3,3,0)
#else
import Foreign.Storable (Storable)
import qualified Codec.Picture as Compat
#endif
-- ========================================================================= --

-- | The ten object classes of the CIFAR-10 dataset.
--
-- Constructors are declared in the order of the label byte used by the
-- binary format, so the derived 'Enum' instance maps @Airplane@ to 0 and
-- @Truck@ to 9. Do not reorder the constructors.
data Label
  = Airplane    -- ^ label byte 0
  | Automobile  -- ^ label byte 1
  | Bird        -- ^ label byte 2
  | Cat         -- ^ label byte 3
  | Deer        -- ^ label byte 4
  | Dog         -- ^ label byte 5
  | Frog        -- ^ label byte 6
  | Horse       -- ^ label byte 7
  | Ship        -- ^ label byte 8
  | Truck       -- ^ label byte 9
  deriving (Show, Eq, Generic, Bounded, Enum, Read, NFData)

-- Compatibility shim, compiled only when JuicyPixels is older than 3.3.0.
-- It defines structural equality on images: two images are equal when their
-- width, height and underlying pixel buffer are all equal.
-- NOTE(review): presumably JuicyPixels 3.3.0 and later ship this instance
-- themselves, which is why it is guarded - TODO confirm.
-- This is an orphan instance; the UndecidableInstances pragma at the top of
-- the file is enabled under the same condition.
#if MIN_VERSION_JuicyPixels(3,3,0)
#else
instance (Eq (Compat.PixelBaseComponent a), Storable (Compat.PixelBaseComponent a))
    => Eq (Image a) where
  a == b = Compat.imageWidth  a == Compat.imageWidth  b &&
           Compat.imageHeight a == Compat.imageHeight b &&
           Compat.imageData   a == Compat.imageData   b
#endif

-- | A single CIFAR datapoint: an RGB image paired with its class 'Label'.
--
-- The image is expected to be 'width' x 'height' (32x32) pixels, but the type
-- itself does not enforce that. Use 'image' and 'label' to project out the
-- two components of the wrapped pair.
--
-- NOTE(review): the derived classes differ between the two compiler branches
-- below; only the fallback branch derives 'Generic'.
newtype CIFARImage = CIFARImage { getXY :: (Image PixelRGB8, Label) }
#if MIN_VERSION_GLASGOW_HASKELL(8,2,0,0)
  deriving newtype (Eq, NFData)
#else
  deriving (Eq, Generic, NFData)
#endif

-- | Fixed-format summary of a datapoint. Only the label varies between
-- values; the pixel data is never rendered.
instance Show CIFARImage where
  show im = concat
    [ "CIFARImage{Height: 32, Width: 32, Pixel: RGB8, Label: "
    , show (label im)
    , "}"
    ]

-- | Height, in pixels, of every 'CIFARImage' (the CIFAR-10 binary format
-- stores 32 rows per image).
height :: Int
height = 32

-- | Width, in pixels, of every 'CIFARImage' (the CIFAR-10 binary format
-- stores 32 bytes per row and channel).
width :: Int
width = 32

-- | Project out the JuicyPixels image of a CIFAR datapoint, dropping its
-- label.
image :: CIFARImage -> Image PixelRGB8
image (CIFARImage (pixels, _)) = pixels

-- | Project out the class 'Label' of a CIFAR datapoint, dropping its image.
label :: CIFARImage -> Label
label (CIFARImage (_, lbl)) = lbl

-- | Base URL, <https://www.cs.toronto.edu/~kriz>, under which the cifar-10
-- and cifar-100 archives are hosted. 'cifar10' appends the archive file name
-- to this prefix.
cifarURL :: Url 'Https
cifarURL = https "www.cs.toronto.edu" /: "~kriz"

-------------------------------------------------------------------------------
-- Second argument handed to the Dataset constructor in cifar10.
-- NOTE(review): Nothing presumably selects the default temporary directory
-- chosen by Numeric.Datasets - confirm against that module.
tempdir :: Maybe FilePath
tempdir = Nothing

-- | The CIFAR-10 dataset, defined from the gzipped tarball
-- @cifar-10-binary.tar.gz@ hosted under 'cifarURL'. This is a binary format,
-- not CSV.
--
-- The 'Dataset' is assembled from that URL, 'tempdir', 'unzipCifar' (wrapped
-- in 'Just') and 'parseCifar' (wrapped in @Parsable@).
-- NOTE(review): the last two presumably act as the preprocessing step and the
-- per-record parser - confirm against "Numeric.Datasets".
cifar10 :: Dataset CIFARImage
cifar10 = Dataset
  (URL $ cifarURL /: "cifar-10-binary.tar.gz")
  tempdir
  (Just unzipCifar)
  (Parsable parseCifar)

-- cifar10Sha256 = "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"

-- | Parser for a single record of a CIFAR-10 binary batch file.
--
-- A record is one label byte followed by three colour planes of
-- @width * height@ bytes each (red, then green, then blue). Every plane is
-- stored in row-major order, so the i-th byte of a plane belongs to the pixel
-- in column @i `mod` width@ of row @i `div` width@.
--
-- The parser consumes exactly one record. It fails, instead of throwing from
-- 'toEnum', when the label byte does not name a constructor of 'Label', and
-- it fails when fewer bytes than a full record remain.
parseCifar :: Atto.Parser CIFARImage
parseCifar = do
  lbl <- parseLabel
  rs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  gs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  bs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  let ipixels = zipWith4 (\xy r g b -> (xy, PixelRGB8 r g b)) coords rs gs bs
  pure $ CIFARImage (newImage ipixels, lbl)
  where
    -- Number of bytes in one colour plane (1024 for a 32x32 image).
    planeSize :: Int
    planeSize = width * height

    -- Read the label byte and reject values beyond the last constructor, so
    -- that malformed input is reported as a parse failure rather than as a
    -- lazily thrown exception from the derived Enum instance.
    parseLabel :: Atto.Parser Label
    parseLabel = do
      tag :: Int <- fromIntegral <$> Atto.anyWord8
      if tag <= fromEnum (maxBound :: Label)
        then pure (toEnum tag)
        else fail ("parseCifar: label byte out of range: " ++ show tag)

    -- Build the immutable image by writing each pixel at its (x, y) position.
    newImage :: [((Int, Int), PixelRGB8)] -> Image PixelRGB8
    newImage ipixels = runST $ do
      -- newMutableImage takes the width first, then the height.
      mim <- newMutableImage width height
      mapM_ (\((x, y), rgb) -> writePixel mim x y rgb) ipixels
      freezeImage mim

    -- (x, y) coordinates in row-major order: x (the column) varies fastest,
    -- matching the byte order of each colour plane.
    coords :: [(Int, Int)]
    coords = [(x, y) | y <- [0 .. height - 1], x <- [0 .. width - 1]]

-- | Turn the downloaded gzipped tarball into one lazy bytestring made of the
-- contents of its batch files.
--
-- The archive is gunzipped and read as a tar stream. Every regular file whose
-- size is exactly 30730000 bytes is kept and everything else is skipped. Kept
-- payloads are prepended to the accumulator of a left fold, so they are
-- concatenated in the reverse of the order in which the fold visits them.
--
-- A tar format error is rethrown with 'throw' from pure code.
--
-- FIXME: this should be in MonadThrow
unzipCifar :: BL.ByteString -> BL.ByteString
unzipCifar zipbs =
  case Tar.foldlEntries collect [] (Tar.read (GZip.decompress zipbs)) of
    Left failure -> throw (fst failure)
    Right chunks -> BL.concat chunks
  where
    collect :: [BL.ByteString] -> Tar.Entry -> [BL.ByteString]
    collect acc entry
      -- Each file is exactly 30730000 bytes long. All other files are
      -- metadata. See https://www.cs.toronto.edu/~kriz/cifar.html
      | Tar.NormalFile payload size <- Tar.entryContent entry
      , size == 30730000
      = payload : acc
      | otherwise = acc