{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
module Torch.Data.Loaders.Cifar10
( default_cifar_path
, Mode(..)
, mode_path
, testLength
, trainLength
, Category(..)
, I.rgb2torch
, cifar10set
, defaultCifar10set
) where
import System.FilePath ((</>))
import Data.Proxy (Proxy(..))
import GHC.Generics (Generic)
import Text.Read (readMaybe)
import Control.DeepSeq (NFData)
import Data.Vector (Vector)
import System.Random.MWC (GenIO, createSystemRandom)
import Data.Hashable
#ifdef CUDA
import Torch.Cuda.Double
#else
import Torch.Double
#endif
import qualified Data.Char as Char
import qualified Torch.Data.Loaders.Internal as I
-- | Root directory of the CIFAR-10 dataset, used by 'defaultCifar10set'.
-- 'mode_path' appends a lower-cased split name (@test@ or @train@) to it.
--
-- NOTE(review): this is a machine-specific absolute path; on other hosts
-- call 'cifar10set' with an explicit root instead.
default_cifar_path :: FilePath
default_cifar_path = "/mnt/lake/datasets/cifar-10"
-- | Which split of the dataset to load. 'mode_path' turns a constructor into
-- its subdirectory name by lower-casing 'show' (@test@ / @train@).
data Mode
  = Test
  | Train
  deriving (Bounded, Enum, Eq, Ord, Show)
-- | Type-level size tag for the 'Test' split. Only the types carry
-- information; the argument is ignored and the result is a bare 'Proxy'.
--
-- NOTE(review): 1000 looks like the CIFAR-10 test-image count per class
-- (10 classes, 10000 test images in total). Confirm whether callers expect
-- a per-class or whole-split size.
testLength :: Proxy 'Test -> Proxy 1000
testLength = const Proxy
-- | Type-level size tag for the 'Train' split. Only the types carry
-- information; the argument is ignored and the result is a bare 'Proxy'.
--
-- NOTE(review): 5000 looks like the CIFAR-10 training-image count per class
-- (10 classes, 50000 training images in total). Confirm whether callers
-- expect a per-class or whole-split size.
trainLength :: Proxy 'Train -> Proxy 5000
trainLength = const Proxy
-- | The ten CIFAR-10 object classes.
--
-- Constructor spellings matter: 'cifar10set' maps a folder name to a
-- 'Category' through the derived 'Read' instance. Constructor order fixes the
-- derived 'Enum' / 'Ord' / 'Bounded' behaviour. 'NFData' and 'Hashable' are
-- derived through @DeriveAnyClass@, relying on their 'Generic'-based default
-- methods.
data Category
  = Airplane
  | Automobile
  | Bird
  | Cat
  | Deer
  | Dog
  | Frog
  | Horse
  | Ship
  | Truck
  deriving (Bounded, Enum, Eq, Generic, Hashable, NFData, Ord, Read, Show)
-- | Directory holding one split of the dataset. It is the given root joined
-- with the lower-cased constructor name of the 'Mode', e.g. @root\/test@.
mode_path :: FilePath -> Mode -> FilePath
mode_path root split = root </> map Char.toLower (show split)
-- | Build the labelled @(Category, FilePath)@ vector for one split of the
-- dataset rooted at @p@.
--
-- The split directory is @'mode_path' p m@. It is passed to
-- 'I.shuffleCatFolders' along with the generator @g@ and a parser from folder
-- name to 'Category'. How unparseable folders and shuffling are handled is up
-- to 'I.shuffleCatFolders'.
--
-- The parser removes every forward and back slash, upper-cases the first
-- remaining character, lower-cases the rest, and reads the result with the
-- derived 'Read' instance. For example, @"airplane"@ and @"AIRPLANE/"@ both
-- become 'Airplane', and an empty name gives 'Nothing'.
cifar10set :: GenIO -> FilePath -> Mode -> IO (Vector (Category, FilePath))
cifar10set g p m = I.shuffleCatFolders g folderCategory splitDir
  where
    splitDir :: FilePath
    splitDir = mode_path p m

    -- NOTE(review): slashes are dropped anywhere in the string, not only at
    -- the ends. This assumes the argument is a bare folder name rather than a
    -- multi-component path; confirm against I.shuffleCatFolders.
    folderCategory :: FilePath -> Maybe Category
    folderCategory name =
      case filter (`notElem` separators) name of
        [] -> Nothing
        (c : cs) -> readMaybe (Char.toUpper c : map Char.toLower cs)

    separators :: String
    separators = "/\\"
-- | 'cifar10set' over 'default_cifar_path'. Each call seeds a fresh generator
-- with 'createSystemRandom'.
defaultCifar10set :: Mode -> IO (Vector (Category, FilePath))
defaultCifar10set m = do
  gen <- createSystemRandom
  cifar10set gen default_cifar_path m