{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Load RGB PNG images from disk into nested lists of channel values.
--
-- The image backend is selected at compile time: GD when @USE_GD@ is
-- defined, JuicyPixels otherwise.
module Torch.Data.Loaders.RGBVector
  ( Normalize(..)
  , file2rgb
  , rgb2list
  , assertList
  ) where

import Data.Proxy
import Data.Vector (Vector)
import Control.Concurrent (threadDelay)
import Control.Monad -- (forM_, filterM)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Exception.Safe
-- import Control.DeepSeq
import GHC.Conc (getNumProcessors)
import GHC.TypeLits (KnownNat)
import Numeric.Dimensions
import System.Random.MWC (GenIO)
import System.Random.MWC.Distributions (uniformShuffle)
import System.Directory (listDirectory, doesDirectoryExist)
import System.FilePath ((</>), takeExtension)
import Control.Concurrent
import Control.Monad.Primitive
import qualified Data.Vector as V
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as M

#ifdef USE_GD
import qualified Graphics.GD as GD
#else
import qualified Codec.Picture as JP
#endif

import Torch.Data.Loaders.Logging

-- | Element type of every channel value produced by this module.
type HsReal = Double

-- | Mutable image buffer indexed as @[channel][row][column]@.
type MRGBVector s = MVector s (MVector s (MVector s HsReal))

-- | Immutable counterpart of 'MRGBVector'.
type RGBVector = Vector (Vector (Vector HsReal))

-- | Module name used as the header of error messages raised here.
modulename :: String
modulename = "Torch.Data.Loaders.RGBVector"

-- | How 'rgb2list' rescales each raw channel value @w@.
data Normalize
  = ZeroToOne    -- ^ @w / 255@
  | NegOneToOne  -- ^ @(w / 255) * 2 - 1@
  | NoNormalize  -- ^ leave @w@ untouched
  deriving (Eq, Ord, Show, Enum, Bounded)

-- | Load an RGB PNG image as a nested list indexed
-- @[channel][row][column]@, with channels in the order red, green, blue.
--
-- The number of rows and columns is taken from the type-level @(h, w)@
-- proxy. Each channel value is rescaled according to the 'Normalize'
-- argument. Failures reported by 'file2rgb' are propagated as 'Left'.
rgb2list
  :: forall h w
  .  (All KnownDim '[h, w], All KnownNat '[h, w])
  => Proxy '(h, w)
  -> Normalize
  -> FilePath
  -> ExceptT String IO [[[HsReal]]]
rgb2list hwp donorm fp = do
  pxs <- file2rgb hwp fp
  -- lift $ assertPixels pxs
  lift $ do
    vec <- mkRGBVec
    fillFrom pxs $ \chw px -> writePx vec chw (prep px)
    freezeList vec
  where
    -- Rescale one raw channel value.
    prep :: HsReal -> HsReal
    prep w =
      case donorm of
        NoNormalize -> w
        ZeroToOne   -> w / 255
        NegOneToOne -> (w / 255) * 2 - 1

    (height, width) = reifyHW hwp

    -- Allocate a 3 x height x width buffer. The innermost rows are left
    -- uninitialised; 'fillFrom' is expected to write every cell because
    -- 'file2rgb' yields one entry per (row, column) pair.
    mkRGBVec :: PrimMonad m => m (MRGBVector (PrimState m))
    mkRGBVec = M.replicateM 3 (M.replicateM height (M.unsafeNew width))

    -- Store one channel value at (channel, row, column), unchecked.
    writePx
      :: PrimMonad m
      => MRGBVector (PrimState m)
      -> (Int, Int, Int)
      -> HsReal
      -> m ()
    writePx channels (c, h, w) px = do
      rows <- M.unsafeRead channels c
      cols <- M.unsafeRead rows h
      M.unsafeWrite cols w px

    -- Read the whole buffer back out as nested lists.
    freezeList :: PrimMonad m => MRGBVector (PrimState m) -> m [[[HsReal]]]
    freezeList mvecs =
      readN mvecs 3 $ \mframe ->
        readN mframe height $ \mrow ->
          readN mrow width pure

    -- Read the first @n@ elements of a mutable vector, applying @op@ to
    -- each one in index order.
    readN :: PrimMonad m => MVector (PrimState m) a -> Int -> (a -> m b) -> m [b]
    readN mvec n op = mapM (M.read mvec >=> op) [0 .. n - 1]

    -- Feed every pixel to @filler@ once per channel, numbering the
    -- channels 0 (red), 1 (green), 2 (blue).
    fillFrom
      :: (Num y, PrimMonad m)
      => [((Int, Int), (Int, Int, Int))]
      -> ((Int, Int, Int) -> y -> m ())
      -> m ()
    fillFrom pxs filler =
      forM_ pxs $ \((h, w), (r, g, b)) ->
        forM_ (zip [0 ..] [r, g, b]) $ \(c, px) ->
          filler (c, h, w) (fromIntegral px)

-- | Read a PNG file and return one @((row, column), (r, g, b))@ entry for
-- every position in the @height x width@ grid given by the type-level
-- proxy, in row-major order.
--
-- With the JuicyPixels backend a decoding failure, or an image smaller
-- than the requested size, is reported as 'Left'.
file2rgb
  :: forall h w hw rgb
  .  (All KnownDim '[h, w], All KnownNat '[h, w])
  => hw ~ (Int, Int)
  => rgb ~ (Int, Int, Int)
  => Proxy '(h, w)
  -> FilePath
  -> ExceptT String IO [(hw, rgb)]
file2rgb hwp fp = do
#ifdef USE_GD
  im <- lift $ GD.loadPngFile fp
  -- NOTE(review): no size check in this branch; confirm how GD.getPixel
  -- treats coordinates outside the image before relying on it.
  forM coords $ \(h, w) -> do
    -- GD addresses pixels as (x, y), i.e. (column, row).
    (r, g, b, _) <- lift $ GD.toRGBA <$> GD.getPixel (w, h) im
    pure ((h, w), (fromIntegral r, fromIntegral g, fromIntegral b))
#else
  im <- JP.convertRGB8 <$> ExceptT (JP.readPng fp)
  -- Refuse images that cannot supply every requested pixel instead of
  -- indexing outside the decoded data.
  when (JP.imageWidth im < width || JP.imageHeight im < height) $
    throwE $
      modulename ++ ": image " ++ show fp ++ " is "
        ++ show (JP.imageHeight im) ++ "x" ++ show (JP.imageWidth im)
        ++ " (height x width) but at least "
        ++ show height ++ "x" ++ show width ++ " was requested"
  forM coords $ \(h, w) -> do
    -- JuicyPixels takes x (column) before y (row).
    let JP.PixelRGB8 r g b = JP.pixelAt im w h
    pure ((h, w), (fromIntegral r, fromIntegral g, fromIntegral b))
#endif
  where
    (height, width) = reifyHW hwp
    -- Every (row, column) position, row-major.
    coords = [(h, w) | h <- [0 .. height - 1], w <- [0 .. width - 1]]

-- | Sanity-check decoded pixels.
--
-- Throws (via 'throwString') when every pixel is pure black, or when any
-- pixel has a channel outside @[0, 255]@. Note that an empty pixel list
-- counts as "all zeros".
assertPixels :: [((Int, Int), (Int, Int, Int))] -> IO ()
assertPixels pxs
  | all (allZero . snd) pxs =
      throwString $ mkError modulename "IMAGE ALL ZEROS!"
  | any (outOfBounds . snd) pxs =
      throwString $ mkError modulename "IMAGE OUT OF PIXEL BOUNDS!"
  | otherwise = pure ()
  where
    allZero (r, g, b) = all (== 0) [r, g, b]
    outOfBounds (r, g, b) = any (\x -> x < 0 || x > 255) [r, g, b]

-- | Sanity-check a flat list of channel values.
--
-- Throws (via 'throwString') when any value lies outside
-- @[-0.1, 255.1]@, or when every value is zero (which includes the empty
-- list). The first argument is the header passed to 'mkError'.
assertList :: String -> [HsReal] -> IO ()
assertList hdr rs
  | not (null oob) =
      throwString $ show ({-oob,-} length oob, length rs, mkError hdr "OOB found!")
  | all (== 0) rs =
      throwString $ mkError hdr "all-zeros found!"
  | otherwise = pure ()
  where
    oob = filter (\x -> x < -0.1 || x > 255.1) rs

-- | Bring the type-level @(h, w)@ pair down to a value-level
-- @(height, width)@ pair of 'Int's.
reifyHW
  :: forall h w
  .  (All KnownDim '[h, w], All KnownNat '[h, w])
  => Proxy '(h, w)
  -> (Int, Int)
reifyHW _ =
  ( fromIntegral (dimVal (dim :: Dim h))
  , fromIntegral (dimVal (dim :: Dim w))
  )