{-# language BangPatterns #-} {-# language DataKinds #-} {-# language MagicHash #-} {-# language ScopedTypeVariables #-} {-# language UnboxedTuples #-} {-# language UnliftedFFITypes #-} module Linux.Socket ( -- * Functions uninterruptibleReceiveMultipleMessageA , uninterruptibleReceiveMultipleMessageB , uninterruptibleReceiveMultipleMessageC , uninterruptibleReceiveMultipleMessageD , uninterruptibleAccept4 -- * Types , SocketFlags(..) -- * Option Names , LST.headerInclude -- * Message Flags , LST.dontWait , LST.truncate , LST.controlTruncate -- * Socket Flags , LST.closeOnExec , LST.nonblocking -- * Twiddle , applySocketFlags -- * UDP Header , LST.sizeofUdpHeader , LST.pokeUdpHeaderSourcePort , LST.pokeUdpHeaderDestinationPort , LST.pokeUdpHeaderLength , LST.pokeUdpHeaderChecksum -- * IPv4 Header , LST.sizeofIpHeader , LST.pokeIpHeaderVersionIhl , LST.pokeIpHeaderTypeOfService , LST.pokeIpHeaderTotalLength , LST.pokeIpHeaderIdentifier , LST.pokeIpHeaderFragmentOffset , LST.pokeIpHeaderTimeToLive , LST.pokeIpHeaderProtocol , LST.pokeIpHeaderChecksum , LST.pokeIpHeaderSourceAddress , LST.pokeIpHeaderDestinationAddress ) where import Prelude hiding (truncate) import Control.Monad (when) import Data.Bits ((.|.)) import Data.Primitive.Addr (Addr(..),plusAddr,nullAddr) import Data.Primitive (MutableByteArray(..),ByteArray(..),MutablePrimArray(..)) import Data.Primitive.Unlifted.Array (MutableUnliftedArray(..),UnliftedArray) import Data.Word (Word8) import Foreign.C.Error (Errno,getErrno) import Foreign.C.Types (CInt(..),CSize(..),CUInt(..)) import GHC.Exts (Ptr(..),RealWorld,MutableByteArray#,Addr#,MutableArrayArray#,Int(I#)) import GHC.Exts (shrinkMutableByteArray#,touch#,nullAddr#) import GHC.IO (IO(..)) import Linux.Socket.Types (SocketFlags(..)) import Posix.Socket (Type(..),MessageFlags(..),Message(Receive),SocketAddress(..)) import System.Posix.Types (Fd(..),CSsize(..)) import qualified Data.Primitive as PM import qualified Data.Primitive.Unlifted.Array as PM import qualified Control.Monad.Primitive as PM import qualified Posix.Socket as S import qualified Linux.Socket.Types as LST foreign import ccall unsafe "sys/socket.h recvmmsg" c_unsafe_addr_recvmmsg :: Fd -> Addr# -- This addr is an array of msghdr -> CUInt -- Length of msghdr array -> MessageFlags 'Receive -> Addr# -- Timeout -> IO CSsize foreign import ccall unsafe "sys/socket.h accept4" c_unsafe_accept4 :: Fd -> MutableByteArray# RealWorld -- SocketAddress -> MutableByteArray# RealWorld -- Ptr CInt -> SocketFlags -> IO Fd foreign import ccall unsafe "HaskellPosix.h recvmmsg_sockaddr_in" c_unsafe_recvmmsg_sockaddr_in :: Fd -> MutableByteArray# RealWorld -- lengths -> MutableByteArray# RealWorld -- sockaddrs -> MutableArrayArray# RealWorld -- buffers -> CUInt -- Length of msghdr array -> MessageFlags 'Receive -> IO CInt foreign import ccall unsafe "HaskellPosix.h recvmmsg_sockaddr_discard" c_unsafe_recvmmsg_sockaddr_discard :: Fd -> MutableByteArray# RealWorld -- lengths -> MutableArrayArray# RealWorld -- buffers -> CUInt -- Length of msghdr array -> MessageFlags 'Receive -> IO CInt -- | Linux extends the @type@ argument of -- to allow -- setting two socket flags on socket creation: @SOCK_CLOEXEC@ and -- @SOCK_NONBLOCK@. It is advisable to set @SOCK_CLOEXEC@ on when -- opening a socket on linux. For example, we may open a TCP Internet -- socket with: -- -- > uninterruptibleSocket internet (applySocketFlags closeOnExec stream) defaultProtocol -- -- To additionally open the socket in nonblocking mode -- (e.g. with @SOCK_NONBLOCK@): -- -- > uninterruptibleSocket internet (applySocketFlags (closeOnExec <> nonblocking) stream) defaultProtocol -- applySocketFlags :: SocketFlags -> Type -> Type applySocketFlags (SocketFlags s) (Type t) = Type (s .|. t) -- | Receive multiple messages. This does not provide the socket -- addresses or the control messages. It does not use any of the -- input-scattering that @recvmmsg@ offers, meaning that a single -- datagram is never split across noncontiguous memory. It supplies -- @NULL@ for the timeout argument. All of the messages must have the -- same maximum size. All resulting byte arrays have been explicitly -- pinned. In addition to bytearrays corresponding to each datagram, -- this also provides the maximum @msg_len@ that @recvmmsg@ wrote -- back out. This is provided so that users of @MSG_TRUNC@ can detect -- when bytes were dropped from the end of a message (although it does -- let the user figure out which message had bytes dropped). uninterruptibleReceiveMultipleMessageA :: Fd -- ^ Socket -> CSize -- ^ Maximum bytes per message -> CUInt -- ^ Maximum number of messages -> MessageFlags 'Receive -- ^ Flags -> IO (Either Errno (CUInt,UnliftedArray ByteArray)) uninterruptibleReceiveMultipleMessageA !s !msgSize !msgCount !flags = do bufs <- PM.unsafeNewUnliftedArray (cuintToInt msgCount) mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader) iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector) let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = ptrToAddr (PM.mutableByteArrayContents mmsghdrsBuf) let iovecsAddr = ptrToAddr (PM.mutableByteArrayContents iovecsBuf) initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsghdrsAddr msgSize msgCount r <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr# if r > (-1) then do (_,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize 0 (cssizeToInt r) bufs mmsghdrsAddr touchMutableUnliftedArray bufs touchMutableByteArray iovecsBuf touchMutableByteArray mmsghdrsBuf pure (Right (maxMsgSz,frozenBufs)) else do touchMutableUnliftedArray bufs touchMutableByteArray iovecsBuf touchMutableByteArray mmsghdrsBuf fmap Left getErrno -- | Receive multiple messages. This is similar to -- @uninterruptibleReceiveMultipleMessageA@. However, it also -- provides the @sockaddr@s of the remote endpoints. These are -- written in contiguous memory to a bytearray of length -- @max_num_msgs * expected_sockaddr_sz@. The @sockaddr@s must -- all be expected to be of the same length. This function -- provides a @sockaddr@ size check that is non-zero when any -- @sockaddr@ had a length other than the expected length. -- This can be used to detect if the @sockaddr@ array has one or -- more corrupt @sockaddr@s in it. All byte arrays returned by -- this function are pinned. -- -- The values in the returned tuple are: -- -- * Error-checking number for @sockaddr@ size. Non-zero indicates -- that at least one @sockaddr@ required a number of bytes other -- than the expected number. -- * Pinned bytearray with all of the @sockaddr@s in it as a -- array of structures. -- * The size of the largest message received. If @MSG_TRUNC@ is used -- this lets the caller know if one or more messages were truncated. -- * The message data of each message. -- -- The @sockaddr@s bytearray and the unlifted array of messages are -- guaranteed to have the same number of elements. uninterruptibleReceiveMultipleMessageB :: Fd -- ^ Socket -> CInt -- ^ Expected @sockaddr@ size -> CSize -- ^ Maximum bytes per message -> CUInt -- ^ Maximum number of messages -> MessageFlags 'Receive -- ^ Flags -> IO (Either Errno (CInt,ByteArray,CUInt,UnliftedArray ByteArray)) uninterruptibleReceiveMultipleMessageB !s !expSockAddrSize !msgSize !msgCount !flags = do bufs <- PM.unsafeNewUnliftedArray (cuintToInt msgCount) mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader) iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector) sockaddrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt expSockAddrSize) -- Linux does not require zeroing out sockaddr_in before using it, -- so we leave sockaddrsBuf alone after initialization. let sockaddrsAddr = ptrToAddr (PM.mutableByteArrayContents sockaddrsBuf) let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = ptrToAddr (PM.mutableByteArrayContents mmsghdrsBuf) let iovecsAddr = ptrToAddr (PM.mutableByteArrayContents iovecsBuf) initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr mmsghdrsAddr sockaddrsAddr expSockAddrSize msgSize msgCount r <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr# if r > (-1) then do (validation,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize expSockAddrSize (cssizeToInt r) bufs mmsghdrsAddr shrinkMutableByteArray sockaddrsBuf (cssizeToInt r * cintToInt expSockAddrSize) sockaddrs <- PM.unsafeFreezeByteArray sockaddrsBuf touchMutableByteArray iovecsBuf touchMutableByteArray mmsghdrsBuf touchMutableByteArray sockaddrsBuf pure (Right (validation,sockaddrs,maxMsgSz,frozenBufs)) else do touchMutableUnliftedArray bufs touchMutableByteArray iovecsBuf touchMutableByteArray mmsghdrsBuf touchMutableByteArray sockaddrsBuf fmap Left getErrno -- | All three buffer arguments need to have the same length (in elements, not bytes). uninterruptibleReceiveMultipleMessageC :: Fd -- ^ Socket -> MutablePrimArray RealWorld CInt -- ^ Buffer for payload lengths -> MutablePrimArray RealWorld S.SocketAddressInternet -- ^ Buffer for @sockaddr_in@s -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- ^ Buffers for payloads -> CUInt -- ^ Maximum number of datagrams to receive, length of buffers -> MessageFlags 'Receive -- ^ Flags -> IO (Either Errno CInt) uninterruptibleReceiveMultipleMessageC !s (MutablePrimArray lens) (MutablePrimArray addrs) (PM.MutableUnliftedArray payloads) !msgCount !flags = c_unsafe_recvmmsg_sockaddr_in s lens addrs payloads msgCount flags >>= errorsFromInt -- | All three buffer arguments need to have the same length (in elements, not bytes). -- This discards the source addresses. uninterruptibleReceiveMultipleMessageD :: Fd -- ^ Socket -> MutablePrimArray RealWorld CInt -- ^ Buffer for payload lengths -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- ^ Buffers for payloads -> CUInt -- ^ Maximum number of datagrams to receive, length of buffers -> MessageFlags 'Receive -- ^ Flags -> IO (Either Errno CInt) uninterruptibleReceiveMultipleMessageD !s (MutablePrimArray lens) (PM.MutableUnliftedArray payloads) !msgCount !flags = c_unsafe_recvmmsg_sockaddr_discard s lens payloads msgCount flags >>= errorsFromInt -- This sets up an array of mmsghdr. Each msghdr has msg_iov set to -- be an array of iovec with a single element. initializeMultipleMessageHeadersWithoutSockAddr :: MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- buffers -> Addr -- array of iovec -> Addr -- array of message headers -> CSize -- message size -> CUInt -- message count -> IO () initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsgHdrsAddr msgSize msgCount = let go !ix !iovecAddr !mmsgHdrAddr = if ix < cuintToInt msgCount then do pokeMultipleMessageHeader mmsgHdrAddr nullAddr 0 iovecAddr 1 nullAddr 0 mempty 0 initializeIOVector bufs iovecAddr msgSize ix go (ix + 1) (plusAddr iovecAddr (cintToInt S.sizeofIOVector)) (plusAddr mmsgHdrAddr (cintToInt LST.sizeofMultipleMessageHeader)) else pure () in go 0 iovecsAddr mmsgHdrsAddr -- This sets up an array of mmsghdr. Each msghdr has msg_iov set to -- be an array of iovec with a single element. One giant buffer with -- space for all of the @sockaddr@s is used. initializeMultipleMessageHeadersWithSockAddr :: MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- buffers -> Addr -- array of iovec -> Addr -- array of message headers -> Addr -- array of sockaddrs -> CInt -- expected sockaddr size -> CSize -- message size -> CUInt -- message count -> IO () initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0 sockaddrSize msgSize msgCount = let go !ix !iovecAddr !mmsgHdrAddr !sockaddrAddr = if ix < cuintToInt msgCount then do pokeMultipleMessageHeader mmsgHdrAddr sockaddrAddr sockaddrSize iovecAddr 1 nullAddr 0 mempty 0 initializeIOVector bufs iovecAddr msgSize ix go (ix + 1) (plusAddr iovecAddr (cintToInt S.sizeofIOVector)) (plusAddr mmsgHdrAddr (cintToInt LST.sizeofMultipleMessageHeader)) (plusAddr sockaddrAddr (cintToInt sockaddrSize)) else pure () in go 0 iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0 ptrToAddr :: Ptr Word8 -> Addr ptrToAddr (Ptr x) = Addr x -- Initialize a single iovec. We write the pinned byte array into -- both the iov_base field and into an unlifted array. initializeIOVector :: MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -> Addr -> CSize -> Int -> IO () initializeIOVector bufs iovecAddr msgSize ix = do buf <- PM.newPinnedByteArray (csizeToInt msgSize) PM.writeUnliftedArray bufs ix buf S.pokeIOVectorBase iovecAddr (ptrToAddr (PM.mutableByteArrayContents buf)) S.pokeIOVectorLength iovecAddr msgSize -- Freeze a slice of the mutable byte arrays inside the unlifted array, -- shrinking the byte arrays before doing so. shrinkAndFreezeMessages :: CSize -- Full size of each buffer -> CInt -- Expected sockaddr size -> Int -- Actual number of received messages -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -> Addr -- Array of mmsghdr -> IO (CInt,CUInt,UnliftedArray ByteArray) shrinkAndFreezeMessages !bufSize !expSockAddrSize !n !bufs !mmsghdr0 = do r <- PM.unsafeNewUnliftedArray n go r 0 0 0 mmsghdr0 where go !r !validation !ix !maxMsgSz !mmsghdr = if ix < n then do sz <- LST.peekMultipleMessageHeaderLength mmsghdr sockaddrSz <- LST.peekMultipleMessageHeaderNameLength mmsghdr buf <- PM.readUnliftedArray bufs ix when (cuintToInt sz < csizeToInt bufSize) (shrinkMutableByteArray buf (cuintToInt sz)) PM.writeUnliftedArray r ix =<< PM.unsafeFreezeByteArray buf go r (validation .|. (sockaddrSz - expSockAddrSize)) (ix + 1) (max maxMsgSz sz) (plusAddr mmsghdr (cintToInt LST.sizeofMultipleMessageHeader)) else do a <- PM.unsafeFreezeUnliftedArray r pure (validation,maxMsgSz,a) pokeMultipleMessageHeader :: Addr -> Addr -> CInt -> Addr -> CSize -> Addr -> CSize -> MessageFlags 'Receive -> CUInt -> IO () pokeMultipleMessageHeader mmsgHdrAddr a b c d e f g len = do LST.pokeMultipleMessageHeaderName mmsgHdrAddr a LST.pokeMultipleMessageHeaderNameLength mmsgHdrAddr b LST.pokeMultipleMessageHeaderIOVector mmsgHdrAddr c LST.pokeMultipleMessageHeaderIOVectorLength mmsgHdrAddr d LST.pokeMultipleMessageHeaderControl mmsgHdrAddr e LST.pokeMultipleMessageHeaderControlLength mmsgHdrAddr f LST.pokeMultipleMessageHeaderFlags mmsgHdrAddr g LST.pokeMultipleMessageHeaderLength mmsgHdrAddr len shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO () shrinkMutableByteArray (MutableByteArray arr) (I# sz) = PM.primitive_ (shrinkMutableByteArray# arr sz) -- | Variant of 'Posix.Socket.uninterruptibleAccept' that allows setting -- flags on the newly-accepted connection. uninterruptibleAccept4 :: Fd -- ^ Listening socket -> CInt -- ^ Maximum socket address size -> SocketFlags -- ^ Set non-blocking and close-on-exec without extra syscall -> IO (Either Errno (CInt,SocketAddress,Fd)) -- ^ Peer information and connected socket {-# inline uninterruptibleAccept4 #-} uninterruptibleAccept4 !sock !maxSz !flags = do sockAddrBuf@(MutableByteArray sockAddrBuf#) <- PM.newByteArray (cintToInt maxSz) lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt)) PM.writeByteArray lenBuf 0 maxSz r <- c_unsafe_accept4 sock sockAddrBuf# lenBuf# flags if r > (-1) then do (sz :: CInt) <- PM.readByteArray lenBuf 0 if sz < maxSz then shrinkMutableByteArray sockAddrBuf (cintToInt sz) else pure () sockAddr <- PM.unsafeFreezeByteArray sockAddrBuf pure (Right (sz,SocketAddress sockAddr,r)) else fmap Left getErrno cintToInt :: CInt -> Int cintToInt = fromIntegral cuintToInt :: CUInt -> Int cuintToInt = fromIntegral csizeToInt :: CSize -> Int csizeToInt = fromIntegral cssizeToInt :: CSsize -> Int cssizeToInt = fromIntegral errorsFromInt :: CInt -> IO (Either Errno CInt) {-# inline errorsFromInt #-} errorsFromInt r = if r > (-1) then pure (Right r) else fmap Left getErrno touchMutableUnliftedArray :: MutableUnliftedArray RealWorld a -> IO () touchMutableUnliftedArray (MutableUnliftedArray x) = touchMutableUnliftedArray# x touchMutableByteArray :: MutableByteArray RealWorld -> IO () touchMutableByteArray (MutableByteArray x) = touchMutableByteArray# x touchMutableUnliftedArray# :: MutableArrayArray# RealWorld -> IO () touchMutableUnliftedArray# x = IO $ \s -> case touch# x s of s' -> (# s', () #) touchMutableByteArray# :: MutableByteArray# RealWorld -> IO () touchMutableByteArray# x = IO $ \s -> case touch# x s of s' -> (# s', () #)