module Codec.Picture.Jpg.FastDct( referenceDct, fastDctLibJpeg ) where

import Control.Applicative( (<$>) )
import Data.Int( Int16, Int32 )
import Data.Bits( Bits, shiftR, shiftL )
import Control.Monad.ST( ST )

import qualified Data.Vector.Storable.Mutable as M

import Codec.Picture.Jpg.Types
import Control.Monad( forM, forM_ )

{-# INLINE (.>>.) #-}
{-# INLINE (.<<.) #-}
-- | Infix spellings of 'shiftR' and 'shiftL'.
(.>>.), (.<<.) :: (Bits a) => a -> Int -> a
(.>>.) = shiftR
(.<<.) = shiftL

-- | Reference implementation of the DCT, written directly from the formula
-- of ITU-81. It is slow and performs far too many operations, but it is
-- accurate and makes a good point of comparison.
--
-- Reads the 64 samples of the second argument, overwrites the 64 cells of
-- the first argument with the truncated coefficients and returns that same
-- buffer.
referenceDct :: MutableMacroBlock s Int32
             -> MutableMacroBlock s Int16
             -> ST s (MutableMacroBlock s Int32)
referenceDct workData block = do
    forM_ frequencies $ \(u, v) -> do
        acc <- spatialSum u v
        -- The coefficient for frequency (u, v) is stored at v * 8 + u.
        (workData .<-. (v * 8 + u)) $
            truncate ((1 / 4) * scale u * scale v * acc)
    return workData
  where
    -- Every (u, v) frequency pair, u varying slowest.
    frequencies = [(u, v) | u <- [0 :: Int .. 7], v <- [0 .. 7]]

    -- Every (x, y) sample position, x varying slowest.
    positions = [(x, y) | x <- [0 .. 7], y <- [0 .. 7 :: Int]]

    -- Cosine-weighted sum of all the samples for one frequency pair.
    spatialSum u v = sum <$> forM positions (\(x, y) -> do
        sample <- fromIntegral <$> (block .!!!. (y * 8 + x))
        return $ sample
               * cos ((2 * fromIntegral x + 1) * fromIntegral u * (pi :: Float) / 16)
               * cos ((2 * fromIntegral y + 1) * fromIntegral v * pi / 16))

    -- Normalisation factor: 1 / sqrt 2 for frequency 0, 1 otherwise.
    scale :: Int -> Float
    scale 0 = 1 / sqrt 2
    scale _ = 1

-- | Number of fractional bits of the fixed-point constants below
-- ('cONST_BITS') and number of extra bits of precision kept between the
-- two passes of the fast DCT ('pASS1_BITS').
pASS1_BITS, cONST_BITS :: Int
cONST_BITS = 13
pASS1_BITS = 2

-- | Fixed-point multipliers: each one is the real number spelled in its
-- name, multiplied by 2 ^ 'cONST_BITS' and rounded.
fIX_0_298631336, fIX_0_390180644, fIX_0_541196100,
    fIX_0_765366865, fIX_0_899976223, fIX_1_175875602,
    fIX_1_501321110, fIX_1_847759065, fIX_1_961570560,
    fIX_2_053119869, fIX_2_562915447, fIX_3_072711026 :: Int32
fIX_0_298631336 = 2446    -- FIX(0.298631336)
fIX_0_390180644 = 3196    -- FIX(0.390180644)
fIX_0_541196100 = 4433    -- FIX(0.541196100)
fIX_0_765366865 = 6270    -- FIX(0.765366865)
fIX_0_899976223 = 7373    -- FIX(0.899976223)
fIX_1_175875602 = 9633    -- FIX(1.175875602)
fIX_1_501321110 = 12299   -- FIX(1.501321110)
fIX_1_847759065 = 15137   -- FIX(1.847759065)
fIX_1_961570560 = 16069   -- FIX(1.961570560)
fIX_2_053119869 = 16819   -- FIX(2.053119869)
fIX_2_562915447 = 20995   -- FIX(2.562915447)
fIX_3_072711026 = 25172   -- FIX(3.072711026)

-- | Offset removed from the samples; pass 1 subtracts it eight times from
-- the first coefficient of every row.
cENTERJSAMPLE :: Int32
cENTERJSAMPLE = 128

-- | Fast DCT extracted from libjpeg.
--
-- Reads the 64 samples of the second argument, overwrites the first
-- argument with the transformed coefficients (rows first, then columns)
-- and returns that same buffer.
fastDctLibJpeg :: MutableMacroBlock s Int32
               -> MutableMacroBlock s Int16
               -> ST s (MutableMacroBlock s Int32)
fastDctLibJpeg workData sample_block = do
    forM_ [0 .. 7] (rowPass workData)
    forM_ [0 .. 7] (columnPass workData)
    return workData
  where
    -- Pass 1: process rows, reading the samples and writing the work buffer.
    -- Note results are scaled up by sqrt(8) compared to a true DCT;
    -- furthermore, we scale the results by 2**PASS1_BITS.
    rowPass dest row = do
        let base = row * 8
            fetch k = fromIntegral <$> (sample_block .!!!. (base + k))
            store k val = (dest .<-. (base + k)) val
            descale val = val .>>. (cONST_BITS - pASS1_BITS)
            -- Fudge factor giving round-to-nearest in 'descale'.
            rounding = 1 .<<. (cONST_BITS - pASS1_BITS - 1) :: Int32

        b0 <- fetch 0
        b1 <- fetch 1
        b2 <- fetch 2
        b3 <- fetch 3
        b4 <- fetch 4
        b5 <- fetch 5
        b6 <- fetch 6
        b7 <- fetch 7

        let sum07 = b0 + b7
            sum16 = b1 + b6
            sum25 = b2 + b5
            sum34 = b3 + b4

            evenOuter = sum07 + sum34
            evenInner = sum16 + sum25

            (coef2, coef6) =
                evenRotation rounding (sum07 - sum34) (sum16 - sum25)
            (coef1, coef3, coef5, coef7) =
                oddRotation rounding (b0 - b7) (b1 - b6) (b2 - b5) (b3 - b4)

        -- The sample offset is removed here, on the first coefficient only.
        store 0 ((evenOuter + evenInner - 8 * cENTERJSAMPLE) .<<. pASS1_BITS)
        store 4 ((evenOuter - evenInner) .<<. pASS1_BITS)
        store 2 (descale coef2)
        store 6 (descale coef6)
        store 1 (descale coef1)
        store 3 (descale coef3)
        store 5 (descale coef5)
        store 7 (descale coef7)

    -- Pass 2: process columns, in place in the work buffer.
    -- The PASS1_BITS scaling is removed, and three more bits are shifted
    -- out on top of it, i.e. the results are additionally divided by 8.
    -- NOTE(review): the rounding constants below are the ones matching
    -- shifts of PASS1_BITS and CONST_BITS + PASS1_BITS, not the shifts
    -- actually applied (three bits larger) - confirm this is intended.
    columnPass :: M.STVector s Int32 -> Int -> ST s ()
    columnPass blk col = do
        let fetch k = blk .!!!. (8 * k + col)
            store k val = (blk .<-. (8 * k + col)) val
            descale val = val .>>. (cONST_BITS + pASS1_BITS + 3)
            rounding = 1 .<<. (cONST_BITS + pASS1_BITS - 1) :: Int32

        b0 <- fetch 0
        b1 <- fetch 1
        b2 <- fetch 2
        b3 <- fetch 3
        b4 <- fetch 4
        b5 <- fetch 5
        b6 <- fetch 6
        b7 <- fetch 7

        let sum07 = b0 + b7
            sum16 = b1 + b6
            sum25 = b2 + b5
            sum34 = b3 + b4

            -- Fudge factor for coefficients 0 and 4 is added once, here.
            evenOuter = sum07 + sum34 + (1 .<<. (pASS1_BITS - 1))
            evenInner = sum16 + sum25

            (coef2, coef6) =
                evenRotation rounding (sum07 - sum34) (sum16 - sum25)
            (coef1, coef3, coef5, coef7) =
                oddRotation rounding (b0 - b7) (b1 - b6) (b2 - b5) (b3 - b4)

        store 0 ((evenOuter + evenInner) .>>. (pASS1_BITS + 3))
        store 4 ((evenOuter - evenInner) .>>. (pASS1_BITS + 3))
        store 2 (descale coef2)
        store 6 (descale coef6)
        store 1 (descale coef1)
        store 3 (descale coef3)
        store 5 (descale coef5)
        store 7 (descale coef7)

    -- Coefficients 2 and 6 before descaling, computed from the two
    -- differences of the even part; 'rounding' is added to both.
    {-# INLINE evenRotation #-}
    evenRotation :: Int32 -> Int32 -> Int32 -> (Int32, Int32)
    evenRotation rounding outer inner =
        ( pivot + outer * fIX_0_765366865
        , pivot - inner * fIX_1_847759065
        )
      where
        pivot = (outer + inner) * fIX_0_541196100 + rounding

    -- Coefficients 1, 3, 5 and 7 before descaling, computed from the four
    -- differences b0 - b7, b1 - b6, b2 - b5 and b3 - b4; 'rounding' is
    -- added to each of them through 'pivot'.
    {-# INLINE oddRotation #-}
    oddRotation :: Int32 -> Int32 -> Int32 -> Int32 -> Int32
                -> (Int32, Int32, Int32, Int32)
    oddRotation rounding d0 d1 d2 d3 =
        ( d0 * fIX_1_501321110 + r03 + r02
        , d1 * fIX_3_072711026 + r12 + r13
        , d2 * fIX_2_053119869 + r12 + r02
        , d3 * fIX_0_298631336 + r03 + r13
        )
      where
        pivot = ((d0 + d2) + (d1 + d3)) * fIX_1_175875602 + rounding
        r03 = (d0 + d3) * negate fIX_0_899976223
        r12 = (d1 + d2) * negate fIX_2_562915447
        r02 = (d0 + d2) * negate fIX_0_390180644 + pivot
        r13 = (d1 + d3) * negate fIX_1_961570560 + pivot