{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.Bytes
  ( send
  , sendInterruptible
  , receive
  , receiveInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM,TVar)
import Control.Exception (throwIO)
import Control.Exception (bracket)
import Control.Monad ((<=<))
import Control.Monad (when)
import Data.Bytes.Types (Bytes(Bytes),MutableBytes(MutableBytes))
import Data.Functor (($>))
import Data.Primitive (MutableByteArray,ByteArray)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite,threadWaitWriteSTM)
import GHC.Exts (RealWorld)
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Control.Concurrent.STM as STM
import qualified Data.Bytes.Types
import qualified Data.Primitive as PM
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X

-- | Send the entire byte sequence. This calls POSIX @send()@ in a loop
-- until all of the bytes have been sent.
--
-- If this is passed the empty byte sequence, it doesn't actually call
-- POSIX @send()@. It just returns that it succeeded.
--
-- When @send()@ reports @EAGAIN@ or @EWOULDBLOCK@, the calling thread
-- blocks until the socket is writeable and the send is retried. Any other
-- error is returned as 'Left'. If an error occurs after a partial send, the
-- number of bytes that were already transmitted is not reported.
send ::
     Socket
  -> Bytes
  -> IO (Either Errno ())
send s Bytes{array,offset,length=len} = case len of
  0 -> pure (Right ())
  _ -> S.withFdSocket s $ \fd ->
    -- We attempt the first send without testing if the socket is
    -- ready for writes. This is because it is uncommon for the transmit
    -- buffer to already be full.
    sendLoop (Fd fd) array offset len

-- | Send @len@ bytes of @arr@, starting at byte offset @off@, looping
-- until every byte has been accepted by the kernel.
--
-- The first attempt does not wait for the file descriptor to be ready.
-- The loop only blocks (via 'threadWaitWrite') after @send()@ reports
-- @EAGAIN@ or @EWOULDBLOCK@, and then it retries the same range. Every
-- call passes the 'X.noSignal' and 'X.dontWait' flags. Any other errno is
-- returned as 'Left'.
sendLoop :: Fd -> ByteArray -> Int -> Int -> IO (Either Errno ())
sendLoop !fd !arr !off !len =
  X.uninterruptibleSendByteArray fd arr off (fromIntegral len) (X.noSignal <> X.dontWait) >>= \case
    Left e
      | e == EAGAIN || e == EWOULDBLOCK -> do
          -- Transmit buffer is full: wait until the descriptor is
          -- writeable, then retry the same range.
          threadWaitWrite fd
          sendLoop fd arr off len
      | otherwise -> pure (Left e)
    Right sentSzC ->
      let sentSz = fromIntegral sentSzC :: Int
       in case compare sentSz len of
            EQ -> pure (Right ())
            -- Partial write: advance past the bytes that were accepted.
            LT -> sendLoop fd arr (off + sentSz) (len - sentSz)
            GT -> fail "Network.Unexceptional.Bytes.send: send claimed to send too many bytes"

-- | Interruptible counterpart of 'sendLoop'. It sends @len@ bytes of @arr@
-- starting at @off@, looping over partial writes.
--
-- When @send()@ reports @EAGAIN@ or @EWOULDBLOCK@, it waits (via
-- 'waitUntilWriteable') for either the descriptor to become writeable or
-- the interrupt to become 'True'. If the interrupt wins, it returns
-- @Left EAGAIN@. Some bytes may already have been sent at that point.
sendInterruptibleLoop :: TVar Bool -> Fd -> ByteArray -> Int -> Int -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !arr !off !len =
  X.uninterruptibleSendByteArray fd arr off (fromIntegral len) (X.noSignal <> X.dontWait) >>= \case
    Left e
      | e == EAGAIN || e == EWOULDBLOCK ->
          waitUntilWriteable interrupt fd >>= \case
            Ready -> sendInterruptibleLoop interrupt fd arr off len
            Interrupted -> pure (Left EAGAIN)
      | otherwise -> pure (Left e)
    Right sentSzC ->
      let sentSz = fromIntegral sentSzC :: Int
       in case compare sentSz len of
            EQ -> pure (Right ())
            -- Partial write: advance past the bytes that were accepted.
            LT -> sendInterruptibleLoop interrupt fd arr (off + sentSz) (len - sentSz)
            GT -> fail "Network.Unexceptional.Bytes.sendInterruptible: send claimed to send too many bytes"

-- | Variant of 'send' that fails with @EAGAIN@ when the interrupt is
-- 'True'.
--
-- The interrupt is consulted once before anything is sent. After that, it
-- is only consulted while waiting for a full transmit buffer to drain. It
-- is not checked between sends that succeed without blocking. If the
-- interrupt fires midway, a prefix of the bytes may already have been
-- transmitted, and how many is not reported.
--
-- If this is passed the empty byte sequence, it returns success without
-- reading the interrupt or calling POSIX @send()@.
sendInterruptible ::
     TVar Bool
  -> Socket
  -> Bytes
  -> IO (Either Errno ())
sendInterruptible !interrupt s Bytes{array,offset,length=len} = case len of
  0 -> pure (Right ())
  _ -> STM.readTVarIO interrupt >>= \case
    True -> pure (Left EAGAIN)
    False -> S.withFdSocket s $ \fd ->
      -- We attempt the first send without testing if the socket is
      -- ready for writes. This is because it is uncommon for the transmit
      -- buffer to already be full.
      sendInterruptibleLoop interrupt (Fd fd) array offset len

-- | What ended a wait for a socket to become writeable while an interrupt
-- variable was also being watched.
data Outcome
  = Ready
    -- ^ The file descriptor became writeable.
  | Interrupted
    -- ^ The interrupt variable held 'True' and the descriptor was not yet
    -- writeable.

-- | Succeed when the interrupt variable holds 'True'. Otherwise retry the
-- enclosing STM transaction (via 'STM.check'). This makes it usable as a
-- fallback branch of '<|>' that only fires once the interrupt is set.
checkFinished :: TVar Bool -> STM ()
checkFinished = STM.check <=< STM.readTVar

-- | Block until either the file descriptor becomes writeable ('Ready') or
-- the interrupt variable becomes 'True' ('Interrupted'). If both hold when
-- the transaction runs, 'Ready' wins because it is the left branch of
-- '<|>'.
--
-- The interest registered by 'threadWaitWriteSTM' is released through
-- 'bracket'. This means it is deregistered even when an asynchronous
-- exception (e.g. 'killThread' or a timeout) interrupts the wait. The
-- previous straight-line code skipped deregistration in that case.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

-- | Receive at most @n@ bytes from the socket into a freshly allocated
-- buffer. If this returns zero bytes, it means that the peer has
-- performed an orderly shutdown.
--
-- If @n@ is not positive, this throws 'Types.NonpositiveReceptionSize'
-- with 'throwIO' instead of returning 'Left'.
receive ::
     Socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO (Either Errno Bytes)
receive s n
  | n > 0 = do
      dst <- PM.newByteArray n
      MB.receive s (MutableBytes dst 0 n) >>= handleRececeptionResult dst n
  | otherwise = throwIO Types.NonpositiveReceptionSize

-- | Interruptible variant of 'receive'. It receives at most @n@ bytes into
-- a freshly allocated buffer. If this returns zero bytes, it means that
-- the peer has performed an orderly shutdown.
--
-- If @n@ is not positive, this throws 'Types.NonpositiveReceptionSize'
-- with 'throwIO' instead of returning 'Left'.
--
-- NOTE(review): interrupt handling is delegated entirely to
-- 'MB.receiveInterruptible'. Presumably it fails with @EAGAIN@ like
-- 'sendInterruptible' does; confirm in "Network.Unexceptional.MutableBytes".
receiveInterruptible ::
     TVar Bool -- ^ Interrupt
  -> Socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO (Either Errno Bytes)
receiveInterruptible !interrupt s n
  | n > 0 = do
      dst <- PM.newByteArray n
      MB.receiveInterruptible interrupt s (MutableBytes dst 0 n)
        >>= handleRececeptionResult dst n
  | otherwise = throwIO Types.NonpositiveReceptionSize

-- | Convert the result of receiving into @dst@ (a buffer of @n@ bytes)
-- into an immutable 'Bytes'.
--
-- Errors are passed through unchanged. When @m@ bytes were received and
-- @m < n@, the buffer is first shrunk in place to @m@ bytes. It is then
-- frozen without copying ('PM.unsafeFreezeByteArray'), so @dst@ must not
-- be written to after this call.
handleRececeptionResult :: MutableByteArray RealWorld -> Int -> Either Errno Int -> IO (Either Errno Bytes)
handleRececeptionResult !dst !n = \case
  Left e -> pure (Left e)
  Right m -> do
    when (m < n) (PM.shrinkMutableByteArray dst m)
    frozen <- PM.unsafeFreezeByteArray dst
    pure (Right (Bytes frozen 0 m))