{-# LANGUAGE MultiWayIf #-}
module Codec.Compression.Zlib.Deflate(
         inflate
       , computeCodeValues
       )
 where

import Codec.Compression.Zlib.HuffmanTree
import Codec.Compression.Zlib.Monad
import Control.Monad
import Data.Bits
import Data.ByteString.Lazy(ByteString)
import qualified Data.ByteString.Lazy as BS
import Data.Int
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Word
import MonadLib(raise)

inflate :: DeflateM ByteString
inflate =
  do isFinal <- inflateBlock
     if isFinal
        then checkChecksum >> finalOutput
        else inflate
 where
  shiftAdd x y = (x `shiftL` 8) .|. fromIntegral y
  --
  checkChecksum =
    do advanceToByte
       rest     <- readRest
       ourAdler <- finalAdler
       let theirAdler = BS.foldl shiftAdd 0 rest
       if | BS.length rest < 4     -> raise (ChecksumError "checksum missing")
          | BS.length rest > 4     -> raise (FormatError "Ends in middle of file")
          | theirAdler /= ourAdler -> raise (ChecksumError "checksum mismatch")
          | otherwise              -> return ()


inflateBlock :: DeflateM Bool
inflateBlock =
  do bfinal <- nextBit
     btype  <- nextBits 2
     case btype :: Word8 of
       0 -> -- no compression
         do advanceToByte
            len  <- nextWord16
            nlen <- nextWord16
            unless (len == complement nlen) $
              raise (FormatError "Len/nlen mismatch in uncompressed block.")
            emitBlock =<< nextBlock len
            return bfinal
       1 -> -- compressed with fixed Huffman codes
         do flt <- fixedLitTree
            fdt <- fixedDistanceTree
            runInflate flt fdt
            return bfinal
       2 -> -- compressed with dynamic Huffman codes
         do hlit  <- (257+) `fmap` nextBits 5
            hdist <- (1+)   `fmap` nextBits 5
            hclen <- (4+)   `fmap` nextBits 4
            codeLens <- replicateM hclen (nextBits 3)
            let codeLens' = zip codeLengthOrder codeLens
            codeTree <- computeHuffmanTree codeLens'
            lens <- getCodeLengths codeTree 0 (hlit + hdist) 0 Map.empty
            -- We do this as a big chunk and then split it up because the spec
            -- allows repeat codes to cross the hlit / hdist boundary. So now we
            -- need to pull off the hdist items.
            let (litlens, offdistlens) =
                    Map.partitionWithKey (\ k _ -> k < hlit) lens
                distlens = Map.mapKeys (\ k -> k - hlit) offdistlens
            litTree  <- computeHuffmanTree (Map.toList litlens)
            distTree <- computeHuffmanTree (Map.toList distlens)
            runInflate litTree distTree
            return bfinal
       _ -> -- reserved / error
         raise (FormatError ("Unacceptable BTYPE: " ++ show btype))
 where
  runInflate :: HuffmanTree Int -> HuffmanTree Int -> DeflateM ()
  runInflate litTree distTree =
    do code <- nextCode litTree
       if | code <  256 -> do emitByte (fromIntegral code)
                              runInflate litTree distTree
          | code == 256 -> return ()
          | code > 256  -> do len      <- getLength code
                              distCode <- nextCode distTree
                              dist     <- getDistance distCode
                              emitPastChunk dist len
                              runInflate litTree distTree

-- -----------------------------------------------------------------------------

getCodeLengths :: HuffmanTree Int ->
                  Int -> Int -> Int ->
                  Map Int Int ->
                  DeflateM (Map Int Int)
getCodeLengths tree n maxl prev acc
  | n >= maxl   = return acc
  | otherwise =
    do code <- nextCode tree
       if | code <= 15 ->
                getCodeLengths tree (n+1) maxl code (Map.insert n code acc)
          | code == 16 -> -- copy the previous code length 3 - 6 times
             do num <- (3+) `fmap` nextBits 2
                getCodeLengths tree (n+num) maxl prev (addNTimes n num prev acc)
          | code == 17 -> -- repeat a code length of 0 for 3 - 10 times
             do num <- (3+) `fmap` nextBits 3
                getCodeLengths tree (n+num) maxl 0    (addNTimes n num 0 acc)
          | code == 18 -> -- repeat a code length of 0 for 11 - 138 times
             do num <- (11+) `fmap` nextBits 7
                getCodeLengths tree (n+num) maxl 0    (addNTimes n num 0 acc)
 where
  addNTimes idx count val old =
    let idxs = take count [idx..]
        vals = replicate count val
    in Map.union old (Map.fromList (zip idxs vals))

-- -----------------------------------------------------------------------------

getLength :: Int -> DeflateM Int64
getLength c =
  case Map.lookup c getLengthMap of
    Nothing -> raise (DecompressionError ("getLength for bad code: "++show c))
    Just m  -> m

getLengthMap :: Map Int (DeflateM Int64)
getLengthMap = Map.fromList [
    (257, return 3)
  , (258, return 4)
  , (259, return 5)
  , (260, return 6)
  , (261, return 7)
  , (262, return 8)
  , (263, return 9)
  , (264, return 10)
  , (265, (+ 11)  `fmap` nextBits 1)
  , (266, (+ 13)  `fmap` nextBits 1)
  , (267, (+ 15)  `fmap` nextBits 1)
  , (268, (+ 17)  `fmap` nextBits 1)
  , (269, (+ 19)  `fmap` nextBits 2)
  , (270, (+ 23)  `fmap` nextBits 2)
  , (271, (+ 27)  `fmap` nextBits 2)
  , (272, (+ 31)  `fmap` nextBits 2)
  , (273, (+ 35)  `fmap` nextBits 3)
  , (274, (+ 43)  `fmap` nextBits 3)
  , (275, (+ 51)  `fmap` nextBits 3)
  , (276, (+ 59)  `fmap` nextBits 3)
  , (277, (+ 67)  `fmap` nextBits 4)
  , (278, (+ 83)  `fmap` nextBits 4)
  , (279, (+ 99)  `fmap` nextBits 4)
  , (280, (+ 115) `fmap` nextBits 4)
  , (281, (+ 131) `fmap` nextBits 5)
  , (282, (+ 163) `fmap` nextBits 5)
  , (283, (+ 195) `fmap` nextBits 5)
  , (284, (+ 227) `fmap` nextBits 5)
  , (285, return 258)
  ]

getDistance :: Int -> DeflateM Int
getDistance c =
  case Map.lookup c getDistanceMap of
    Nothing -> raise (DecompressionError ("getDistance for bad code: "++show c))
    Just m  -> m

getDistanceMap :: Map Int (DeflateM Int)
getDistanceMap = Map.fromList [
    (0,  return 1)
  , (1,  return 2)
  , (2,  return 3)
  , (3,  return 4)
  , (4,  (+ 5)     `fmap` nextBits 1)
  , (5,  (+ 7)     `fmap` nextBits 1)
  , (6,  (+ 9)     `fmap` nextBits 2)
  , (7,  (+ 13)    `fmap` nextBits 2)
  , (8,  (+ 17)    `fmap` nextBits 3)
  , (9,  (+ 25)    `fmap` nextBits 3)
  , (10, (+ 33)    `fmap` nextBits 4)
  , (11, (+ 49)    `fmap` nextBits 4)
  , (12, (+ 65)    `fmap` nextBits 5)
  , (13, (+ 97)    `fmap` nextBits 5)
  , (14, (+ 129)   `fmap` nextBits 6)
  , (15, (+ 193)   `fmap` nextBits 6)
  , (16, (+ 257)   `fmap` nextBits 7)
  , (17, (+ 385)   `fmap` nextBits 7)
  , (18, (+ 513)   `fmap` nextBits 8)
  , (19, (+ 769)   `fmap` nextBits 8)
  , (20, (+ 1025)  `fmap` nextBits 9)
  , (21, (+ 1537)  `fmap` nextBits 9)
  , (22, (+ 2049)  `fmap` nextBits 10)
  , (23, (+ 3073)  `fmap` nextBits 10)
  , (24, (+ 4097)  `fmap` nextBits 11)
  , (25, (+ 6145)  `fmap` nextBits 11)
  , (26, (+ 8193)  `fmap` nextBits 12)
  , (27, (+ 12289) `fmap` nextBits 12)
  , (28, (+ 16385) `fmap` nextBits 13)
  , (29, (+ 24577) `fmap` nextBits 13)
  ]

-- -----------------------------------------------------------------------------

fixedLitTree :: DeflateM (HuffmanTree Int)
fixedLitTree = computeHuffmanTree
  ([(x, 8) | x <- [0   .. 143]] ++
   [(x, 9) | x <- [144 .. 255]] ++
   [(x, 7) | x <- [256 .. 279]] ++
   [(x, 8) | x <- [280 .. 287]])

fixedDistanceTree :: DeflateM (HuffmanTree Int)
fixedDistanceTree = computeHuffmanTree [(x,5) | x <- [0..31]]

-- -----------------------------------------------------------------------------

computeHuffmanTree :: [(Int, Int)] -> DeflateM (HuffmanTree Int)
computeHuffmanTree initialData =
  case createHuffmanTree (computeCodeValues initialData) of
    Left  err -> raise (HuffmanTreeError err)
    Right x   -> return x

computeCodeValues :: Ord a => [(a, Int)] -> [(a, Int, Int)]
computeCodeValues vals = Map.foldrWithKey (\ v (l, c) a -> (v,l,c):a) [] codes
 where
  valsNo0s = filter (\ (_, b) -> (b /= 0)) vals
  valsSort = sortBy (\ (a,_) (b,_) -> compare a b) valsNo0s
  blCount  = foldr (\ (_,k) m -> Map.insertWith (+) k 1 m) Map.empty valsNo0s
  nextcode = step2 0 1 (Map.insert 0 0 Map.empty)
  lenTree  = Map.fromList valsSort
  codeTree = step3 (map fst valsSort) nextcode Map.empty
  maxBits  = maximum (map snd valsSort)
  codes    = Map.intersectionWith (,) lenTree codeTree
  --
  step2 code bits nc
    | bits > maxBits = nc
    | otherwise =
      let prevCount = Map.findWithDefault 0 (bits - 1) blCount
          code' = (code + prevCount) `shiftL` 1
      in step2 code' (bits + 1) (Map.insert bits code' nc) 
  --
  step3 [] _ ct = ct
  step3 (n:rest) nc ct =
    let len        = Map.findWithDefault 0 n lenTree
        Just ncLen = Map.lookup len nc
        ct'        = Map.insert n ncLen ct
        nc'        = Map.insert len (ncLen + 1) nc
    in if len == 0
          then step3 rest nc  ct
          else step3 rest nc' ct'

codeLengthOrder :: [Int]
codeLengthOrder =
  [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]