module Network.ONCRPC.RecordMarking
( sendRecord
, RecordState(RecordStart)
, recordDone
, recordRemaining
, recvRecord
) where
import Control.Monad (unless)
import Data.Bits (Bits, finiteBitSize, bit, clearBit, setBit, testBit)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Word (Word32)
import qualified Network.Socket as Net
import qualified Network.Socket.All as NetAll
import qualified Network.Socket.ByteString as NetBS
import qualified Network.Socket.ByteString.Lazy as NetBSL
-- | A raw record-marking fragment header as it appears on the wire, i.e. in
-- network byte order ('mkFragmentHeader' produces it with 'Net.htonl' and
-- 'unFragmentHeader' decodes it with 'Net.ntohl').
type FragmentHeader = Word32
-- | Index of the most significant bit of a (host-order) 'FragmentHeader'
-- (31 for a 'Word32'). That bit flags the last fragment of a record; the
-- bits below it hold the fragment length.
fragmentHeaderBit :: Int
fragmentHeaderBit = finiteBitSize (0 :: FragmentHeader) - 1
-- | Largest fragment length representable in the header bits below the
-- last-fragment flag (all of bits @0 .. fragmentHeaderBit - 1@ set, i.e.
-- @2^31 - 1@). 'sendRecord' never emits a fragment longer than this.
maxFragmentSize :: (Bits i, Integral i) => i
maxFragmentSize = bit fragmentHeaderBit - 1
-- | Decode a wire-format 'FragmentHeader' into
-- @(isLastFragment, fragmentLength)@.
--
-- The header is first converted from network to host byte order; the flag is
-- bit 'fragmentHeaderBit' and the length is everything below it.
unFragmentHeader :: Integral i => FragmentHeader -> (Bool, i)
unFragmentHeader header = (isLast, fromIntegral lengthBits)
  where
    hostOrder = Net.ntohl header
    isLast = testBit hostOrder fragmentHeaderBit
    lengthBits = clearBit hostOrder fragmentHeaderBit
-- | Build a wire-format (network byte order) 'FragmentHeader' from a
-- last-fragment flag and a fragment length.
--
-- The length is narrowed to 32 bits with 'fromIntegral' and not masked, so
-- callers must keep it at or below 'maxFragmentSize'. A larger value would
-- spill into (and set) the last-fragment bit.
mkFragmentHeader :: Integral i => Bool -> i -> FragmentHeader
mkFragmentHeader isLast len = Net.htonl marked
  where
    raw = fromIntegral len
    marked
      | isLast = setBit raw fragmentHeaderBit
      | otherwise = raw
-- | Send one record over a stream socket using record marking.
--
-- The payload is cut into fragments of at most 'maxFragmentSize' bytes. Each
-- fragment is preceded by its 'FragmentHeader' (written with
-- 'NetAll.sendStorable'). Only the final fragment has the last-fragment flag
-- set. An empty payload is sent as a single zero-length last fragment.
sendRecord :: Net.Socket -> BSL.ByteString -> IO ()
sendRecord sock = go
  where
    go payload = do
      let (chunk, rest) = BSL.splitAt maxFragmentSize payload
          final = BSL.null rest
      NetAll.sendStorable sock (mkFragmentHeader final (BSL.length chunk))
      NetBSL.sendAll sock chunk
      unless final (go rest)
-- | Receive-side position within a record-marked stream, threaded through
-- successive 'recvRecord' calls.
--
-- Only 'RecordStart' is exported, so it is the only state callers can
-- construct. Other states are produced by 'recvRecord' and observed through
-- 'recordDone' / 'recordRemaining'.
data RecordState
  = RecordStart -- ^ At the beginning of a new record (any previous record has been fully read)
  | RecordHeader -- ^ Inside a record, between fragments: the next fragment header is due
  | RecordFragment
    { _fragmentLast :: !Bool -- ^ Whether the current fragment is the record's last
    , _fragmentLength :: !Int -- ^ Payload bytes of the current fragment not yet read
    } -- ^ Inside a fragment's payload
  deriving (Eq, Show)
-- | 'True' exactly when positioned at the start of a record. In that state the
-- previous record (if any) has been completely received.
recordDone :: RecordState -> Bool
recordDone = (== RecordStart)
-- | Bytes still to be read in the current record, when that is knowable:
--
-- * @Just 0@ at a record boundary;
-- * the unread length of the fragment while in the record's final fragment;
-- * 'Nothing' whenever more fragments (of unknown size) may still follow.
recordRemaining :: RecordState -> Maybe Int
recordRemaining state = case state of
  RecordStart -> Just 0
  RecordFragment { _fragmentLast = True, _fragmentLength = left } -> Just left
  _ -> Nothing
-- | Receive the next block of a record.
--
-- In the middle of a fragment, this reads (part of) its remaining payload.
-- At a record or fragment boundary, it first reads the next 'FragmentHeader'
-- and then continues into that fragment's payload. It returns the bytes read,
-- which may be fewer than the fragment holds, together with the updated
-- state. The record is complete once the returned state is 'RecordStart'
-- again (see 'recordDone').
--
-- An empty result with an unchanged state means nothing could be read.
-- NOTE(review): this is presumably end of stream ('NetAll.recvStorable'
-- returning 'Nothing', or 'NetBS.recv' returning empty). Confirm against
-- 'NetAll' and the network version in use.
recvRecord :: Net.Socket -> RecordState -> IO (BS.ByteString, RecordState)
recvRecord sock (RecordFragment final remaining)
  -- Zero-length fragments are legal ('sendRecord' emits one for an empty
  -- record). Complete them without asking the socket for 0 bytes.
  -- NOTE(review): network-3's recv/recvBuf reject a non-positive length.
  | remaining <= 0 = return (BS.empty, afterFragment)
  | otherwise = do
      -- NOTE(review): recv sizes its buffer by the requested length, which
      -- here is peer-controlled (up to 2^31 - 1). Consider capping it.
      chunk <- NetBS.recv sock remaining
      let got = BS.length chunk
      return
        ( chunk
        , if got < remaining
            then RecordFragment final (remaining - got)
            else afterFragment
        )
  where
    -- State once the current fragment's payload has been fully consumed.
    afterFragment
      | final = RecordStart
      | otherwise = RecordHeader
recvRecord sock s =
  maybe (return (BS.empty, s))
        (recvRecord sock . uncurry RecordFragment . unFragmentHeader)
    =<< NetAll.recvStorable sock