{-# Language ScopedTypeVariables, MultiWayIf, TypeFamilies #-}
module Data.Bond.Internal.FastBinaryProto (
        FastBinaryProto(..)
    ) where

import Data.Bond.Proto
import Data.Bond.Types
import Data.Bond.Internal.BinaryUtils
import Data.Bond.Internal.BondedUtils
import Data.Bond.Internal.Cast
import Data.Bond.Internal.Protocol
import Data.Bond.Internal.ProtoUtils
import Data.Bond.Internal.SchemaUtils
import Data.Bond.Internal.TaggedProtocol

import Data.Bond.Schema.BondDataType
import Data.Bond.Schema.ProtocolType

import Control.Applicative
import Control.Monad
import Control.Monad.Error
import Data.List
import Data.Maybe
import Data.Proxy
import Prelude          -- ghc 7.10 workaround for Control.Applicative

import qualified Data.Binary.Get as B
import qualified Data.Binary.Put as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashSet as H
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Vector as V

-- |A binary, tagged protocol similar to 'CompactBinaryProto', but optimized
-- for deserialization speed rather than payload compactness.
--
-- Encoding details, as implemented in this module:
--
-- * a field header is a one-byte type tag followed by a 16-bit ordinal
--   written with @putWord16le@; the STOP marker is written as a bare tag;
-- * lengths of lists, sets, maps, strings, wide strings and blobs are
--   written with @putVarInt@;
-- * numeric scalars are fixed-width words (signed integers are converted
--   with @fromIntegral@ to the unsigned word of the same width);
-- * the protocol signature is @protoHeader fAST_PROTOCOL 1@.
data FastBinaryProto = FastBinaryProto

instance TaggedProtocol FastBinaryProto where
    -- Field header: one type byte, then a 16-bit ordinal. The STOP and
    -- STOP_BASE markers carry no ordinal on the wire, so 0 is reported.
    getFieldHeader = do
        tag <- fmap (BondDataType . fromIntegral) getWord8
        let isMarker = tag == bT_STOP || tag == bT_STOP_BASE
        ordinal <- if isMarker then return 0 else getWord16le
        return (tag, Ordinal ordinal)

    -- List header: one element-type byte, then a varint element count.
    getListHeader = do
        elemTag <- fmap (BondDataType . fromIntegral) getWord8
        count <- getVarInt
        return (elemTag, count)

    getTaggedStruct = getTaggedData

    putFieldHeader tag (Ordinal ordinal) = putTag tag >> putWord16le ordinal

    putListHeader elemTag count = putTag elemTag >> putVarInt count

    -- Whatever putTaggedData emits, terminated by a STOP marker.
    putTaggedStruct s = do
        putTaggedData s
        putTag bT_STOP

    -- Walk field headers, skipping each payload, until STOP; a STOP_BASE
    -- marker does not end the scan.
    skipStruct = go
      where
        go = do
            (tag, _) <- getFieldHeader
            if | tag == bT_STOP -> return ()
               | tag == bT_STOP_BASE -> go
               | otherwise -> skipType tag >> go

    skipRestOfStruct = skipType bT_STRUCT

    -- Consume one value of the given wire type without decoding it.
    skipType t
        | Just width <- lookup t fixedWidths = skip width
        | t == bT_STRING = getVarInt >>= skip
        | t == bT_WSTRING = getVarInt >>= skip . (* 2)
        | t == bT_STRUCT = skipFields
        | t == bT_LIST || t == bT_SET = do
            elemTag <- readTag
            count <- getVarInt
            replicateM_ count (skipType elemTag)
        | t == bT_MAP = do
            keyTag <- readTag
            valueTag <- readTag
            count <- getVarInt
            replicateM_ count (skipType keyTag >> skipType valueTag)
        | otherwise = fail $ "Invalid type to skip " ++ bondTypeName t
      where
        readTag = BondDataType . fromIntegral <$> getWord8
        -- Byte widths of the fixed-size scalar types.
        fixedWidths =
            [ (bT_BOOL, 1), (bT_UINT8, 1), (bT_INT8, 1)
            , (bT_UINT16, 2), (bT_INT16, 2)
            , (bT_UINT32, 4), (bT_INT32, 4), (bT_FLOAT, 4)
            , (bT_UINT64, 8), (bT_INT64, 8), (bT_DOUBLE, 8)
            ]
        -- Struct body: repeated (type byte, 2-byte ordinal, value) entries,
        -- terminated by STOP; STOP_BASE is passed over.
        skipFields = do
            fieldTag <- readTag
            if | fieldTag == bT_STOP -> return ()
               | fieldTag == bT_STOP_BASE -> skipFields
               | otherwise -> skip 2 >> skipType fieldTag >> skipFields

-- Untagged reads/writes delegate to binaryDecode/binaryEncode; the
-- schema-driven variants delegate to readTaggedWithSchema/writeTaggedWithSchema.
instance BondProto FastBinaryProto where
    protoSig = const (protoHeader fAST_PROTOCOL 1)
    bondRead = binaryDecode
    bondWrite = binaryEncode
    bondReadWithSchema = readTaggedWithSchema
    bondWriteWithSchema = writeTaggedWithSchema

-- Tagged reads and writes delegate to readTagged / writeTagged.
instance BondTaggedProto FastBinaryProto where
    bondWriteTagged = writeTagged
    bondReadTagged = readTagged

instance Protocol FastBinaryProto where
    -- Reads run in Data.Binary's Get; writes run in PutM wrapped in
    -- ErrorT String, so encoders can report failures with throwError.
    type ReaderM FastBinaryProto = B.Get
    type WriterM FastBinaryProto = ErrorT String B.PutM

    bondGetStruct = getStruct TopLevelStruct
    bondGetBaseStruct = getStruct BaseStruct

    -- Any non-zero byte decodes as True.
    bondGetBool = do
        v <- getWord8
        return $ v /= 0
    -- Scalars are fixed-width words read with the *le helpers; signed
    -- values are the same-width unsigned word converted via fromIntegral.
    bondGetUInt8 = getWord8
    bondGetUInt16 = getWord16le
    bondGetUInt32 = getWord32le
    bondGetUInt64 = getWord64le
    bondGetInt8 = fromIntegral <$> getWord8
    bondGetInt16 = fromIntegral <$> getWord16le
    bondGetInt32 = fromIntegral <$> getWord32le
    bondGetInt64 = fromIntegral <$> getWord64le
    bondGetFloat = wordToFloat <$> getWord32le
    bondGetDouble = wordToDouble <$> getWord64le
    -- Utf8 string: varint byte count, then that many raw bytes.
    bondGetString = do
        n <- getVarInt
        Utf8 <$> getByteString n
    -- Utf16 string: varint count of 2-byte units, then 2 * count bytes.
    bondGetWString = do
        n <- getVarInt
        Utf16 <$> getByteString (n * 2)
    -- A blob is encoded as a list whose element tag must be int8.
    bondGetBlob = do
        (t, n) <- getListHeader
        unless (t == bT_INT8) $ fail $ "invalid element tag " ++ bondTypeName t ++ " in blob field"
        Blob <$> getByteString n
    bondGetDefNothing = Just <$> bondGet
    bondGetList = getList
    bondGetHashSet = H.fromList <$> bondGetList
    bondGetSet = S.fromList <$> bondGetList
    bondGetMap = getMap
    bondGetVector = getVector
    -- A nullable is encoded as a list holding zero or one element.
    bondGetNullable = do
        v <- bondGetList
        case v of
            [] -> return Nothing
            [x] -> return (Just x)
            _ -> fail $ "list of length " ++ show (length v) ++ " where nullable expected"
    -- Measure the encoded struct without consuming input (lookAhead around
    -- skipType), then capture exactly those bytes, prefixed with the
    -- protocol header, as an undecoded stream.
    bondGetBonded = do
        size <- lookAhead $ do
            start <- bytesRead
            skipType bT_STRUCT
            stop <- bytesRead
            return (stop - start)
        bs <- getLazyByteString (fromIntegral size)
        return $ BondedStream $ BL.append (protoHeader fAST_PROTOCOL 1) bs

    bondPutStruct = putStruct TopLevelStruct
    bondPutBaseStruct = putBaseStruct
    bondPutField = putField
    -- Nothing is emitted for a Nothing value, which is only allowed when the
    -- field is optional. throwError matches the error style of bondPutBonded.
    bondPutDefNothingField p n Nothing =
        unless (isOptionalField p n) $ throwError "can't write nothing to non-optional field"
    bondPutDefNothingField p n (Just v) = putField p n v

    bondPutBool True = putWord8 1
    bondPutBool False = putWord8 0
    bondPutUInt8 = putWord8
    bondPutUInt16 = putWord16le
    bondPutUInt32 = putWord32le
    bondPutUInt64 = putWord64le
    bondPutInt8 = putWord8 . fromIntegral
    bondPutInt16 = putWord16le . fromIntegral
    bondPutInt32 = putWord32le . fromIntegral
    bondPutInt64 = putWord64le . fromIntegral
    bondPutFloat = putWord32le . floatToWord
    bondPutDouble = putWord64le . doubleToWord
    bondPutString (Utf8 s) = do
        putVarInt $ BS.length s
        putByteString s
    -- The length prefix counts 2-byte units, so an odd byte length cannot be
    -- represented. Writing it anyway would leave a stray byte that the reader
    -- (bondGetWString consumes exactly 2 * count bytes) would misread as the
    -- start of the next value, so fail instead of emitting a corrupt stream.
    bondPutWString (Utf16 s)
        | odd len = throwError $ "can't write wstring with odd byte length " ++ show len
        | otherwise = do
            putVarInt $ len `div` 2
            putByteString s
      where
        len = BS.length s
    bondPutList = putList
    bondPutNullable = bondPutList . maybeToList
    bondPutHashSet = putHashSet
    bondPutSet = putSet
    bondPutMap = putMap
    bondPutVector = putVector
    -- Mirror of bondGetBlob: an int8 list header followed by the raw bytes.
    bondPutBlob (Blob b) = do
        putListHeader bT_INT8 (BS.length b)
        putByteString b
    bondPutBonded (BondedObject a) = bondPut a
    -- An undecoded stream is recoded to this protocol, and its first 4 bytes
    -- are dropped before embedding. NOTE(review): presumably these 4 bytes
    -- are the protocol header that bondGetBonded prepends; confirm against
    -- protoHeader.
    bondPutBonded s = do
        BondedStream stream <- case bondRecode FastBinaryProto s of
            Left msg -> throwError $ "Bonded recode error: " ++ msg
            Right v -> return v
        putLazyByteString (BL.drop 4 stream)

-- |Read a list of @a@: a list header whose element tag must equal the wire
-- type of @a@, followed by exactly that many elements.
getList :: forall a. BondType a => BondGet FastBinaryProto [a]
getList = do
    (tag, count) <- getListHeader
    when (tag /= expected) $
        fail $ "invalid element tag " ++ bondTypeName tag ++ " in list field, " ++ bondTypeName expected ++ " expected"
    replicateM count bondGet
  where
    expected = getWireType (Proxy :: Proxy a)

-- |Read a 'Vector' of @a@ using the same wire layout as a list: a list header
-- whose element tag must match the wire type of @a@, then the elements.
getVector :: forall a. BondType a => BondGet FastBinaryProto (Vector a)
getVector = do
    (tag, count) <- getListHeader
    if tag == expected
        then V.replicateM count bondGet
        else fail $ "invalid element tag " ++ bondTypeName tag ++ " in list field, " ++ bondTypeName expected ++ " expected"
  where
    expected = getWireType (Proxy :: Proxy a)

-- |Read a map: key type byte, value type byte, varint entry count, then that
-- many key/value pairs (key first). Both type bytes must match the wire types
-- of @k@ and @v@. If a key occurs more than once, the last occurrence wins
-- (semantics of 'M.fromList').
getMap :: forall k v. (Ord k, BondType k, BondType v) => BondGet FastBinaryProto (Map k v)
getMap = do
    let etk = getWireType (Proxy :: Proxy k)
    let etv = getWireType (Proxy :: Proxy v)
    tk <- BondDataType . fromIntegral <$> getWord8
    tv <- BondDataType . fromIntegral <$> getWord8
    -- The messages name the map key/value explicitly so a malformed map is
    -- not reported as a list problem.
    unless (tk == etk) $ fail $ "invalid key tag " ++ bondTypeName tk ++ " in map field, " ++ bondTypeName etk ++ " expected"
    unless (tv == etv) $ fail $ "invalid value tag " ++ bondTypeName tv ++ " in map field, " ++ bondTypeName etv ++ " expected"
    n <- getVarInt
    fmap M.fromList $ replicateM n $ do
        k <- bondGet
        v <- bondGet
        return (k, v)

-- |Write a list: a list header (element wire type, element count), then each
-- element in order.
putList :: forall a. BondType a => [a] -> BondPut FastBinaryProto
putList xs = putListHeader elemType (length xs) >> forM_ xs bondPut
  where
    elemType = getWireType (Proxy :: Proxy a)

-- |Write a 'HashSet' using the list layout: header, then the members in
-- 'H.toList' order.
putHashSet :: forall a. BondType a => HashSet a -> BondPut FastBinaryProto
putHashSet xs =
    let members = H.toList xs
        elemType = getWireType (Proxy :: Proxy a)
     in putListHeader elemType (H.size xs) >> mapM_ bondPut members

-- |Write a 'Set' using the list layout: header, then the members in ascending
-- order.
putSet :: forall a. BondType a => Set a -> BondPut FastBinaryProto
putSet xs = do
    putListHeader elemType (S.size xs)
    forM_ (S.toAscList xs) bondPut
  where
    elemType = getWireType (Proxy :: Proxy a)

-- |Write a map: key type tag, value type tag, varint entry count, then each
-- entry in ascending key order as the key followed by its value.
putMap :: forall k v. (BondType k, BondType v) => Map k v -> BondPut FastBinaryProto
putMap m = do
    mapM_ putTag [keyTag, valueTag]
    putVarInt (M.size m)
    forM_ (M.toAscList m) $ \(key, value) -> bondPut key >> bondPut value
  where
    keyTag = getWireType (Proxy :: Proxy k)
    valueTag = getWireType (Proxy :: Proxy v)

-- |Write a 'Vector' using the list layout: header, then each element in index
-- order.
putVector :: forall a. BondType a => Vector a -> BondPut FastBinaryProto
putVector xs = putListHeader elemType (V.length xs) >> V.forM_ xs bondPut
  where
    elemType = getWireType (Proxy :: Proxy a)