-- |ONC RPC record marking: a record is sent as a sequence of fragments, each
-- preceded by a 4-byte header whose most significant bit flags the last
-- fragment and whose remaining bits give the fragment length.
module Network.ONCRPC.RecordMarking
  ( sendRecord
  , RecordState(RecordStart)
  , recordDone
  , recordRemaining
  , recvRecord
  ) where

import Control.Monad (unless)
import Data.Bits (Bits, finiteBitSize, bit, clearBit, setBit, testBit)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as BSL
import Data.Word (Word32)
import qualified Network.Socket as Net
import qualified Network.Socket.All as NetAll
import qualified Network.Socket.ByteString as NetBS
import qualified Network.Socket.ByteString.Lazy as NetBSL

-- |A raw RPC record fragment header, stored in network byte order.
type FragmentHeader = Word32

-- |Bit index of the "last fragment" flag: the top bit of the header.
fragmentHeaderBit :: Int
fragmentHeaderBit = pred $ finiteBitSize (0 :: FragmentHeader)

-- |Largest fragment length a header can express (all bits below the flag).
maxFragmentSize :: (Bits i, Integral i) => i
maxFragmentSize = pred $ bit fragmentHeaderBit

-- |Decode a network-order header into (is last fragment, fragment length).
unFragmentHeader :: Integral i => FragmentHeader -> (Bool, i)
unFragmentHeader w = (testBit w' fragmentHeaderBit, fromIntegral $ clearBit w' fragmentHeaderBit)
  where
    w' = Net.ntohl w

-- |Build a network-order header from the "last fragment" flag and the
-- fragment length.  The length is expected not to exceed 'maxFragmentSize';
-- larger values would spill into the flag bit.
mkFragmentHeader :: Integral i => Bool -> i -> FragmentHeader
mkFragmentHeader l n = Net.htonl $ sb l $ fromIntegral n
  where
    sb True x = setBit x fragmentHeaderBit
    sb False x = x

-- |The header for a fragment as it appears on the wire: the host-order value
-- encoded big-endian, i.e. the same bytes as the network-order 'Word32'.
fragmentHeaderBytes :: Integral i => Bool -> i -> BSL.ByteString
fragmentHeaderBytes l n =
  BSB.toLazyByteString $ BSB.word32BE $ Net.ntohl $ mkFragmentHeader l n

-- |Send a record, split into fragments of at most 'maxFragmentSize' bytes.
-- An empty record is sent as a single empty last fragment.
--
-- Each fragment's header and payload are passed to one 'NetBSL.sendAll'
-- call.  Writing the 4-byte header and the payload as two separate sends is
-- the write-write pattern that lets Nagle's algorithm combined with delayed
-- ACKs hold back the payload of small requests.
sendRecord :: Net.Socket -> BSL.ByteString -> IO ()
sendRecord sock b = do
  NetBSL.sendAll sock $ fragmentHeaderBytes l (BSL.length h) `BSL.append` h
  unless l $ sendRecord sock t
  where
    (h, t) = BSL.splitAt maxFragmentSize b
    -- this fragment is the last one when nothing remains after it
    l = BSL.null t

-- |Where 'recvRecord' is in the incoming stream of records.
data RecordState
  = RecordStart
    -- ^ between records: the next header begins a new record
  | RecordHeader
    -- ^ inside a record: another fragment header is expected next
  | RecordFragment
    { _fragmentLast :: !Bool -- ^ whether this fragment ends the record
    , _fragmentLength :: !Int -- ^ bytes of this fragment not yet received
    }
  deriving (Eq, Show)

-- |Is the current record complete?
recordDone :: RecordState -> Bool
recordDone RecordStart = True
recordDone _ = False

-- |How many bytes are left in this record, if known?
--
-- Between records ('RecordStart') nothing is left, so the answer is @Just 0@.
-- While reading the last fragment the remaining count is exact.  In any other
-- state more fragments may follow, so the total is unknown.
recordRemaining :: RecordState -> Maybe Int
recordRemaining RecordStart = Just 0
recordRemaining (RecordFragment True n) = Just n
recordRemaining _ = Nothing

-- |Receive the next block of a record.
--
-- Returns the bytes read together with the state to pass to the next call.
-- When the socket yields no data (presumably end of stream) the result is
-- empty and the state is returned unchanged.
--
-- Reads are bounded by a fixed chunk size, so one call may return only part
-- of a fragment even when the peer has sent all of it.
recvRecord :: Net.Socket -> RecordState -> IO (BS.ByteString, RecordState)
recvRecord sock (RecordFragment e n)
  -- A zero-length fragment has nothing to read.  The network package's recv
  -- rejects a zero-length request, so finish the fragment without reading.
  -- For a non-last fragment, go straight on to the next header rather than
  -- returning an empty block that could be mistaken for end of stream.
  | n <= 0 =
      if e
        then return (BS.empty, RecordStart)
        else recvRecord sock RecordHeader
  | otherwise = do
      b <- NetBS.recv sock (min n recvChunkSize)
      let l = BS.length b
          s'
            | l < n = RecordFragment e (n - l)
            | e = RecordStart
            | otherwise = RecordHeader
      return (b, s')
  where
    -- Upper bound on a single read.  The fragment length comes from the peer
    -- and may be as large as 2^31-1; recv allocates a buffer of the requested
    -- size before reading, so it must not be asked for the whole fragment.
    recvChunkSize :: Int
    recvChunkSize = 65536
recvRecord sock s =
  maybe (return (BS.empty, s)) (recvRecord sock . uncurry RecordFragment . unFragmentHeader)
    =<< NetAll.recvStorable sock