{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}

-- | Linux-specific socket functionality: receiving several messages
-- with a single @recvmmsg@ call, applying socket-creation flags to a
-- socket type, and re-exports of flag values and UDP\/IPv4 header
-- helpers from "Linux.Socket.Types".
module Linux.Socket
  ( -- * Functions
    uninterruptibleReceiveMultipleMessageA
  , uninterruptibleReceiveMultipleMessageB
    -- * Types
  , SocketFlags(..)
    -- * Option Names
  , LST.headerInclude
    -- * Message Flags
  , LST.dontWait
  , LST.truncate
  , LST.controlTruncate
    -- * Socket Flags
  , LST.closeOnExec
  , LST.nonblocking
    -- * Twiddle
  , applySocketFlags
    -- * UDP Header
  , LST.sizeofUdpHeader
  , LST.pokeUdpHeaderSourcePort
  , LST.pokeUdpHeaderDestinationPort
  , LST.pokeUdpHeaderLength
  , LST.pokeUdpHeaderChecksum
    -- * IPv4 Header
  , LST.sizeofIpHeader
  , LST.pokeIpHeaderVersionIhl
  , LST.pokeIpHeaderTypeOfService
  , LST.pokeIpHeaderTotalLength
  , LST.pokeIpHeaderIdentifier
  , LST.pokeIpHeaderFragmentOffset
  , LST.pokeIpHeaderTimeToLive
  , LST.pokeIpHeaderProtocol
  , LST.pokeIpHeaderChecksum
  , LST.pokeIpHeaderSourceAddress
  , LST.pokeIpHeaderDestinationAddress
  ) where

import Prelude hiding (truncate)

import Control.Monad (when)
import Data.Bits ((.|.))
import Data.Primitive (MutableByteArray(..),Addr(..),ByteArray(..))
import Data.Primitive (MutableUnliftedArray(..),UnliftedArray)
import Foreign.C.Error (Errno,getErrno)
import Foreign.C.Types (CInt(..),CSize(..),CUInt(..))
import GHC.Exts (Ptr,RealWorld,ByteArray#,MutableByteArray#,Addr#,MutableArrayArray#,Int(I#))
import GHC.Exts (shrinkMutableByteArray#,touch#,nullAddr#)
import GHC.IO (IO(..))
import Linux.Socket.Types (SocketFlags(..))
import Posix.Socket (Type(..),MessageFlags(..),Message(Receive))
import System.Posix.Types (Fd(..),CSsize(..))

import qualified Data.Primitive as PM
import qualified Control.Monad.Primitive as PM
import qualified Posix.Socket as S
import qualified Linux.Socket.Types as LST

-- Raw binding to @recvmmsg@. Callers in this module pass the address
-- of a pinned buffer holding the message headers and @nullAddr#@ for
-- the timeout.
foreign import ccall unsafe "sys/socket.h recvmmsg"
  c_unsafe_addr_recvmmsg :: Fd
                         -> Addr# -- This addr is an array of mmsghdr
                         -> CUInt -- Length of mmsghdr array
                         -> MessageFlags 'Receive
                         -> Addr# -- Timeout
                         -> IO CSsize

-- | Linux extends the @type@ argument of @socket(2)@ to allow
-- setting two socket flags on socket creation: @SOCK_CLOEXEC@ and
-- @SOCK_NONBLOCK@. It is advisable to set @SOCK_CLOEXEC@ when
-- opening a socket on Linux. For example, we may open a TCP Internet
-- socket with:
--
-- > uninterruptibleSocket internet (applySocketFlags closeOnExec stream) defaultProtocol
--
-- To additionally open the socket in nonblocking mode
-- (e.g. with @SOCK_NONBLOCK@):
--
-- > uninterruptibleSocket internet (applySocketFlags (closeOnExec <> nonblocking) stream) defaultProtocol
--
applySocketFlags :: SocketFlags -> Type -> Type
applySocketFlags (SocketFlags s) (Type t) = Type (s .|. t)

-- | Receive multiple messages. This does not provide the socket
-- addresses or the control messages. It does not use any of the
-- input-scattering that @recvmmsg@ offers, meaning that a single
-- datagram is never split across noncontiguous memory. It supplies
-- @NULL@ for the timeout argument. All of the messages must have the
-- same maximum size. All resulting byte arrays have been explicitly
-- pinned. In addition to bytearrays corresponding to each datagram,
-- this also provides the maximum @msg_len@ that @recvmmsg@ wrote
-- back out. This is provided so that users of @MSG_TRUNC@ can detect
-- when bytes were dropped from the end of a message (although it does
-- not let the user figure out which message had bytes dropped).
uninterruptibleReceiveMultipleMessageA ::
     Fd -- ^ Socket
  -> CSize -- ^ Maximum bytes per message
  -> CUInt -- ^ Maximum number of messages
  -> MessageFlags 'Receive -- ^ Flags
  -> IO (Either Errno (CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageA !s !msgSize !msgCount !flags = do
  let !capacity = cuintToInt msgCount
  -- One payload buffer per message slot; the slots are populated by
  -- the header initialization below.
  bufs <- PM.unsafeNewUnliftedArray capacity
  -- Pinned scratch space for the mmsghdr and iovec arrays. Only their
  -- raw addresses are handed to the foreign call.
  mmsghdrsBuf <- PM.newPinnedByteArray (capacity * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (capacity * cintToInt S.sizeofIOVector)
  let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = PM.mutableByteArrayContents mmsghdrsBuf
  let iovecsAddr = PM.mutableByteArrayContents iovecsBuf
  -- Touch every buffer whose address was given to the foreign call so
  -- that it is still referenced after the call has returned.
  let keepAlive = do
        touchMutableUnliftedArray bufs
        touchMutableByteArray iovecsBuf
        touchMutableByteArray mmsghdrsBuf
  initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsghdrsAddr msgSize msgCount
  r <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr#
  if r < 0
    then do
      keepAlive
      Left <$> getErrno
    else do
      -- No sockaddrs are requested, so the expected sockaddr size is 0
      -- and the size-check component of the result is discarded.
      (_,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize 0 (cssizeToInt r) bufs mmsghdrsAddr
      keepAlive
      pure (Right (maxMsgSz,frozenBufs))

-- | Receive multiple messages. This is similar to
-- @uninterruptibleReceiveMultipleMessageA@. However, it also
-- provides the @sockaddr@s of the remote endpoints. These are
-- written in contiguous memory to a bytearray of length
-- @max_num_msgs * expected_sockaddr_sz@. The @sockaddr@s must
-- all be expected to be of the same length. This function
-- provides a @sockaddr@ size check that is non-zero when any
-- @sockaddr@ had a length other than the expected length.
-- This can be used to detect if the @sockaddr@ array has one or
-- more corrupt @sockaddr@s in it. All byte arrays returned by
-- this function are pinned.
--
-- The values in the returned tuple are:
--
-- * Error-checking number for @sockaddr@ size.
--   Non-zero indicates
--   that at least one @sockaddr@ required a number of bytes other
--   than the expected number.
-- * Pinned bytearray with all of the @sockaddr@s in it as an
--   array of structures.
-- * The size of the largest message received. If @MSG_TRUNC@ is used
--   this lets the caller know if one or more messages were truncated.
-- * The message data of each message.
--
-- The @sockaddr@s bytearray and the unlifted array of messages are
-- guaranteed to have the same number of elements.
uninterruptibleReceiveMultipleMessageB ::
     Fd -- ^ Socket
  -> CInt -- ^ Expected @sockaddr@ size
  -> CSize -- ^ Maximum bytes per message
  -> CUInt -- ^ Maximum number of messages
  -> MessageFlags 'Receive -- ^ Flags
  -> IO (Either Errno (CInt,ByteArray,CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageB !s !expSockAddrSize !msgSize !msgCount !flags = do
  bufs <- PM.unsafeNewUnliftedArray (cuintToInt msgCount)
  mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector)
  sockaddrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt expSockAddrSize)
  -- Linux does not require zeroing out sockaddr_in before using it,
  -- so we leave sockaddrsBuf alone after initialization.
  let sockaddrsAddr = PM.mutableByteArrayContents sockaddrsBuf
  let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = PM.mutableByteArrayContents mmsghdrsBuf
  let iovecsAddr = PM.mutableByteArrayContents iovecsBuf
  initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr mmsghdrsAddr sockaddrsAddr expSockAddrSize msgSize msgCount
  r <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr#
  if r > (-1)
    then do
      (validation,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize expSockAddrSize (cssizeToInt r) bufs mmsghdrsAddr
      -- Drop the sockaddr slots that belong to messages that were
      -- not received.
      shrinkMutableByteArray sockaddrsBuf (cssizeToInt r * cintToInt expSockAddrSize)
      sockaddrs <- PM.unsafeFreezeByteArray sockaddrsBuf
      -- Touch every buffer whose address was written into the message
      -- headers. This includes bufs, matching the error branch below
      -- and uninterruptibleReceiveMultipleMessageA.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      touchMutableByteArray sockaddrsBuf
      pure (Right (validation,sockaddrs,maxMsgSz,frozenBufs))
    else do
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      touchMutableByteArray sockaddrsBuf
      fmap Left getErrno

-- This sets up an array of mmsghdr. Each msghdr has msg_iov set to
-- be an array of iovec with a single element. The name (sockaddr)
-- and control fields are set to null with zero length.
initializeMultipleMessageHeadersWithoutSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- buffers
  -> Addr -- array of iovec
  -> Addr -- array of message headers
  -> CSize -- message size
  -> CUInt -- message count
  -> IO ()
initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsgHdrsAddr msgSize msgCount =
  let go !ix !iovecAddr !mmsgHdrAddr = if ix < cuintToInt msgCount
        then do
          pokeMultipleMessageHeader mmsgHdrAddr PM.nullAddr 0 iovecAddr 1 PM.nullAddr 0 mempty 0
          initializeIOVector bufs iovecAddr msgSize ix
          -- Advance both cursors by one element of their arrays.
          go (ix + 1)
             (PM.plusAddr iovecAddr (cintToInt S.sizeofIOVector))
             (PM.plusAddr mmsgHdrAddr (cintToInt LST.sizeofMultipleMessageHeader))
        else pure ()
   in go 0 iovecsAddr mmsgHdrsAddr

-- This sets up an array of mmsghdr. Each msghdr has msg_iov set to
-- be an array of iovec with a single element. One giant buffer with
-- space for all of the @sockaddr@s is used.
initializeMultipleMessageHeadersWithSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld) -- buffers
  -> Addr -- array of iovec
  -> Addr -- array of message headers
  -> Addr -- array of sockaddrs
  -> CInt -- expected sockaddr size
  -> CSize -- message size
  -> CUInt -- message count
  -> IO ()
initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0 sockaddrSize msgSize msgCount =
  fill 0 iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0
  where
  total = cuintToInt msgCount
  -- Byte distance between consecutive elements of each array.
  iovecStride = cintToInt S.sizeofIOVector
  headerStride = cintToInt LST.sizeofMultipleMessageHeader
  nameStride = cintToInt sockaddrSize
  -- Walk the three arrays in lockstep, one element per message slot.
  fill !slot !iovec !header !name
    | slot >= total = pure ()
    | otherwise = do
        pokeMultipleMessageHeader header name sockaddrSize iovec 1 PM.nullAddr 0 mempty 0
        initializeIOVector bufs iovec msgSize slot
        fill (slot + 1)
             (PM.plusAddr iovec iovecStride)
             (PM.plusAddr header headerStride)
             (PM.plusAddr name nameStride)

-- Initialize a single iovec. A fresh pinned byte array of the
-- requested size is stored in the unlifted array at the given index,
-- and its address and size are written to the iov_base and length
-- fields.
initializeIOVector ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> CSize
  -> Int
  -> IO ()
initializeIOVector bufs iovecAddr msgSize ix =
  PM.newPinnedByteArray (csizeToInt msgSize) >>= \payload -> do
    PM.writeUnliftedArray bufs ix payload
    S.pokeIOVectorBase iovecAddr (PM.mutableByteArrayContents payload)
    S.pokeIOVectorLength iovecAddr msgSize

-- Freeze a slice of the mutable byte arrays inside the unlifted array,
-- shrinking the byte arrays before doing so.
shrinkAndFreezeMessages ::
     CSize -- Full size of each buffer
  -> CInt -- Expected sockaddr size
  -> Int -- Actual number of received messages
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr -- Array of mmsghdr
  -> IO (CInt,CUInt,UnliftedArray ByteArray)
shrinkAndFreezeMessages !bufSize !expSockAddrSize !n !bufs !mmsghdr0 = do
  frozen <- PM.unsafeNewUnliftedArray n
  let capacity = csizeToInt bufSize
      stride = cintToInt LST.sizeofMultipleMessageHeader
      -- Accumulates the OR of (actual - expected) sockaddr lengths and
      -- the largest message length seen so far.
      loop !check !ix !largest !header
        | ix < n = do
            sz <- LST.peekMultipleMessageHeaderLength header
            sockaddrSz <- LST.peekMultipleMessageHeaderNameLength header
            buf <- PM.readUnliftedArray bufs ix
            let len = cuintToInt sz
            -- Only shrink when the reported length is smaller than the
            -- buffer; otherwise the buffer keeps its full size.
            when (len < capacity) (shrinkMutableByteArray buf len)
            message <- PM.unsafeFreezeByteArray buf
            PM.writeUnliftedArray frozen ix message
            loop (check .|. (sockaddrSz - expSockAddrSize))
                 (ix + 1)
                 (max largest sz)
                 (PM.plusAddr header stride)
        | otherwise = do
            messages <- PM.unsafeFreezeUnliftedArray frozen
            pure (check,largest,messages)
  loop 0 0 0 mmsghdr0

-- Write the name, name length, I/O vector, I/O vector length,
-- control, control length, flags, and length fields of the message
-- header at the given address, in that order.
pokeMultipleMessageHeader ::
     Addr
  -> Addr
  -> CInt
  -> Addr
  -> CSize
  -> Addr
  -> CSize
  -> MessageFlags 'Receive
  -> CUInt
  -> IO ()
pokeMultipleMessageHeader header name nameLen iovec iovecLen control controlLen msgFlags msgLen = do
  LST.pokeMultipleMessageHeaderName header name
  LST.pokeMultipleMessageHeaderNameLength header nameLen
  LST.pokeMultipleMessageHeaderIOVector header iovec
  LST.pokeMultipleMessageHeaderIOVectorLength header iovecLen
  LST.pokeMultipleMessageHeaderControl header control
  LST.pokeMultipleMessageHeaderControlLength header controlLen
  LST.pokeMultipleMessageHeaderFlags header msgFlags
  LST.pokeMultipleMessageHeaderLength header msgLen

-- Shrink a mutable byte array in place to the given number of bytes.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray marr) (I# len) =
  PM.primitive_ (\st -> shrinkMutableByteArray# marr len st)

-- Numeric conversions from C integer types to Int via fromIntegral.
cintToInt :: CInt -> Int
cintToInt n = fromIntegral n

cuintToInt :: CUInt -> Int
cuintToInt n = fromIntegral n

csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n

cssizeToInt :: CSsize -> Int
cssizeToInt n = fromIntegral n
-- Apply touch# to the array underlying a mutable unlifted array.
touchMutableUnliftedArray :: MutableUnliftedArray RealWorld a -> IO ()
touchMutableUnliftedArray (MutableUnliftedArray arr) =
  touchMutableUnliftedArray# arr

-- Apply touch# to the array underlying a mutable byte array.
touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray (MutableByteArray arr) =
  touchMutableByteArray# arr

-- Wrap touch# on a MutableArrayArray# as an IO action returning unit.
touchMutableUnliftedArray# :: MutableArrayArray# RealWorld -> IO ()
touchMutableUnliftedArray# arr = IO step
  where
  step st0 = case touch# arr st0 of
    st1 -> (# st1, () #)

-- Wrap touch# on a MutableByteArray# as an IO action returning unit.
touchMutableByteArray# :: MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# arr = IO step
  where
  step st0 = case touch# arr st0 of
    st1 -> (# st1, () #)