{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}

module Network.Unexceptional.ByteString
  ( send
  , sendInterruptible
  , receive
  , receiveExactly
  , receiveExactlyInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM, TVar)
import Control.Exception (bracket)
import Control.Monad (when, (<=<))
import Data.ByteString.Internal (ByteString (BS))
import Data.Bytes.Types (MutableBytes (MutableBytes))
import Data.Functor (($>))
import Data.Primitive (ByteArray (ByteArray))
import Data.Primitive.Addr (Addr (Addr), plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EAGAIN, pattern EWOULDBLOCK)
import GHC.Conc (threadWaitWrite, threadWaitWriteSTM)
import GHC.ForeignPtr (ForeignPtr (ForeignPtr), ForeignPtrContents (PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd (Fd))

import qualified Control.Concurrent.STM as STM
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X

{- | Send the entire byte sequence. This calls POSIX @send@ in a loop
until all of the bytes have been sent.

The result is @Right ()@ once every byte has been accepted by @send@, or
@Left@ carrying the first error that is neither @EAGAIN@ nor
@EWOULDBLOCK@.
-}
send ::
  Socket ->
  ByteString ->
  IO (Either Errno ())
send sock !payload = S.withFdSocket sock onDescriptor
 where
  -- Borrow the payload's buffer for the duration of the transmission.
  onDescriptor rawFd =
    ByteString.unsafeUseAsCStringLen payload (transmit rawFd)

  -- The first send is attempted without testing whether the socket is
  -- ready for writes. This is because it is uncommon for the transmit
  -- buffer to already be full.
  transmit rawFd (PM.Ptr base#, total) =
    sendLoop (Fd rawFd) (Addr base#) total

-- Worker for 'send'. It does not wait for the file descriptor to be ready
-- before making an attempt: the send is issued right away, and only when
-- that attempt reports EAGAIN or EWOULDBLOCK does the thread block in
-- 'threadWaitWrite' before retrying the same range. Any other error is
-- returned unchanged. After a partial send the loop continues with the
-- unsent suffix of the buffer.
sendLoop :: Fd -> Addr -> Int -> IO (Either Errno ())
sendLoop !fd !addr !len = do
  attempt <-
    X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case attempt of
    Left e
      | e == EAGAIN || e == EWOULDBLOCK ->
          threadWaitWrite fd >> sendLoop fd addr len
      | otherwise -> pure (Left e)
    Right sentSzC
      -- Everything that was requested went out.
      | sentSz == len -> pure (Right ())
      -- Partial send: advance past the bytes that were accepted.
      | sentSz < len -> sendLoop fd (plusAddr addr sentSz) (len - sentSz)
      | otherwise ->
          fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"
     where
      sentSz = fromIntegral sentSzC :: Int

{- | Send the entire byte sequence. This calls POSIX @send@ in a loop
until all of the bytes have been sent.

Whenever the socket is not writable, the calling thread waits for either
writability or for the supplied 'TVar' to hold 'True'. If the 'TVar' wins,
the result is @Left EAGAIN@; some prefix of the bytes may already have
been sent at that point.
-}
sendInterruptible ::
  TVar Bool ->
  Socket ->
  ByteString ->
  IO (Either Errno ())
sendInterruptible !interrupt sock !payload =
  S.withFdSocket sock $ \rawFd ->
    let
      -- The first send is attempted without testing whether the socket
      -- is ready for writes. This is because it is uncommon for the
      -- transmit buffer to already be full.
      transmit (PM.Ptr base#, total) =
        sendInterruptibleLoop interrupt (Fd rawFd) (Addr base#) total
     in
      ByteString.unsafeUseAsCStringLen payload transmit

-- Worker for 'sendInterruptible'. It does not wait for the file descriptor
-- to be ready before making an attempt. When an attempt reports EAGAIN or
-- EWOULDBLOCK, 'waitUntilWriteable' decides what happens next: 'Ready'
-- retries the same range, 'Interrupted' gives up with @Left EAGAIN@. Any
-- other error is returned unchanged.
sendInterruptibleLoop :: TVar Bool -> Fd -> Addr -> Int -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !addr !len =
  attempt >>= either onError onSent
 where
  attempt =
    X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)

  onError e
    | e == EAGAIN || e == EWOULDBLOCK =
        waitUntilWriteable interrupt fd >>= \case
          Ready -> sendInterruptibleLoop interrupt fd addr len
          Interrupted -> pure (Left EAGAIN)
    | otherwise = pure (Left e)

  onSent sentSzC
    -- Everything that was requested went out.
    | sentSz == len = pure (Right ())
    -- Partial send: advance past the bytes that were accepted.
    | sentSz < len =
        sendInterruptibleLoop interrupt fd (plusAddr addr sentSz) (len - sentSz)
    | otherwise =
        fail "Network.Unexceptional.ByteString.sendInterruptibleLoop: send claimed to send too many bytes"
   where
    sentSz = fromIntegral sentSzC :: Int

{- | Receive at most the requested number of bytes. If this returns zero
bytes, it means that the peer has performed an orderly shutdown.
-}
receive ::
  Socket ->
  -- | Maximum number of bytes to receive
  Int ->
  IO (Either Errno ByteString)
receive sock n = do
  -- Pinned, so that the address taken from the buffer below is the
  -- address of pinned memory.
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receive sock (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right m -> do
      -- Give back the part of the buffer that was not filled.
      when (m < n) $ PM.shrinkMutableByteArray buf m
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- The resulting ByteString points directly into the frozen array;
      -- 'PlainPtr' carries the array itself as the pointer's contents.
      pure $
        Right $
          BS
            ( ForeignPtr
                (Exts.byteArrayContents# frozen#)
                (PlainPtr (Exts.unsafeCoerce# frozen#))
            )
            m

{- | Blocks until an exact number of bytes has been received. On success
the resulting 'ByteString' has exactly the requested length.
-}
receiveExactly ::
  Socket ->
  -- | Exact number of bytes to receive, must be greater than zero
  Int ->
  IO (Either Errno ByteString)
receiveExactly !sock !n = do
  -- Pinned, so that the address taken from the buffer below is the
  -- address of pinned memory.
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receiveExactly sock (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right _ -> do
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- The resulting ByteString points directly into the frozen array;
      -- 'PlainPtr' carries the array itself as the pointer's contents.
      pure $
        Right $
          BS
            ( ForeignPtr
                (Exts.byteArrayContents# frozen#)
                (PlainPtr (Exts.unsafeCoerce# frozen#))
            )
            n

{- | Blocks until an exact number of bytes has been received. The 'TVar'
is handed to 'MB.receiveExactlyInterruptible', which decides how an
interrupt is reported; any @Left@ it produces is returned unchanged.
-}
receiveExactlyInterruptible ::
  TVar Bool ->
  Socket ->
  -- | Exact number of bytes to receive, must be greater than zero
  Int ->
  IO (Either Errno ByteString)
receiveExactlyInterruptible !intr !sock !n = do
  -- Pinned, so that the address taken from the buffer below is the
  -- address of pinned memory.
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receiveExactlyInterruptible intr sock (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right _ -> do
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- The resulting ByteString points directly into the frozen array;
      -- 'PlainPtr' carries the array itself as the pointer's contents.
      pure $
        Right $
          BS
            ( ForeignPtr
                (Exts.byteArrayContents# frozen#)
                (PlainPtr (Exts.unsafeCoerce# frozen#))
            )
            n

{- | Block until either the file descriptor is reported writable or the
interrupt flag holds 'True', whichever happens first.

Returns 'Ready' when the readiness action from 'threadWaitWriteSTM'
completed and 'Interrupted' when the flag was observed set. Readiness is
the left operand of '<|>', so it takes precedence if both hold.

The deregistration action returned by 'threadWaitWriteSTM' is run through
'bracket', so the registration is also released when the waiting thread
receives an exception while blocked in the transaction.
-}
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

-- | How a wait performed by 'waitUntilWriteable' ended.
data Outcome
  = -- | The descriptor's readiness action completed.
    Ready
  | -- | The interrupt flag was observed to hold 'True'.
    Interrupted

-- | Read the flag and pass its value to 'STM.check': the action completes
-- only when the flag holds 'True'.
checkFinished :: TVar Bool -> STM ()
checkFinished flag = (STM.check <=< STM.readTVar) flag