{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.ByteArray
  ( receiveExactly
  , receiveFromInterruptible
  ) where

import Control.Exception (throwIO)
import Control.Monad (when)
import Data.Primitive (ByteArray)
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite)
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))
import Control.Concurrent.STM (TVar)

import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X
import qualified Linux.Socket as X
import qualified Data.Bytes.Types
import qualified Network.Socket as S
import qualified Data.Primitive as PM
import qualified Network.Unexceptional.MutableBytes as MB

-- | Blocks until an exact number of bytes has been received.
--
-- A fresh buffer of exactly @n@ bytes is allocated and handed, in full
-- (offset 0, length @n@), to 'MB.receiveExactly'. On success the buffer is
-- frozen in place (no copy) and returned. On failure the 'Errno' reported by
-- 'MB.receiveExactly' is returned unchanged and the buffer is dropped.
receiveExactly ::
     Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteArray)
receiveExactly !sock !len = do
  buf <- PM.newByteArray len
  outcome <- MB.receiveExactly sock (MutableBytes buf 0 len)
  case outcome of
    Left err -> pure (Left err)
    -- Nothing else holds a reference to the mutable buffer after this point,
    -- so freezing it without copying is sound.
    Right _ -> fmap Right (PM.unsafeFreezeByteArray buf)

-- | Receive up to @n@ bytes from the socket together with the sender
-- address reported by 'MB.receiveFromInterruptible'.
--
-- A buffer of @n@ bytes is allocated and passed, in full (offset 0,
-- length @n@), to 'MB.receiveFromInterruptible'. On success the buffer is
-- shrunk to the byte count that call reports, then frozen in place (no copy).
-- The resulting 'ByteArray' is returned alongside the address. On failure the
-- 'Errno' is returned unchanged.
--
-- NOTE(review): the 'TVar' 'Bool' is forwarded as-is. Its meaning is defined
-- by 'MB.receiveFromInterruptible' (presumably setting it to 'True' aborts
-- the wait). Confirm there.
receiveFromInterruptible ::
     TVar Bool
  -> Socket
  -> Int -- ^ Maximum number of bytes to receive.
  -> IO (Either Errno (ByteArray, S.SockAddr))
receiveFromInterruptible !interrupt sock !maxLen = do
  buf <- PM.newByteArray maxLen
  let window = MutableBytes buf 0 maxLen
  outcome <- MB.receiveFromInterruptible interrupt sock window
  case outcome of
    Left err -> pure (Left err)
    Right (received, peer) -> do
      -- Trim to the number of bytes actually received so the frozen array's
      -- length matches the payload rather than the requested maximum.
      PM.shrinkMutableByteArray buf received
      frozen <- PM.unsafeFreezeByteArray buf
      pure (Right (frozen, peer))