{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional
  ( accept_
  , socket
  , connect
  , connectInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM,TVar)
import Control.Exception (bracket,mask_)
import Control.Monad ((<=<))
import Data.Functor (($>))
import Foreign.C.Error (Errno(Errno))
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN,pattern EINPROGRESS,pattern EINTR)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (castPtr,nullPtr)
import GHC.Conc (threadWaitRead,threadWaitWrite,threadWaitWriteSTM)
import GHC.Exts (Ptr)
import Network.Socket (Socket,SockAddr,mkSocket,withFdSocket,SocketOption(SoError),getSocketOption)
import Network.Socket.Address (SocketAddress,pokeSocketAddress,sizeOfSocketAddress)
import System.Posix.Types (Fd(Fd))

import qualified Control.Concurrent.STM as STM
import qualified Linux.Socket as X
import qualified Posix.Socket as X
import qualified Network.Socket as N

-- | Accept a connection. See the documentation in @network@ for @accept@.
--
-- Waits for the listening socket to become readable, then calls @accept4@
-- requesting a nonblocking, close-on-exec descriptor. If @accept4@ reports
-- @EAGAIN@, @EWOULDBLOCK@, or @EINTR@, this waits and tries again. Any other
-- error is returned as 'Left'.
--
-- Asynchronous exceptions are masked around the @accept4@ call and the
-- construction of the resulting 'Socket'. An exception therefore cannot be
-- delivered after the descriptor exists but before it is wrapped. The wait
-- for readability stays outside the mask, so a thread blocked in 'accept_'
-- can still be interrupted.
accept_ :: Socket -> IO (Either Errno Socket)
accept_ listing_sock = withFdSocket listing_sock $ \listing_fd ->
  let acceptLoop :: IO (Either Errno Socket)
      acceptLoop = do
        threadWaitRead (Fd listing_fd)
        r <- mask_ $
          X.uninterruptibleAccept4_ (Fd listing_fd) (X.nonblocking <> X.closeOnExec)
            >>= traverse (\(Fd fd) -> mkSocket fd)
        case r of
          -- Nothing was accepted; wait for readiness again.
          Left e | e == EAGAIN || e == EWOULDBLOCK || e == EINTR -> acceptLoop
          _ -> pure r
   in acceptLoop

-- | Connect to a socket address. See the documentation in @network@
-- for @connect@.
--
-- When @connect(2)@ reports @EINPROGRESS@ (the usual result for a
-- nonblocking socket) or @EINTR@, this waits until the socket is writeable
-- and then reads @SO_ERROR@ to learn whether the connection succeeded. A
-- nonzero @SO_ERROR@ is returned as 'Left'.
--
-- @EINTR@ is deliberately not retried. POSIX specifies that an interrupted
-- @connect@ keeps establishing the connection asynchronously, so calling
-- @connect@ again would fail with @EALREADY@ or @EISCONN@.
connect :: Socket -> SockAddr -> IO (Either Errno ())
connect s sa = withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd ->
  let awaitCompletion :: IO (Either Errno ())
      awaitCompletion = do
        threadWaitWrite (Fd fd)
        errB <- getSocketOption s SoError
        pure $ case errB of
          0 -> Right ()
          _ -> Left (Errno (fromIntegral errB))
   in X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
        Right _ -> pure (Right ())
        Left EINPROGRESS -> awaitCompletion
        Left EINTR -> awaitCompletion
        Left err -> pure (Left err)

-- | Variant of 'connect' that can be interrupted by setting the interrupt
-- variable to @True@. If interrupted in this way, this function returns
-- @EAGAIN@. For example, to attempt to connect for no more than 1 second:
--
-- > interrupt <- Control.Concurrent.STM.registerDelay 1_000_000
-- > connectInterruptible interrupt sock sockAddr
--
-- The interrupt variable is only consulted while waiting for a connection
-- that @connect(2)@ reported as @EINPROGRESS@ or @EINTR@. When interrupted,
-- this function neither closes the socket nor aborts the pending attempt.
--
-- @EINTR@ is handled like @EINPROGRESS@ rather than retried. POSIX
-- specifies that an interrupted @connect@ completes asynchronously, so a
-- second @connect@ would fail with @EALREADY@ or @EISCONN@.
connectInterruptible :: TVar Bool -> Socket -> SockAddr -> IO (Either Errno ())
connectInterruptible !interrupt s sa = withSocketAddress sa $ \p_sa sz -> withFdSocket s $ \fd ->
  let awaitCompletion :: IO (Either Errno ())
      awaitCompletion = waitUntilWriteable interrupt (Fd fd) >>= \case
        Interrupted -> pure (Left EAGAIN)
        Ready -> do
          errB <- getSocketOption s SoError
          pure $ case errB of
            0 -> Right ()
            _ -> Left (Errno (fromIntegral errB))
   in X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
        Right _ -> pure (Right ())
        Left EINPROGRESS -> awaitCompletion
        Left EINTR -> awaitCompletion
        Left err -> pure (Left err)

-- | Serialize a socket address into a temporary buffer. The buffer and its
-- size in bytes are handed to the continuation. If the address encodes to
-- zero bytes, no buffer is allocated and the continuation receives
-- 'nullPtr' with size 0.
--
-- Copied from the network library. TODO: See if network can just export
-- this.
withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a
withSocketAddress addr f
  | size == 0 = f nullPtr 0
  | otherwise = allocaBytes size $ \buf -> do
      pokeSocketAddress buf addr
      f (castPtr buf) size
  where
    size = sizeOfSocketAddress addr

-- | Result of waiting in 'waitUntilWriteable'.
data Outcome
  = Ready
    -- ^ The file descriptor became writeable.
  | Interrupted
    -- ^ The interrupt variable was observed to hold @True@.

-- | Retry the enclosing STM transaction until the given variable holds
-- @True@. Succeeds with @()@ once it does.
checkFinished :: TVar Bool -> STM ()
checkFinished interrupt = (STM.check <=< STM.readTVar) interrupt

-- | Block until the file descriptor becomes writeable ('Ready') or the
-- interrupt variable holds @True@ ('Interrupted'). If both conditions hold
-- when the transaction runs, 'Ready' wins because it is the left
-- alternative of '<|>'.
--
-- The registration created by 'threadWaitWriteSTM' is released with
-- 'bracket'. It is therefore deregistered even if an asynchronous exception
-- (for example from @timeout@ or @killThread@) arrives while this thread is
-- blocked in the transaction.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

-- | Create a socket. See the documentation in @network@ for @socket@.
--
-- There is no interruptible variant of this function because it cannot
-- block. (It does not actually perform network activity.)
--
-- The descriptor is opened close-on-exec and nonblocking. Only stream and
-- datagram socket types are supported. Any other type fails in 'IO' via
-- 'fail' rather than returning 'Left'.
socket ::
     N.Family         -- ^ Family Name (usually AF_INET)
  -> N.SocketType     -- ^ Socket Type (usually Stream)
  -> N.ProtocolNumber -- ^ Protocol Number (getProtocolByName to find value)
  -> IO (Either Errno Socket) -- ^ Unconnected Socket
socket !fam !stype !protocol = case stype of
  N.Stream -> create X.stream
  N.Datagram -> create X.datagram
  _ -> fail "Network.Unexceptional.socket: Currently only supports stream and datagram types"
  where
    domain = X.Family (N.packFamily fam)
    flags = X.closeOnExec <> X.nonblocking
    -- Masked so that a freshly created descriptor is wrapped in a 'Socket'
    -- before any asynchronous exception can be delivered.
    create !sockTy = mask_ $
      X.uninterruptibleSocket domain (X.applySocketFlags flags sockTy) (X.Protocol protocol)
        >>= traverse (\(Fd fd) -> mkSocket fd)