{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
module Network.Unexceptional.ByteString
( send
, sendInterruptible
, receive
, receiveExactly
, receiveExactlyInterruptible
) where
import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM, TVar)
import Control.Exception (bracket)
import Control.Monad (when, (<=<))
import Data.ByteString.Internal (ByteString (BS))
import Data.Bytes.Types (MutableBytes (MutableBytes))
import Data.Functor (($>))
import Data.Primitive (ByteArray (ByteArray))
import Data.Primitive.Addr (Addr (Addr), plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EAGAIN, pattern EWOULDBLOCK)
import GHC.Conc (threadWaitWrite, threadWaitWriteSTM)
import GHC.ForeignPtr (ForeignPtr (ForeignPtr), ForeignPtrContents (PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd (Fd))
import qualified Control.Concurrent.STM as STM
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X
-- | Send every byte of the given 'ByteString' over the socket.
--
-- The socket's file descriptor and the string's underlying buffer are
-- borrowed for the duration of the call and handed to 'sendLoop', which
-- keeps sending until the whole buffer has been written or an 'Errno' is
-- reported.
--
-- Returns @Right ()@ once all bytes were sent, or @Left errno@ on failure.
send :: Socket -> ByteString -> IO (Either Errno ())
send sock !payload =
  S.withFdSocket sock $ \rawFd ->
    -- The pointer is only valid inside this callback, so the whole send
    -- loop must run within it.
    ByteString.unsafeUseAsCStringLen payload $ \(PM.Ptr base#, total) ->
      sendLoop (Fd rawFd) (Addr base#) total
-- | Repeatedly send from @addr@ until @len@ bytes have been written.
--
-- Each attempt is made with the @noSignal@ and @dontWait@ flags. When the
-- attempt reports 'EAGAIN' or 'EWOULDBLOCK', the calling thread waits for
-- the descriptor to become writable and tries again with the same
-- arguments. Any other 'Errno' is returned as @Left@.
--
-- After a short write the loop advances the address by the number of bytes
-- sent and continues with the remainder. A reported count larger than the
-- requested length is treated as impossible and raised with 'fail'.
sendLoop :: Fd -> Addr -> Int -> IO (Either Errno ())
sendLoop !fd !addr !len = do
  attempt <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case attempt of
    Left err
      | err == EAGAIN || err == EWOULDBLOCK -> do
          threadWaitWrite fd
          sendLoop fd addr len
      | otherwise -> pure (Left err)
    Right written -> continueAfter (fromIntegral written)
 where
  -- Decide what to do given how many bytes the last call reported as sent.
  continueAfter :: Int -> IO (Either Errno ())
  continueAfter sent
    | sent == len = pure (Right ())
    | sent < len = sendLoop fd (plusAddr addr sent) (len - sent)
    | otherwise =
        fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"
-- | Variant of 'send' that takes an interrupt variable.
--
-- The interrupt 'TVar', the socket's file descriptor and the string's
-- underlying buffer are passed to 'sendInterruptibleLoop'. That loop
-- returns @Left EAGAIN@ if the interrupt variable is observed as 'True'
-- while it is waiting for the socket to become writable.
--
-- Returns @Right ()@ once all bytes were sent, or @Left errno@ otherwise.
sendInterruptible :: TVar Bool -> Socket -> ByteString -> IO (Either Errno ())
sendInterruptible !interrupt sock !payload =
  S.withFdSocket sock $ \rawFd ->
    -- The pointer is only valid inside this callback, so the whole send
    -- loop must run within it.
    ByteString.unsafeUseAsCStringLen payload $ \(PM.Ptr base#, total) ->
      sendInterruptibleLoop interrupt (Fd rawFd) (Addr base#) total
-- | Interruptible counterpart of 'sendLoop'.
--
-- Each attempt is made with the @noSignal@ and @dontWait@ flags. When the
-- attempt reports 'EAGAIN' or 'EWOULDBLOCK', the loop calls
-- 'waitUntilWriteable': on 'Ready' it retries with the same arguments, on
-- 'Interrupted' it gives up with @Left EAGAIN@. Any other 'Errno' is
-- returned unchanged.
--
-- After a short write the loop advances the address by the number of bytes
-- sent and continues with the remainder. A reported count larger than the
-- requested length is treated as impossible and raised with 'fail'.
sendInterruptibleLoop :: TVar Bool -> Fd -> Addr -> Int -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !addr !len = do
  attempt <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case attempt of
    Left err
      | err == EAGAIN || err == EWOULDBLOCK -> do
          outcome <- waitUntilWriteable interrupt fd
          case outcome of
            Ready -> sendInterruptibleLoop interrupt fd addr len
            Interrupted -> pure (Left EAGAIN)
      | otherwise -> pure (Left err)
    Right written -> continueAfter (fromIntegral written)
 where
  -- Decide what to do given how many bytes the last call reported as sent.
  continueAfter :: Int -> IO (Either Errno ())
  continueAfter sent
    | sent == len = pure (Right ())
    | sent < len = sendInterruptibleLoop interrupt fd (plusAddr addr sent) (len - sent)
    | otherwise =
        fail "Network.Unexceptional.ByteString.sendInterruptibleLoop: send claimed to send too many bytes"
-- | Receive at most @n@ bytes from the socket into a fresh 'ByteString'.
--
-- A pinned buffer of @n@ bytes is allocated and handed to 'MB.receive'.
-- When fewer than @n@ bytes are reported, the buffer is shrunk to the
-- reported count before being frozen, so the resulting 'ByteString' has
-- exactly that length.
--
-- Returns @Left errno@ if 'MB.receive' fails.
receive :: Socket -> Int -> IO (Either Errno ByteString)
receive sock maxLen = do
  buf <- PM.newPinnedByteArray maxLen
  outcome <- MB.receive sock (MutableBytes buf 0 maxLen)
  case outcome of
    Left err -> pure (Left err)
    Right got -> do
      when (got < maxLen) $ PM.shrinkMutableByteArray buf got
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- The frozen pinned array is reused as the ByteString's backing
      -- store without copying. 'PlainPtr' expects a mutable array, hence
      -- the coercion; the buffer is not written to after this point.
      let keepAlive = PlainPtr (Exts.unsafeCoerce# frozen#)
          fptr = ForeignPtr (Exts.byteArrayContents# frozen#) keepAlive
      pure (Right (BS fptr got))
-- | Receive exactly @n@ bytes from the socket into a fresh 'ByteString'.
--
-- A pinned buffer of @n@ bytes is allocated and filled by
-- 'MB.receiveExactly'. On success the buffer is frozen and wrapped,
-- without copying, as a 'ByteString' of length @n@.
--
-- Returns @Left errno@ if 'MB.receiveExactly' fails.
receiveExactly :: Socket -> Int -> IO (Either Errno ByteString)
receiveExactly !sock !size = do
  buf <- PM.newPinnedByteArray size
  filled <- MB.receiveExactly sock (MutableBytes buf 0 size)
  case filled of
    Left err -> pure (Left err)
    Right _ -> do
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- 'PlainPtr' expects a mutable array, hence the coercion; the buffer
      -- is not written to after this point.
      let keepAlive = PlainPtr (Exts.unsafeCoerce# frozen#)
          fptr = ForeignPtr (Exts.byteArrayContents# frozen#) keepAlive
      pure (Right (BS fptr size))
-- | Variant of 'receiveExactly' that takes an interrupt variable.
--
-- A pinned buffer of @n@ bytes is allocated and filled by
-- 'MB.receiveExactlyInterruptible', which is given the interrupt 'TVar'.
-- On success the buffer is frozen and wrapped, without copying, as a
-- 'ByteString' of length @n@.
--
-- Returns @Left errno@ if 'MB.receiveExactlyInterruptible' fails.
receiveExactlyInterruptible :: TVar Bool -> Socket -> Int -> IO (Either Errno ByteString)
receiveExactlyInterruptible !intr !sock !size = do
  buf <- PM.newPinnedByteArray size
  filled <- MB.receiveExactlyInterruptible intr sock (MutableBytes buf 0 size)
  case filled of
    Left err -> pure (Left err)
    Right _ -> do
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- 'PlainPtr' expects a mutable array, hence the coercion; the buffer
      -- is not written to after this point.
      let keepAlive = PlainPtr (Exts.unsafeCoerce# frozen#)
          fptr = ForeignPtr (Exts.byteArrayContents# frozen#) keepAlive
      pure (Right (BS fptr size))
-- | Block until either the file descriptor is writable or the interrupt
-- variable holds 'True'.
--
-- Interest in the descriptor is registered with 'threadWaitWriteSTM', and a
-- single STM transaction then waits on both conditions. The readiness
-- branch is the left operand of '<|>', so 'Ready' is returned when both
-- conditions hold at once.
--
-- The registration is released with 'bracket', so the deregistration
-- action runs on normal completion and also when an exception (including
-- an asynchronous one) is delivered while the transaction is blocked.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  -- 'snd' selects the deregistration action from the
  -- (readiness transaction, deregister) pair.
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)
-- | Result of waiting for a descriptor to become writable.
data Outcome
  = -- | The descriptor's readiness transaction completed.
    Ready
  | -- | The interrupt variable was observed to be 'True'.
    Interrupted
-- | STM action that succeeds only when the given variable holds 'True'.
--
-- The variable is read and its value passed to 'STM.check', so the
-- enclosing transaction retries while the value is 'False'.
checkFinished :: TVar Bool -> STM ()
checkFinished interrupt = (STM.check <=< STM.readTVar) interrupt