module Codec.Compression.Zlib.Deflate(
inflate
, computeCodeValues
)
where
import Codec.Compression.Zlib.HuffmanTree
import Codec.Compression.Zlib.Monad
import Control.Monad
import Data.Bits
import Data.ByteString.Lazy(ByteString)
import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Word
inflate :: DeflateM (Maybe ByteString)
inflate =
do isFinal <- inflateBlock
if isFinal
then do advanceToByte
rest <- readRest
ourAdler <- finalAdler
result <- finalOutput
let theirAdler = BS.foldl shiftAdd 0 rest
if | BS.length rest /= 4 -> return Nothing
| theirAdler /= ourAdler -> return Nothing
| otherwise -> return (Just result)
else inflate
where shiftAdd x y = (x `shiftL` 8) .|. fromIntegral y
inflateBlock :: DeflateM Bool
inflateBlock =
do bfinal <- nextBit
btype <- nextBits 2
case btype :: Word8 of
0 ->
do advanceToByte
len <- nextWord16
nlen <- nextWord16
unless (len == complement nlen) $
fail "Len/nlen mismatch in uncompressed block."
emitBlock =<< nextBlock len
return bfinal
1 ->
do runInflate fixedLitTree fixedDistanceTree
return bfinal
2 ->
do hlit <- (257+) `fmap` nextBits 5
hdist <- (1+) `fmap` nextBits 5
hclen <- (4+) `fmap` nextBits 4
codeLens <- replicateM hclen (nextBits 3)
let codeLens' = zip codeLengthOrder codeLens
codeTree = computeHuffmanTree codeLens'
lens <- getCodeLengths codeTree 0 (hlit + hdist) 0 Map.empty
let (litlens, offdistlens) =
Map.partitionWithKey (\ k _ -> k < hlit) lens
distlens = Map.mapKeys (\ k -> k hlit) offdistlens
litTree = computeHuffmanTree (Map.toList litlens)
distTree = computeHuffmanTree (Map.toList distlens)
runInflate litTree distTree
return bfinal
_ ->
error ("Unacceptable BTYPE: " ++ show btype)
where
runInflate :: HuffmanTree Int -> HuffmanTree Int -> DeflateM ()
runInflate litTree distTree =
do code <- nextCode litTree
if | code < 256 -> do emitByte (fromIntegral code)
runInflate litTree distTree
| code == 256 -> return ()
| code > 256 -> do len <- getLength code
distCode <- nextCode distTree
dist <- getDistance distCode
emitPastChunk dist len
runInflate litTree distTree
getCodeLengths :: HuffmanTree Int ->
Int -> Int -> Int ->
Map Int Int ->
DeflateM (Map Int Int)
getCodeLengths tree n maxl prev acc
| n >= maxl = return acc
| otherwise =
do code <- nextCode tree
if | code <= 15 ->
getCodeLengths tree (n+1) maxl code (Map.insert n code acc)
| code == 16 ->
do num <- (3+) `fmap` nextBits 2
getCodeLengths tree (n+num) maxl prev (addNTimes n num prev acc)
| code == 17 ->
do num <- (3+) `fmap` nextBits 3
getCodeLengths tree (n+num) maxl 0 (addNTimes n num 0 acc)
| code == 18 ->
do num <- (11+) `fmap` nextBits 7
getCodeLengths tree (n+num) maxl 0 (addNTimes n num 0 acc)
where
addNTimes idx count val old =
let idxs = take count [idx..]
vals = replicate count val
in Map.union old (Map.fromList (zip idxs vals))
getLength :: Int -> DeflateM Int64
getLength c =
case Map.lookup c getLengthMap of
Nothing -> error ("getLength for bad code: " ++ show c)
Just m -> m
getLengthMap :: Map Int (DeflateM Int64)
getLengthMap = Map.fromList [
(257, return 3)
, (258, return 4)
, (259, return 5)
, (260, return 6)
, (261, return 7)
, (262, return 8)
, (263, return 9)
, (264, return 10)
, (265, (+ 11) `fmap` nextBits 1)
, (266, (+ 13) `fmap` nextBits 1)
, (267, (+ 15) `fmap` nextBits 1)
, (268, (+ 17) `fmap` nextBits 1)
, (269, (+ 19) `fmap` nextBits 2)
, (270, (+ 23) `fmap` nextBits 2)
, (271, (+ 27) `fmap` nextBits 2)
, (272, (+ 31) `fmap` nextBits 2)
, (273, (+ 35) `fmap` nextBits 3)
, (274, (+ 43) `fmap` nextBits 3)
, (275, (+ 51) `fmap` nextBits 3)
, (276, (+ 59) `fmap` nextBits 3)
, (277, (+ 67) `fmap` nextBits 4)
, (278, (+ 83) `fmap` nextBits 4)
, (279, (+ 99) `fmap` nextBits 4)
, (280, (+ 115) `fmap` nextBits 4)
, (281, (+ 131) `fmap` nextBits 5)
, (282, (+ 163) `fmap` nextBits 5)
, (283, (+ 195) `fmap` nextBits 5)
, (284, (+ 227) `fmap` nextBits 5)
, (285, return 258)
]
getDistance :: Int -> DeflateM Int
getDistance c =
case Map.lookup c getDistanceMap of
Nothing -> error ("getDistance for bad code: " ++ show c)
Just m -> m
getDistanceMap :: Map Int (DeflateM Int)
getDistanceMap = Map.fromList [
(0, return 1)
, (1, return 2)
, (2, return 3)
, (3, return 4)
, (4, (+ 5) `fmap` nextBits 1)
, (5, (+ 7) `fmap` nextBits 1)
, (6, (+ 9) `fmap` nextBits 2)
, (7, (+ 13) `fmap` nextBits 2)
, (8, (+ 17) `fmap` nextBits 3)
, (9, (+ 25) `fmap` nextBits 3)
, (10, (+ 33) `fmap` nextBits 4)
, (11, (+ 49) `fmap` nextBits 4)
, (12, (+ 65) `fmap` nextBits 5)
, (13, (+ 97) `fmap` nextBits 5)
, (14, (+ 129) `fmap` nextBits 6)
, (15, (+ 193) `fmap` nextBits 6)
, (16, (+ 257) `fmap` nextBits 7)
, (17, (+ 385) `fmap` nextBits 7)
, (18, (+ 513) `fmap` nextBits 8)
, (19, (+ 769) `fmap` nextBits 8)
, (20, (+ 1025) `fmap` nextBits 9)
, (21, (+ 1537) `fmap` nextBits 9)
, (22, (+ 2049) `fmap` nextBits 10)
, (23, (+ 3073) `fmap` nextBits 10)
, (24, (+ 4097) `fmap` nextBits 11)
, (25, (+ 6145) `fmap` nextBits 11)
, (26, (+ 8193) `fmap` nextBits 12)
, (27, (+ 12289) `fmap` nextBits 12)
, (28, (+ 16385) `fmap` nextBits 13)
, (29, (+ 24577) `fmap` nextBits 13)
]
fixedLitTree :: HuffmanTree Int
fixedLitTree = computeHuffmanTree
([(x, 8) | x <- [0 .. 143]] ++
[(x, 9) | x <- [144 .. 255]] ++
[(x, 7) | x <- [256 .. 279]] ++
[(x, 8) | x <- [280 .. 287]])
fixedDistanceTree :: HuffmanTree Int
fixedDistanceTree = computeHuffmanTree [(x,5) | x <- [0..31]]
computeHuffmanTree :: [(Int, Int)] -> HuffmanTree Int
computeHuffmanTree = createHuffmanTree . computeCodeValues
computeCodeValues :: Ord a => [(a, Int)] -> [(a, Int, Int)]
computeCodeValues vals = Map.foldrWithKey (\ v (l, c) a -> (v,l,c):a) [] codes
where
valsNo0s = filter (\ (_, b) -> (b /= 0)) vals
valsSort = sortBy (\ (a,_) (b,_) -> compare a b) valsNo0s
blCount = foldr (\ (_,k) m -> Map.insertWith (+) k 1 m) Map.empty valsNo0s
nextcode = step2 0 1 (Map.insert 0 0 Map.empty)
lenTree = Map.fromList valsSort
codeTree = step3 (map fst valsSort) nextcode Map.empty
maxBits = maximum (map snd valsSort)
codes = Map.intersectionWith (,) lenTree codeTree
step2 code bits nc
| bits > maxBits = nc
| otherwise =
let prevCount = Map.findWithDefault 0 (bits 1) blCount
code' = (code + prevCount) `shiftL` 1
in step2 code' (bits + 1) (Map.insert bits code' nc)
step3 [] _ ct = ct
step3 (n:rest) nc ct =
let len = Map.findWithDefault 0 n lenTree
Just ncLen = Map.lookup len nc
ct' = Map.insert n ncLen ct
nc' = Map.insert len (ncLen + 1) nc
in if len == 0
then step3 rest nc ct
else step3 rest nc' ct'
codeLengthOrder :: [Int]
codeLengthOrder =
[16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]