{-

Inflate implementation for Haskell

Copyright 2004, 2007 Ian Lynagh
Licence: Your choice of GPL version 2 or 3 clause BSD.

This module provides a Haskell implementation of the inflate function,
as described by RFC 1951.

-}

module Codec.Compression.Deflate.Inflate (Octets, inflate) where

import Codec.Compression.LazyStateT
import Codec.Compression.UnsafeInterleave
import Codec.Compression.Utils

-- import Control.Monad
import Control.Monad.State
-- import Data.Bits
import Data.List
import Data.IORef
-- import Data.Word
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BS
import Data.ByteString.Base (fromForeignPtr)
import Foreign

type Octet = Word8        -- The basic input/output type
type Octets = ByteString  -- We use lazy bytestrings rather than [Word8]
                          -- for efficiency
type Code = Word16        -- A generic code
type Dist = Code          -- A distance code (1-32768)
type LitLen = Code        -- A literal/length code (3-258)
type Length = Word8       -- Number of bits needed to identify a code

type Table = InfM Code        -- A Huffman table
type Tables = (Table, Table)  -- lit/len and dist Huffman tables

-- The decoder state threaded through InfM.
data St = St {
    num_bits :: !Word8,          -- number of remaining input bits
    bits     :: !Word,           -- remaining input bits (< 8)
    octets   :: !Octets,         -- remaining input octets
    history  :: !(Ptr Octet),    -- last 32768 output octets
    loc      :: !Word16,         -- where in history we are
    var      :: !(IORef Octets)  -- where to put trailing chars
  }

type InfM a = LazyStateT St IO a

-- Run an inflate computation over the input octets os, starting with an
-- empty bit buffer and a freshly allocated 32768-octet history window.
-- ref is stored in the state as var.
--
-- NOTE(review): the history buffer comes from mallocArray and nothing
-- visible here frees it, so each call appears to leak 32768 octets.
-- The buffer is still needed while output is being produced, so it
-- cannot simply be freed when this function returns -- confirm before
-- changing.
extract_InfM :: IORef Octets -> Octets -> InfM a -> IO a
extract_InfM ref os m = do
    arr <- mallocArray 32768
    let init_state = St {
            num_bits = 0,
            bits = 0,
            octets = os,
            history = arr,
            loc = 0,
            var = ref
          }
    evalLazyStateT m init_state

-- Throw away any buffered bits so that the next read starts at an
-- octet boundary.
align_8_bits :: InfM ()
align_8_bits = do
    s <- get
    put $ s { bits = 0, num_bits = 0 }

-- Take the next n octets of input, bypassing the bit buffer.
-- Calls error if fewer than n octets remain.
-- n at most 65535
get_octets :: Word16 -> InfM Octets
get_octets n = do
    s <- get
    let n' = fromIntegral n
    -- Split first and measure only the prefix: taking the length of
    -- the whole remaining input would force the entire lazy stream.
    case BS.splitAt n' (octets s) of
        (pref, suff)
            | BS.length pref < n' ->
                error "get_octets: Insufficient remaining"
            | otherwise -> do
                put $ s { octets = suff }
                return pref

-- Read the next i bits of input, least significant bit first, pulling
-- in further input octets as needed. Calls error if the input runs out.
-- XXX Should we mask on return instead of on store?
-- i at most 16
get_w16 :: Word8 -> InfM Word16
get_w16 0 = return 0
get_w16 i = do
    s <- get
    let n = num_bits s
    if i == n
      then do
        -- Exactly the buffered bits: hand them all over.
        put $ s { num_bits = 0, bits = 0 }
        return $ fromIntegral $ bits s
      else if i < n
      then do
        let bs = bits s
            i' = fromIntegral i
            mask = (1 `shiftL` i') - 1
        put $ s { num_bits = n - i, bits = bs `shiftR` i' }
        return $ fromIntegral $ (bs .&. mask)
      else if BS.null (octets s)
      then error "get_w16: Insufficient remaining"
      -- XXX Could inline from here down
      else do
        -- Not enough bits buffered: add one more octet above the
        -- bits we already hold and try again.
        let os = octets s
            bs = fromIntegral $ BS.head os
            new_bs = bs `shiftL` fromIntegral n
        put $ s { num_bits = num_bits s + 8,
                  bits = bits s .|. new_bs,
                  octets = BS.tail os }
        get_w16 i

-- As get_w16, but with the result narrowed to a Word8.
-- i at most 8
get_w8 :: Word8 -> InfM Word8
get_w8 i = do
    w <- get_w16 i
    return (fromIntegral w)

-- Read a single bit of input. Calls error if the input runs out.
get_bit :: InfM Bool
get_bit = do
    s <- get
    let n = num_bits s
    if n > 0
      then do
        let bs = bits s
        put $ s { num_bits = n - 1, bits = bs `shiftR` 1 }
        return $ testBit bs 0
      else if BS.null (octets s)
      then error "get_bit: Insufficient remaining"
      else do
        -- Bit buffer is empty: load the next octet, return its low
        -- bit and keep the other 7.
        let os = octets s
            bs = fromIntegral $ BS.head os
        put $ s { num_bits = 7, bits = bs `shiftR` 1, octets = BS.tail os }
        return $ testBit bs 0

{-
We have 2 ways to provide more output. We can either write a single octet
out or repeat a given number of octets a given distance back in the
history.
-}

-- Record one literal octet in the history window and advance the
-- write position, wrapping at 32768.
output :: Octet -> InfM ()
output w = do
    s <- get
    let here = loc s
    lift $ pokeElemOff (history s) (fromIntegral here) w
    put $ s { loc = (here + 1) `mod` 32768 }

-- Repeat len octets starting dist octets back in the history window.
-- The octets are appended to the window and also returned as a fresh
-- chunk of output.
-- len `elem` [3..258]
-- dist `elem` [1..32768]
repeat_w32s :: Word16 -> Word16 -> InfM Octets
repeat_w32s len dist = do
    s <- get
    let cur = loc s
        hist = history s
        src = fromIntegral ((cur - dist) `mod` 32768)
        count = fromIntegral len
        -- Copies one octet at a time, wrapping either index back to 0
        -- when it reaches the end of the window.
        -- XXX This should be roughly a moveArray
        copy_back !0 !_ !_ = return ()
        copy_back k 32768 to = copy_back k 0 to
        copy_back k from 32768 = copy_back k from 0
        copy_back k from to = do
            peekElemOff hist from >>= pokeElemOff hist to
            copy_back (k - 1) (from + 1) (to + 1)
    put $ s { loc = (cur + len) `mod` 32768 }
    lift $ copy_back len src (fromIntegral cur)
    fp <- lift $ mallocForeignPtrArray count
    -- Copy the run out of the window into the result buffer, in two
    -- pieces if it wraps past the end of the window.
    lift $ withForeignPtr fp $ \p ->
        if (src + count) <= 32768
          then copyArray p (hist `advancePtr` src) count
          else do
            let len1 = 32768 - src
                len2 = count - len1
            copyArray p (hist `advancePtr` src) len1
            copyArray (p `advancePtr` len1) hist len2
    return $ BS.fromChunks [fromForeignPtr fp count]

{-
The hardcore stuff!

To inflate an octet stream we use inflate_blocks to do the hard work.
It in turn looks at the first 3 bits to decide whether to just output
an uncompressed segment or pass off the work to inflate_tables and
inflate_codes.
-}

-- Inflate the octets os. Once the last block has been decoded, the
-- input that was not consumed is written to ref.
inflate :: IORef Octets -> Octets -> IO Octets
inflate ref os = extract_InfM ref os (inflate_blocks False)

-- Decode blocks until the one marked as last has been processed.
-- Bool is true if we have seen the "last" block marker
inflate_blocks :: Bool -> InfM Octets
inflate_blocks True = do
    align_8_bits -- redundant as we only look at octets
    s <- get
    liftIO $ writeIORef (var s) (octets s)
    return BS.empty
inflate_blocks False = do
    header <- get_w16 3 -- XXX Could be a more efficient type
    let is_last = testBit header 0
    case header `shiftR` 1 of
        0 -> stored is_last
        1 -> inflate_codes is_last inflate_trees_fixed
        2 -> inflate_tables >>= inflate_codes is_last
        3 -> error "inflate_blocks: case 11 reserved"
        _ -> error "inflate_blocks: can't happen"
  where
    -- An uncompressed block: a 16 bit length, its complement, and
    -- then that many octets to be copied straight to the output.
    stored is_last = do
        align_8_bits
        len <- get_w16 16
        nlen <- get_w16 16
        -- check nlen = 1s complement of len
        unless (nlen == complement len) $
            error "inflate_blocks: Mismatched lengths"
        ws <- get_octets len
        mapM_ output $ BS.unpack ws -- XXX efficiency
        ws_tail <- unsafeInterleave $ inflate_blocks is_last
        return (ws `myAppend` ws_tail)

-- Read the header of a dynamic Huffman block and build its
-- literal/length and distance tables.
inflate_tables :: InfM Tables
inflate_tables = do
    hlit <- get_w16 5
    hdist <- get_w16 5
    hclen <- get_w8 4
    let order = [16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15]
        -- Pair the next 3 bit code length with the symbol it is for.
        read_length sym = do
            w <- get_w8 3
            return (w, sym)
    llc_bs <- mapM read_length $ genericTake (hclen + 4) order
    let tab = make_table llc_bs
    lit_dist_lengths <- make_lit_dist_lengths tab
                                              (258 + hlit + hdist)
                                              (error "inflate_tables dummy")
    -- XXX Use Exactly variant?
    let (lit_lengths, dist_lengths) =
            genericSplitAt (257 + hlit) lit_dist_lengths
        lit_table = make_table (zip lit_lengths [0..])
        dist_table = make_table (zip dist_lengths [0..])
    return (lit_table, dist_table)

{-
make_lit_dist_lengths reads n (at most ~350) dist and length code lengths.
-}

-- Read n code lengths using the code length table tab. last_thing is
-- the most recently produced length, which code 16 repeats.
--
-- n is unsigned, so running past the end cannot be seen here as a
-- negative count; meta_code rejects a run that is longer than n
-- before it subtracts.
make_lit_dist_lengths :: Table -> Word16 -> Length -> InfM [Length]
make_lit_dist_lengths _ 0 _ = return []
make_lit_dist_lengths tab n last_thing = do
    c <- tab
    (ls, n', last_thing') <- meta_code n c last_thing
    ws <- make_lit_dist_lengths tab n' last_thing'
    return (ls ++ ws)

-- Expand one code length symbol i, given that n lengths are still
-- wanted and last_thing was the previous length. Returns the lengths
-- produced, how many are still wanted, and the new previous length.
--
--   0..15  that length, once
--   16     the previous length, 3 + (2 extra bits) times
--   17     zero, 3 + (3 extra bits) times
--   18     zero, 11 + (7 extra bits) times
--
-- Calls error on any other symbol, and on a run longer than n.
meta_code :: Word16 -> Code -> Length -> InfM ([Length], Word16, Length)
meta_code n i last_thing
    | i < 16 = let i' = fromIntegral i
               in return ([i'], n - 1, i')
    | i == 16 = do w <- get_w16 2
                   run (3 + w) last_thing
    | i == 17 = do w <- get_w16 3
                   run (3 + w) 0
    | i == 18 = do w <- get_w16 7
                   run (11 + w) 0
    | otherwise = error $ "meta_code: " ++ show i
  where
    -- Check before subtracting: n - l on Word16 would wrap round to a
    -- huge count instead of going negative.
    run l x
        | l > n = error "meta_code: repeat runs past the end of the code lengths"
        | otherwise = return (genericReplicate l x, n - l, x)

-- Decode literal/length symbols with the given tables until the end
-- of block symbol (256), then carry on with the next block.
inflate_codes :: Bool -> Tables -> InfM Octets
inflate_codes seen_last tabs@(tab_litlen, tab_dist) = do
    i <- tab_litlen
    if i == 256
      then inflate_blocks seen_last
      else do
        pref <- if i < 256
                  then do
                    -- A literal octet
                    let i' = fromIntegral i
                    output i'
                    return $ BS.singleton i'
                  else case lookup i litlens of
                    Nothing -> error "do_code_litlen"
                    -- num_extra_bits `elem` [0..5]
                    Just (base, num_extra_bits) -> do
                        extra <- get_w16 num_extra_bits
                        -- l `elem` [3..258]
                        let l = base + extra
                        -- dist `elem` [1..32768]
                        dist <- dist_code tab_dist
                        repeat_w32s l dist
        o <- unsafeInterleave $ inflate_codes seen_last tabs
        return (pref `myAppend` o)

-- For each length symbol, its base length and number of extra bits.
litlens :: [(Code, (LitLen, Word8))]
litlens = zip [257..285] $ mk_bases 3 litlen_counts ++ [(258, 0)]
  where
    litlen_counts = [(8,0),(4,1),(4,2),(4,3),(4,4),(4,5)]

-- Decode a distance symbol with tab and add on its extra bits.
dist_code :: Table -> InfM Dist
dist_code tab = do
    code <- tab
    case lookup code dists of
        Nothing -> error "dist_code"
        -- num_extra_bits `elem` [0..13]
        Just (base, num_extra_bits) -> do
            extra <- get_w16 num_extra_bits
            return (base + extra)

-- For each distance symbol, its base distance and number of extra bits.
dists :: [(Code, (Dist, Word8))]
dists = zip [0..29] $ mk_bases 1 dist_counts
  where
    dist_counts = (4,0) : [(2, extra) | extra <- [1..13]]
-- Build a list of (base, number of extra bits) pairs. Each (k, b) in
-- counts contributes k entries with b extra bits, and each entry's
-- base is 2^b more than the one before, starting from base.
mk_bases :: Word16 -> [(Int, Word16)] -> [(Word16, Word8)]
mk_bases base counts = snd (mapAccumL step base widths)
  where
    widths = concatMap (uncurry replicate) counts
    step current w = (current + 2 ^ w, (current, fromIntegral w))

-- The fixed tables.
inflate_trees_fixed :: Tables
inflate_trees_fixed = (make_table litlen_lengths, make_table dist_lengths)
  where
    litlen_lengths = zip (replicate 144 8 ++ replicate 112 9
                          ++ replicate 24 7 ++ replicate 8 8)
                         [0..287]
    dist_lengths = zip (replicate 30 5) [0..29]

{-
The Huffman Tree

As the name suggests, the obvious way to store Huffman trees is in a
tree datastructure. Externally we want to view them as functions though,
so we wrap the tree with get_code, which reads bits from the input one
at a time and returns the corresponding code. To make a tree from a list
of length code pairs is a simple recursive process.
-}

data Tree = Branch Tree Tree | Leaf Code | Null

-- Build a decoder from (code length, code) pairs. Pairs with a
-- length of 0 are ignored.
make_table :: [(Length, Code)] -> Table
make_table lcs =
    case make_tree 0 (sort [lc | lc@(l, _) <- lcs, l /= 0]) of
        (tree, []) -> get_code tree
        _ -> error $ "make_table: Left-over lcs from"

-- Walk down the tree, reading one bit per branch (the first subtree
-- for a 0 bit, the second for a 1 bit), until a leaf is reached.
get_code :: Tree -> InfM Code
get_code tree = case tree of
    Branch zero_tree one_tree -> do
        b <- get_bit
        get_code (if b then one_tree else zero_tree)
    Leaf w -> return w
    Null -> error "get_code Null"

-- Build the subtree at depth i from the sorted pairs, returning it
-- along with the pairs that were not used.
make_tree :: Length -> [(Length, Code)] -> (Tree, [(Length, Code)])
make_tree _ [] = (Null, [])
make_tree i lcs@((l, c):rest) =
    case compare i l of
        EQ -> (Leaf c, rest)
        LT -> let (zero_tree, after_zero) = make_tree (i + 1) lcs
                  (one_tree, after_one) = make_tree (i + 1) after_zero
              in (Branch zero_tree one_tree, after_one)
        GT -> error "make_tree: can't happen"