{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module Network.Unexceptional
( accept_
, socket
, connect
, connectInterruptible
) where
import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM, TVar)
import Control.Exception (bracket, mask_)
import Control.Monad ((<=<))
import Data.Functor (($>))
import Foreign.C.Error (Errno (Errno))
import Foreign.C.Error.Pattern (pattern EAGAIN, pattern EINPROGRESS, pattern EINTR, pattern EWOULDBLOCK)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (castPtr, nullPtr)
import GHC.Conc (threadWaitRead, threadWaitWrite, threadWaitWriteSTM)
import GHC.Exts (Ptr)
import Network.Socket (SockAddr, Socket, SocketOption (SoError), getSocketOption, mkSocket, withFdSocket)
import Network.Socket.Address (SocketAddress, pokeSocketAddress, sizeOfSocketAddress)
import System.Posix.Types (Fd (Fd))
import qualified Control.Concurrent.STM as STM
import qualified Linux.Socket as X
import qualified Network.Socket as N
import qualified Posix.Socket as X
-- | Accept a connection on a listening socket, reporting failures as an
-- 'Errno' in 'Left' instead of throwing.
--
-- The loop first waits (via the GHC IO manager) for the listening
-- descriptor to become readable. It then calls @accept4@ with the
-- nonblocking and close-on-exec flags set on the new descriptor.
-- 'EAGAIN' and 'EWOULDBLOCK' (e.g. another thread accepted the pending
-- connection first) and 'EINTR' restart the wait. Any other errno is
-- returned in 'Left'.
accept_ :: Socket -> IO (Either Errno Socket)
accept_ listing_sock = withFdSocket listing_sock $ \listing_fd ->
  let acceptLoop :: IO (Either Errno Socket)
      acceptLoop = do
        -- Blocking wait happens outside the mask so it stays interruptible.
        threadWaitRead (Fd listing_fd)
        -- Mask asynchronous exceptions between accept4 handing us a fresh
        -- descriptor and mkSocket taking ownership of it; an exception in
        -- that window would otherwise leak the descriptor. This mirrors the
        -- mask_ around socket creation in 'socket'.
        r <- mask_ $
          X.uninterruptibleAccept4_ (Fd listing_fd) (X.nonblocking <> X.closeOnExec) >>= \case
            Left e -> pure (Left e)
            Right (Fd fd) -> Right <$> mkSocket fd
        case r of
          Left e
            | e == EAGAIN || e == EWOULDBLOCK || e == EINTR -> acceptLoop
          _ -> pure r
   in acceptLoop
-- | Connect a socket to an address, reporting failures as an 'Errno' in
-- 'Left' instead of throwing.
--
-- If the connect call completes immediately the result is @'Right' ()@.
-- If it reports 'EINPROGRESS' or 'EINTR', this waits (via the GHC IO
-- manager) until the socket is writeable. It then reads @SO_ERROR@: zero
-- means success, anything else is returned as that 'Errno'. Any other
-- errno from the connect call is returned in 'Left' unchanged.
--
-- 'EINTR' is deliberately not retried. POSIX specifies that an interrupted
-- connect is not aborted but completes asynchronously. Issuing connect
-- again would report 'EALREADY' (or 'EISCONN' once finished) instead of
-- the real outcome.
connect :: Socket -> SockAddr -> IO (Either Errno ())
connect s sa = withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd -> do
  r <- X.uninterruptibleConnectPtr (Fd fd) p_sa sz
  case r of
    Right _ -> pure (Right ())
    Left err
      | err == EINPROGRESS || err == EINTR -> do
          threadWaitWrite (Fd fd)
          errB <- getSocketOption s SoError
          pure $
            if errB == 0
              then Right ()
              else Left (Errno (fromIntegral errB))
      | otherwise -> pure (Left err)
-- | Connect a socket to an address, like a nonthrowing connect, but the
-- wait for an in-progress connection can be abandoned by setting
-- @interrupt@ to 'True'.
--
-- If the connect call completes immediately the result is @'Right' ()@.
-- If it reports 'EINPROGRESS' or 'EINTR', this waits until either the
-- socket is writeable or @interrupt@ is 'True'.
--
-- * Writeable: @SO_ERROR@ is read; zero means success, anything else is
--   returned as that 'Errno'.
-- * Interrupted: returns @'Left' 'EAGAIN'@. The socket is not closed here
--   and the pending attempt is not cancelled by this function.
--
-- An 'EAGAIN' reported directly by the connect call is also passed through
-- as @'Left' 'EAGAIN'@, so callers cannot distinguish the two from the
-- result alone.
--
-- 'EINTR' is deliberately not retried. POSIX specifies that an interrupted
-- connect continues asynchronously, and re-issuing it would report
-- 'EALREADY' or 'EISCONN' instead of the real outcome.
connectInterruptible :: TVar Bool -> Socket -> SockAddr -> IO (Either Errno ())
connectInterruptible !interrupt s sa =
  withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd -> do
    r <- X.uninterruptibleConnectPtr (Fd fd) p_sa sz
    case r of
      Right _ -> pure (Right ())
      Left err
        | err == EINPROGRESS || err == EINTR ->
            waitUntilWriteable interrupt (Fd fd) >>= \case
              Interrupted -> pure (Left EAGAIN)
              Ready -> do
                errB <- getSocketOption s SoError
                pure $
                  if errB == 0
                    then Right ()
                    else Left (Errno (fromIntegral errB))
        | otherwise -> pure (Left err)
-- | Marshal a socket address into a temporary buffer and hand the pointer
-- and its byte size to the continuation.
--
-- The buffer is exactly 'sizeOfSocketAddress' bytes and only lives for the
-- duration of the callback (it comes from 'allocaBytes'). An address whose
-- size is zero is not allocated at all: the callback receives 'nullPtr'
-- and @0@.
withSocketAddress :: (SocketAddress sa) => sa -> (Ptr sa -> Int -> IO a) -> IO a
withSocketAddress addr f
  | size == 0 = f nullPtr 0
  | otherwise = allocaBytes size $ \buf -> do
      pokeSocketAddress buf addr
      f (castPtr buf) size
  where
    size :: Int
    size = sizeOfSocketAddress addr
-- | Which event ended a wait in 'waitUntilWriteable'.
data Outcome
  = -- | The file descriptor became writeable.
    Ready
  | -- | The interrupt flag was observed as 'True' first.
    Interrupted
-- | STM guard that succeeds once the flag holds 'True' and otherwise
-- retries the surrounding transaction (via 'STM.check').
checkFinished :: TVar Bool -> STM ()
checkFinished flag = STM.check =<< STM.readTVar flag
-- | Block until @fd@ is writeable or @interrupt@ holds 'True', whichever
-- the transaction observes first.
--
-- Readiness is tried before the interrupt flag (left branch of '<|>'), so
-- if both hold at once the result is 'Ready'.
--
-- The IO-manager registration returned by 'threadWaitWriteSTM' is released
-- through 'bracket'. It is therefore deregistered even when an
-- asynchronous exception interrupts the blocking transaction; the previous
-- straight-line version skipped deregistration in that case.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)
-- | Create a new socket, reporting a failing system call as an 'Errno' in
-- 'Left' instead of throwing.
--
-- The close-on-exec and nonblocking flags are folded into the socket-type
-- argument of the system call. The resulting descriptor is wrapped in a
-- 'Socket' with asynchronous exceptions masked, so it cannot be lost
-- between creation and wrapping.
--
-- Only 'N.Stream' and 'N.Datagram' are supported. Any other socket type
-- calls 'fail', which throws an 'IOError' in 'IO' rather than returning
-- 'Left'.
socket ::
  N.Family ->
  N.SocketType ->
  N.ProtocolNumber ->
  IO (Either Errno Socket)
socket !fam !stype !protocol
  | N.Stream <- stype = open X.stream
  | N.Datagram <- stype = open X.datagram
  | otherwise =
      fail "Network.Unexceptional.socket: Currently only supports stream and datagram types"
  where
    open :: X.Type -> IO (Either Errno Socket)
    open !ty = mask_ $ do
      result <-
        X.uninterruptibleSocket
          (X.Family (N.packFamily fam))
          (X.applySocketFlags (X.closeOnExec <> X.nonblocking) ty)
          (X.Protocol protocol)
      either (pure . Left) (\(Fd raw) -> Right <$> mkSocket raw) result