{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}

module Network.Unexceptional.ByteArray
  ( receiveExactly
  , receiveFromInterruptible
  ) where

import Control.Concurrent.STM (TVar)
import Data.Bytes.Types (MutableBytes (MutableBytes))
import Data.Primitive (ByteArray)
import Foreign.C.Error (Errno)
import Network.Socket (Socket)

import qualified Data.Primitive as PM
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB

-- | Blocks until an exact number of bytes has been received.
--
-- Allocates a fresh buffer of the requested size and asks
-- 'MB.receiveExactly' to fill all of it. On success, the buffer is
-- frozen in place (no copy) and returned. On failure, the 'Errno'
-- reported by 'MB.receiveExactly' is passed through unchanged.
receiveExactly ::
  Socket ->
  -- | Exact number of bytes to receive, must be greater than zero
  Int ->
  IO (Either Errno ByteArray)
receiveExactly !s !n = do
  dst <- PM.newByteArray n
  MB.receiveExactly s (MutableBytes dst 0 n) >>= \case
    Left e -> pure (Left e)
    -- dst is local to this function and never written after this point,
    -- so freezing it without copying is safe.
    Right () -> Right <$> PM.unsafeFreezeByteArray dst

-- | Receive up to the given number of bytes, together with the sender's
-- address.
--
-- A buffer of the maximum size is allocated and handed to
-- 'MB.receiveFromInterruptible'. On success, the buffer is shrunk to the
-- number of bytes actually received and frozen in place (no copy). On
-- failure, the 'Errno' is passed through unchanged.
receiveFromInterruptible ::
  -- | Interrupt flag, forwarded as-is to 'MB.receiveFromInterruptible'
  -- (presumably used to abort a pending wait; see that function).
  TVar Bool ->
  Socket ->
  -- | Maximum number of bytes to receive.
  Int ->
  IO (Either Errno (ByteArray, S.SockAddr))
receiveFromInterruptible !intr !s !n = do
  dst <- PM.newByteArray n
  MB.receiveFromInterruptible intr s (MutableBytes dst 0 n) >>= \case
    Left err -> pure (Left err)
    Right (sz, sockAddr) -> do
      -- Trim the buffer so the result's length equals the received byte
      -- count. NOTE(review): shrinkMutableByteArray requires sz <= n;
      -- this relies on MB.receiveFromInterruptible never reporting more
      -- bytes than the buffer it was given. Confirm against that function.
      PM.shrinkMutableByteArray dst sz
      -- dst is local and not written after this point, so an in-place
      -- freeze is safe.
      payload <- PM.unsafeFreezeByteArray dst
      pure (Right (payload, sockAddr))