{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}

module Network.Unexceptional.Bytes
  ( send
  , sendInterruptible
  , receive
  , receiveInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM, TVar)
import Control.Exception (bracket, throwIO)
import Control.Monad (when, (<=<))
import Data.Bytes.Types (Bytes (Bytes), MutableBytes (MutableBytes))
import Data.Functor (($>))
import Data.Primitive (ByteArray, MutableByteArray)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EAGAIN, pattern EWOULDBLOCK)
import GHC.Conc (threadWaitWrite, threadWaitWriteSTM)
import GHC.Exts (RealWorld)
import Network.Socket (Socket)
import System.Posix.Types (Fd (Fd))

import qualified Control.Concurrent.STM as STM
import qualified Data.Bytes.Types
import qualified Data.Primitive as PM
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X

{- | Send the entire byte sequence, calling POSIX @send()@ repeatedly
until every byte has been written.

An empty byte sequence short-circuits to @Right ()@ without calling
@send()@ at all.
-}
send ::
  Socket ->
  Bytes ->
  IO (Either Errno ())
send s Bytes {array, offset, length = len}
  | len == 0 = pure (Right ())
  | otherwise =
      -- The transmit buffer is rarely already full, so the first send is
      -- tried straight away instead of after a writability check.
      S.withFdSocket s (\fd -> sendLoop (Fd fd) array offset len)

-- | Send every byte of the @len@-byte slice of @arr@ that starts at @off@,
-- looping over partial sends.
--
-- Does not wait for the file descriptor to be ready before the first
-- @send()@. The call is made with 'X.noSignal' and 'X.dontWait'. Only when
-- it fails with @EAGAIN@ or @EWOULDBLOCK@ does this block in
-- 'threadWaitWrite' before retrying the same slice.
--
-- Any other errno is returned as @Left e@. Throws an 'IOError' (via 'fail')
-- if @send()@ reports more bytes sent than were requested.
sendLoop :: Fd -> ByteArray -> Int -> Int -> IO (Either Errno ())
sendLoop !fd !arr !off !len =
  X.uninterruptibleSendByteArray fd arr off (fromIntegral len) (X.noSignal <> X.dontWait) >>= \case
    Left e
      | e == EAGAIN || e == EWOULDBLOCK -> do
          threadWaitWrite fd
          sendLoop fd arr off len
      | otherwise -> pure (Left e)
    Right sentSzC ->
      let sentSz = fromIntegral sentSzC :: Int
       in case compare sentSz len of
            EQ -> pure (Right ())
            -- Partial send: advance past what was written and keep going.
            LT -> sendLoop fd arr (off + sentSz) (len - sentSz)
            -- Name the exported entry point that actually runs this loop.
            GT ->
              fail
                "Network.Unexceptional.Bytes.send: send claimed to send too many bytes"

-- | Interruptible send loop: send every byte of the @len@-byte slice of
-- @arr@ that starts at @off@, looping over partial sends.
--
-- The first @send()@ is attempted without waiting for writability. When it
-- fails with @EAGAIN@ or @EWOULDBLOCK@, this waits in 'waitUntilWriteable'.
-- It then retries the same slice on 'Ready', or gives up with
-- @Left EAGAIN@ on 'Interrupted'. The interrupt is only consulted during
-- such waits.
--
-- If an interrupt arrives after some bytes were already sent, the result is
-- still just @Left EAGAIN@; the amount already sent is not reported.
--
-- Any other errno is returned as @Left e@. Throws an 'IOError' (via 'fail')
-- if @send()@ reports more bytes sent than were requested.
sendInterruptibleLoop :: TVar Bool -> Fd -> ByteArray -> Int -> Int -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !arr !off !len =
  X.uninterruptibleSendByteArray fd arr off (fromIntegral len) (X.noSignal <> X.dontWait) >>= \case
    Left e
      | e == EAGAIN || e == EWOULDBLOCK ->
          waitUntilWriteable interrupt fd >>= \case
            Ready -> sendInterruptibleLoop interrupt fd arr off len
            Interrupted -> pure (Left EAGAIN)
      | otherwise -> pure (Left e)
    Right sentSzC ->
      let sentSz = fromIntegral sentSzC :: Int
       in case compare sentSz len of
            EQ -> pure (Right ())
            -- Partial send: advance past what was written and keep going.
            LT -> sendInterruptibleLoop interrupt fd arr (off + sentSz) (len - sentSz)
            -- Name the exported entry point that actually runs this loop.
            GT ->
              fail
                "Network.Unexceptional.Bytes.sendInterruptible: send claimed to send too many bytes"

-- | Variant of 'send' that fails with @EAGAIN@ when the interrupt is
-- observed to be 'True'.
--
-- The interrupt is read once before the first @send()@. An already-set
-- interrupt yields @Left EAGAIN@ without sending anything. After that it is
-- watched again whenever a @send()@ would block. An empty byte sequence
-- returns @Right ()@ without reading the interrupt at all.
sendInterruptible ::
  TVar Bool ->
  Socket ->
  Bytes ->
  IO (Either Errno ())
sendInterruptible !interrupt s Bytes {array, offset, length = len}
  | len == 0 = pure (Right ())
  | otherwise = do
      alreadyInterrupted <- STM.readTVarIO interrupt
      if alreadyInterrupted
        then pure (Left EAGAIN)
        else
          -- The transmit buffer is rarely already full, so the first send
          -- is tried straight away instead of after a writability check.
          S.withFdSocket s $ \fd ->
            sendInterruptibleLoop interrupt (Fd fd) array offset len

-- | Result of waiting for a socket to become writable while also watching
-- an interrupt variable.
data Outcome
  = -- | The descriptor was reported ready for writing.
    Ready
  | -- | The interrupt variable was observed to be 'True'.
    Interrupted

-- | STM action that succeeds only when the interrupt variable holds 'True'.
-- While it holds 'False', the enclosing transaction retries.
checkFinished :: TVar Bool -> STM ()
checkFinished interrupt = do
  finished <- STM.readTVar interrupt
  STM.check finished

-- | Block until @fd@ is reported writable or @interrupt@ holds 'True'.
--
-- Returns 'Ready' when the readiness action from 'threadWaitWriteSTM'
-- succeeds. Returns 'Interrupted' when @interrupt@ is 'True'. Because the
-- readiness check is the left branch of '<|>' (STM @orElse@), 'Ready' wins
-- if both hold when the transaction runs.
--
-- The registration returned by 'threadWaitWriteSTM' is released through
-- 'bracket'. It is therefore deregistered even when an asynchronous
-- exception (e.g. from @timeout@ or @killThread@) aborts the blocking
-- transaction. The previous straight-line version skipped deregistration
-- in that case.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _deregister) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

-- | Receive at most @n@ bytes from the socket into a freshly allocated
-- buffer. If this returns zero bytes, it means that the peer has
-- performed an orderly shutdown.
--
-- Throws 'Types.NonpositiveReceptionSize' when @n@ is not positive.
receive ::
  Socket ->
  -- | Maximum number of bytes to receive
  Int ->
  IO (Either Errno Bytes)
receive s n
  | n <= 0 = throwIO Types.NonpositiveReceptionSize
  | otherwise = do
      buf <- PM.newByteArray n
      result <- MB.receive s (MutableBytes buf 0 n)
      handleRececeptionResult buf n result

-- | Interruptible variant of 'receive': receive at most @n@ bytes into a
-- freshly allocated buffer. If this returns zero bytes, it means that the
-- peer has performed an orderly shutdown.
--
-- The interrupt variable is passed to 'MB.receiveInterruptible'; see that
-- function for how it is honored. Throws 'Types.NonpositiveReceptionSize'
-- when @n@ is not positive.
receiveInterruptible ::
  -- | Interrupt
  TVar Bool ->
  Socket ->
  -- | Maximum number of bytes to receive
  Int ->
  IO (Either Errno Bytes)
receiveInterruptible !interrupt s n
  | n <= 0 = throwIO Types.NonpositiveReceptionSize
  | otherwise = do
      buf <- PM.newByteArray n
      result <- MB.receiveInterruptible interrupt s (MutableBytes buf 0 n)
      handleRececeptionResult buf n result

-- | Turn the result of receiving into @dst@ (capacity @n@) into an
-- immutable 'Bytes'.
--
-- An errno is passed through unchanged. On success with @m@ bytes, the
-- buffer is shrunk to @m@ when @m < n@. It is then frozen in place with
-- 'PM.unsafeFreezeByteArray', so @dst@ must not be written afterwards.
-- NOTE(review): assumes the receive never reports @m > n@.
handleRececeptionResult :: MutableByteArray RealWorld -> Int -> Either Errno Int -> IO (Either Errno Bytes)
handleRececeptionResult !dst !n result = either (pure . Left) finish result
  where
    finish :: Int -> IO (Either Errno Bytes)
    finish m = do
      when (m < n) $ PM.shrinkMutableByteArray dst m
      frozen <- PM.unsafeFreezeByteArray dst
      pure (Right (Bytes frozen 0 m))