{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Data.Loaders.Internal where
import Control.Monad (filterM)
import Control.Monad.Trans.Except
import Data.Char (toLower)
import Data.List
import Data.Proxy
import Data.Vector (Vector)
import qualified Data.Vector as V
import GHC.TypeLits (KnownNat)
import System.Directory (listDirectory, doesDirectoryExist)
import System.FilePath ((</>), takeExtension)
import System.Random.MWC (GenIO)
import System.Random.MWC.Distributions (uniformShuffle)

#ifdef CUDA
import Torch.Cuda.Double
import qualified Torch.Cuda.Long as Long
import qualified Torch.Cuda.Double.Dynamic as Dynamic
import qualified Torch.Double.Dynamic as CPU
#else
import Torch.Double
import qualified Torch.Long as Long
import qualified Torch.Double.Storage as Storage
import qualified Torch.Double.Dynamic as Dynamic
#endif
import Torch.Data.Loaders.RGBVector
-- | Load the image file at @path@ and turn it into a @[3, h, w]@ tensor.
--
-- The type-level dimensions @h@ and @w@ are handed to 'rgb2list' as a
-- @Proxy '(h, w)@ together with the 'Normalize' setting; the result of that
-- call is then fed to 'cuboid' to build the tensor. Any failure reported by
-- either step surfaces as a @String@ in the 'ExceptT' layer.
rgb2torch
  :: forall h w . (All KnownDim '[h, w], All KnownNat '[h, w])
  => Normalize
  -> FilePath
  -> ExceptT String IO (Tensor '[3, h, w])
rgb2torch normalize path = cuboid =<< rgb2list dims normalize path
  where
    -- Type-level (height, width) pair, bound to the signature's @h@ and @w@
    -- via ScopedTypeVariables.
    dims :: Proxy '(h, w)
    dims = Proxy
-- | Gather the PNG files stored in category sub-folders of @path@, tag each
-- one with its category, and return them all in a random order.
--
-- Only the immediate entries of @path@ that are directories are considered.
-- Each folder name is passed to @cast@; folders for which it yields
-- 'Nothing' contribute nothing. For an accepted folder, every entry whose
-- extension is @.png@ (compared case-insensitively, so @.PNG@ matches too)
-- becomes a pair of the category value and @path </> folder </> entry@.
-- Entries are selected by extension only; they are not checked to be
-- regular files.
--
-- The concatenated result is shuffled with 'uniformShuffle' using @g@.
shuffleCatFolders
  :: forall c
  .  GenIO                   -- ^ generator driving the shuffle
  -> (FilePath -> Maybe c)   -- ^ folder name to category, if it is one
  -> FilePath                -- ^ root directory holding category folders
  -> IO (Vector (c, FilePath))
shuffleCatFolders g cast path = do
  cats <- filterM (doesDirectoryExist . (path </>)) =<< listDirectory path
  imgfiles <- traverse catContents cats
  uniformShuffle (V.concat imgfiles) g
  where
    -- Images of a single category folder, or an empty vector when @cast@
    -- does not recognise the folder name.
    catContents :: FilePath -> IO (Vector (c, FilePath))
    catContents catFP =
      case cast catFP of
        Nothing -> pure mempty
        Just c ->
          let
            fdr = path </> catFP
            asPair img = (c, fdr </> img)
          in
            V.fromList . fmap asPair . filter isImage
              <$> listDirectory fdr

    -- Lower-case the extension before comparing so that files named
    -- "*.PNG" (or mixed case) are not silently excluded from the dataset.
    isImage :: FilePath -> Bool
    isImage = (== ".png") . fmap toLower . takeExtension