{-# LANGUAGE MultiWayIf #-}
-- | DEFLATE block decoding, expressed as actions in 'DeflateM'.
--
-- 'inflate' is the entry point: it decodes blocks until one carries the
-- final-block flag and then checks the trailing checksum bytes.
-- 'computeCodeValues' is the pure canonical-Huffman code assignment used to
-- build every decoding tree in this module.
module Codec.Compression.Zlib.Deflate(
         inflate
       , computeCodeValues
       )
 where

import Codec.Compression.Zlib.HuffmanTree
import Codec.Compression.Zlib.Monad
import Control.Monad
import Data.Bits
import Data.ByteString.Lazy(ByteString)
import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Word

-- | Decode blocks until one is flagged as final, then validate the trailer.
--
-- After the final block the remaining input (as returned by 'readRest' once
-- 'advanceToByte' has run) must be exactly four bytes. Those bytes are
-- combined into one number, first byte most significant, and compared with
-- 'finalAdler'.
--
-- Returns @Just output@ when both checks pass and @Nothing@ otherwise.
inflate :: DeflateM (Maybe ByteString)
inflate =
  do isFinal <- inflateBlock
     if isFinal
       then do advanceToByte
               rest     <- readRest
               ourAdler <- finalAdler
               result   <- finalOutput
               -- Only forced once the length check below has succeeded.
               let theirAdler = BS.foldl' shiftAdd 0 rest
               if | BS.length rest /= 4    -> return Nothing
                  | theirAdler /= ourAdler -> return Nothing
                  | otherwise              -> return (Just result)
       else inflate
  where
    -- Append one byte to the low end of the accumulator.
    shiftAdd x y = (x `shiftL` 8) .|. fromIntegral y

-- | Decode a single block and return its final-block flag.
--
-- The block header is one flag bit followed by a two-bit block type:
--
-- * @0@: stored data. Two 16-bit words are read and the first must be the
--   bitwise complement of the second; that many bytes are then copied out.
-- * @1@: data compressed with the fixed trees 'fixedLitTree' and
--   'fixedDistanceTree'.
-- * @2@: data compressed with trees described in the block itself.
--
-- Any other block type, and a stored-length mismatch, are reported through
-- the monad's 'fail'.
inflateBlock :: DeflateM Bool
inflateBlock =
  do bfinal <- nextBit
     btype  <- nextBits 2
     case btype :: Word8 of
       0 -> -- no compression
         do advanceToByte
            len  <- nextWord16
            nlen <- nextWord16
            unless (len == complement nlen) $
              fail "Len/nlen mismatch in uncompressed block."
            emitBlock =<< nextBlock len
            return bfinal
       1 -> -- compressed with fixed Huffman codes
         do runInflate fixedLitTree fixedDistanceTree
            return bfinal
       2 -> -- compressed with dynamic Huffman codes
         do hlit  <- (257+) `fmap` nextBits 5
            hdist <- (1+)   `fmap` nextBits 5
            hclen <- (4+)   `fmap` nextBits 4
            codeLens <- replicateM hclen (nextBits 3)
            let codeLens' = zip codeLengthOrder codeLens
                codeTree  = computeHuffmanTree codeLens'
            lens <- getCodeLengths codeTree 0 (hlit + hdist) 0 Map.empty
            -- We do this as a big chunk and then split it up because the spec
            -- allows repeat codes to cross the hlit / hdist boundary. So now
            -- we need to pull off the hdist items.
            let (litlens, offdistlens) =
                  Map.partitionWithKey (\ k _ -> k < hlit) lens
                -- Re-base the distance entries so they start at symbol 0.
                distlens = Map.mapKeys (\ k -> k - hlit) offdistlens
                litTree  = computeHuffmanTree (Map.toList litlens)
                distTree = computeHuffmanTree (Map.toList distlens)
            runInflate litTree distTree
            return bfinal
       _ -> -- reserved / error
         -- Reported with 'fail', like the stored-block length mismatch above,
         -- rather than with a pure 'error'.
         fail ("Unacceptable BTYPE: " ++ show btype)
  where
    -- Decode symbols with the given literal/length and distance trees until
    -- symbol 256 is seen. Symbols below 256 are emitted as bytes; symbols
    -- above 256 select a length and are followed by a distance symbol, and
    -- the resulting (distance, length) pair is handed to 'emitPastChunk'.
    runInflate :: HuffmanTree Int -> HuffmanTree Int -> DeflateM ()
    runInflate litTree distTree =
      do code <- nextCode litTree
         if | code < 256  -> do emitByte (fromIntegral code)
                                runInflate litTree distTree
            | code == 256 -> return ()
            | otherwise   -> do len      <- getLength code
                                distCode <- nextCode distTree
                                dist     <- getDistance distCode
                                emitPastChunk dist len
                                runInflate litTree distTree

-- -----------------------------------------------------------------------------

-- | Read code lengths for symbols @n .. maxl - 1@ using the code-length tree.
--
-- Arguments, in order: the tree used to decode each code-length symbol, the
-- index of the next symbol to fill in, the index at which to stop, the most
-- recently recorded length (the target of a "repeat previous" symbol), and
-- the map of lengths accumulated so far.
--
-- Symbols 0-15 are literal lengths; 16, 17 and 18 are run-length codes whose
-- repeat count comes from extra bits. Any other symbol is reported through
-- 'fail'.
--
-- NOTE(review): a repeat count is not checked against @maxl@, so a run may
-- record entries at indices @>= maxl@.
getCodeLengths :: HuffmanTree Int ->
                  Int -> Int -> Int -> Map Int Int ->
                  DeflateM (Map Int Int)
getCodeLengths tree n maxl prev acc
  | n >= maxl = return acc
  | otherwise =
      do code <- nextCode tree
         if | code <= 15 ->
                getCodeLengths tree (n+1) maxl code (Map.insert n code acc)
            | code == 16 -> -- copy the previous code length 3 - 6 times
                do num <- (3+) `fmap` nextBits 2
                   getCodeLengths tree (n+num) maxl prev
                                  (addNTimes n num prev acc)
            | code == 17 -> -- repeat a code length of 0 for 3 - 10 times
                do num <- (3+) `fmap` nextBits 3
                   getCodeLengths tree (n+num) maxl 0 (addNTimes n num 0 acc)
            | code == 18 -> -- repeat a code length of 0 for 11 - 138 times
                do num <- (11+) `fmap` nextBits 7
                   getCodeLengths tree (n+num) maxl 0 (addNTimes n num 0 acc)
            | otherwise ->
                -- Without this branch the multi-way if would be
                -- non-exhaustive and die with a pattern-match failure.
                fail ("Unexpected code length symbol: " ++ show code)
  where
    -- Record @val@ for @count@ consecutive indices starting at @idx@.
    -- 'Map.union' is left-biased, so entries already in @old@ are kept.
    addNTimes idx count val old =
      let idxs = take count [idx..]
          vals = replicate count val
      in Map.union old (Map.fromList (zip idxs vals))

-- -----------------------------------------------------------------------------

-- | Turn a length symbol (257-285) into a match length, reading any extra
-- bits the symbol calls for. A symbol missing from 'getLengthMap' is reported
-- through 'fail'.
getLength :: Int -> DeflateM Int64
getLength c =
  case Map.lookup c getLengthMap of
    Nothing -> fail ("getLength for bad code: " ++ show c)
    Just m  -> m

-- | For each length symbol, the action that yields its match length: a base
-- value plus however many extra bits that symbol reads from the stream.
getLengthMap :: Map Int (DeflateM Int64)
getLengthMap = Map.fromList [
    (257, return 3)
  , (258, return 4)
  , (259, return 5)
  , (260, return 6)
  , (261, return 7)
  , (262, return 8)
  , (263, return 9)
  , (264, return 10)
  , (265, (+ 11)  `fmap` nextBits 1)
  , (266, (+ 13)  `fmap` nextBits 1)
  , (267, (+ 15)  `fmap` nextBits 1)
  , (268, (+ 17)  `fmap` nextBits 1)
  , (269, (+ 19)  `fmap` nextBits 2)
  , (270, (+ 23)  `fmap` nextBits 2)
  , (271, (+ 27)  `fmap` nextBits 2)
  , (272, (+ 31)  `fmap` nextBits 2)
  , (273, (+ 35)  `fmap` nextBits 3)
  , (274, (+ 43)  `fmap` nextBits 3)
  , (275, (+ 51)  `fmap` nextBits 3)
  , (276, (+ 59)  `fmap` nextBits 3)
  , (277, (+ 67)  `fmap` nextBits 4)
  , (278, (+ 83)  `fmap` nextBits 4)
  , (279, (+ 99)  `fmap` nextBits 4)
  , (280, (+ 115) `fmap` nextBits 4)
  , (281, (+ 131) `fmap` nextBits 5)
  , (282, (+ 163) `fmap` nextBits 5)
  , (283, (+ 195) `fmap` nextBits 5)
  , (284, (+ 227) `fmap` nextBits 5)
  , (285, return 258)
  ]

-- | Turn a distance symbol (0-29) into a backwards distance, reading any
-- extra bits the symbol calls for. A symbol missing from 'getDistanceMap' is
-- reported through 'fail'.
getDistance :: Int -> DeflateM Int
getDistance c =
  case Map.lookup c getDistanceMap of
    Nothing -> fail ("getDistance for bad code: " ++ show c)
    Just m  -> m

-- | For each distance symbol, the action that yields its distance: a base
-- value plus however many extra bits that symbol reads from the stream.
getDistanceMap :: Map Int (DeflateM Int)
getDistanceMap = Map.fromList [
    (0,  return 1)
  , (1,  return 2)
  , (2,  return 3)
  , (3,  return 4)
  , (4,  (+ 5)     `fmap` nextBits 1)
  , (5,  (+ 7)     `fmap` nextBits 1)
  , (6,  (+ 9)     `fmap` nextBits 2)
  , (7,  (+ 13)    `fmap` nextBits 2)
  , (8,  (+ 17)    `fmap` nextBits 3)
  , (9,  (+ 25)    `fmap` nextBits 3)
  , (10, (+ 33)    `fmap` nextBits 4)
  , (11, (+ 49)    `fmap` nextBits 4)
  , (12, (+ 65)    `fmap` nextBits 5)
  , (13, (+ 97)    `fmap` nextBits 5)
  , (14, (+ 129)   `fmap` nextBits 6)
  , (15, (+ 193)   `fmap` nextBits 6)
  , (16, (+ 257)   `fmap` nextBits 7)
  , (17, (+ 385)   `fmap` nextBits 7)
  , (18, (+ 513)   `fmap` nextBits 8)
  , (19, (+ 769)   `fmap` nextBits 8)
  , (20, (+ 1025)  `fmap` nextBits 9)
  , (21, (+ 1537)  `fmap` nextBits 9)
  , (22, (+ 2049)  `fmap` nextBits 10)
  , (23, (+ 3073)  `fmap` nextBits 10)
  , (24, (+ 4097)  `fmap` nextBits 11)
  , (25, (+ 6145)  `fmap` nextBits 11)
  , (26, (+ 8193)  `fmap` nextBits 12)
  , (27, (+ 12289) `fmap` nextBits 12)
  , (28, (+ 16385) `fmap` nextBits 13)
  , (29, (+ 24577) `fmap` nextBits 13)
  ]

-- -----------------------------------------------------------------------------

-- | Literal/length tree for fixed-Huffman blocks: symbols 0-143 use 8 bits,
-- 144-255 use 9 bits, 256-279 use 7 bits and 280-287 use 8 bits.
fixedLitTree :: HuffmanTree Int
fixedLitTree = computeHuffmanTree
  ([(x, 8) | x <- [0   .. 143]] ++
   [(x, 9) | x <- [144 .. 255]] ++
   [(x, 7) | x <- [256 .. 279]] ++
   [(x, 8) | x <- [280 .. 287]])

-- | Distance tree for fixed-Huffman blocks: 32 symbols, each 5 bits long.
fixedDistanceTree :: HuffmanTree Int
fixedDistanceTree = computeHuffmanTree [(x,5) | x <- [0..31]]

-- -----------------------------------------------------------------------------

-- | Build a decoding tree from @(symbol, code length)@ pairs.
computeHuffmanTree :: [(Int, Int)] -> HuffmanTree Int
computeHuffmanTree = createHuffmanTree . computeCodeValues

-- | Assign canonical Huffman code values to symbols from their code lengths.
--
-- Takes @(symbol, length)@ pairs and returns @(symbol, length, code)@
-- triples in ascending symbol order. Symbols whose length is 0 are left out.
-- Within one length, codes are consecutive and handed out in ascending
-- symbol order; the first code of each length is derived from the number of
-- codes of the previous length.
--
-- For example, lengths @[('A',2),('B',1),('C',3),('D',3)]@ produce
-- @[('A',2,2),('B',1,0),('C',3,6),('D',3,7)]@.
computeCodeValues :: Ord a => [(a, Int)] -> [(a, Int, Int)]
computeCodeValues vals =
  Map.foldrWithKey (\ v (l, c) a -> (v,l,c):a) [] codes
  where
    valsNo0s = filter (\ (_, b) -> (b /= 0)) vals
    valsSort = sortBy (\ (a,_) (b,_) -> compare a b) valsNo0s
    -- How many symbols use each code length.
    blCount  = foldr (\ (_,k) m -> Map.insertWith (+) k 1 m) Map.empty valsNo0s
    -- First code value for each length from 0 up to maxBits.
    nextcode = step2 0 1 (Map.insert 0 0 Map.empty)
    lenTree  = Map.fromList valsSort
    codeTree = step3 (map fst valsSort) nextcode Map.empty
    -- A total fold rather than 'maximum', so an input with no non-zero
    -- lengths is safe even when this value is forced.
    maxBits  = foldr (max . snd) 0 valsSort
    codes    = Map.intersectionWith (,) lenTree codeTree
    -- Compute the first code of each bit length: take the first code of the
    -- previous length, add the count of codes of that length, shift left.
    step2 code bits nc
      | bits > maxBits = nc
      | otherwise      =
          let prevCount = Map.findWithDefault 0 (bits - 1) blCount
              code'     = (code + prevCount) `shiftL` 1
          in step2 code' (bits + 1) (Map.insert bits code' nc)
    -- Walk the symbols in ascending order, giving each the next unused code
    -- of its length and bumping that length's counter.
    step3 []       _  ct = ct
    step3 (n:rest) nc ct =
      let len = Map.findWithDefault 0 n lenTree
          -- step2 inserts every length from 1 to maxBits (and 0 is seeded),
          -- so this lookup succeeds for any positive length in lenTree.
          Just ncLen = Map.lookup len nc
          ct' = Map.insert n ncLen ct
          nc' = Map.insert len (ncLen + 1) nc
      in if len == 0
           then step3 rest nc  ct
           else step3 rest nc' ct'

-- | The symbols to which the code-length-code lengths read from a dynamic
-- block header are assigned, in the order those lengths are read.
codeLengthOrder :: [Int]
codeLengthOrder = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]