{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}

-- | Loading of gzip-compressed MNIST IDX files and batched access to them as
-- typed tensors.
module Torch.Typed.Vision where

import qualified Codec.Compression.GZip as GZip
import Control.Monad (forM_)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Foreign.Marshal.Utils (copyBytes, fillBytes)
import qualified Data.ByteString.Lazy as BS.Lazy
import Data.Kind
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import GHC.Exts (IsList (fromList))
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import Torch.Data.Pipeline
import qualified Torch.Device as D
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Tensor as D
import qualified Torch.TensorOptions as D
import Torch.Typed.Auxiliary
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | An MNIST data set tagged with the monad it is served in, the device the
-- batches are moved to, and the (type-level) batch size.
data MNIST (m :: Type -> Type) (device :: (D.DeviceType, Nat)) (batchSize :: Nat) = MNIST {mnistData :: MnistData}

-- | Keys are batch numbers: key @ix@ yields the images and labels with sample
-- indices @ix * batchSize .. (ix + 1) * batchSize - 1@, moved to @device@.
instance
  (KnownNat batchSize, KnownDevice device, Applicative m) =>
  Dataset m (MNIST m device batchSize) Int (Tensor device 'D.Float '[batchSize, 784], Tensor device 'D.Int64 '[batchSize])
  where
  getItem MNIST {mnistData = dataset} ix =
    let batch = natValI @batchSize
        firstIdx = ix * batch
        batchIdxs = [firstIdx .. firstIdx + batch - 1]
        imgs = getImages @batchSize dataset batchIdxs
        lbls = getLabels @batchSize dataset batchIdxs
     in pure (toDevice @device imgs, toDevice @device lbls)

  -- Only complete batches get a key; a trailing partial batch is dropped.
  keys MNIST {mnistData = dataset}
    -- A zero batch size has no meaningful batches (and would divide by zero).
    | batch <= 0 = fromList []
    | otherwise =
      fromList [0 .. Torch.Typed.Vision.length dataset `Prelude.div` batch - 1]
    where
      batch = natValI @batchSize

-- | Raw, decompressed contents of one MNIST image file and its label file,
-- headers included.
data MnistData = MnistData
  { images :: BS.ByteString,
    labels :: BS.ByteString
  }

type Rows = 28

type Cols = 28

-- | Number of pixels (one byte each) per image.
type DataDim = Rows * Cols

type ClassDim = 10

-- | Number of bytes skipped at the start of the image file before the first
-- pixel (the IDX header of the @idx3-ubyte@ files).
imageHeaderBytes :: Int
imageHeaderBytes = 16

-- | Number of bytes skipped at the start of the label file before the first
-- label (the IDX header of the @idx1-ubyte@ files).
labelHeaderBytes :: Int
labelHeaderBytes = 8

-- | Labels of the first @n@ of the given sample indices as a 1-D tensor.
--
-- Raises an error (from 'BS.index') if an index lies outside the label file.
getLabels ::
  forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Int64 '[n]
getLabels mnist imageIdxs =
  UnsafeMkTensor . D.asTensor . map (getLabel mnist) . take (natValI @n) $ imageIdxs

-- | Label of a single sample, read as the byte following the label header.
getLabel :: MnistData -> Int -> Int
getLabel mnist imageIdx =
  fromIntegral $ BS.index (labels mnist) (imageIdx + labelHeaderBytes)

-- | Pixel bytes of a single image converted to 'Float', in file order.
--
-- Raises an error (from 'BS.index') if the image lies outside the image file.
imagePixels :: MnistData -> Int -> [Float]
imagePixels mnist imageIdx =
  [ fromIntegral (BS.index (images mnist) (start + r))
    | r <- [0 .. pixelCount - 1]
  ]
  where
    pixelCount = natValI @DataDim
    start = imageHeaderBytes + imageIdx * pixelCount

-- | A single image as a flat tensor of 'DataDim' floats.
getImage :: MnistData -> Int -> CPUTensor 'D.Float '[DataDim]
getImage mnist imageIdx =
  UnsafeMkTensor . D.asTensor $ imagePixels mnist imageIdx

-- | List-based variant of 'getImages': builds the batch from Haskell lists of
-- pixels for the first @n@ of the given sample indices.
getImages' ::
  forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Float '[n, DataDim]
getImages' mnist imageIdxs =
  UnsafeMkTensor . D.asTensor . map (imagePixels mnist) . take (natValI @n) $ imageIdxs

-- | Images for the first @n@ of the given sample indices, copied row by row
-- from the raw image bytes into a freshly allocated @n x DataDim@ tensor that
-- is then converted to 'D.Float'.
--
-- Raises an error if any of the used indices lies outside the image file.
-- If fewer than @n@ indices are supplied, the remaining rows are zero.
getImages ::
  forall n. KnownNat n => MnistData -> [Int] -> CPUTensor 'D.Float '[n, DataDim]
getImages mnist imageIdxs =
  -- unsafePerformIO: the action only allocates a private tensor, fills every
  -- byte of it from the immutable ByteString (or with zeros) and returns it,
  -- so the result depends on the arguments alone.
  UnsafeMkTensor . unsafePerformIO $ do
    let BSI.PS fptr off len = images mnist
        rowBytes = natValI @DataDim
        batch = natValI @n
        wanted = take batch imageIdxs
        copied = Prelude.length wanted
        -- Number of complete images actually present after the header.
        imageCount = (len - imageHeaderBytes) `Prelude.div` rowBytes
    -- copyBytes works on raw pointers, so reject anything that would read
    -- outside the ByteString before touching memory.
    case [idx | idx <- wanted, idx < 0 || idx >= imageCount] of
      bad : _ ->
        error $
          "Torch.Typed.Vision.getImages: image index "
            <> show bad
            <> " out of range (data holds "
            <> show imageCount
            <> " images)"
      [] -> pure ()
    t <-
      (cast2 LibTorch.empty_lo :: [Int] -> D.TensorOptions -> IO D.Tensor)
        [batch, rowBytes]
        (D.withDType D.UInt8 D.defaultOpts)
    D.withTensor t $ \dst ->
      F.withForeignPtr fptr $ \src -> do
        forM_ (zip [0 ..] wanted) $ \(row, idx) ->
          copyBytes
            (F.plusPtr dst (rowBytes * row))
            (F.plusPtr src (off + imageHeaderBytes + rowBytes * idx))
            rowBytes
        -- empty_lo leaves the buffer uninitialised; zero the rows for which
        -- no index was supplied.
        fillBytes
          (F.plusPtr dst (rowBytes * copied))
          0
          (rowBytes * (batch - copied))
    return $ D.toType D.Float t

-- | Number of samples, taken as the label file size minus its header.
length :: MnistData -> Int
length mnist = BS.length (labels mnist) - labelHeaderBytes

-- | Read @path/file@ strictly and gunzip it into a strict 'BS.ByteString'.
decompressFile :: String -> String -> IO BS.ByteString
decompressFile path file =
  BS.Lazy.toStrict . GZip.decompress . BS.Lazy.fromStrict
    <$> BS.readFile (path <> "/" <> file)

-- | Load the four standard MNIST archives from the given directory and return
-- @(training data, test data)@.
initMnist :: String -> IO (MnistData, MnistData)
initMnist path = do
  imagesBS <- decompressFile path "train-images-idx3-ubyte.gz"
  labelsBS <- decompressFile path "train-labels-idx1-ubyte.gz"
  testImagesBS <- decompressFile path "t10k-images-idx3-ubyte.gz"
  testLabelsBS <- decompressFile path "t10k-labels-idx1-ubyte.gz"
  return (MnistData imagesBS labelsBS, MnistData testImagesBS testLabelsBS)