{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.ByteArray
  ( receiveExactly
  ) where

import Control.Exception (throwIO)
import Control.Monad (when)
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Data.Primitive (ByteArray)
import Foreign.C.Error (Errno, eCONNRESET)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite)
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Data.Bytes.Types
import qualified Data.Primitive as PM
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X

-- | Blocks until an exact number of bytes has been received.
--
-- Allocates a fresh buffer of @n@ bytes and fills it by repeatedly calling
-- 'MB.receive' on the unfilled suffix until every byte has been written.
-- The buffer is then frozen and returned.
--
-- * Returns @Left e@ as soon as any call to 'MB.receive' reports an error @e@.
-- * Returns @Left eCONNRESET@ if a call reports that zero bytes were
--   received before the buffer is full. A zero-length read makes no
--   progress, so retrying would spin forever. @recv(2)@ reports 0 on every
--   call once the peer has shut down.
--   NOTE(review): this assumes 'MB.receive' surfaces that end-of-stream 0
--   as @Right 0@. If it already maps end-of-stream to a @Left@, this branch
--   is simply never taken.
-- * Throws 'Types.NonpositiveReceptionSize' when @n@ is not positive.
receiveExactly ::
     Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteArray)
receiveExactly s n = if n > 0
  then do
    dst <- PM.newByteArray n
    let loop :: Int -> Int -> IO (Either Errno ByteArray)
        -- ix: offset of the first unfilled byte; remaining: bytes still needed.
        loop !ix !remaining = case remaining of
          0 -> do
            -- Safe to freeze in place: dst is not touched again after this.
            dst' <- PM.unsafeFreezeByteArray dst
            pure (Right dst')
          _ -> MB.receive s (MutableBytes dst ix remaining) >>= \case
            Left e -> pure (Left e)
            -- No progress was made; retrying would never terminate, so
            -- report the premature end of stream as an error.
            Right 0 -> pure (Left eCONNRESET)
            Right k -> loop (ix + k) (remaining - k)
    loop 0 n
  else throwIO Types.NonpositiveReceptionSize