{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
#if MIN_VERSION_GLASGOW_HASKELL(8,2,0,0)
{-# LANGUAGE DerivingStrategies #-}
#endif
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
#if MIN_VERSION_JuicyPixels(3,3,0)
#else
{-# LANGUAGE UndecidableInstances #-}
#endif
module Numeric.Datasets.CIFAR10
( Label(..)
, CIFARImage(..), height, width, image, label
, cifarURL
, cifar10
, parseCifar
) where
import Codec.Picture (Image, PixelRGB8(PixelRGB8), Pixel8, writePixel)
import Codec.Picture.Types (newMutableImage, freezeImage)
import Control.DeepSeq
import Control.Exception (throw)
import Control.Monad.ST (runST)
import Data.List (zipWith4)
import GHC.Generics (Generic)
import Network.HTTP.Req (Url, (/:), https, Scheme(..))
import qualified Codec.Archive.Tar as Tar
import qualified Codec.Compression.GZip as GZip
import qualified Data.Attoparsec.ByteString.Lazy as Atto
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Numeric.Datasets
#if MIN_VERSION_JuicyPixels(3,3,0)
#else
import Foreign.Storable (Storable)
import qualified Codec.Picture as Compat
#endif
-- | The ten object classes of CIFAR-10.
--
-- The constructor order is significant: 'parseCifar' turns the label byte
-- of a record into a 'Label' through the derived 'Enum' instance, so the
-- n-th constructor (counting from zero) corresponds to label byte n.
data Label
  = Airplane
  | Automobile
  | Bird
  | Cat
  | Deer
  | Dog
  | Frog
  | Horse
  | Ship
  | Truck
  deriving
    ( Show
    , Eq
    , Generic
    , Bounded
    , Enum
    , Read
    , NFData
    )
#if MIN_VERSION_JuicyPixels(3,3,0)
#else
-- | Orphan 'Eq' instance for JuicyPixels images, compiled only when
-- JuicyPixels is older than 3.3.0 (presumably newer releases provide their
-- own instance; confirm against the JuicyPixels changelog).
--
-- Two images are equal when their widths, heights and raw pixel buffers
-- are all equal.
instance (Eq (Compat.PixelBaseComponent a), Storable (Compat.PixelBaseComponent a))
    => Eq (Image a) where
  lhs == rhs = and
    [ Compat.imageWidth lhs == Compat.imageWidth rhs
    , Compat.imageHeight lhs == Compat.imageHeight rhs
    , Compat.imageData lhs == Compat.imageData rhs
    ]
#endif
-- | One CIFAR-10 sample: an RGB image paired with its class 'Label'.
--
-- Use 'image' and 'label' to project the two components.
newtype CIFARImage = CIFARImage
  { getXY :: (Image PixelRGB8, Label)
    -- ^ The underlying (image, label) pair.
  }
#if MIN_VERSION_GLASGOW_HASKELL(8,2,0,0)
  deriving newtype (Eq, NFData)
#else
  deriving (Eq, Generic, NFData)
#endif
-- | Renders a short summary (fixed dimensions, pixel type and the label)
-- instead of dumping the pixel data.
instance Show CIFARImage where
  show im = prefix ++ shows (label im) "}"
    where
      prefix :: String
      prefix = "CIFARImage{Height: 32, Width: 32, Pixel: RGB8, Label: "
-- | Height, in pixels, of every image produced by this module.
height :: Int
height = 32
-- | Width, in pixels, of every image produced by this module.
width :: Int
width = 32
-- | Extract the pixel data of a sample.
image :: CIFARImage -> Image PixelRGB8
image (CIFARImage (pixels, _)) = pixels
-- | Extract the class label of a sample.
label :: CIFARImage -> Label
label (CIFARImage (_, lbl)) = lbl
-- | Base HTTPS location under which the CIFAR archive is hosted; 'cifar10'
-- appends the archive file name to it.
cifarURL :: Url 'Https
cifarURL = https "www.cs.toronto.edu" /: "~kriz"
-- | Temporary directory handed to the 'Dataset' description in 'cifar10'.
-- NOTE(review): 'Nothing' presumably asks the dataset machinery to pick a
-- default location; confirm against "Numeric.Datasets".
tempdir :: Maybe FilePath
tempdir = Nothing
-- | Description of the CIFAR-10 dataset.
--
-- It combines the location of the binary archive (relative to 'cifarURL'),
-- the temporary directory setting, the 'unzipCifar' step that turns the
-- downloaded archive into one stream of records, and 'parseCifar' as the
-- parser for a single record.
cifar10 :: Dataset CIFARImage
cifar10 = Dataset archiveSource tempdir unpackStep recordReader
  where
    archiveSource = URL (cifarURL /: "cifar-10-binary.tar.gz")
    unpackStep = Just unzipCifar
    recordReader = Parsable parseCifar
-- | Parse one record of the CIFAR-10 binary format.
--
-- A record consists of a single label byte followed by three colour planes
-- (red, then green, then blue), each @'width' * 'height'@ bytes long. The
-- planes are stored in row-major order: the first 'width' bytes of a plane
-- are the first row of the image, and so on.
--
-- The parser fails, with a message naming the offending byte, when the
-- label byte does not denote a constructor of 'Label'. It also fails (via
-- attoparsec) when the input ends before a full record was read.
parseCifar :: Atto.Parser CIFARImage
parseCifar = do
  lbl <- parseLabel
  rs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  gs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  bs :: [Pixel8] <- BS.unpack <$> Atto.take planeSize
  let ipixels = zipWith4 (\ix r g b -> (ix, PixelRGB8 r g b)) ixs rs gs bs
  pure $ CIFARImage (newImage ipixels, lbl)
  where
    -- Number of bytes in one colour plane.
    planeSize :: Int
    planeSize = width * height

    -- Decode the label byte, rejecting values outside the range of 'Label'
    -- through the parser instead of letting 'toEnum' raise a pure exception.
    parseLabel :: Atto.Parser Label
    parseLabel = do
      code <- fromIntegral <$> Atto.anyWord8
      if code >= fromEnum (minBound :: Label) && code <= fromEnum (maxBound :: Label)
        then pure (toEnum code)
        else fail ("parseCifar: invalid label byte " ++ show code)

    -- Build the image by writing every pixel at its (x, y) position.
    newImage :: [((Int, Int), PixelRGB8)] -> Image PixelRGB8
    newImage ipixels = runST $ do
      -- 'newMutableImage' expects the width first and the height second.
      mim <- newMutableImage width height
      mapM_ (\((x, y), rgb) -> writePixel mim x y rgb) ipixels
      freezeImage mim

    -- Pixel coordinates in the order the plane bytes arrive: row by row,
    -- with x (the column) varying fastest. 'writePixel' takes x before y,
    -- so the pairs are (column, row).
    ixs :: [(Int, Int)]
    ixs = [ (x, y) | y <- [0 .. height - 1], x <- [0 .. width - 1] ]
-- | Turn the gzip-compressed tar archive into one lazy stream of records.
--
-- The archive is decompressed and read as a tar file; the contents of every
-- regular file whose size is exactly 30730000 bytes are kept and everything
-- else is dropped. That size equals 10000 records of 3073 bytes, 3073 being
-- what 'parseCifar' consumes per record (1 label byte plus 3 planes of 1024
-- bytes).
--
-- Kept files are accumulated by prepending, so the result concatenates them
-- in the reverse of their order inside the archive.
--
-- A tar format error is raised as a pure exception with 'throw'.
unzipCifar :: BL.ByteString -> BL.ByteString
unzipCifar zipbs =
  case Tar.foldlEntries collect [] archive of
    Left (err, _) -> throw err
    Right chunks -> BL.concat chunks
  where
    archive :: Tar.Entries Tar.FormatError
    archive = Tar.read (GZip.decompress zipbs)

    collect :: [BL.ByteString] -> Tar.Entry -> [BL.ByteString]
    collect acc entry
      | Tar.NormalFile payload size <- Tar.entryContent entry
      , size == 30730000 = payload : acc
      | otherwise = acc