{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Data.Loaders.Internal where

-- import Prelude hiding (print, putStrLn)
-- import qualified Prelude as P (print, putStrLn)
-- import GHC.Int
import Data.Proxy
import Data.Vector (Vector)
-- import qualified Data.List as List ((!!))
-- import Control.Concurrent (threadDelay)
import Control.Monad (filterM)
-- import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
-- import Control.Exception.Safe
-- import Control.DeepSeq
-- import GHC.Conc (getNumProcessors)
import GHC.TypeLits (KnownNat)
-- import Numeric.Dimensions
import System.Random.MWC (GenIO)
import System.Random.MWC.Distributions (uniformShuffle)
import System.Directory (listDirectory, doesDirectoryExist)
import System.FilePath ((</>), takeExtension)
-- import Control.Concurrent
-- import Control.Monad.Primitive
import qualified Data.Vector as V
-- import Data.Vector.Mutable (MVector)
-- import qualified Data.Vector.Mutable as M
#ifdef CUDA
import Torch.Cuda.Double
import qualified Torch.Cuda.Long as Long
import qualified Torch.Cuda.Double.Dynamic as Dynamic
import qualified Torch.Double.Dynamic as CPU
import Torch.Double
import qualified Torch.Long as Long
import qualified Torch.Double.Storage as Storage
import qualified Torch.Double.Dynamic as Dynamic

import Torch.Data.Loaders.RGBVector
import Data.List

-- -- | asynchronously map across a pool with a maximum level of concurrency
-- mapPool :: Traversable t => Int -> (a -> IO b) -> t a -> IO (t b)
-- mapPool mx fn xs = do
--   sem <- MSem.new mx
--   Async.mapConcurrently (MSem.with sem . fn) xs

-- | Load an RGB PNG image into a Torch tensor.
--
-- The normalization setting and the path are handed to 'rgb2list' together
-- with a proxy for the @(h, w)@ dimensions fixed by the result type; whatever
-- 'rgb2list' produces is then fed to 'cuboid' to obtain the tensor. Both steps
-- run in the same 'ExceptT' 'String' layer, so a failure in either one
-- short-circuits and is returned as the error message.
rgb2torch
  :: forall h w . (All KnownDim '[h, w], All KnownNat '[h, w])
  => Normalize   -- ^ normalization setting, forwarded unchanged to 'rgb2list'
  -> FilePath    -- ^ path of the image file to load
  -> ExceptT String IO (Tensor '[3, h, w])
-- The type application is written without a space after @\@@: newer GHCs parse
-- a whitespace-surrounded @\@@ as an infix operator, not a type application.
rgb2torch n f = rgb2list (Proxy @'(h, w)) n f >>= cuboid

-- | Given a folder with subfolders of category images, return a uniform-randomly
-- shuffled vector of image filepaths, each paired with its category.
--
-- Only the immediate subdirectories of the dataset folder are visited. The
-- conversion function is applied to the bare subfolder name; a subfolder it
-- maps to 'Nothing' contributes no entries. Inside a category folder only
-- entries accepted by 'isImage' are kept. Each returned path is the dataset
-- path joined with the subfolder name and the file name, so it is absolute
-- exactly when the dataset path is.
shuffleCatFolders
  :: forall c
  .  GenIO                        -- ^ generator for shuffle
  -> (FilePath -> Maybe c)        -- ^ how to convert a subfolder into a category
  -> FilePath                     -- ^ absolute path of the dataset
  -> IO (Vector (c, FilePath))    -- ^ shuffled list
shuffleCatFolders g cast path = do
  -- keep only directory entries of the dataset folder
  cats <- filterM (doesDirectoryExist . (path </>)) =<< listDirectory path
  imgfiles <- traverse catContents cats
  uniformShuffle (V.concat imgfiles) g
 where
  -- The local signature refers to the outer @c@; that works because of the
  -- explicit @forall c@ above together with ScopedTypeVariables.
  catContents :: FilePath -> IO (Vector (c, FilePath))
  catContents catFP =
    case cast catFP of
      -- not a recognised category: contribute nothing
      Nothing -> pure mempty
      Just c ->
        let
          fdr = path </> catFP
          asPair img = (c, fdr </> img)
        in
          V.fromList . fmap asPair . filter isImage
            <$> listDirectory fdr

-- | Decide whether a path names an image, judged solely by its extension:
-- the result is 'True' exactly when 'takeExtension' yields @".png"@.
-- The comparison is case-sensitive and the file itself is never inspected.
isImage :: FilePath -> Bool
isImage fp = takeExtension fp == ".png"