module Database.CQL.Protocol.Codec
    ( encodeByte
    , decodeByte
    , encodeSignedByte
    , decodeSignedByte
    , encodeShort
    , decodeShort
    , encodeSignedShort
    , decodeSignedShort
    , encodeInt
    , decodeInt
    , encodeString
    , decodeString
    , encodeLongString
    , decodeLongString
    , encodeBytes
    , decodeBytes
    , encodeShortBytes
    , decodeShortBytes
    , encodeUUID
    , decodeUUID
    , encodeList
    , decodeList
    , encodeMap
    , decodeMap
    , encodeMultiMap
    , decodeMultiMap
    , encodeSockAddr
    , decodeSockAddr
    , encodeConsistency
    , decodeConsistency
    , encodeOpCode
    , decodeOpCode
    , encodeColumnType
    , decodeColumnType
    , encodePagingState
    , decodePagingState
    , decodeKeyspace
    , decodeTable
    , decodeQueryId
    , putValue
    , getValue
    ) where
import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString (ByteString)
import Data.Decimal
import Data.Int
import Data.IP
import Data.List (unfoldr)
import Data.Text (Text)
import Data.UUID (UUID)
import Data.Word
import Data.Serialize hiding (decode, encode)
import Database.CQL.Protocol.Types
import Network.Socket (SockAddr (..), PortNumber (..))
import Prelude
import qualified Data.ByteString         as B
import qualified Data.ByteString.Lazy    as LB
import qualified Data.Text.Encoding      as T
import qualified Data.Text.Lazy          as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Data.UUID               as UUID
encodeByte :: Putter Word8
encodeByte = put
decodeByte :: Get Word8
decodeByte = get
encodeSignedByte :: Putter Int8
encodeSignedByte = put
decodeSignedByte :: Get Int8
decodeSignedByte = get
encodeShort :: Putter Word16
encodeShort = put
decodeShort :: Get Word16
decodeShort = get
encodeSignedShort :: Putter Int16
encodeSignedShort = put
decodeSignedShort :: Get Int16
decodeSignedShort = get
encodeInt :: Putter Int32
encodeInt = put
decodeInt :: Get Int32
decodeInt = get
encodeString :: Putter Text
encodeString = encodeShortBytes . T.encodeUtf8
decodeString :: Get Text
decodeString = T.decodeUtf8 <$> decodeShortBytes
encodeLongString :: Putter LT.Text
encodeLongString = encodeBytes . LT.encodeUtf8
decodeLongString :: Get LT.Text
decodeLongString = do
    n <- get :: Get Int32
    LT.decodeUtf8 <$> getLazyByteString (fromIntegral n)
encodeBytes :: Putter LB.ByteString
encodeBytes bs = do
    put (fromIntegral (LB.length bs) :: Int32)
    putLazyByteString bs
decodeBytes :: Get (Maybe LB.ByteString)
decodeBytes = do
    n <- get :: Get Int32
    if n < 0
        then return Nothing
        else Just <$> getLazyByteString (fromIntegral n)
encodeShortBytes :: Putter ByteString
encodeShortBytes bs = do
    put (fromIntegral (B.length bs) :: Word16)
    putByteString bs
decodeShortBytes :: Get ByteString
decodeShortBytes = do
    n <- get :: Get Word16
    getByteString (fromIntegral n)
encodeUUID :: Putter UUID
encodeUUID = putLazyByteString . UUID.toByteString
decodeUUID :: Get UUID
decodeUUID = do
    uuid <- UUID.fromByteString <$> getLazyByteString 16
    maybe (fail "decode-uuid: invalid") return uuid
encodeList :: Putter [Text]
encodeList sl = do
    put (fromIntegral (length sl) :: Word16)
    mapM_ encodeString sl
decodeList :: Get [Text]
decodeList = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) decodeString
encodeMap :: Putter [(Text, Text)]
encodeMap m = do
    put (fromIntegral (length m) :: Word16)
    forM_ m $ \(k, v) -> encodeString k >> encodeString v
decodeMap :: Get [(Text, Text)]
decodeMap = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeString)
encodeMultiMap :: Putter [(Text, [Text])]
encodeMultiMap mm = do
    put (fromIntegral (length mm) :: Word16)
    forM_ mm $ \(k, v) -> encodeString k >> encodeList v
decodeMultiMap :: Get [(Text, [Text])]
decodeMultiMap = do
    n <- get :: Get Word16
    replicateM (fromIntegral n) ((,) <$> decodeString <*> decodeList)
encodeSockAddr :: Putter SockAddr
encodeSockAddr (SockAddrInet p a) = do
    putWord8 4
    putWord32le a
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrInet6 p _ (a, b, c, d) _) = do
    putWord8 16
    putWord32host a
    putWord32host b
    putWord32host c
    putWord32host d
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrUnix _) = fail "encode-socket: unix address not allowed"
decodeSockAddr :: Get SockAddr
decodeSockAddr = do
    n <- getWord8
    case n of
        4  -> do
            i <- getIPv4
            p <- getPort
            return $ SockAddrInet p i
        16 -> do
            i <- getIPv6
            p <- getPort
            return $ SockAddrInet6 p 0 i 0
        _  -> fail $ "decode-socket: unknown: " ++ show n
  where
    getPort :: Get PortNumber
    getPort = fromIntegral <$> getWord32be
    getIPv4 :: Get Word32
    getIPv4 = getWord32le
    getIPv6 :: Get (Word32, Word32, Word32, Word32)
    getIPv6 = (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
encodeConsistency :: Putter Consistency
encodeConsistency Any         = encodeShort 0x00
encodeConsistency One         = encodeShort 0x01
encodeConsistency Two         = encodeShort 0x02
encodeConsistency Three       = encodeShort 0x03
encodeConsistency Quorum      = encodeShort 0x04
encodeConsistency All         = encodeShort 0x05
encodeConsistency LocalQuorum = encodeShort 0x06
encodeConsistency EachQuorum  = encodeShort 0x07
encodeConsistency Serial      = encodeShort 0x08
encodeConsistency LocalSerial = encodeShort 0x09
encodeConsistency LocalOne    = encodeShort 0x0A
decodeConsistency :: Get Consistency
decodeConsistency = decodeShort >>= mapCode
      where
        mapCode 0x00 = return Any
        mapCode 0x01 = return One
        mapCode 0x02 = return Two
        mapCode 0x03 = return Three
        mapCode 0x04 = return Quorum
        mapCode 0x05 = return All
        mapCode 0x06 = return LocalQuorum
        mapCode 0x07 = return EachQuorum
        mapCode 0x08 = return Serial
        mapCode 0x09 = return LocalSerial
        mapCode 0x10 = return LocalOne
        mapCode code = fail $ "decode-consistency: unknown: " ++ show code
encodeOpCode :: Putter OpCode
encodeOpCode OcError         = encodeByte 0x00
encodeOpCode OcStartup       = encodeByte 0x01
encodeOpCode OcReady         = encodeByte 0x02
encodeOpCode OcAuthenticate  = encodeByte 0x03
encodeOpCode OcOptions       = encodeByte 0x05
encodeOpCode OcSupported     = encodeByte 0x06
encodeOpCode OcQuery         = encodeByte 0x07
encodeOpCode OcResult        = encodeByte 0x08
encodeOpCode OcPrepare       = encodeByte 0x09
encodeOpCode OcExecute       = encodeByte 0x0A
encodeOpCode OcRegister      = encodeByte 0x0B
encodeOpCode OcEvent         = encodeByte 0x0C
encodeOpCode OcBatch         = encodeByte 0x0D
encodeOpCode OcAuthChallenge = encodeByte 0x0E
encodeOpCode OcAuthResponse  = encodeByte 0x0F
encodeOpCode OcAuthSuccess   = encodeByte 0x10
decodeOpCode :: Get OpCode
decodeOpCode = decodeByte >>= mapCode
  where
    mapCode 0x00 = return OcError
    mapCode 0x01 = return OcStartup
    mapCode 0x02 = return OcReady
    mapCode 0x03 = return OcAuthenticate
    mapCode 0x05 = return OcOptions
    mapCode 0x06 = return OcSupported
    mapCode 0x07 = return OcQuery
    mapCode 0x08 = return OcResult
    mapCode 0x09 = return OcPrepare
    mapCode 0x0A = return OcExecute
    mapCode 0x0B = return OcRegister
    mapCode 0x0C = return OcEvent
    mapCode 0x0D = return OcBatch
    mapCode 0x0E = return OcAuthChallenge
    mapCode 0x0F = return OcAuthResponse
    mapCode 0x10 = return OcAuthSuccess
    mapCode word = fail $ "decode-opcode: unknown: " ++ show word
encodeColumnType :: Putter ColumnType
encodeColumnType (CustomColumn x)   = encodeShort 0x0000 >> encodeString x
encodeColumnType AsciiColumn        = encodeShort 0x0001
encodeColumnType BigIntColumn       = encodeShort 0x0002
encodeColumnType BlobColumn         = encodeShort 0x0003
encodeColumnType BooleanColumn      = encodeShort 0x0004
encodeColumnType CounterColumn      = encodeShort 0x0005
encodeColumnType DecimalColumn      = encodeShort 0x0006
encodeColumnType DoubleColumn       = encodeShort 0x0007
encodeColumnType FloatColumn        = encodeShort 0x0008
encodeColumnType IntColumn          = encodeShort 0x0009
encodeColumnType TextColumn         = encodeShort 0x000A
encodeColumnType TimestampColumn    = encodeShort 0x000B
encodeColumnType UuidColumn         = encodeShort 0x000C
encodeColumnType VarCharColumn      = encodeShort 0x000D
encodeColumnType VarIntColumn       = encodeShort 0x000E
encodeColumnType TimeUuidColumn     = encodeShort 0x000F
encodeColumnType InetColumn         = encodeShort 0x0010
encodeColumnType (MaybeColumn x)    = encodeColumnType x
encodeColumnType (ListColumn x)     = encodeShort 0x0020 >> encodeColumnType x
encodeColumnType (MapColumn  x y)   = encodeShort 0x0021 >> encodeColumnType x >> encodeColumnType y
encodeColumnType (SetColumn  x)     = encodeShort 0x0022 >> encodeColumnType x
encodeColumnType (TupleColumn xs)   = encodeShort 0x0031 >> mapM_ encodeColumnType xs
encodeColumnType (UdtColumn k n xs) = do
    encodeShort 0x0030
    encodeString (unKeyspace k)
    encodeString n
    encodeShort (fromIntegral (length xs))
    forM_ xs $ \(x, t) -> encodeString x >> encodeColumnType t
decodeColumnType :: Get ColumnType
decodeColumnType = decodeShort >>= toType
  where
    toType 0x0000 = CustomColumn <$> decodeString
    toType 0x0001 = return AsciiColumn
    toType 0x0002 = return BigIntColumn
    toType 0x0003 = return BlobColumn
    toType 0x0004 = return BooleanColumn
    toType 0x0005 = return CounterColumn
    toType 0x0006 = return DecimalColumn
    toType 0x0007 = return DoubleColumn
    toType 0x0008 = return FloatColumn
    toType 0x0009 = return IntColumn
    toType 0x000A = return TextColumn
    toType 0x000B = return TimestampColumn
    toType 0x000C = return UuidColumn
    toType 0x000D = return VarCharColumn
    toType 0x000E = return VarIntColumn
    toType 0x000F = return TimeUuidColumn
    toType 0x0010 = return InetColumn
    toType 0x0020 = ListColumn  <$> (decodeShort >>= toType)
    toType 0x0021 = MapColumn   <$> (decodeShort >>= toType) <*> (decodeShort >>= toType)
    toType 0x0022 = SetColumn   <$> (decodeShort >>= toType)
    toType 0x0030 = UdtColumn   <$> (Keyspace <$> decodeString) <*> decodeString <*> do
        n <- fromIntegral <$> decodeShort
        replicateM n ((,) <$> decodeString <*> (decodeShort >>= toType))
    toType 0x0031 = TupleColumn <$> do
        n <- fromIntegral <$> decodeShort
        replicateM n (decodeShort >>= toType)
    toType other  = fail $ "decode-type: unknown: " ++ show other
encodePagingState :: Putter PagingState
encodePagingState (PagingState s) = encodeBytes s
decodePagingState :: Get (Maybe PagingState)
decodePagingState = liftM PagingState <$> decodeBytes
putValue :: Version -> Putter Value
putValue V3 (CqlList x)        = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    mapM_ (toBytes 4 . putNative) x
putValue V2 (CqlList x)        = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    mapM_ (toBytes 2 . putNative) x
putValue V3 (CqlSet x)         = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    mapM_ (toBytes 4 . putNative) x
putValue V2 (CqlSet x)         = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    mapM_ (toBytes 2 . putNative) x
putValue V3 (CqlMap x)         = toBytes 4 $ do
    encodeInt (fromIntegral (length x))
    forM_ x $ \(k, v) -> toBytes 4 (putNative k) >> toBytes 4 (putNative v)
putValue V2 (CqlMap x)         = toBytes 4 $ do
    encodeShort (fromIntegral (length x))
    forM_ x $ \(k, v) -> toBytes 2 (putNative k) >> toBytes 2 (putNative v)
putValue V3 (CqlTuple x)       = mapM_ (toBytes 4 . putValue V3) x
putValue V3 (CqlUdt x)         = mapM_ (toBytes 4 . putValue V3 . snd) x
putValue _ (CqlMaybe Nothing)  = put (1 :: Int32)
putValue v (CqlMaybe (Just x)) = putValue v x
putValue _ value               = toBytes 4 $ putNative value
putNative :: Putter Value
putNative (CqlCustom x)    = putLazyByteString x
putNative (CqlBoolean x)   = putWord8 $ if x then 1 else 0
putNative (CqlInt x)       = put x
putNative (CqlBigInt x)    = put x
putNative (CqlFloat x)     = putFloat32be x
putNative (CqlDouble x)    = putFloat64be x
putNative (CqlText x)      = putByteString (T.encodeUtf8 x)
putNative (CqlUuid x)      = encodeUUID x
putNative (CqlTimeUuid x)  = encodeUUID x
putNative (CqlTimestamp x) = put x
putNative (CqlAscii x)     = putByteString (T.encodeUtf8 x)
putNative (CqlBlob x)      = putLazyByteString x
putNative (CqlCounter x)   = put x
putNative (CqlInet x)      = case x of
    IPv4 i -> putWord32le (toHostAddress i)
    IPv6 i -> do
        let (a, b, c, d) = toHostAddress6 i
        putWord32host a
        putWord32host b
        putWord32host c
        putWord32host d
putNative (CqlVarInt x)    = integer2bytes x
putNative (CqlDecimal x)   = do
    put (fromIntegral (decimalPlaces x) :: Int32)
    integer2bytes (decimalMantissa x)
putNative v@(CqlList  _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlSet   _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlMap   _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlMaybe _)   = fail $ "putNative: collection type: " ++ show v
putNative v@(CqlTuple _)   = fail $ "putNative: tuple type: " ++ show v
putNative v@(CqlUdt   _)   = fail $ "putNative: UDT: " ++ show v
getValue :: Version -> ColumnType -> Get Value
getValue V3 (ListColumn t)    = CqlList <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len) (withBytes 4 (getNative t)))
getValue V2 (ListColumn t)    = CqlList <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len) (withBytes 2 (getNative t)))
getValue V3 (SetColumn t)     = CqlSet <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len) (withBytes 4 (getNative t)))
getValue V2 (SetColumn t)     = CqlSet <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len) (withBytes 2 (getNative t)))
getValue V3 (MapColumn t u)   = CqlMap <$> (getList $ do
    len <- decodeInt
    replicateM (fromIntegral len)
               ((,) <$> withBytes 4 (getNative t) <*> withBytes 4 (getNative u)))
getValue V2 (MapColumn t u)   = CqlMap <$> (getList $ do
    len <- decodeShort
    replicateM (fromIntegral len)
               ((,) <$> withBytes 2 (getNative t) <*> withBytes 2 (getNative u)))
getValue V3 (TupleColumn t)   = CqlTuple <$> mapM (getValue V3) t
getValue V3 (UdtColumn _ _ x) = CqlUdt <$> do
    let (n, t) = unzip x
    zip n <$> mapM (getValue V3) t
getValue v (MaybeColumn t)    = do
    n <- lookAhead (get :: Get Int32)
    if n < 0
        then uncheckedSkip 4 >> return (CqlMaybe Nothing)
        else CqlMaybe . Just <$> getValue v t
getValue _ colType            = withBytes 4 $ getNative colType
getNative :: ColumnType -> Get Value
getNative (CustomColumn _) = CqlCustom <$> remainingBytesLazy
getNative BooleanColumn    = CqlBoolean . (/= 0) <$> getWord8
getNative IntColumn        = CqlInt <$> get
getNative BigIntColumn     = CqlBigInt <$> get
getNative FloatColumn      = CqlFloat  <$> getFloat32be
getNative DoubleColumn     = CqlDouble <$> getFloat64be
getNative TextColumn       = CqlText . T.decodeUtf8 <$> remainingBytes
getNative VarCharColumn    = CqlText . T.decodeUtf8 <$> remainingBytes
getNative AsciiColumn      = CqlAscii . T.decodeUtf8 <$> remainingBytes
getNative BlobColumn       = CqlBlob <$> remainingBytesLazy
getNative UuidColumn       = CqlUuid <$> decodeUUID
getNative TimeUuidColumn   = CqlTimeUuid <$> decodeUUID
getNative TimestampColumn  = CqlTimestamp <$> get
getNative CounterColumn    = CqlCounter <$> get
getNative InetColumn       = CqlInet <$> do
    len <- remaining
    case len of
        4  -> IPv4 . fromHostAddress <$> getWord32le
        16 -> do
            a <- (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
            return $ IPv6 (fromHostAddress6 a)
        n  -> fail $ "getNative: invalid Inet length: " ++ show n
getNative VarIntColumn     = CqlVarInt <$> bytes2integer
getNative DecimalColumn    = do
    x <- get :: Get Int32
    y <- bytes2integer
    return (CqlDecimal (Decimal (fromIntegral x) y))
getNative c@(ListColumn  _)   = fail $ "getNative: collection type: " ++ show c
getNative c@(SetColumn   _)   = fail $ "getNative: collection type: " ++ show c
getNative c@(MapColumn _ _)   = fail $ "getNative: collection type: " ++ show c
getNative c@(MaybeColumn _)   = fail $ "getNative: collection type: " ++ show c
getNative c@(TupleColumn _)   = fail $ "getNative: tuple type: " ++ show c
getNative c@(UdtColumn _ _ _) = fail $ "getNative: udt: " ++ show c
getList :: Get [a] -> Get [a]
getList m = do
    n <- lookAhead (get :: Get Int32)
    if n < 0 then uncheckedSkip 4 >> return []
             else withBytes 4 m
withBytes :: Int -> Get a -> Get a
withBytes s p = do
    n <- case s of
        2 -> fromIntegral <$> (get :: Get Word16)
        4 -> fromIntegral <$> (get :: Get Int32)
        _ -> fail $ "withBytes: invalid size: " ++ show s
    when (n < 0) $
        fail "withBytes: null"
    b <- getBytes n
    case runGet p b of
        Left  e -> fail $ "withBytes: " ++ e
        Right x -> return x
remainingBytes :: Get ByteString
remainingBytes = remaining >>= getByteString . fromIntegral
remainingBytesLazy :: Get LB.ByteString
remainingBytesLazy = remaining >>= getLazyByteString . fromIntegral
toBytes :: Int -> Put -> Put
toBytes s p = do
    let bytes = runPut p
    case s of
        2 -> put (fromIntegral (B.length bytes) :: Word16)
        _ -> put (fromIntegral (B.length bytes) :: Int32)
    putByteString bytes
integer2bytes :: Putter Integer
integer2bytes n = do
    put sign
    put (unroll (abs n))
  where
    sign = fromIntegral (signum n) :: Word8
    unroll :: Integer -> [Word8]
    unroll = unfoldr step
      where
        step 0 = Nothing
        step i = Just (fromIntegral i, i `shiftR` 8)
bytes2integer :: Get Integer
bytes2integer = do
    sign  <- get
    bytes <- get
    let v = roll bytes
    return $! if sign == (1 :: Word8) then v else  v
  where
    roll :: [Word8] -> Integer
    roll = foldr unstep 0
      where
        unstep b a = a `shiftL` 8 .|. fromIntegral b
decodeKeyspace :: Get Keyspace
decodeKeyspace = Keyspace <$> decodeString
decodeTable :: Get Table
decodeTable = Table <$> decodeString
decodeQueryId :: Get (QueryId k a b)
decodeQueryId = QueryId <$> decodeShortBytes