{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
module Torch.Data.Loaders.RGBVector
( Normalize(..)
, file2rgb
, rgb2list
, assertList
) where
import Data.Proxy
import Data.Vector (Vector)
import Control.Concurrent (threadDelay)
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Exception.Safe
import GHC.Conc (getNumProcessors)
import GHC.TypeLits (KnownNat)
import Numeric.Dimensions
import System.Random.MWC (GenIO)
import System.Random.MWC.Distributions (uniformShuffle)
import System.Directory (listDirectory, doesDirectoryExist)
import System.FilePath ((</>), takeExtension)
import Control.Concurrent
import Control.Monad.Primitive
import qualified Data.Vector as V
import Data.Vector.Mutable (MVector)
import qualified Data.Vector.Mutable as M
#ifdef USE_GD
import qualified Graphics.GD as GD
#else
import qualified Codec.Picture as JP
#endif
import Torch.Data.Loaders.Logging
-- | Element type of every pixel value handled in this module.
type HsReal = Double
-- | Mutable, three-level nesting of boxed vectors holding 'HsReal' values.
-- 'rgb2list' builds one of these indexed as channel, then row, then column.
type MRGBVector s = MVector s (MVector s (MVector s HsReal))
-- | Immutable three-level nesting of boxed vectors holding 'HsReal' values.
-- NOTE(review): not referenced by the visible code; presumably the frozen
-- counterpart of 'MRGBVector' with the same index order -- confirm.
type RGBVector = Vector (Vector (Vector HsReal))
-- | Name of this module, passed to 'mkError' as the header of the error
-- messages raised from here.
modulename :: String
modulename = "Torch.Data.Loaders.RGBVector"
-- | How 'rgb2list' rescales the raw channel values read from an image.
data Normalize
  = ZeroToOne
    -- ^ Divide each value by 255.
  | NegOneToOne
    -- ^ Divide each value by 255, then multiply by 2 and subtract 1.
  | NoNormalize
    -- ^ Leave each value as read.
  deriving (Eq, Ord, Show, Enum, Bounded)
-- | Load the PNG at @fp@ and return its pixels as nested lists indexed
-- @[channel][row][column]@, where channel 0 is red, 1 is green and 2 is blue.
--
-- The outer list has 3 entries, each holding @h@ rows of @w@ values. Raw
-- channel values are rescaled according to the 'Normalize' argument:
-- 'NoNormalize' keeps them as read, 'ZeroToOne' divides by 255 and
-- 'NegOneToOne' maps @x@ to @(x / 255) * 2 - 1@.
--
-- A failure reported by 'file2rgb' is returned as a @Left@.
rgb2list
  :: forall h w . (All KnownDim '[h, w], All KnownNat '[h, w])
  => Proxy '(h, w)
  -> Normalize
  -> FilePath
  -> ExceptT String IO [[[HsReal]]]
rgb2list hwp donorm fp = do
  pxs <- file2rgb hwp fp
  -- Nothing below produces a Left, so the IO action is lifted directly.
  lift $ do
    vec <- mkRGBVec
    fillFrom pxs $ \chw px -> writePx vec chw (prep px)
    freezeList vec
  where
    (height, width) = reifyHW hwp

    -- Rescale one raw channel value as requested by the caller.
    prep :: HsReal -> HsReal
    prep x = case donorm of
      NoNormalize -> x
      ZeroToOne   -> x / 255
      NegOneToOne -> (x / 255) * 2 - 1

    -- Allocate 3 channels of height rows of width cells. Every cell starts
    -- at 0 so that no cell is ever left uninitialised.
    mkRGBVec :: PrimMonad m => m (MRGBVector (PrimState m))
    mkRGBVec = M.replicateM 3 (M.replicateM height (M.replicate width 0))

    -- Store one value at (channel, row, column). The accesses are unchecked:
    -- the coordinates are produced by file2rgb from the same hwp that sized
    -- the buffer, and the channel index comes from fillFrom (0 to 2).
    writePx
      :: PrimMonad m
      => MRGBVector (PrimState m)
      -> (Int, Int, Int)
      -> HsReal
      -> m ()
    writePx channels (c, row, col) px = do
      rows <- M.unsafeRead channels c
      cols <- M.unsafeRead rows row
      M.unsafeWrite cols col px

    -- Copy the mutable buffer out into nested lists, channel by channel.
    freezeList
      :: PrimMonad m => MRGBVector (PrimState m) -> m [[[HsReal]]]
    freezeList mvecs =
      readN mvecs 3 $ \mframe ->
        readN mframe height $ \mrow ->
          readN mrow width pure

    -- Read the first n elements of a mutable vector, applying op to each.
    readN :: PrimMonad m => MVector (PrimState m) a -> Int -> (a -> m b) -> m [b]
    readN mvec n op = mapM (M.read mvec >=> op) [0 .. n - 1]

    -- Run filler once per channel of every pixel, numbering the channels
    -- 0, 1, 2 for red, green, blue.
    fillFrom
      :: (Num y, PrimMonad m)
      => [((Int, Int), (Int, Int, Int))]
      -> ((Int, Int, Int) -> y -> m ())
      -> m ()
    fillFrom pixels filler =
      forM_ pixels $ \((row, col), (r, g, b)) ->
        forM_ (zip [0 ..] [r, g, b]) $ \(c, px) ->
          filler (c, row, col) (fromIntegral px)
-- | Read the PNG at @fp@ and return the colour of every pixel in the
-- top-left @h@ by @w@ region, in row-major order.
--
-- Each entry pairs a @(row, column)@ coordinate with the @(red, green, blue)@
-- channel values of that pixel; any alpha channel is dropped.
--
-- Returns a @Left@ when the image is narrower than @w@ or shorter than @h@.
-- Without the GD backend a decoding failure of the file is also a @Left@.
file2rgb
  :: forall h w hw rgb
  .  (All KnownDim '[h, w], All KnownNat '[h, w])
  => hw ~ (Int, Int)
  => rgb ~ (Int, Int, Int)
  => Proxy '(h, w)
  -> FilePath
  -> ExceptT String IO [(hw, rgb)]
file2rgb hwp fp = do
#ifdef USE_GD
  im <- lift $ GD.loadPngFile fp
  (imW, imH) <- lift $ GD.imageSize im
  checkSize imW imH
  forM coords $ \(row, col) -> do
    -- GD points are (x, y), that is (column, row).
    (r, g, b, _) <- lift $ GD.toRGBA <$> GD.getPixel (col, row) im
#else
  im <- JP.convertRGB8 <$> ExceptT (JP.readPng fp)
  checkSize (JP.imageWidth im) (JP.imageHeight im)
  forM coords $ \(row, col) -> do
    -- pixelAt takes x (column) before y (row).
    let JP.PixelRGB8 r g b = JP.pixelAt im col row
#endif
    pure ((row, col), (fromIntegral r, fromIntegral g, fromIntegral b))
  where
    (height, width) = reifyHW hwp

    -- Every (row, column) coordinate of the requested region, row by row.
    coords :: [(Int, Int)]
    coords = [(row, col) | row <- [0 .. height - 1], col <- [0 .. width - 1]]

    -- Reject images that cannot supply the whole requested region.
    checkSize :: Int -> Int -> ExceptT String IO ()
    checkSize imW imH =
      when (imW < width || imH < height) $
        throwE $ concat
          [ fp, ": image is ", show imW, "x", show imH
          , " (width x height) but at least "
          , show width, "x", show height, " is required"
          ]
-- | Sanity-check a list of pixels as produced by 'file2rgb'.
--
-- Throws (via 'throwString') when every channel of every pixel is 0, which
-- includes the empty list, or when any channel of any pixel lies outside
-- the range 0 to 255. Returns unit otherwise.
assertPixels :: [((Int, Int), (Int, Int, Int))] -> IO ()
assertPixels pxs
  | all (all (== 0)) channels =
      throwString $ mkError modulename "IMAGE ALL ZEROS!"
  -- A single out-of-range channel is enough to reject the image.
  | any (any outOfBounds) channels =
      throwString $ mkError modulename "IMAGE OUT OF PIXEL BOUNDS!"
  | otherwise = pure ()
  where
    -- Channel values of each pixel, with the coordinates dropped.
    channels :: [[Int]]
    channels = [[r, g, b] | (_, (r, g, b)) <- pxs]

    outOfBounds :: Int -> Bool
    outOfBounds x = x < 0 || x > 255
-- | Sanity-check a flat list of pixel values, using @hdr@ as the header of
-- any error message.
--
-- Throws (via 'throwString') when at least one value is below -0.1 or above
-- 255.1; the message then carries the number of offending values and the
-- total length. Otherwise throws when every value equals 0, which includes
-- the empty list. Returns unit when neither condition holds.
assertList :: String -> [HsReal] -> IO ()
assertList hdr rs
  | not (null outliers) =
      throwString $ show (length outliers, length rs, mkError hdr "OOB found!")
  | all (== 0) rs =
      throwString $ mkError hdr "all-zeros found!"
  | otherwise = pure ()
  where
    outliers :: [HsReal]
    outliers = filter outOfRange rs

    outOfRange :: HsReal -> Bool
    outOfRange x = x < (-0.1) || x > 255.1
-- | Bring the type-level height and width of the proxy down to the value
-- level, returned as @(height, width)@.
reifyHW
  :: forall h w
  .  (All KnownDim '[h, w], All KnownNat '[h, w])
  => Proxy '(h, w)
  -> (Int, Int)
reifyHW _ = (height, width)
  where
    height, width :: Int
    height = fromIntegral (dimVal (dim :: Dim h))
    width = fromIntegral (dimVal (dim :: Dim w))