{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}

module Network.Unexceptional
  ( accept_
  , socket
  , connect
  , connectInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM, TVar)
import Control.Exception (bracket, mask_)
import Control.Monad ((<=<))
import Data.Functor (($>))
import Foreign.C.Error (Errno (Errno))
import Foreign.C.Error.Pattern (pattern EAGAIN, pattern EINPROGRESS, pattern EINTR, pattern EWOULDBLOCK)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (castPtr, nullPtr)
import GHC.Conc (threadWaitRead, threadWaitWrite, threadWaitWriteSTM)
import GHC.Exts (Ptr)
import Network.Socket (SockAddr, Socket, SocketOption (SoError), getSocketOption, mkSocket, withFdSocket)
import Network.Socket.Address (SocketAddress, pokeSocketAddress, sizeOfSocketAddress)
import System.Posix.Types (Fd (Fd))

import qualified Control.Concurrent.STM as STM
import qualified Linux.Socket as X
import qualified Network.Socket as N
import qualified Posix.Socket as X

{- | Accept a connection. See the documentation in @network@ for @accept@.

The accepted socket is created with the non-blocking and close-on-exec
flags set.

The thread first blocks (interruptibly) until the listening socket is
readable. It then calls @accept4@ and wraps the new descriptor in a
'Socket' with asynchronous exceptions masked, so no asynchronous exception
can be delivered between the descriptor being created and it being owned
by a 'Socket'.

@EAGAIN@, @EWOULDBLOCK@ and @EINTR@ cause another wait-and-accept round.
Any other error is returned in 'Left'.
-}
accept_ :: Socket -> IO (Either Errno Socket)
accept_ listing_sock = withFdSocket listing_sock $ \listing_fd -> do
  let acceptLoop :: IO (Either Errno Socket)
      acceptLoop = do
        -- Interruptible wait: no descriptor exists yet, so an async
        -- exception here cannot leak anything.
        threadWaitRead (Fd listing_fd)
        -- Mask the accept4 + mkSocket pair so an async exception cannot
        -- land after accept4 hands us a descriptor but before mkSocket
        -- takes ownership of it. NOTE(review): relies on mkSocket doing
        -- no interruptible (blocking) operation, which holds for current
        -- network releases; confirm on upgrade.
        result <-
          mask_ $
            X.uninterruptibleAccept4_ (Fd listing_fd) (X.nonblocking <> X.closeOnExec)
              >>= \case
                Left e -> pure (Left e)
                Right (Fd fd) -> Right <$> mkSocket fd
        case result of
          -- Spurious wakeup (another thread took the connection) or a
          -- signal: wait again.
          Left EAGAIN -> acceptLoop
          Left EWOULDBLOCK -> acceptLoop
          Left EINTR -> acceptLoop
          _ -> pure result
  acceptLoop

{- | Connect to a socket address. See the documentation in @network@
for @connect@.

If the initial @connect@ call reports @EINPROGRESS@ or @EINTR@, the
connection attempt is still underway. POSIX specifies that an interrupted
@connect@ is not aborted and instead completes asynchronously. In both
cases this waits for the socket to become writable, then reads @SO_ERROR@
to learn the result: @0@ means success, anything else is returned as an
'Errno' in 'Left'. Any other error from @connect@ is returned directly.
-}
connect :: Socket -> SockAddr -> IO (Either Errno ())
connect s sa = withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd -> do
  let -- Wait for the in-flight connection attempt to finish and report
      -- its outcome via SO_ERROR.
      awaitConnected :: IO (Either Errno ())
      awaitConnected = do
        threadWaitWrite (Fd fd)
        errB <- getSocketOption s SoError
        pure $
          if errB == 0
            then Right ()
            else Left (Errno (fromIntegral errB))
  X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
    Right _ -> pure (Right ())
    -- Re-issuing connect after EINTR would fail with EALREADY (or EISCONN
    -- if the handshake already completed), so EINTR must be handled the
    -- same way as EINPROGRESS.
    Left EINTR -> awaitConnected
    Left EINPROGRESS -> awaitConnected
    Left err -> pure (Left err)

{- | Variant of 'connect' that can be interrupted by setting the interrupt
variable to @True@. If interrupted in this way, this function returns
@EAGAIN@. For example, to attempt to connect for no more than 1 second:

> interrupt <- Control.Concurrent.STM.registerDelay 1_000_000
> connectInterruptible interrupt sock sockAddr

As with 'connect', @EINPROGRESS@ and @EINTR@ from the initial @connect@
call both mean the attempt is still underway. This then waits until the
socket is writable or the interrupt fires, and on writability consults
@SO_ERROR@. When interrupted, the socket is left as is: this function does
not cancel the pending connection attempt.
-}
connectInterruptible :: TVar Bool -> Socket -> SockAddr -> IO (Either Errno ())
connectInterruptible !interrupt s sa =
  withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd -> do
    let -- Wait for the in-flight attempt (or the interrupt) and report
        -- the outcome.
        awaitConnected :: IO (Either Errno ())
        awaitConnected =
          waitUntilWriteable interrupt (Fd fd) >>= \case
            Interrupted -> pure (Left EAGAIN)
            Ready -> do
              errB <- getSocketOption s SoError
              pure $
                if errB == 0
                  then Right ()
                  else Left (Errno (fromIntegral errB))
    X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
      Right _ -> pure (Right ())
      -- An interrupted connect keeps completing asynchronously; retrying
      -- it would report EALREADY/EISCONN, so wait for it like EINPROGRESS.
      Left EINTR -> awaitConnected
      Left EINPROGRESS -> awaitConnected
      Left err -> pure (Left err)

-- Copied this from the network library. TODO: See if network can
-- just export this.

{- | Marshal a socket address into a temporary buffer and hand the buffer
and its size in bytes to the continuation.

When 'sizeOfSocketAddress' reports a size of zero, no buffer is allocated
and the continuation receives 'nullPtr' and @0@. Otherwise the buffer is
allocated with 'allocaBytes' and is only valid for the duration of the
continuation.
-}
withSocketAddress :: (SocketAddress sa) => sa -> (Ptr sa -> Int -> IO a) -> IO a
withSocketAddress addr k
  | size == 0 = k nullPtr 0
  | otherwise =
      allocaBytes size $ \buf -> do
        pokeSocketAddress buf addr
        k (castPtr buf) size
 where
  size = sizeOfSocketAddress addr

-- | How a wait for writability ended.
data Outcome
  = -- | The file descriptor became writable.
    Ready
  | -- | The interrupt variable was observed to be @True@.
    Interrupted

-- | Succeed if the variable currently holds @True@; otherwise 'STM.retry'
-- (via 'STM.check'), blocking the transaction until it changes.
checkFinished :: TVar Bool -> STM ()
checkFinished flag = STM.readTVar flag >>= STM.check

{- | Block until the file descriptor becomes writable or the interrupt
variable holds @True@, whichever is observed first. If both hold when the
transaction runs, 'Ready' wins because it is the left alternative.

The interest registered by 'threadWaitWriteSTM' is always released,
through the deregistration action it returns. 'bracket' guarantees this
even if an asynchronous exception arrives while the thread is blocked in
'STM.atomically'.
-}
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _deregister) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

{- | Create a socket. See the documentation in @network@ for @socket@.

There is no interruptible variant of this function because it cannot
block. (It does not actually perform network activity.)

The new descriptor is opened with the close-on-exec and non-blocking flags.
It is wrapped in a 'Socket' with asynchronous exceptions masked. Only
stream and datagram socket types are supported; any other type is rejected
with 'fail'.
-}
socket ::
  -- | Family Name (usually AF_INET)
  N.Family ->
  -- | Socket Type (usually Stream)
  N.SocketType ->
  -- | Protocol Number (getProtocolByName to find value)
  N.ProtocolNumber ->
  -- | Unconnected Socket
  IO (Either Errno Socket)
socket !fam !stype !protocol
  | N.Stream <- stype = create X.stream
  | N.Datagram <- stype = create X.datagram
  | otherwise =
      fail "Network.Unexceptional.socket: Currently only supports stream and datagram types"
 where
  create !kind = mask_ $ do
    let family = X.Family (N.packFamily fam)
        flagged = X.applySocketFlags (X.closeOnExec <> X.nonblocking) kind
    opened <- X.uninterruptibleSocket family flagged (X.Protocol protocol)
    case opened of
      Left err -> pure (Left err)
      Right (Fd fd) -> Right <$> mkSocket fd