{-# LANGUAGE CPP #-}

#include "HsNetDef.h"

module Network.Socket.Syscall where

import Foreign.Marshal.Utils (with)
import qualified Control.Exception as E
# if defined(mingw32_HOST_OS)
import System.IO.Error (catchIOError)
#endif

#if defined(mingw32_HOST_OS)
import Foreign (FunPtr)
import GHC.Conc (asyncDoProc)
#else
import Foreign.C.Error (getErrno, eINTR, eINPROGRESS)
import GHC.Conc (threadWaitWrite)
#endif

#ifdef HAVE_ADVANCED_SOCKET_FLAGS
import Network.Socket.Cbits
#else
import Network.Socket.Fcntl
#endif

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Options
import Network.Socket.Types

-- ----------------------------------------------------------------------------
-- On Windows, our sockets are not put in non-blocking mode (non-blocking
-- is not supported for regular file descriptors on Windows, and it would
-- be a pain to support it only for sockets).  So there are two cases:
--
--  - the threaded RTS uses safe calls for socket operations to get
--    non-blocking I/O, just like the rest of the I/O library
--
--  - with the non-threaded RTS, only some operations on sockets will be
--    non-blocking.  Reads and writes go through the normal async I/O
--    system.  accept() uses asyncDoProc so is non-blocking.  A handful
--    of others (recvFrom, sendFd, recvFd) will block all threads - if this
--    is a problem, -threaded is the workaround.
--

-----------------------------------------------------------------------------
-- Connection Functions

-- The following are the connection and binding primitives.  The names
-- of the equivalent C functions have been preserved where possible.
-- Note that some of these names, @bind@ in particular, have a different
-- meaning to many Haskell programmers; qualify them where they would
-- clash (e.g. @Network.Socket.bind@).

-- | Create a new socket using the given address family, socket type
-- and protocol number.  The address family is usually 'AF_INET',
-- 'AF_INET6', or 'AF_UNIX'.  The socket type is usually 'Stream' or
-- 'Datagram'.  The protocol number is usually 'defaultProtocol'.
-- If 'AF_INET6' is used and the socket type is 'Stream' or 'Datagram',
-- the 'IPv6Only' socket option is set to 0 so that both IPv4 and IPv6
-- can be handled with one socket.  On OpenBSD the option is left
-- unchanged, and on Windows a failure to set it is ignored.
--
-- The freshly created descriptor is closed again if any of the setup
-- steps that follow its creation throws (including asynchronous
-- exceptions).
--
-- >>> import Network.Socket
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST, AI_NUMERICSERV], addrSocketType = Stream }
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "5000")
-- >>> sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
-- >>> Network.Socket.bind sock (addrAddress addr)
-- >>> getSocketName sock
-- 127.0.0.1:5000
socket :: Family         -- ^ Family Name (usually AF_INET)
       -> SocketType     -- ^ Socket Type (usually Stream)
       -> ProtocolNumber -- ^ Protocol Number (getProtocolByName to find value)
       -> IO Socket      -- ^ Unconnected Socket
socket family stype protocol = E.bracketOnError create c_close $ \fd -> do
    -- bracketOnError ensures the descriptor is closed if anything below
    -- throws, even asynchronously.
    setNonBlock fd
    s <- mkSocket fd
    -- This socket is not managed by the IO manager yet, so the plain
    -- c_close cleanup above is enough; "close" (which uses
    -- "closeFdWith") is not needed.
    unsetIPv6Only s
    return s
  where
    -- Call socket(2), retrying on the conditions handled by
    -- throwSocketErrorIfMinus1Retry and throwing on failure.
    create :: IO CInt
    create =
        throwSocketErrorIfMinus1Retry "Network.Socket.socket" $
            c_socket (packFamily family) (modifyFlag (packSocketType stype)) protocol

    -- Adjust the C socket type passed to socket(2).
    modifyFlag :: CInt -> CInt
#ifdef HAVE_ADVANCED_SOCKET_FLAGS
    -- Request non-blocking mode as part of the socket(2) call itself.
    modifyFlag c_stype = c_stype .|. sockNonBlock
#else
    modifyFlag c_stype = c_stype
#endif

#ifdef HAVE_ADVANCED_SOCKET_FLAGS
    -- Nothing to do: modifyFlag already added sockNonBlock.
    setNonBlock _ = return ()
#else
    setNonBlock fd = setNonBlockIfNeeded fd
#endif

    unsetIPv6Only :: Socket -> IO ()
#if HAVE_DECL_IPV6_V6ONLY
    unsetIPv6Only s = when (family == AF_INET6 && stype `elem` [Stream, Datagram]) $
# if defined(mingw32_HOST_OS)
      -- The IPv6Only option is only supported on Windows Vista and later,
      -- so trying to change it might throw an error.
      setSocketOption s IPv6Only 0 `catchIOError` \_ -> return ()
# elif defined(openbsd_HOST_OS)
      -- don't change IPv6Only
      return ()
# else
      -- The default value of the IPv6Only option is platform specific,
      -- so we explicitly set it to 0 to provide a common default.
      setSocketOption s IPv6Only 0
# endif
#else
    unsetIPv6Only _ = return ()
#endif

-----------------------------------------------------------------------------
-- Binding a socket

-- | Bind the socket to an address.  The socket must not already be
-- bound.  The address must belong to the same 'Family' that was passed
-- to 'socket'.  If the special port number 'defaultPort' is passed then
-- the system assigns an available port.
--
-- Failures are reported by throwing an exception tagged
-- @"Network.Socket.bind"@.
bind :: SocketAddress sa => Socket -> sa -> IO ()
bind s sa = withSocketAddress sa $ \p_sa siz -> withFdSocket s $ \fd ->
    -- bind(2) only signals success or failure, so use the "_" variant
    -- (as 'listen' does) rather than discarding the result by hand.
    throwSocketErrorIfMinus1Retry_ "Network.Socket.bind" $
        c_bind fd p_sa (fromIntegral siz)

-----------------------------------------------------------------------------
-- Connecting a socket

-- | Connect to a remote socket at address.
--
-- The address is marshalled with 'withSocketAddress' and the actual
-- @connect(2)@ handling is delegated to 'connectLoop'.  Failures are
-- reported by throwing an exception whose location string starts with
-- @"Network.Socket.connect: "@.
connect :: SocketAddress sa => Socket -> sa -> IO ()
connect s sa = withSocketsDo $ withSocketAddress sa $ \p_sa sz ->
    connectLoop s p_sa (fromIntegral sz)

-- | Run @connect(2)@ on the socket's descriptor until it has succeeded
-- or failed.
--
-- On Windows any failure is thrown immediately.  Elsewhere the call is
-- retried on @EINTR@.  On @EINPROGRESS@ the thread waits until the
-- descriptor is writable and then reads 'SoError' to learn whether the
-- connection was established.  Any other error is thrown.
connectLoop :: SocketAddress sa => Socket -> Ptr sa -> CInt -> IO ()
connectLoop s p_sa sz = withFdSocket s loop
  where
    errLoc :: String
    errLoc = "Network.Socket.connect: " ++ show s

    loop :: CInt -> IO ()
    loop fd = do
        r <- c_connect fd p_sa sz
        when (r == -1) $
#if defined(mingw32_HOST_OS)
            throwSocketError errLoc
#else
            getErrno >>= onErrno fd

    -- Decide how to proceed after connect(2) returned -1.
    -- NOTE: an "err == eAGAIN -> connectBlocked" case is left disabled.
    onErrno fd err
      | err == eINTR       = loop fd
      | err == eINPROGRESS = connectBlocked
      | otherwise          = throwSocketError errLoc

    -- Wait for the in-progress connection to finish, then report the
    -- pending socket error, if any.
    connectBlocked :: IO ()
    connectBlocked = do
        withFdSocket s $ threadWaitWrite . fromIntegral
        err <- getSocketOption s SoError
        when (err /= 0) $ throwSocketErrorCode errLoc (fromIntegral err)
#endif

-----------------------------------------------------------------------------
-- Listen

-- | Listen for connections made to the socket.  The second argument
-- specifies the maximum number of queued connections and should be at
-- least 1.  The maximum value is system-dependent; larger values are
-- typically reduced silently to the system limit (e.g. @SOMAXCONN@).
--
-- A backlog outside the range of a C @int@ is clamped to that range
-- rather than wrapped around.  Failures are reported by throwing an
-- exception tagged @"Network.Socket.listen"@.
listen :: Socket -> Int -> IO ()
listen s backlog = withFdSocket s $ \fd ->
    throwSocketErrorIfMinus1Retry_ "Network.Socket.listen" $
        c_listen fd cBacklog
  where
    -- A bare 'fromIntegral' would wrap an out-of-range 'Int' (possible
    -- on 64-bit platforms) into an arbitrary, possibly negative, 'CInt'.
    cBacklog :: CInt
    cBacklog = fromIntegral (max lo (min hi backlog))
    lo = fromIntegral (minBound :: CInt) :: Int
    hi = fromIntegral (maxBound :: CInt) :: Int

-----------------------------------------------------------------------------
-- Accept
--
-- At the C level, accept() blocks until a connection is pending on the
-- given listening socket, unless that socket has been set to
-- non-blocking.  It returns a new socket, which should be used to
-- exchange data with the peer and should then be closed.  Because each
-- accepted connection gets its own socket, the original socket remains
-- free to queue further incoming requests.

-- | Accept a connection.  The socket must be bound to an address and
-- listening for connections.  The return value is a pair @(conn,
-- address)@ where @conn@ is a new socket object usable to send and
-- receive data on the connection, and @address@ is the address bound
-- to the socket on the other end of the connection.
-- On Unix, FD_CLOEXEC is set to the new 'Socket'.
--
-- The accepted descriptor is closed again if anything fails before it
-- has been wrapped in a 'Socket'.
accept :: SocketAddress sa => Socket -> IO (Socket, sa)
accept listing_sock = withNewSocketAddress $ \new_sa sz ->
    withFdSocket listing_sock $ \listing_fd ->
        E.bracketOnError (callAccept listing_fd new_sa sz) c_close $ \new_fd -> do
            -- Decode the peer address while new_fd is still a raw
            -- descriptor, so a failure here is cleaned up by c_close.
            new_addr <- peekSocketAddress new_sa
            new_sock <- mkSocket new_fd
            return (new_sock, new_addr)
  where
#if defined(mingw32_HOST_OS)
     -- Threaded RTS: use a safe accept() call.  Non-threaded RTS: run the
     -- call through asyncDoProc so other Haskell threads keep running.
     callAccept fd sa sz
       | threaded  = with (fromIntegral sz) $ \ ptr_len ->
                       throwSocketErrorIfMinus1Retry "Network.Socket.accept" $
                         c_accept_safe fd sa ptr_len
       | otherwise =
           -- bracket frees the parameter block even if asyncDoProc is
           -- interrupted by an exception (previously it leaked).
           E.bracket (c_newAcceptParams fd (fromIntegral sz) sa) c_free $ \paramData -> do
             rc     <- asyncDoProc c_acceptDoProc paramData
             new_fd <- c_acceptNewSock paramData
             when (rc /= 0) $
               throwSocketErrorCode "Network.Socket.accept" (fromIntegral rc)
             return new_fd
#else
     callAccept fd sa sz = with (fromIntegral sz) $ \ ptr_len -> do
# ifdef HAVE_ADVANCED_SOCKET_FLAGS
       -- accept4 applies both flags to the new descriptor in one call.
       throwSocketErrorWaitRead listing_sock "Network.Socket.accept"
                        (c_accept4 fd sa ptr_len (sockNonBlock .|. sockCloexec))
# else
       new_fd <- throwSocketErrorWaitRead listing_sock "Network.Socket.accept"
                        (c_accept fd sa ptr_len)
       -- The caller's bracketOnError only owns new_fd once this action
       -- returns, so close it here if configuring it fails.
       (setNonBlockIfNeeded new_fd >> setCloseOnExecIfNeeded new_fd)
         `E.onException` c_close new_fd
       return new_fd
# endif /* HAVE_ADVANCED_SOCKET_FLAGS */
#endif

-- Raw FFI bindings to the socket system calls wrapped above.  The
-- CSockLen??? markers flag arguments whose C prototype uses socklen_t
-- (or socklen_t *) while these bindings use CInt.
-- NOTE(review): this assumes socklen_t has the size of a C int on every
-- supported platform; confirm.
foreign import CALLCONV unsafe "socket"
  c_socket :: CInt -> CInt -> CInt -> IO CInt
foreign import CALLCONV unsafe "bind"
  c_bind :: CInt -> Ptr sa -> CInt{-CSockLen???-} -> IO CInt
-- SAFE_ON_WIN is a macro presumably defined in HsNetDef.h (a safe call on
-- Windows, where sockets are blocking, unsafe elsewhere); TODO confirm.
foreign import CALLCONV SAFE_ON_WIN "connect"
  c_connect :: CInt -> Ptr sa -> CInt{-CSockLen???-} -> IO CInt
foreign import CALLCONV unsafe "listen"
  c_listen :: CInt -> CInt -> IO CInt

-- With advanced socket flags, accept4 lets 'accept' pass
-- sockNonBlock .|. sockCloexec directly.  Otherwise plain accept is used
-- and the flags are set afterwards by separate calls.
#ifdef HAVE_ADVANCED_SOCKET_FLAGS
foreign import CALLCONV unsafe "accept4"
  c_accept4 :: CInt -> Ptr sa -> Ptr CInt{-CSockLen???-} -> CInt -> IO CInt
#else
foreign import CALLCONV unsafe "accept"
  c_accept :: CInt -> Ptr sa -> Ptr CInt{-CSockLen???-} -> IO CInt
#endif

#if defined(mingw32_HOST_OS)
-- Windows-only bindings used by 'accept'.
--
-- A safe accept, used when the RTS is threaded, so that a blocking call
-- does not stall other Haskell threads.
foreign import CALLCONV safe "accept"
  c_accept_safe :: CInt -> Ptr sa -> Ptr CInt{-CSockLen???-} -> IO CInt
-- Imported as a pure Bool.  'accept' uses it to choose between the
-- safe-call path and the asyncDoProc path.
foreign import ccall unsafe "rtsSupportsBoundThreads"
  threaded :: Bool
-- Helpers from HsNet.h for the non-threaded path.  newAcceptParams builds
-- a parameter block (released with free once done), acceptDoProc is the
-- procedure run by asyncDoProc on that block (presumably performing the
-- accept itself), and acceptNewSock reads the resulting descriptor.
foreign import ccall unsafe "HsNet.h acceptNewSock"
  c_acceptNewSock :: Ptr () -> IO CInt
foreign import ccall unsafe "HsNet.h newAcceptParams"
  c_newAcceptParams :: CInt -> CInt -> Ptr a -> IO (Ptr ())
foreign import ccall unsafe "HsNet.h &acceptDoProc"
  c_acceptDoProc :: FunPtr (Ptr () -> IO Int)
foreign import ccall unsafe "free"
  c_free:: Ptr a -> IO ()
#endif