{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
module Network.Unexceptional.ByteArray
( receiveExactly
) where
import Control.Exception (throwIO)
import Control.Monad (when)
import Data.Primitive (ByteArray)
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Foreign.C.Error (Errno,eCONNRESET)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite)
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))
import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X
import qualified Linux.Socket as X
import qualified Data.Bytes.Types
import qualified Network.Socket as S
import qualified Data.Primitive as PM
import qualified Network.Unexceptional.MutableBytes as MB
-- | Receive exactly @n@ bytes from the socket into a freshly allocated
-- 'ByteArray'.
--
-- 'MB.receive' is called repeatedly, each time on the still-unfilled
-- tail of the buffer, until all @n@ bytes have arrived.
--
-- * Returns @Right arr@ (with @arr@ of length @n@) once the buffer is full.
-- * Returns @Left e@ as soon as any 'MB.receive' call fails with @e@.
-- * Returns @Left eCONNRESET@ if a 'MB.receive' call reports that zero
--   bytes were read before the buffer is full. Retrying in that case could
--   never make progress.
--
-- NOTE(review): this assumes 'MB.receive' reports end of stream as
-- @Right 0@, which is what recv(2) does on an orderly peer shutdown.
-- ECONNRESET is the closest available 'Errno' for "connection ended before
-- the requested byte count arrived".
--
-- Throws 'Types.NonpositiveReceptionSize' when @n@ is not positive.
receiveExactly ::
     Socket
  -> Int -- ^ Number of bytes to receive; must be positive
  -> IO (Either Errno ByteArray)
receiveExactly s n = if n > 0
  then do
    dst <- PM.newByteArray n
    let loop :: Int -> Int -> IO (Either Errno ByteArray)
        -- ix: offset of the first unfilled byte; remaining: bytes still needed
        loop !ix !remaining = case remaining of
          0 -> do
            -- The buffer is full and dst is not used again, so freezing
            -- without a copy is fine.
            dst' <- PM.unsafeFreezeByteArray dst
            pure (Right dst')
          _ -> MB.receive s (MutableBytes dst ix remaining) >>= \case
            Left e -> pure (Left e)
            -- A zero-byte read leaves ix and remaining unchanged, so
            -- looping again would spin forever. Report it as an error.
            Right 0 -> pure (Left eCONNRESET)
            Right k -> loop (ix + k) (remaining - k)
    loop 0 n
  else throwIO Types.NonpositiveReceptionSize