{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language MagicHash #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.ByteString
  ( send
  , receive
  ) where

import Control.Monad (when)
import Data.ByteString.Internal (ByteString(BS))
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Primitive (ByteArray(ByteArray))
import Data.Primitive.Addr (Addr(Addr),plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite)
import GHC.ForeignPtr (ForeignPtr(ForeignPtr),ForeignPtrContents(PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X

-- | Send the entire byte sequence. This calls POSIX @send@ in a loop
-- until all of the bytes have been sent.
--
-- Any error other than @EAGAIN@ / @EWOULDBLOCK@ reported by the
-- underlying send is returned as 'Left'.
send ::
     Socket
  -> ByteString
  -> IO (Either Errno ())
send s !b = S.withFdSocket s withFd
  where
    -- Expose the raw pointer and length of the (already forced) bytes.
    withFd fd = ByteString.unsafeUseAsCStringLen b (transmit (Fd fd))
    -- The first send is attempted without checking whether the socket
    -- is ready for writes, because it is uncommon for the transmit
    -- buffer to already be full.
    transmit fd (PM.Ptr ptr, len) = sendLoop fd (Addr ptr) len

-- | Send @len@ bytes starting at @addr@ on @fd@ using
-- 'X.uninterruptibleSend' with the 'X.noSignal' and 'X.dontWait' flags.
-- Sending repeats until every byte has been accepted.
--
-- The first attempt is made immediately, without waiting for the file
-- descriptor to be ready. If the send reports @EAGAIN@ or
-- @EWOULDBLOCK@, the thread blocks in 'threadWaitWrite' and then
-- retries the same span. After a partial send, the remaining suffix is
-- sent. Any other errno is returned as 'Left'. A reported byte count
-- larger than requested is treated as impossible and raised with 'fail'.
sendLoop :: Fd -> Addr -> Int -> IO (Either Errno ())
sendLoop !fd !addr0 !len0 = attempt addr0 len0
  where
    flags = X.noSignal <> X.dontWait
    attempt !addr !len = do
      outcome <- X.uninterruptibleSend fd addr (fromIntegral len) flags
      case outcome of
        Left err
          | err == EAGAIN || err == EWOULDBLOCK ->
              -- Wait until the descriptor is writable, then retry the
              -- same span.
              threadWaitWrite fd *> attempt addr len
          | otherwise -> pure (Left err)
        Right sentC -> advance addr len (fromIntegral sentC)
    advance !addr !len !sent
      | sent == len = pure (Right ())
      | sent < len = attempt (plusAddr addr sent) (len - sent)
      | otherwise = fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"

-- | Receive at most @n@ bytes from the socket into a freshly allocated
-- pinned buffer, returning exactly the bytes that arrived.
--
-- If this returns zero bytes, it means that the peer has performed an
-- orderly shutdown. Errors reported by the underlying receive are
-- returned as 'Left'.
receive ::
     Socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO (Either Errno ByteString)
receive s n = do
  -- Allocate a pinned buffer so its address stays stable and can be
  -- captured by the resulting ByteString's ForeignPtr.
  dst@(PM.MutableByteArray dst#) <- PM.newPinnedByteArray n
  MB.receive s (MutableBytes dst 0 n) >>= \case
    Left e -> pure (Left e)
    Right m -> do
      -- Trim the buffer to the number of bytes actually received.
      when (m < n) (PM.shrinkMutableByteArray dst m)
      ByteArray frozen# <- PM.unsafeFreezeByteArray dst
      -- PlainPtr holds the original mutable array directly, which keeps
      -- the buffer reachable while the ByteString is. There is no need
      -- to unsafeCoerce# the frozen array back to a mutable one.
      pure (Right (BS (ForeignPtr (Exts.byteArrayContents# frozen#) (PlainPtr dst#)) m))