{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.MutableBytes
  ( receive
  , receiveInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM,TVar)
import Control.Exception (bracket,throwIO)
import Control.Monad ((<=<))
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Functor (($>))
import Data.Primitive (MutableByteArray)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitRead,threadWaitReadSTM)
import GHC.Exts (RealWorld)
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Control.Concurrent.STM as STM
import qualified Data.Bytes.Types
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.Types as Types
import qualified Posix.Socket as X

-- | Receive bytes from a socket into a slice of a mutable buffer.
--
-- At most @length@ bytes are requested, where @length@ is the size of the
-- slice. On success the result is the number of bytes actually received;
-- on failure it is the 'Errno' reported by the receive call.
--
-- Throws 'Types.NonpositiveReceptionSize' when the slice length is not
-- positive. The first receive is attempted immediately, without first
-- waiting for the socket to become readable.
receive ::
     Socket
  -> MutableBytes RealWorld -- ^ Slice of a buffer to fill
  -> IO (Either Errno Int)
receive sock MutableBytes{array = buf, offset = off, length = n}
  | n <= 0 = throwIO Types.NonpositiveReceptionSize
  | otherwise = S.withFdSocket sock $ \rawFd ->
      receiveLoop (Fd rawFd) buf off n

-- | Like 'receive', but waiting for the socket to become readable can be
-- abandoned by setting the interrupt 'TVar' to 'True'.
--
-- When the wait is abandoned, the result is @Left EAGAIN@, the same value
-- a would-block receive reports. Throws 'Types.NonpositiveReceptionSize'
-- when the slice length is not positive. The first receive is attempted
-- immediately, without first waiting for readiness.
receiveInterruptible ::
     TVar Bool -- ^ Interrupt flag; 'True' abandons waiting for readiness
  -> Socket
  -> MutableBytes RealWorld -- ^ Slice of a buffer to fill
  -> IO (Either Errno Int)
receiveInterruptible !cancel sock MutableBytes{array = buf, offset = off, length = n}
  | n <= 0 = throwIO Types.NonpositiveReceptionSize
  | otherwise = S.withFdSocket sock $ \rawFd ->
      receiveInterruptibleLoop cancel (Fd rawFd) buf off n

-- | Receive into the @len@-byte slice of @arr@ that starts at @off@.
--
-- The first receive is issued immediately with the 'X.dontWait' flag,
-- without checking readiness first. If it fails with 'EAGAIN' or
-- 'EWOULDBLOCK', the calling thread blocks in 'threadWaitRead' until the
-- descriptor is readable, then tries again. The loop ends after the first
-- successful receive (returning the byte count) or the first other error
-- (returning that 'Errno').
--
-- Throws 'Types.ReceivedTooManyBytes' if the reported count exceeds @len@.
receiveLoop ::
     Fd
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either Errno Int)
receiveLoop !fd !arr !off !len = do
  result <- X.uninterruptibleReceiveMutableByteArray fd arr off (fromIntegral len) X.dontWait
  case result of
    Right count -> acceptCount (fromIntegral count)
    Left err
      | wouldBlock err -> do
          threadWaitRead fd
          receiveLoop fd arr off len
      | otherwise -> pure (Left err)
  where
    wouldBlock err = err == EAGAIN || err == EWOULDBLOCK
    -- A count larger than the requested length is treated as an invariant
    -- violation rather than an ordinary error.
    acceptCount :: Int -> IO (Either Errno Int)
    acceptCount received
      | received > len = throwIO Types.ReceivedTooManyBytes
      | otherwise = pure (Right received)

-- | Interruptible variant of 'receiveLoop'.
--
-- The first receive is issued immediately with the 'X.dontWait' flag. If it
-- fails with 'EAGAIN' or 'EWOULDBLOCK', the thread waits in
-- 'waitUntilReadable'. On 'Ready' it retries; on 'Interrupted' it gives up
-- and returns @Left EAGAIN@. Any other error is returned as-is. The loop
-- ends after the first successful receive.
--
-- Throws 'Types.ReceivedTooManyBytes' if the reported count exceeds @len@.
receiveInterruptibleLoop ::
     TVar Bool
  -> Fd
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either Errno Int)
receiveInterruptibleLoop !cancel !fd !arr !off !len = do
  result <- X.uninterruptibleReceiveMutableByteArray fd arr off (fromIntegral len) X.dontWait
  case result of
    Right count -> acceptCount (fromIntegral count)
    Left err
      | wouldBlock err -> waitUntilReadable cancel fd >>= resume
      | otherwise -> pure (Left err)
  where
    wouldBlock err = err == EAGAIN || err == EWOULDBLOCK
    -- An interrupted wait is reported the same way as a would-block receive.
    resume Ready = receiveInterruptibleLoop cancel fd arr off len
    resume Interrupted = pure (Left EAGAIN)
    -- A count larger than the requested length is treated as an invariant
    -- violation rather than an ordinary error.
    acceptCount :: Int -> IO (Either Errno Int)
    acceptCount received
      | received > len = throwIO Types.ReceivedTooManyBytes
      | otherwise = pure (Right received)

-- | An STM action that succeeds only when the flag holds 'True'; while it
-- holds 'False', the transaction retries.
checkFinished :: TVar Bool -> STM ()
checkFinished flag = STM.readTVar flag >>= STM.check

-- | The result of waiting for a descriptor to become readable.
data Outcome
  = Ready       -- ^ The descriptor was reported readable.
  | Interrupted -- ^ The interrupt flag was 'True' before readiness was seen.

-- | Block until @fd@ is readable or the interrupt flag holds 'True'.
--
-- Readiness is tried first (via 'STM.orElse' / '<|>'). So if both
-- conditions hold when the transaction runs, the result is 'Ready'.
-- Otherwise, a 'True' interrupt flag gives 'Interrupted'.
--
-- The registration made by 'threadWaitReadSTM' is released through
-- 'bracket'. It is therefore deregistered on normal return and also when an
-- asynchronous exception (e.g. from @timeout@ or @killThread@) arrives while
-- the thread is blocked in 'STM.atomically'. The previous sequential version
-- skipped @deregister@ in that case and leaked the registration.
waitUntilReadable :: TVar Bool -> Fd -> IO Outcome
waitUntilReadable !interrupt !fd =
  bracket (threadWaitReadSTM fd) snd $ \(isReadyAction, _deregister) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)