{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-|
Module: Data.Store.Streaming
Description: A thin streaming layer that uses 'Store' for serialisation.

For efficiency reasons, 'Store' does not provide facilities for
incrementally consuming input.  In order to avoid partial input, this
module introduces 'Message's that wrap values of instances of 'Store'.

In addition to the serialisation of a value, the serialised message
also contains the length of the serialisation.  This way, instead of
consuming input incrementally, more input can be demanded before
deserialisation is attempted in the first place.

Each message starts with a fixed magic number, in order to detect
(randomly) invalid data.

-}
module Data.Store.Streaming
       ( -- * 'Message's to stream data using 'Store' for serialisation.
         Message (..)
         -- * Encoding 'Message's
       , encodeMessage
         -- * Decoding 'Message's
       , PeekMessage
       , FillByteBuffer
       , peekMessage
       , decodeMessage
       , peekMessageBS
       , decodeMessageBS
#ifndef mingw32_HOST_OS
       , ReadMoreData(..)
       , peekMessageFd
       , decodeMessageFd
#endif
         -- * Conduits for encoding and decoding
       , conduitEncode
       , conduitDecode
       ) where

import           Control.Exception (throwIO)
import           Control.Monad (unless)
import           Control.Monad.Fail (MonadFail)
import           Control.Monad.IO.Class
import           Control.Monad.Trans.Resource (MonadResource)
import           Data.ByteString (ByteString)
import qualified Data.Conduit as C
import qualified Data.Conduit.List as C
import           Data.Store
import           Data.Store.Core (decodeIOWithFromPtr, unsafeEncodeWith)
import           Data.Store.Internal (getSize)
import qualified Data.Text as T
import           Data.Word
import           Foreign.Ptr
import           Prelude
import           System.IO.ByteBuffer (ByteBuffer)
import qualified System.IO.ByteBuffer as BB
import           Control.Monad.Trans.Free.Church (FT, iterTM, wrap)
import           Control.Monad.Trans.Maybe (MaybeT(MaybeT), runMaybeT)
import           Control.Monad.Trans.Class (lift)
import           System.Posix.Types (Fd(..))
import           GHC.Conc (threadWaitRead)
import           Data.Store.Streaming.Internal

-- | A value of type @a@ wrapped for streaming.
--
-- When @a@ is an instance of 'Store', a @Message a@ can be serialised and
-- deserialised in a streaming fashion with the functions of this module.
newtype Message a = Message
  { fromMessage :: a
    -- ^ The wrapped value.
  }
  deriving (Eq, Show)

-- | Serialise a 'Message' into a 'ByteString'.
--
-- The output consists of the message magic, then the size tag holding the
-- serialised size of the payload, then the payload itself.  The total
-- buffer size is 'headerLength' plus the payload size.
encodeMessage :: Store a => Message a -> ByteString
encodeMessage (Message payload) =
    unsafeEncodeWith writeAll (headerLength + payloadSize)
  where
    payloadSize = getSize payload
    writeAll = poke messageMagic >> poke payloadSize >> poke payload

-- | A computation that peeks at the next message.  It either finishes with
-- a successfully deserialised value, or suspends to request another piece
-- of input of type @i@.
type PeekMessage i m a = FT ((->) i) m a

-- | Suspend the computation until the next piece of input is supplied,
-- and return that input.
needMoreInput :: PeekMessage i m i
needMoreInput = wrap pure

-- | An action that fills a 'ByteBuffer' from one piece of input of type @i@.
--
-- The 'Int' argument is the number of bytes the caller would like.  A
-- filling function may use it, e.g. as the maximum size to read when
-- filling with 'BB.fillFromFd', or ignore it, as the 'ByteString'-copying
-- fillers in this module do.
type FillByteBuffer i m
  =  ByteBuffer
  -> Int
  -> i
  -> m ()

-- | Decode a value of type @a@ from the given number of bytes starting at
-- a 'Ptr', using the type's 'Store' 'peek'.
decodeFromPtr :: (MonadIO m, Store a) => Ptr Word8 -> Int -> m a
decodeFromPtr ptr len = liftIO (decodeIOWithFromPtr peek ptr len)

-- | Consume @n@ bytes from the 'ByteBuffer' and decode them as a value of
-- type @a@.
--
-- While the buffer cannot supply @n@ bytes, the computation requests more
-- input and passes it to @fill@, together with the byte count reported by
-- 'BB.unsafeConsume', and then tries again.
peekSized :: (MonadIO m, Store a) => FillByteBuffer i m -> ByteBuffer -> Int -> PeekMessage i m a
peekSized fill bb n = loop
  where
    loop = BB.unsafeConsume bb n >>= either refill (`decodeFromPtr` n)
    refill missing = do
      inp <- needMoreInput
      lift (fill bb missing inp)
      loop

-- | Read the message magic from a 'ByteBuffer' and check it.
--
-- Throws a 'PeekException' (reported at offset 0) if the value read is
-- not 'messageMagic'.
peekMessageMagic :: MonadIO m => FillByteBuffer i m -> ByteBuffer -> PeekMessage i m ()
peekMessageMagic fill bb = do
  magic <- peekSized fill bb magicLength
  unless (magic == messageMagic) $
    liftIO $ throwIO $ PeekException 0 $ T.pack ("Wrong message magic, " ++ show magic)

-- | Decode a 'SizeTag' from a 'ByteBuffer'.
--
-- The size tag is the serialised size of the message body that follows
-- it, and callers use it as a byte count.  A negative tag cannot describe
-- any body and can only come from corrupt input, so it is rejected here
-- with a 'PeekException' instead of being passed on as a length.
peekMessageSizeTag :: MonadIO m => FillByteBuffer i m -> ByteBuffer -> PeekMessage i m SizeTag
peekMessageSizeTag fill bb = do
  size <- peekSized fill bb sizeTagLength
  if size < 0
    then liftIO $ throwIO $ PeekException 0 $ T.pack $
      "Data.Store.Streaming.peekMessageSizeTag: negative message size " ++ show size
    else return size

-- | Decode a 'Message' from a 'ByteBuffer'.
--
-- First the header is read: the message magic, which is checked, and then
-- the size tag.  Then the body is decoded from exactly as many bytes as
-- the size tag says.
peekMessage :: (MonadIO m, Store a) => FillByteBuffer i m -> ByteBuffer -> PeekMessage i m (Message a)
peekMessage fill bb = do
  peekMessageMagic fill bb
  bodyLength <- peekMessageSizeTag fill bb
  Message <$> peekSized fill bb bodyLength

-- | Decode a 'Message' from a 'ByteBuffer', using @getInp@ to obtain
-- additional input whenever the buffer must be refilled.
--
-- The result is 'Nothing' only when @getInp@ yields 'Nothing' while the
-- 'ByteBuffer' holds zero bytes, i.e. there was no message left to read.
-- If some data is available but not enough to decode a whole 'Message',
-- a 'PeekException' is thrown instead.
decodeMessage :: (Store a, MonadIO m) => FillByteBuffer i m -> ByteBuffer -> m (Maybe i) -> m (Maybe (Message a))
decodeMessage fill bb getInp = do
  mbMagic <- runPeek (peekMessageMagic fill bb)
  case mbMagic of
    Nothing -> do
      -- Nothing has been consumed yet, so an empty buffer simply means
      -- there was no message to read.
      leftover <- BB.availableBytes bb
      unless (leftover == 0) $ incomplete leftover
      return Nothing
    Just () -> do
      mbBody <- runPeek (peekMessageSizeTag fill bb >>= peekSized fill bb)
      case mbBody of
        Just body -> return (Just (Message body))
        -- The magic has already been consumed, so running out of input
        -- here means the message is incomplete.
        Nothing -> BB.availableBytes bb >>= incomplete
  where
    -- Run a 'PeekMessage' computation, answering every request for input
    -- with 'getInp'.  The first 'Nothing' from 'getInp' aborts the whole
    -- computation with 'Nothing'.
    runPeek peekAction =
      runMaybeT (iterTM (\resume -> MaybeT getInp >>= resume) peekAction)
    -- Throw for a message that could not be completed from the input.
    incomplete leftover =
      liftIO $ throwIO $ PeekException leftover $ T.pack
        "Data.Store.Streaming.decodeMessage: could not get enough bytes to decode message"

-- | 'peekMessage' specialised to 'ByteString' input.
--
-- Each chunk of input is passed to 'BB.copyByteString' as a whole, and the
-- requested byte count is ignored.
peekMessageBS :: (MonadIO m, Store a) => ByteBuffer -> PeekMessage ByteString m (Message a)
peekMessageBS = peekMessage appendChunk
  where
    appendChunk buf _ chunk = BB.copyByteString buf chunk

-- | 'decodeMessage' specialised to 'ByteString' input.
--
-- Each chunk returned by the input action is passed to 'BB.copyByteString'
-- as a whole, and the requested byte count is ignored.
decodeMessageBS :: (MonadIO m, Store a)
            => ByteBuffer -> m (Maybe ByteString) -> m (Maybe (Message a))
decodeMessageBS = decodeMessage appendChunk
  where
    appendChunk buf _ chunk = BB.copyByteString buf chunk

#ifndef mingw32_HOST_OS

-- | A token signalling that more input should be read from the 'Fd'.
--
-- It has a single nullary constructor and carries no data; it is a more
-- descriptive stand-in for @()@.
--
-- This data-type is only available on POSIX systems (essentially, non-windows)
data ReadMoreData
  = ReadMoreData
  deriving (Eq, Show)

-- | Peek a 'Message' from a /non-blocking/ 'Fd'.
--
-- Each time more input is needed, the computation requests a
-- 'ReadMoreData' token.  Once it receives one, it reads into the
-- 'ByteBuffer' with 'BB.fillFromFd', asking for the missing number of bytes.
--
-- This function is only available on POSIX systems (essentially, non-windows)
peekMessageFd :: (MonadIO m, MonadFail m, Store a) => ByteBuffer -> Fd -> PeekMessage ReadMoreData m (Message a)
peekMessageFd bb fd = peekMessage refill bb
  where
    refill buf needed ReadMoreData = BB.fillFromFd buf fd needed >> return ()

-- | Decode one 'Message' from an 'Fd'.
--
-- Whenever more input is needed, the calling thread waits with
-- 'threadWaitRead' until the 'Fd' is readable.  It then reads into the
-- 'ByteBuffer' with 'BB.fillFromFd', asking for the missing number of bytes.
-- The input action never reports end of input, so this keeps waiting until
-- a complete message has been decoded.  Invalid data, such as a wrong
-- message magic, is reported by throwing a 'PeekException'.
--
-- This function is only available on POSIX systems (essentially, non-windows)
decodeMessageFd :: (MonadIO m, MonadFail m, Store a) => ByteBuffer -> Fd -> m (Message a)
decodeMessageFd bb fd = do
  mbMsg <- decodeMessage refill bb (liftIO (threadWaitRead fd) >> return (Just ReadMoreData))
  case mbMsg of
    Just msg -> return msg
    -- 'decodeMessage' only returns 'Nothing' when its input action does,
    -- and the action above always returns 'Just'.
    Nothing -> liftIO (fail "decodeMessageFd: impossible: got Nothing")
  where
    refill buf needed ReadMoreData = BB.fillFromFd buf fd needed >> return ()

#endif

-- | Conduit that serialises each incoming 'Message' with 'encodeMessage'
-- and yields the resulting 'ByteString' downstream.
conduitEncode
  :: (Monad m, Store a)
  => C.Conduit (Message a) m ByteString
  -- ^ NOTE: ignore the conduit deprecation warning. Otherwise
  -- incompatible with old conduit versions
conduitEncode = C.awaitForever (C.yield . encodeMessage)

-- | Conduit that decodes 'Message's from a stream of 'ByteString' chunks.
--
-- A 'ByteBuffer' is allocated with 'BB.new' and released with 'BB.free'
-- through 'C.bracketP'.  Messages are decoded with 'decodeMessageBS' and
-- yielded until it returns 'Nothing', i.e. until upstream is exhausted with
-- no buffered bytes left.  A trailing incomplete message makes
-- 'decodeMessageBS' throw a 'PeekException'.
conduitDecode :: (MonadResource m, Store a)
              => Maybe Int
              -- ^ Initial length of the 'ByteBuffer' used for
              -- buffering the incoming 'ByteString's.  If 'Nothing',
              -- use the default value of 4MB.
              -> C.Conduit ByteString m (Message a)
              -- ^ NOTE: ignore the conduit deprecation
              -- warning. Otherwise incompatible with old conduit
              -- versions.
conduitDecode bufSize = C.bracketP (BB.new bufSize) BB.free decodeLoop
  where
    decodeLoop buffer =
      decodeMessageBS buffer C.await
        >>= maybe (return ()) (\message -> C.yield message >> decodeLoop buffer)