-- Copyright 2016 Google Inc. All Rights Reserved.
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}

-- | Utility functions for parsing and encoding individual types.
module Data.ProtoLens.Encoding.Bytes(
    -- * Running encodings
    Parser,
    Builder,
    runParser,
    isolate,
    runBuilder,
    -- * Bytestrings
    getBytes,
    putBytes,
    -- * Integral types
    getVarInt,
    getVarIntH,
    putVarInt,
    getFixed32,
    getFixed64,
    putFixed32,
    putFixed64,
    -- * Floating-point types
    wordToFloat,
    wordToDouble,
    floatToWord,
    doubleToWord,
    -- * Signed types
    signedInt32ToWord,
    wordToSignedInt32,
    signedInt64ToWord,
    wordToSignedInt64,
    -- * Other utilities
    atEnd,
    runEither,
    (<?>),
    foldMapBuilder,
    ) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (throwE, ExceptT)
import Data.Bits
import Data.ByteString (ByteString)
import Data.ByteString.Builder as Builder
import qualified Data.ByteString.Builder.Internal as Internal
import qualified Data.ByteString.Lazy as L
import Data.Int (Int32, Int64)
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif
import qualified Data.Vector.Generic as V
import Data.Word (Word8, Word32, Word64)
import Foreign.Marshal (malloc, free)
import Foreign.Storable (peek)
import System.IO (Handle, hGetBuf)
#if MIN_VERSION_base(4,11,0)
import qualified GHC.Float as Float
#else
import Foreign.Ptr (castPtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Storable (Storable, poke)
import System.IO.Unsafe (unsafePerformIO)
#endif

import Data.ProtoLens.Encoding.Parser

-- | Run a 'Builder' to completion and return its output as a single strict
-- 'ByteString' (built lazily first, then copied into one strict chunk).
runBuilder :: Builder -> ByteString
runBuilder b = L.toStrict (Builder.toLazyByteString b)

-- | A 'Builder' that writes the bytes of the given strict @ByteString@
-- verbatim.
putBytes :: ByteString -> Builder
putBytes bs = Builder.byteString bs

-- | Parse a base-128 varint, least-significant 7-bit group first.
--
-- VarInts are inherently unsigned; int32/int64 and sint32/sint64 encode
-- negative numbers differently on top of this raw unsigned value.
getVarInt :: Parser Word64
getVarInt = go 0 1
  where
    -- go <partial result> <128^(index of the next byte)>
    go :: Word64 -> Word64 -> Parser Word64
    go !acc !scale = do
        byte <- getWord8
        getVarIntLoopFinish go acc scale byte

-- | Same as 'getVarInt', but reads the varint one byte at a time from a
-- 'Handle' instead of from a 'Parser'.
--
-- Fails with @"Unexpected end of file"@ if 'hGetBuf' cannot supply a byte
-- before the varint is complete.
getVarIntH :: Handle -> ExceptT String IO Word64
getVarIntH h = do
    -- One-byte scratch buffer that 'hGetBuf' reads into.
    buf <- liftIO malloc
    let loopStart :: Word64 -> Word64 -> ExceptT String IO Word64
        loopStart !n !s = do
            count <- liftIO (hGetBuf h buf 1)
            if count == 1
                then liftIO (peek buf) >>= getVarIntLoopFinish loopStart n s
                else do
                    -- 'throwE' short-circuits past the final 'free' below,
                    -- so release the buffer here to avoid leaking it.
                    -- NOTE(review): an exception raised by 'hGetBuf' itself
                    -- still bypasses both 'free' calls.
                    liftIO (free buf)
                    throwE "Unexpected end of file"
    res <- loopStart 0 1
    liftIO (free buf)
    return res

-- | One step of varint decoding, shared by 'getVarInt' and 'getVarIntH'.
--
-- Folds @byte@ into the partial result @acc@, where @scale@ is @128^i@ for
-- the @i@-th byte (counting from 0). If the byte's high bit is set, more
-- bytes follow, so the "loop start" callback is invoked with the new partial
-- result and the next scale. Otherwise the accumulated value is returned,
-- already forced.
getVarIntLoopFinish
  :: (Monad m)
  => (Word64 -> Word64 -> m Word64) -- "loop start" callback
  -> Word64
  -> Word64
  -> Word8
  -> m Word64
getVarIntLoopFinish readMore !acc !scale !byte
    | testMsb byte = readMore acc' (128 * scale)
    | otherwise = return $! acc'
  where
    acc' :: Word64
    acc' = decodeVarIntStep acc scale byte

-- | Fold one varint byte into a partially decoded value.
--
-- @acc@ is the result of the previous steps, @scale@ is @128^i@ for the
-- @i@-th byte (counting from 0), and @byte@ is the current byte. Only the
-- low 7 bits of @byte@ carry payload; the high bit is masked off.
decodeVarIntStep :: Word64 -> Word64 -> Word8 -> Word64
decodeVarIntStep acc scale byte = acc + scale * payload
  where
    payload :: Word64
    payload = fromIntegral (byte .&. 0x7f)

-- | Whether the most significant bit (bit 7) of a byte is set. In a varint
-- this marks that further bytes follow.
testMsb :: Word8 -> Bool
testMsb b = testBit b 7

-- | Encode a 'Word64' as a base-128 varint.
--
-- Each byte holds 7 bits of the value, least-significant group first. Every
-- byte except the last has its high bit set to signal continuation.
putVarInt :: Word64 -> Builder
putVarInt n
    | n >= 128 =
        Builder.word8 (fromIntegral n .|. 0x80) <> putVarInt (n `shiftR` 7)
    | otherwise = Builder.word8 (fromIntegral n)

-- | Parse a fixed-width 32-bit value (via 'getWord32le').
getFixed32 :: Parser Word32
getFixed32 = getWord32le

-- | Parse a fixed-width 64-bit value as two consecutive 'getFixed32' words.
-- The first word read supplies the low 32 bits and the second the high 32.
getFixed64 :: Parser Word64
getFixed64 = do
    lo <- getFixed32
    hi <- getFixed32
    return (combine lo hi)
  where
    -- The two halves occupy disjoint bits, so OR-ing equals adding.
    combine :: Word32 -> Word32 -> Word64
    combine lo hi = (fromIntegral hi `shiftL` 32) .|. fromIntegral lo

-- Note: putFixed32 and putFixed64 add BangPatterns on top of the
-- standard Builders.
-- This works better when they're composed with other functions.
-- For example, consider `putFixed32 . floatToWord`.
-- Since `putFixed32` may return a continuation, it doesn't automatically
-- force the result of `floatToWord`, so the resulting Word32 must be kept
-- lazily.  The extra strictness means that the Word32 will be evaluated
-- outside of the continuation, and GHC can pass it around unboxed.

-- | Emit a 'Word32' as 4 little-endian bytes ('word32LE').
--
-- The argument is forced before the 'Builder' is returned, so a lazily
-- computed input (e.g. @putFixed32 . floatToWord@) is evaluated outside the
-- builder's continuation. This lets GHC pass it around unboxed.
putFixed32 :: Word32 -> Builder
putFixed32 x = x `seq` word32LE x

-- | Emit a 'Word64' as 8 little-endian bytes ('word64LE').
--
-- The argument is forced before the 'Builder' is returned, so a lazily
-- computed input (e.g. @putFixed64 . doubleToWord@) is evaluated outside the
-- builder's continuation. This lets GHC pass it around unboxed.
putFixed64 :: Word64 -> Builder
putFixed64 x = x `seq` word64LE x

#if MIN_VERSION_base(4,11,0)
-- | Reinterpret the bit pattern of a 'Word64' as a 'Double'.
wordToDouble :: Word64 -> Double
wordToDouble w = Float.castWord64ToDouble w

-- | Reinterpret the bit pattern of a 'Word32' as a 'Float'.
wordToFloat :: Word32 -> Float
wordToFloat w = Float.castWord32ToFloat w

-- | Reinterpret the bit pattern of a 'Double' as a 'Word64'.
doubleToWord :: Double -> Word64
doubleToWord d = Float.castDoubleToWord64 d

-- | Reinterpret the bit pattern of a 'Float' as a 'Word32'.
floatToWord :: Float -> Word32
floatToWord f = Float.castFloatToWord32 f

#else
-- WARNING: SUPER UNSAFE!
-- Reinterpret the in-memory representation of a value of type @a@ as a value
-- of type @b@. It pokes the value into a temporary buffer and peeks it back
-- at the other type. The buffer is sized for @a@, so @b@ must be no larger.
-- Here it is only used for the Word32/Word64 <-> Float/Double conversions.
-- Ideally this would just be unsafeCoerce, but that breaks with -O2 because
-- it violates some assumptions in Core. Going through the FFI turns out to
-- be a more reliable way to do these casts. For more information see:
-- https://ghc.haskell.org/trac/ghc/ticket/2209
-- https://ghc.haskell.org/trac/ghc/ticket/4092
{-# INLINE cast #-}
cast :: (Storable a, Storable b) => a -> b
cast val = unsafePerformIO . alloca $ \ptr ->
    poke ptr val >> peek (castPtr ptr)

-- | Reinterpret the bits of a 'Word64' as a 'Double' (via 'cast').
wordToDouble :: Word64 -> Double
wordToDouble w = cast w

-- | Reinterpret the bits of a 'Word32' as a 'Float' (via 'cast').
wordToFloat :: Word32 -> Float
wordToFloat w = cast w

-- | Reinterpret the bits of a 'Double' as a 'Word64' (via 'cast').
doubleToWord :: Double -> Word64
doubleToWord d = cast d

-- | Reinterpret the bits of a 'Float' as a 'Word32' (via 'cast').
floatToWord :: Float -> Word32
floatToWord f = cast f
#endif

-- | ZigZag-encode an 'Int32'. Values of small magnitude, positive or
-- negative, map to small words: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
signedInt32ToWord :: Int32 -> Word32
signedInt32ToWord n = fromIntegral zigzag
  where
    -- 'shiftR' on a signed type sign-extends, so this is all ones for a
    -- negative @n@ and all zeros otherwise.
    signMask, zigzag :: Int32
    signMask = n `shiftR` 31
    zigzag = (n `shiftL` 1) `xor` signMask

-- | Inverse of ZigZag encoding for 32-bit values: 0 -> 0, 1 -> -1, 2 -> 1,
-- 3 -> -2, ...
wordToSignedInt32 :: Word32 -> Int32
wordToSignedInt32 w = magnitude `xor` sign
  where
    -- The low bit selects the sign: 0 gives a zero mask, 1 gives all ones.
    magnitude, sign :: Int32
    magnitude = fromIntegral (w `shiftR` 1)
    sign = negate (fromIntegral (w .&. 1))

-- | ZigZag-encode an 'Int64'. Values of small magnitude, positive or
-- negative, map to small words: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
signedInt64ToWord :: Int64 -> Word64
signedInt64ToWord n = fromIntegral zigzag
  where
    -- 'shiftR' on a signed type sign-extends, so this is all ones for a
    -- negative @n@ and all zeros otherwise.
    signMask, zigzag :: Int64
    signMask = n `shiftR` 63
    zigzag = (n `shiftL` 1) `xor` signMask

-- | Inverse of ZigZag encoding for 64-bit values: 0 -> 0, 1 -> -1, 2 -> 1,
-- 3 -> -2, ...
wordToSignedInt64 :: Word64 -> Int64
wordToSignedInt64 w = magnitude `xor` sign
  where
    -- The low bit selects the sign: 0 gives a zero mask, 1 gives all ones.
    magnitude, sign :: Int64
    magnitude = fromIntegral (w `shiftR` 1)
    sign = negate (fromIntegral (w .&. 1))

-- | Lift an 'Either' into the 'Parser'. 'Left' becomes a parse failure (via
-- 'fail') carrying its message; 'Right' becomes a successful result.
runEither :: Either String a -> Parser a
runEither (Left msg) = fail msg
runEither (Right x) = return x

-- | Apply @f@ to each element of a vector, in order, and concatenate the
-- resulting @Builder@s.
--
-- This is hand-tuned to perform better than a naive implementation using,
-- e.g., 'V.foldr' or a manual loop. It drives the builder's continuations
-- directly through 'Internal.builder' and 'Internal.runBuilderWith'.
foldMapBuilder :: V.Vector v a => (a -> Builder) -> v a -> Builder
foldMapBuilder f = \vec -> Internal.builder (go vec)
    -- The vector is bound by a lambda on the right-hand side, rather than as
    -- a function argument, so that GHC actually inlines this function.
  where
    -- go <remaining vector> <continuation once empty> <current buffer range>
    -- Keeping it fully saturated, instead of currying away the last two
    -- arguments, avoids GHC creating an intermediate continuation.
    go rest k range =
        if V.null rest
            then k range
            else
                let !hd = V.unsafeHead rest
                    -- lts-8.24 (ghc-8.0) doesn't inline unsafeTail well.
                    -- This bang can be removed once the lower bound is bumped.
                    !tl = V.unsafeTail rest
                in Internal.runBuilderWith (f hd) (go tl k) range
{-# INLINE foldMapBuilder #-}