{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
-- | Helpers for locating and listing the CIFAR-10 dataset on disk.
--
-- A split lives under @<root>/<mode>@ (see 'mode_path'); its category
-- folders are listed by 'I.shuffleCatFolders' and mapped to 'Category'
-- values by folder name.
module Torch.Data.Loaders.Cifar10
  ( default_cifar_path
  , Mode(..)
  , mode_path
  , testLength
  , trainLength
  , Category(..)
  , I.rgb2torch
  , cifar10set
  , defaultCifar10set
  ) where

import Control.DeepSeq (NFData)
import Data.Hashable
import Data.Proxy (Proxy(..))
import Data.Vector (Vector)
import GHC.Generics (Generic)
import System.FilePath ((</>))
import System.Random.MWC (GenIO, createSystemRandom)
import Text.Read (readMaybe)

#ifdef CUDA
import Torch.Cuda.Double
#else
import Torch.Double
#endif

import qualified Data.Char as Char
import qualified Torch.Data.Loaders.Internal as I

-- | Root directory of the CIFAR-10 dataset.
--
-- This should be replaced with a download-aware cache.
default_cifar_path :: FilePath
default_cifar_path = "/mnt/lake/datasets/cifar-10"

-- | Which split of the dataset to load. The lower-cased constructor name is
-- the split's sub-directory (see 'mode_path').
data Mode = Test | Train
  deriving (Eq, Enum, Ord, Show, Bounded)

-- | Type-level size marker for the test split.
--
-- NOTE(review): 1000 equals CIFAR-10's per-category test count (10000 test
-- images over 10 categories), not the split total; confirm which one callers
-- expect.
testLength :: Proxy 'Test -> Proxy 1000
testLength _ = Proxy

-- | Type-level size marker for the training split.
--
-- NOTE(review): 5000 equals CIFAR-10's per-category training count (50000
-- training images over 10 categories), not the split total; confirm which one
-- callers expect.
trainLength :: Proxy 'Train -> Proxy 5000
trainLength _ = Proxy

-- | The ten CIFAR-10 labels. The number after each constructor is its
-- 'fromEnum' index (derived 'Enum', so declaration order matters).
data Category
  = Airplane   -- 0
  | Automobile -- 1
  | Bird       -- 2
  | Cat        -- 3
  | Deer       -- 4
  | Dog        -- 5
  | Frog       -- 6
  | Horse      -- 7
  | Ship       -- 8
  | Truck      -- 9
  deriving (Eq, Enum, Ord, Show, Bounded, Generic, NFData, Read, Hashable)

-- | Directory holding one split: the dataset root joined with the
-- lower-cased 'Mode' name, e.g. @mode_path "/data" Train == "/data/train"@.
mode_path :: FilePath -> Mode -> FilePath
mode_path cifarpath m = cifarpath </> (Char.toLower <$> show m)

-- | List the @(category, file)@ pairs of one split of the dataset rooted at
-- @p@, delegating to 'I.shuffleCatFolders' with the generator @g@
-- (presumably used to shuffle the result — confirm in the Internal module).
--
-- Folder names are turned into a 'Category' by the local @cast@: every @'/'@
-- and @'\\'@ is removed, the first character is upper-cased and the rest
-- lower-cased (so @"airplane/"@ becomes @"Airplane"@), and the result is
-- parsed with 'readMaybe'. Names that do not match a constructor, including
-- the empty name, give 'Nothing'.
cifar10set :: GenIO -> FilePath -> Mode -> IO (Vector (Category, FilePath))
cifar10set g p m = I.shuffleCatFolders g cast (mode_path p m)
  where
    cast :: FilePath -> Maybe Category
    cast fp = case filter (not . (`elem` ("/\\" :: String))) fp of
      h:tl -> readMaybe (Char.toUpper h : map Char.toLower tl)
      []   -> Nothing

-- | 'cifar10set' over 'default_cifar_path', using a new generator obtained
-- from 'createSystemRandom' on every call.
defaultCifar10set :: Mode -> IO (Vector (Category, FilePath))
defaultCifar10set m =
  createSystemRandom >>= \g -> cifar10set g default_cifar_path m

-- Scratch checks of classNLLCriterion, kept for reference only. They are
-- commented out and not compiled; names such as evalBP, Long, CPU and CPULong
-- are not imported by this module.
--
-- test :: Tensor '[1]
-- test
--   = evalBP
--     (classNLLCriterion (Long.unsafeVector [2] :: Long.Tensor '[1]))
--     (unsqueeze1d (dim :: Dim 0) $ unsafeVector [1,0,0] :: Tensor '[1, 3])
--
-- test2 :: Tensor '[1]
-- test2
--   = evalBP
--     ( _classNLLCriterion'
--       (-100) False True
--       -- (Long.unsafeMatrix [[0,1,0]] :: Long.Tensor '[1,3])
--       -- (Long.unsafeVector [0,1,0] :: Long.Tensor '[3])
--       (Long.unsafeVector [0,1,2] :: Long.Tensor '[3])
--     )
--     -- (unsafeVector [1,0,0] :: Tensor '[3])
--     -- (unsafeMatrix [[0,0,1]] :: Tensor '[1,3])
--     (unsafeMatrix
--       [ [1,0,0]
--       , [0,1,0]
--       , [0.5,0.5,0.5]
--       ] :: Tensor '[3,3])

-- test3 :: CPU.Tensor '[1]
-- test3
--   = evalBP
--     ( CPU._classNLLCriterion'
--       (-100) False True
--       -- (CPULong.unsafeMatrix [[0,1,0]] :: CPULong.Tensor '[1,3])
--       (CPULong.unsafeVector [0,8] :: CPULong.Tensor '[2])
--       -- (CPULong.unsafeVector [0,1,0] :: CPULong.Tensor '[3])
--       -- (CPULong.unsafeVector [0,1,2] :: CPULong.Tensor '[3])
--     )
--     -- (CPU.unsafeVector [1,0,0] :: CPU.Tensor '[3])
--     -- (CPU.unsafeMatrix [[0,0,1]] :: CPU.Tensor '[1,3])
--     (CPU.unsafeMatrix
--       [ [1,0,0]
--       -- , [0,1,0]
--       , [0.5,0.5,0.5]
--       ] :: CPU.Tensor '[2,3])