{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Data.Loaders.Internal where

-- import Prelude hiding (print, putStrLn)
-- import qualified Prelude as P (print, putStrLn)
-- import GHC.Int
import Data.Char (toLower)
import Data.Proxy
import Data.Vector (Vector)
-- import qualified Data.List as List ((!!))
-- import Control.Concurrent (threadDelay)
import Control.Monad (filterM)
-- import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
-- import Control.Exception.Safe
-- import Control.DeepSeq
-- import GHC.Conc (getNumProcessors)
import GHC.TypeLits (KnownNat)
-- import Numeric.Dimensions
import System.Random.MWC (GenIO)
import System.Random.MWC.Distributions (uniformShuffle)
import System.Directory (listDirectory, doesDirectoryExist)
import System.FilePath ((</>), takeExtension)
-- import Control.Concurrent
--
-- import Control.Monad.Primitive
import qualified Data.Vector as V
-- import Data.Vector.Mutable (MVector)
-- import qualified Data.Vector.Mutable as M
--
#ifdef CUDA
import Torch.Cuda.Double
import qualified Torch.Cuda.Long as Long
import qualified Torch.Cuda.Double.Dynamic as Dynamic
import qualified Torch.Double.Dynamic as CPU
#else
import Torch.Double
import qualified Torch.Long as Long
import qualified Torch.Double.Storage as Storage
import qualified Torch.Double.Dynamic as Dynamic
#endif

import Torch.Data.Loaders.RGBVector
import Data.List

-- -- | asynchronously map across a pool with a maximum level of concurrency
-- mapPool :: Traversable t => Int -> (a -> IO b) -> t a -> IO (t b)
-- mapPool mx fn xs = do
--   sem <- MSem.new mx
--   Async.mapConcurrently (MSem.with sem . fn) xs

-- | Load the RGB PNG image at the given path into a Torch tensor of shape
-- @[3, h, w]@.
--
-- The work is delegated to two steps chained in @ExceptT String IO@:
-- 'rgb2list' is called with the requested height and width (passed as a
-- type-level pair) and its result is handed to 'cuboid' to build the tensor.
-- Errors raised in the 'ExceptT' layer by either step carry a 'String'
-- message.
rgb2torch
  :: forall h w . (All KnownDim '[h, w], All KnownNat '[h, w])
  => Normalize  -- ^ normalisation setting, passed through to 'rgb2list'
  -> FilePath   -- ^ path of the image file
  -> ExceptT String IO (Tensor '[3, h, w])
rgb2torch n f = rgb2list dims n f >>= cuboid
 where
  -- The proxy is pinned by annotation instead of @Proxy \@ '(h, w)@: from
  -- GHC 9.0 on, an @\@@ followed by whitespace is no longer parsed as a
  -- visible type application, so the spaced form stops compiling.
  dims :: Proxy '(h, w)
  dims = Proxy

-- | Walk a dataset directory whose immediate subdirectories each hold the
-- images of one category, and return every @(category, image path)@ pair in
-- uniformly shuffled order.
--
-- Subdirectories whose name the conversion function maps to 'Nothing'
-- contribute no entries. Each image path is the dataset path joined with the
-- subdirectory name and the file name, so it is absolute whenever the dataset
-- path is.
shuffleCatFolders
  :: forall c
  .  GenIO                        -- ^ generator for shuffle
  -> (FilePath -> Maybe c)        -- ^ how to convert a subfolder into a category
  -> FilePath                     -- ^ absolute path of the dataset
  -> IO (Vector (c, FilePath))    -- ^ shuffled list
shuffleCatFolders g cast path = do
  entries <- listDirectory path
  subdirs <- filterM isSubdir entries
  perCategory <- traverse labelledImages subdirs
  uniformShuffle (V.concat perCategory) g
 where
  -- Only directory entries that are themselves directories are categories.
  isSubdir :: FilePath -> IO Bool
  isSubdir name = doesDirectoryExist (path </> name)

  -- All images of one category folder, each tagged with its category.
  labelledImages :: FilePath -> IO (Vector (c, FilePath))
  labelledImages name = maybe (pure mempty) collect (cast name)
   where
    dir = path </> name
    collect cat = do
      files <- listDirectory dir
      pure $ V.fromList [(cat, dir </> file) | file <- files, isImage file]

-- | Decide whether a path names a PNG image, judging by its extension alone.
--
-- The comparison ignores case, so @.png@, @.PNG@ and @.Png@ are all accepted.
-- Neither the existence of the file nor its contents are inspected.
isImage :: FilePath -> Bool
isImage = (== ".png") . fmap toLower . takeExtension