module Codec.Compression.Lzo.File ( decompressFile
                                  , compressFile
                                  ) where

import           Codec.Compression.Lzo.Block
import           Control.Monad               (unless, when)
import           Data.Binary.Get             (Get, getByteString, getWord16be, getWord32be, getWord8, lookAhead, runGetOrFail, skip)
import           Data.Binary.Put             (Put, putByteString, putLazyByteString, putWord16be, putWord32be, putWord8, runPut)
import           Data.Bits                   (Bits, (.&.), (.|.))
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Lazy        as BSL
import           Data.Digest.Adler32         (adler32)
import           Data.Digest.CRC32           (crc32)
import           Data.Semigroup              ((<>))
import           Data.Word                   (Word32)

-- lzop file format references:
--   https://github.com/ir193/python-lzo/
--   https://github.com/libarchive/libarchive/blob/3649ed23c6b4392d692580c03b10a611e3eaaa32/libarchive/archive_read_support_filter_lzop.c
-- | The nine signature bytes that open every lzop stream.
lzopMagic :: BS.ByteString
lzopMagic = BS.pack
    [ 0x89
    , 0x4c, 0x5a, 0x4f -- "LZO"
    , 0x00
    , 0x0d, 0x0a       -- CR LF
    , 0x1a
    , 0x0a             -- LF
    ]

-- | @hasFlag word mask@ is 'True' when at least one bit of @mask@ is also set
-- in @word@.
hasFlag :: (Num a, Bits a) => a -> a -> Bool
hasFlag word = (/= 0) . (word .&.)

-- | The flags word read from an lzop file header. It is passed to every
-- block decode to decide which per-block checksums are present.
type LzoReadHeader = Word32

-- | One data block: 'Just' the uncompressed bytes, or 'Nothing' for the
-- end-of-stream marker (encoded as an uncompressed length of zero).
type LzoBlock = Maybe BS.ByteString

-- | Consume the nine signature bytes at the start of the stream, failing
-- unless they are exactly 'lzopMagic'.
getMagic :: Get ()
getMagic = getByteString 9 >>= verify
    where
        verify signature
            | signature == lzopMagic = pure ()
            | otherwise =
                fail "Invalid lzop magic bytes, perhaps it is not an lzop file?"

-- | Header flag (lzop F_ADLER32_D): each block carries an Adler-32 of its
-- uncompressed data.
adler32dFlag :: Word32
adler32dFlag = 0x1

-- | Header flag (lzop F_ADLER32_C): each compressed block carries an
-- Adler-32 of its stored bytes.
adler32cFlag :: Word32
adler32cFlag = 0x2

-- | Header flag (lzop F_CRC32_D): each block carries a CRC-32 of its
-- uncompressed data.
crc32dFlag :: Word32
crc32dFlag = 0x100

-- | Header flag (lzop F_CRC32_C): each compressed block carries a CRC-32 of
-- its stored bytes.
crc32cFlag :: Word32
crc32cFlag = 0x200

-- | Abort decoding because a stored checksum disagrees with the one computed
-- from the data it covers.
failChecksum :: Show a => a -> Word32 -> Get b
failChecksum expected actual = fail message
    where
        message = concat
            [ "Checksum does not match; expected "
            , show expected
            , ", found "
            , show actual
            ]

-- | Serialize one data block, or the end-of-stream marker for 'Nothing'.
--
-- A block is written as: uncompressed length, stored length, Adler-32 of the
-- uncompressed bytes, then the payload. The LZO-compressed form is stored
-- only when it is strictly shorter than the input; otherwise the input is
-- stored verbatim and both lengths are equal.
putLzoBlock :: LzoBlock -> Put
putLzoBlock = maybe (putWord32be 0) putBlock
    where
        putBlock raw =
            let packed = compress raw
                rawLen = fromIntegral (BS.length raw) :: Word32
                packedLen = fromIntegral (BS.length packed) :: Word32
                keepRaw = rawLen <= packedLen
            in putWord32be rawLen
                <> putWord32be (if keepRaw then rawLen else packedLen)
                <> putWord32be (adler32 raw)
                <> putByteString (if keepRaw then raw else packed)

-- | Decode one lzop data block using the checksum flags from the file header.
--
-- Returns 'Nothing' for the end-of-stream marker (an uncompressed length of
-- zero) and the decompressed bytes otherwise. Fails on an uncompressed length
-- above 64 MiB, on an impossible stored length (zero, or larger than the
-- uncompressed length), and on any checksum mismatch.
getLzoBlock :: Word32 -- ^ Flags
            -> Get LzoBlock
getLzoBlock ff = {-# SCC "getLzoBlock" #-} do
    -- uncompressed length
    dst <- getWord32be
    if dst == 0
        then pure Nothing
        else Just <$> do
            -- stored (compressed) length
            src <- getWord32be
            when (dst > 64 * 1024 * 1024) $
                fail "Uncompressed data longer than max block size"
            -- A block that does not shrink is stored verbatim with src == dst,
            -- so a stored length of 0 or one above dst means corrupt input.
            when (src == 0 || src > dst) $
                fail "Invalid compressed block length"
            let isCompressed = src < dst
            dAdler <- mGet
                (hasFlag ff adler32dFlag)
                getWord32be
            dCrc <- mGet
                (hasFlag ff crc32dFlag)
                getWord32be
            -- Checksums of the stored bytes are only written for blocks that
            -- were actually compressed. For a verbatim block the payload is
            -- the uncompressed data, so the uncompressed checksums are reused
            -- and nothing extra is read.
            cAdler <- if hasFlag ff adler32cFlag
                then if isCompressed then Just <$> getWord32be else pure dAdler
                else pure Nothing
            cCrc <- if hasFlag ff crc32cFlag
                then if isCompressed then Just <$> getWord32be else pure dCrc
                else pure Nothing
            srcData <- getByteString (fromIntegral src)
            when (hasFlag ff adler32cFlag) $ do
                let actual = adler32 srcData
                unless (Just actual == cAdler) $
                    failChecksum cAdler actual
            when (hasFlag ff crc32cFlag) $ do
                let actual = crc32 srcData
                unless (Just actual == cCrc) $
                    failChecksum cCrc actual
            let decData =
                    if isCompressed
                        then decompress srcData (fromIntegral dst)
                        else srcData
            when (hasFlag ff adler32dFlag) $ do
                let actual = {-# SCC "adler32d" #-} adler32 decData
                unless (Just actual == dAdler) $
                    failChecksum dAdler actual
            when (hasFlag ff crc32dFlag) $ do
                let actual = crc32 decData
                unless (Just actual == dCrc) $
                    failChecksum dCrc actual
            pure decData

-- | Run the decoder only when the field is present. An absent field yields
-- 'Nothing' without consuming any input.
mGet :: Bool -> Get a -> Get (Maybe a)
mGet present dec
    | present = fmap Just dec
    | otherwise = pure Nothing

-- | Header flag bits marking the writer's OS as Unix (lzop F_OS_UNIX). The OS
-- identifier, 3, lives in the most significant byte of the flags word.
unixFlag :: Word32
unixFlag = 0x03 * 0x01000000

-- The header layout below follows libarchive's lzop writer:
-- https://github.com/libarchive/libarchive/blob/3649ed23c6b4392d692580c03b10a611e3eaaa32/libarchive/archive_write_add_filter_lzop.c#L104
-- | The fixed header fields written after the magic bytes, from the format
-- version through the (empty) file name. 'putLzoHeader' checksums exactly
-- these bytes.
preLzoHeader :: Put
preLzoHeader = do
    putWord16be 0x1030                      -- lzop version
    putWord16be 0x940                       -- version
    putWord16be 0x940                       -- just for safety, min version
    putWord8 1                              -- method
    putWord8 2                              -- compression level
    putWord32be (unixFlag .|. adler32dFlag) -- flags
    -- file mode 0x81a4 (= 0o100644), from here:
    -- https://github.com/libarchive/libarchive/blob/3649ed23c6b4392d692580c03b10a611e3eaaa32/libarchive/archive_write_add_filter_lzop.c#L123
    putWord32be 0x81a4
    putWord32be 0                           -- mtime low (ignored)
    putWord32be 0                           -- mtime high (ignored)
    putWord8 0                              -- filename length

-- | Serialize the complete lzop file header: the magic bytes, the fixed
-- fields from 'preLzoHeader', and an Adler-32 checksum of those fields.
putLzoHeader :: Put
putLzoHeader = do
    putByteString lzopMagic
    putLazyByteString fields
    putWord32be (adler32 fields)
    where
        fields = runPut preLzoHeader

-- | Parse the lzop header that follows the magic bytes and return its flags.
--
-- Fails on format versions below 0x0940, on methods other than 1-3, on
-- headers that use filters (flag 0x0800) or an extra field (flag 0x0040),
-- and on a header checksum mismatch. The checksum covers every header byte
-- from the version through the file name. It is a CRC-32 when flag 0x1000
-- (lzop's F_H_CRC32) is set and an Adler-32 otherwise.
getLzoHeader :: Get LzoReadHeader
getLzoHeader = do
    -- The fixed-size part (version through file-name length) is 25 bytes;
    -- keep a copy of it for the checksum.
    headerBytes <- lookAhead (getByteString 25)
    v <- getWord16be
    unless
        (v >= 0x940)
        (fail "lzo format version too low")
    -- library version, version needed to extract
    skip 4
    m <- getWord8
    unless (m `elem` [1..3]) $
        fail ("Unsupported or invalid method: " ++ show m)
    -- compression level
    skip 1
    fl <- getWord32be
    when
        (hasFlag fl 0x0800)
        (fail "Filters not supported.")
    -- mode, mtime low, mtime high
    skip 12
    filenameLength <- getWord8
    fn <- getByteString (fromIntegral filenameLength)
    chk <- getWord32be
    let checked = headerBytes <> fn
        actual =
            if hasFlag fl 0x1000
                then crc32 checked
                else adler32 checked
    unless (chk == actual) $
        failChecksum chk actual
    when (hasFlag fl 0x0040) $
        fail "Extra data not supported."
    pure fl

-- | Serialize a complete lzop stream: the header, the data blocks, and the
-- end-of-stream marker.
--
-- Input pieces are cut into blocks of at most 64 MiB, the same limit that
-- 'getLzoBlock' enforces when reading. Empty pieces are dropped, because a
-- block with uncompressed length 0 would be read back as the end-of-stream
-- marker.
putChunks :: [BS.ByteString] -> Put
putChunks bs =
       putLzoHeader
    <> foldMap (putLzoBlock . Just) (concatMap toBlocks bs)
    <> putLzoBlock Nothing

    where
        -- Largest uncompressed block size accepted on decompression.
        maxBlockSize :: Int
        maxBlockSize = 64 * 1024 * 1024

        -- Split one input piece into non-empty blocks of at most
        -- 'maxBlockSize' bytes.
        toBlocks chunk
            | BS.null chunk = []
            | otherwise =
                let (block, remainder) = BS.splitAt maxBlockSize chunk
                in block : toBlocks remainder

-- | Compress a lazy 'BSL.ByteString' into a complete lzop stream. The input's
-- strict chunks are handed to 'putChunks'.
compressFile :: BSL.ByteString -> BSL.ByteString
compressFile input = runPut (putChunks (BSL.toChunks input))

-- | Decode a complete lzop stream into its decompressed blocks.
--
-- The list is produced lazily, one block at a time. Decoding stops at the
-- end-of-stream marker, and any bytes after it are ignored. Malformed input
-- raises an 'error' when the affected part of the list is forced. The error
-- carries the decoder's message and the byte offset (from the start of the
-- input) at which decoding failed.
getFile :: BSL.ByteString -> [BS.ByteString]
getFile bsl =
    let (rest, headerLen, header) =
            run 0 (getMagic *> getLzoHeader) bsl
    in loop header headerLen rest

    where
        -- Decode blocks until the end-of-stream marker. @pos@ is the offset
        -- of @bs@ within the original input.
        loop ff pos bs =
            let (rest, used, res) = run pos (getLzoBlock ff) bs in
                case res of
                    Nothing -> []
                    Just x  ->
                        -- Force the offset so it does not accumulate a chain
                        -- of thunks across blocks.
                        let pos' = pos + used
                        in pos' `seq` (x : loop ff pos' rest)

        -- Run one decoder. On failure, report the message and absolute offset
        -- instead of 'show'ing the (possibly huge) unconsumed input.
        run pos g input = case runGetOrFail g input of
            Left (_, off, msg) -> error
                ("Codec.Compression.Lzo.File.decompressFile: " ++ msg
                    ++ " (at byte offset " ++ show (pos + off) ++ ")")
            Right decoded -> decoded


-- | Decompress a complete lzop stream. Output is produced lazily, block by
-- block. Malformed input raises an 'error' once the affected part is forced.
decompressFile :: BSL.ByteString -> BSL.ByteString
decompressFile compressed = BSL.fromChunks (getFile compressed)