module Network.Metaverse.Packets where

import Control.Monad

import Data.Word
import Data.Bits

import Data.Binary
import Data.Binary.Put
import Data.Binary.Get

import Network.Metaverse.PacketTypes
import Network.Metaverse.Utils

import qualified Data.ByteString      as B
import qualified Data.ByteString.Lazy as L

-- | Packet sequence number: the 32-bit value stored in a packet header
-- ('packetSequence') and in each appended acknowledgement ('packetAcks').
type SequenceNum = Word32

-- | One packet as handled by 'serialize' and 'deserialize': the header
-- flags, the sequence number, optional extra header bytes, the message body
-- and any acknowledgements appended after the body.
data Packet = Packet {
    -- | Header flag bit 7.  When set, 'serialize' zero-encodes the body and
    -- 'deserialize' zero-decodes it before parsing.
    packetZerocoded  :: Bool,
    -- | Header flag bit 6.
    -- NOTE(review): presumably asks the peer to acknowledge this packet --
    -- confirm against the protocol documentation.
    packetReliable   :: Bool,
    -- | Header flag bit 5.
    -- NOTE(review): presumably marks a resend of an earlier packet --
    -- confirm against the protocol documentation.
    packetRetransmit :: Bool,
    -- | Sequence number, written big-endian right after the flag byte.
    packetSequence   :: SequenceNum,
    -- | Extra header bytes, written after a one-byte length prefix.
    packetExtra      :: B.ByteString,
    -- | The message itself, encoded through its 'Binary' instance.
    packetBody       :: PacketBody,
    -- | Sequence numbers written after the body as big-endian 32-bit words
    -- followed by a one-byte count.  Header flag bit 4 is set exactly when
    -- this list is non-empty.
    packetAcks       :: [SequenceNum]
    }
    deriving Show

-- | Render a 'Packet' into its wire form.
--
-- Layout, in order: one flag byte (bit 4 = acks appended, bit 5 =
-- retransmit, bit 6 = reliable, bit 7 = zerocoded), the big-endian sequence
-- number, a one-byte length followed by the extra header bytes, the encoded
-- body (zero-encoded when the zerocoded flag is set), and finally the acks
-- as big-endian 32-bit words followed by a one-byte ack count.
--
-- Both length prefixes are single bytes and are produced with
-- 'fromIntegral', so more than 255 extra bytes or more than 255 acks wrap
-- modulo 256 and yield a packet that cannot be parsed back.  Callers are
-- expected to stay within those limits.
serialize :: Packet -> B.ByteString
serialize (Packet zcode reliable retrans seqNum extra body acks) =
    B.concat (L.toChunks (runPut wirePut))
  where
    ackCount = length acks
    hasAcks  = ackCount > 0

    -- A single header flag: the given bit when enabled, nothing otherwise.
    flagBit pos enabled
        | enabled   = bit pos
        | otherwise = 0

    flagByte :: Word8
    flagByte = flagBit 4 hasAcks
           .|. flagBit 5 retrans
           .|. flagBit 6 reliable
           .|. flagBit 7 zcode

    -- The body bytes exactly as they go on the wire.
    encodedBody = encode body
    wireBody
        | zcode     = zeroencode encodedBody
        | otherwise = encodedBody

    wirePut = do
        putWord8 flagByte
        putWord32be seqNum
        putWord8 (fromIntegral (B.length extra))
        putByteString extra
        putLazyByteString wireBody
        -- Acks trail the body; their count is the very last byte so that a
        -- reader can peel them off from the end.
        mapM_ putWord32be acks
        when hasAcks $ putWord8 (fromIntegral ackCount)

-- | Parse one wire-format datagram into a 'Packet'; the inverse of
-- 'serialize'.
--
-- The packet layout cannot be read with "Data.Binary" alone: appended acks
-- are counted by the /last/ byte of the message, so they have to be peeled
-- off the end of the strict 'B.ByteString' before the rest is parsed front
-- to back.
--
-- This function is partial.  An empty, truncated or otherwise malformed
-- datagram raises an exception (from 'B.head', 'B.last', 'runGet' or
-- 'decode') once the affected part of the result is evaluated.
deserialize :: B.ByteString -> Packet
deserialize fullMsg =
    Packet zerocoded isReliable isResent seqNum extraBytes parsedBody ackList
  where
    -- The first byte of the header carries the flags.
    flagByte   = B.head fullMsg
    ackFlag    = testBit flagByte 4
    isResent   = testBit flagByte 5
    isReliable = testBit flagByte 6
    zerocoded  = testBit flagByte 7

    -- With the ack flag set, the final byte is the ack count and the
    -- 4 * count bytes before it are the acks themselves; everything in front
    -- of those is header plus body.
    (framed, ackList)
        | ackFlag   = (beforeAcks, runGet readAcks (L.fromChunks [ackBytes]))
        | otherwise = (fullMsg, [])

    ackCount               = fromIntegral (B.last fullMsg) :: Int
    sansCount              = B.init fullMsg
    (beforeAcks, ackBytes) =
        B.splitAt (B.length sansCount - 4 * ackCount) sansCount
    readAcks               = replicateM ackCount getWord32be

    -- Strip the header, leaving only the body (which may or may not be
    -- zerocoded).  The flag byte was already inspected above, so it is read
    -- again here only to move past it.
    readExtra  = getWord8 >>= getBytes . fromIntegral
    readHeader =
        getWord8 >> liftM3 (,,) getWord32be readExtra getRemainingLazyByteString
    (seqNum, extraBytes, wireBody) =
        runGet readHeader (L.fromChunks [framed])

    -- Undo zerocoding when the header flag asks for it, then let the
    -- 'Binary' instance of the body type parse the result.
    parsedBody = decode (if zerocoded then zerodecode wireBody else wireBody)

-- | Undo zero-run encoding: every @0, n@ pair in the input expands to @n@
-- zero bytes, and every other byte is copied through unchanged.
--
-- A lone @0@ at the very end of the input (a marker with no count byte after
-- it) is passed through as-is rather than treated as an error.
--
-- The input is consumed incrementally through 'L.uncons' and 'L.span', so no
-- step measures the length of the remaining input and literal runs are
-- copied a chunk at a time instead of being re-consed byte by byte.
zerodecode :: L.ByteString -> L.ByteString
zerodecode r = case L.uncons r of
    -- Nothing left to decode.
    Nothing                -> r
    -- A zero marker: the next byte, if present, is the run length.
    Just (0, afterMarker)  -> case L.uncons afterMarker of
        Nothing       -> r
        Just (n, r')  -> L.append (L.replicate (fromIntegral n) 0)
                                  (zerodecode r')
    -- A run of literal (non-zero) bytes: copy it through in one piece.
    Just _                 -> L.append literal (zerodecode rest)
  where
    (literal, rest) = L.span (/= 0) r

-- | Zero-run encode a byte string: each maximal run of zero bytes is
-- replaced by one or more @0, n@ pairs (with @n@ between 1 and 255) whose
-- counts add up to the length of the run; non-zero bytes are copied through
-- unchanged.
zeroencode :: L.ByteString -> L.ByteString
zeroencode r
    | L.null r       = r
    | L.null zeroRun = L.append literal (zeroencode afterLiteral)
    | otherwise      = L.append (runMarkers (L.length zeroRun))
                                (zeroencode afterZeros)
  where
    (zeroRun, afterZeros)   = L.span (== 0) r
    (literal, afterLiteral) = L.span (/= 0) r

    -- Markers for a run of n zeros.  A count is a single byte, so runs
    -- longer than 255 are split into several markers.
    runMarkers n
        | n <= 0    = L.empty
        | n <= 255  = L.pack [0, fromIntegral n]
        | otherwise = L.append (L.pack [0, 255]) (runMarkers (n - 255))