{-# LANGUAGE CApiFFI #-}
module Network.ONCRPC.RecordMarking
  ( sendRecord
  , RecordState(RecordStart)
  , recordDone
  , recordRemaining
  , recvRecord
  ) where

import           Control.Monad (unless)
import           Data.Bits (Bits, finiteBitSize, bit, clearBit, setBit, testBit)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import           Data.Word (Word32)
import qualified Network.Socket as Net
import qualified Network.Socket.All as NetAll
import qualified Network.Socket.ByteString as NetBS
import qualified Network.Socket.ByteString.Lazy as NetBSL

-- Host <-> network byte-order conversion for 32-bit words, imported from the
-- C library header.  The capi calling convention goes through the C header,
-- so this also works on platforms where these are defined as macros
-- (presumably why capi rather than ccall was chosen).
foreign import capi unsafe "arpa/inet.h htonl" htonl :: Word32 -> Word32
foreign import capi unsafe "arpa/inet.h ntohl" ntohl :: Word32 -> Word32

-- |A raw RPC record fragment header, stored in network byte order.
type FragmentHeader = Word32

-- |Index of the header bit that flags the last fragment of a record: the
-- most significant bit of a 'FragmentHeader' (bit 31 of a 'Word32').
fragmentHeaderBit :: Int
fragmentHeaderBit = finiteBitSize (0 :: FragmentHeader) - 1

-- |The largest payload a single fragment may carry: every header bit below
-- the last-fragment flag set.  For the 32-bit header this is @2^31 - 1@.
maxFragmentSize :: (Bits i, Integral i) => i
maxFragmentSize = pred (bit fragmentHeaderBit)

-- |Decode a fragment header as received from the wire: convert it to host
-- byte order, then split it into the last-fragment flag and the fragment
-- length (the remaining low bits).
unFragmentHeader :: Integral i => FragmentHeader -> (Bool, i)
unFragmentHeader raw = (isLast, fromIntegral len)
  where
    host = ntohl raw
    isLast = testBit host fragmentHeaderBit
    len = clearBit host fragmentHeaderBit

-- |Build a wire-format (network byte order) fragment header from a
-- last-fragment flag and a fragment length.
--
-- The length is converted to 32 bits with 'fromIntegral', so it must not
-- exceed 'maxFragmentSize'; a larger value would spill into the
-- last-fragment bit.
mkFragmentHeader :: Integral i => Bool -> i -> FragmentHeader
mkFragmentHeader l n = htonl (if l then setBit len fragmentHeaderBit else len)
  where
    len = fromIntegral n

-- |Send a complete record over a socket.
--
-- The payload is split into fragments of at most 'maxFragmentSize' bytes.
-- Each fragment is preceded by its header, and only the final fragment has
-- the last-fragment bit set.  An empty payload goes out as a single empty
-- final fragment.
sendRecord :: Net.Socket -> BSL.ByteString -> IO ()
sendRecord sock = go
  where
    go payload = do
      let (frag, rest) = BSL.splitAt maxFragmentSize payload
          final = BSL.null rest
      NetAll.sendStorable sock (mkFragmentHeader final (BSL.length frag))
      NetBSL.sendAll sock frag
      unless final (go rest)

-- |Progress through an incoming record, as threaded through 'recvRecord'.
data RecordState
  = RecordStart
    -- ^ At a record boundary: any previous record is complete, and the next
    -- read starts with a fragment header.
  | RecordHeader
    -- ^ Between fragments of an unfinished record: the next read is a
    -- fragment header.
  | RecordFragment
    { _fragmentLast :: !Bool
      -- ^ Whether the fragment being read is the last one of the record.
    , _fragmentLength :: !Int
      -- ^ Bytes of the current fragment not yet received.
    }
  deriving (Show, Eq)

-- |Is the current record complete?
--
-- Only 'RecordStart' counts as done; that is the state 'recvRecord' returns
-- to once the last fragment of a record has been fully received.
recordDone :: RecordState -> Bool
recordDone state = case state of
  RecordStart -> True
  _ -> False

-- |How many bytes are left in this record, if known?
--
-- The total is known only at a record boundary (@'Just' 0@) or while inside
-- the record's last fragment.  In any other state more fragments may still
-- follow, so the answer is 'Nothing'.
recordRemaining :: RecordState -> Maybe Int
recordRemaining state = case state of
  RecordStart -> Just 0
  RecordFragment { _fragmentLast = True, _fragmentLength = n } -> Just n
  _ -> Nothing

-- |Receive the next block of a record.
--
-- In 'RecordStart' or 'RecordHeader' state this first reads a fragment
-- header.  If 'NetAll.recvStorable' yields 'Nothing', an empty block is
-- returned with the state unchanged.  Otherwise it reads part of the current
-- fragment, at most 'recvChunkSize' bytes per call, and returns the data
-- with the updated state.  A short or empty read from the socket leaves the
-- fragment unfinished; the caller should call again with the returned state.
recvRecord :: Net.Socket -> RecordState -> IO (BS.ByteString, RecordState)
recvRecord sock (RecordFragment e n)
  -- A zero-length fragment (e.g. the single fragment 'sendRecord' emits for
  -- an empty record) carries no payload.  Advance without touching the
  -- socket, since a 0-byte receive request may be rejected by 'NetBS.recv'.
  | n <= 0 = return (BS.empty, next)
  | otherwise = do
      b <- NetBS.recv sock (min n recvChunkSize)
      let l = BS.length b
      return (b, if l < n then RecordFragment e (n - l) else next)
  where
    -- State after this fragment is fully consumed.
    next = if e then RecordStart else RecordHeader
    -- The fragment length comes from the peer and may be up to 2^31-1.
    -- Never let it dictate the size of the receive buffer that 'NetBS.recv'
    -- allocates; callers already handle short reads.
    recvChunkSize :: Int
    recvChunkSize = 65536
recvRecord sock s =
  maybe (return (BS.empty, s))
    (recvRecord sock . uncurry RecordFragment . unFragmentHeader)
    =<< NetAll.recvStorable sock