{-# language BangPatterns #-} {-# language RankNTypes #-} {-# language DuplicateRecordFields #-} {-# language LambdaCase #-} {-# language NamedFieldPuns #-} {-# language MagicHash #-} module Socket.Stream.IPv4 ( -- * Types Listener , Connection , Endpoint(..) -- * Bracketed , withListener , withAccepted , withConnection , forkAccepted , forkAcceptedUnmasked -- * Communicate , sendByteArray , sendByteArraySlice , sendMutableByteArray , sendMutableByteArraySlice , receiveByteArray , receiveBoundedByteArray , receiveMutableByteArray -- * Exceptions , SocketException(..) , Context(..) , Reason(..) ) where import Control.Concurrent (ThreadId,threadWaitWrite,threadWaitRead) import Control.Concurrent (forkIO,forkIOWithUnmask) import Control.Exception (mask,onException) import Data.Bifunctor (bimap) import Data.Primitive (ByteArray,MutableByteArray(..)) import Data.Word (Word16) import Foreign.C.Error (Errno(..),eAGAIN,eWOULDBLOCK,eINPROGRESS) import Foreign.C.Types (CInt,CSize) import GHC.Exts (RealWorld,Int(I#),shrinkMutableByteArray#) import Socket (SocketException(..),Context(..),Reason(..)) import Socket.Debug (debug) import Socket.IPv4 (Endpoint(..)) import System.Posix.Types (Fd) import Net.Types (IPv4(..)) import qualified Control.Monad.Primitive as PM import qualified Data.Primitive as PM import qualified Linux.Socket as L import qualified Posix.Socket as S -- | A socket that listens for incomming connections. newtype Listener = Listener Fd -- | A connection-oriented stream socket. newtype Connection = Connection Fd withListener :: Endpoint -> (Listener -> Word16 -> IO a) -> IO (Either SocketException a) withListener endpoint@Endpoint{port = specifiedPort} f = mask $ \restore -> do debug ("withSocket: opening listener " ++ show endpoint) e1 <- S.uninterruptibleSocket S.internet (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream) S.defaultProtocol debug ("withSocket: opened listener " ++ show endpoint) case e1 of Left err -> pure (Left (errorCode Open err)) Right fd -> do e2 <- S.uninterruptibleBind fd (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint)) debug ("withSocket: requested binding for listener " ++ show endpoint) case e2 of Left err -> do _ <- S.uninterruptibleClose fd pure (Left (errorCode Bind err)) Right _ -> S.uninterruptibleListen fd 16 >>= \case -- We hardcode the listen backlog to 16. The author is unfamiliar -- with use cases where gains are realized from tuning this parameter. -- Open an issue if this causes problems for anyone. Left err -> do _ <- S.uninterruptibleClose fd debug "withSocket: listen failed with error code" pure (Left (errorCode Listen err)) Right _ -> do -- The getsockname is copied from code in Socket.Datagram.IPv4.Undestined. -- Consider factoring this out. eactualPort <- if specifiedPort == 0 then S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case Left err -> do _ <- S.uninterruptibleClose fd pure (Left (errorCode GetName err)) Right (sockAddrRequiredSz,sockAddr) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet then case S.decodeSocketAddressInternet sockAddr of Just S.SocketAddressInternet{port = actualPort} -> do let cleanActualPort = S.networkToHostShort actualPort debug ("withSocket: successfully bound listener " ++ show endpoint ++ " and got port " ++ show cleanActualPort) pure (Right cleanActualPort) Nothing -> do _ <- S.uninterruptibleClose fd pure (Left (exception GetName SocketAddressFamily)) else do _ <- S.uninterruptibleClose fd pure (Left (exception GetName SocketAddressSize)) else pure (Right specifiedPort) case eactualPort of Left err -> pure (Left err) Right actualPort -> do a <- onException (restore (f (Listener fd) actualPort)) (S.uninterruptibleClose fd) S.uninterruptibleClose fd >>= \case Left err -> pure (Left (errorCode Close err)) Right _ -> pure (Right a) -- | Accept a connection on the listener and run the supplied callback -- on it. This closes the connection when the callback finishes or if -- an exception is thrown. Since this function blocks the thread until -- the callback finishes, it is only suitable for stream socket clients -- that handle one connection at a time. The variant 'forkAcceptedUnmasked' -- is preferrable for servers that need to handle connections concurrently -- (most use cases). withAccepted :: Listener -> (Connection -> Endpoint -> IO a) -> IO (Either SocketException a) withAccepted lst cb = internalAccepted ( \restore action -> do action restore ) lst cb internalAccepted :: ((forall x. IO x -> IO x) -> ((IO a -> IO b) -> IO (Either SocketException b)) -> IO (Either SocketException c)) -> Listener -> (Connection -> Endpoint -> IO a) -> IO (Either SocketException c) internalAccepted wrap (Listener !lst) f = do threadWaitRead lst mask $ \restore -> do S.uninterruptibleAccept lst S.sizeofSocketAddressInternet >>= \case Left err -> pure (Left (errorCode Accept err)) Right (sockAddrRequiredSz,sockAddr,acpt) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet then case S.decodeSocketAddressInternet sockAddr of Just sockAddrInet -> do let acceptedEndpoint = socketAddressInternetToEndpoint sockAddrInet debug ("withAccepted: successfully accepted connection from " ++ show acceptedEndpoint) wrap restore $ \restore' -> do a <- onException (restore' (f (Connection acpt) acceptedEndpoint)) (S.uninterruptibleClose acpt) gracefulClose acpt a Nothing -> do _ <- S.uninterruptibleClose acpt pure (Left (exception GetName SocketAddressFamily)) else do _ <- S.uninterruptibleClose acpt pure (Left (exception GetName SocketAddressSize)) gracefulClose :: Fd -> a -> IO (Either SocketException a) gracefulClose fd a = S.uninterruptibleShutdown fd S.write >>= \case Left err -> do _ <- S.uninterruptibleClose fd pure (Left (errorCode Shutdown err)) Right _ -> do buf <- PM.newByteArray 1 S.uninterruptibleReceiveMutableByteArray fd buf 0 1 mempty >>= \case Left err1 -> if err1 == eWOULDBLOCK || err1 == eAGAIN then do threadWaitRead fd S.uninterruptibleReceiveMutableByteArray fd buf 0 1 mempty >>= \case Left err -> do _ <- S.uninterruptibleClose fd pure (Left (errorCode Shutdown err)) Right sz -> if sz == 0 then fmap (bimap (errorCode Close) (const a)) (S.uninterruptibleClose fd) else do debug ("Socket.Stream.IPv4.gracefulClose: remote not shutdown A") _ <- S.uninterruptibleClose fd pure (Left (exception Shutdown RemoteNotShutdown)) else do _ <- S.uninterruptibleClose fd -- Is this the right error context? It's a call -- to recv, but it happens while shutting down -- the socket. pure (Left (errorCode Shutdown err1)) Right sz -> if sz == 0 then fmap (bimap (errorCode Close) (const a)) (S.uninterruptibleClose fd) else do debug ("Socket.Stream.IPv4.gracefulClose: remote not shutdown B") _ <- S.uninterruptibleClose fd pure (Left (exception Shutdown RemoteNotShutdown)) -- | Accept a connection on the listener and run the supplied callback in -- a new thread. Prefer 'forkAcceptedUnmasked' unless the masking state -- needs to be preserved for the callback. Such a situation seems unlikely -- to the author. forkAccepted :: Listener -> (Either SocketException a -> IO ()) -> (Connection -> Endpoint -> IO a) -> IO (Either SocketException ThreadId) forkAccepted lst consumeException cb = internalAccepted ( \restore action -> do tid <- forkIO $ do x <- action restore restore (consumeException x) pure (Right tid) ) lst cb -- | Accept a connection on the listener and run the supplied callback in -- a new thread. The masking state is set to @Unmasked@ when running the -- callback. forkAcceptedUnmasked :: Listener -> (Either SocketException a -> IO ()) -> (Connection -> Endpoint -> IO a) -> IO (Either SocketException ThreadId) forkAcceptedUnmasked lst consumeException cb = internalAccepted ( \_ action -> do tid <- forkIOWithUnmask $ \unmask -> do x <- action unmask unmask (consumeException x) pure (Right tid) ) lst cb -- | Establish a connection to a server. withConnection :: Endpoint -- ^ Remote endpoint -> (Connection -> IO a) -- ^ Callback to consume connection -> IO (Either SocketException a) withConnection !remote f = mask $ \restore -> do debug ("withSocket: opening connection " ++ show remote) e1 <- S.uninterruptibleSocket S.internet (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream) S.defaultProtocol debug ("withSocket: opened connection " ++ show remote) case e1 of Left err1 -> pure (Left (errorCode Open err1)) Right fd -> do let sockAddr = id $ S.encodeSocketAddressInternet $ endpointToSocketAddressInternet $ remote merr <- S.uninterruptibleConnect fd sockAddr >>= \case Left err2 -> if err2 == eINPROGRESS then do threadWaitWrite fd pure Nothing else pure (Just (errorCode Connect err2)) Right _ -> pure Nothing case merr of Just err -> do _ <- S.uninterruptibleClose fd pure (Left err) Nothing -> do e <- S.uninterruptibleGetSocketOption fd S.levelSocket S.optionError (intToCInt (PM.sizeOf (undefined :: CInt))) case e of Left err -> do _ <- S.uninterruptibleClose fd pure (Left (errorCode Option err)) Right (sz,S.OptionValue val) -> if sz == intToCInt (PM.sizeOf (undefined :: CInt)) then let err = PM.indexByteArray val 0 :: CInt in if err == 0 then do a <- onException (restore (f (Connection fd))) (S.uninterruptibleClose fd) gracefulClose fd a else do _ <- S.uninterruptibleClose fd pure (Left (errorCode Connect (Errno err))) else do _ <- S.uninterruptibleClose fd pure (Left (exception Option OptionValueSize)) sendByteArray :: Connection -- ^ Connection -> ByteArray -- ^ Buffer (will be sliced) -> IO (Either SocketException ()) sendByteArray conn arr = sendByteArraySlice conn arr 0 (PM.sizeofByteArray arr) sendByteArraySlice :: Connection -- ^ Connection -> ByteArray -- ^ Buffer (will be sliced) -> Int -- ^ Offset into payload -> Int -- ^ Lenth of slice into buffer -> IO (Either SocketException ()) sendByteArraySlice !conn !payload !off0 !len0 = go off0 len0 where go !off !len = if len > 0 then internalSend conn payload off len >>= \case Left e -> pure (Left e) Right sz' -> do let sz = csizeToInt sz' go (off + sz) (len - sz) else pure (Right ()) sendMutableByteArray :: Connection -- ^ Connection -> MutableByteArray RealWorld -- ^ Buffer (will be sliced) -> IO (Either SocketException ()) sendMutableByteArray conn arr = sendMutableByteArraySlice conn arr 0 =<< PM.getSizeofMutableByteArray arr sendMutableByteArraySlice :: Connection -- ^ Connection -> MutableByteArray RealWorld -- ^ Buffer (will be sliced) -> Int -- ^ Offset into payload -> Int -- ^ Lenth of slice into buffer -> IO (Either SocketException ()) sendMutableByteArraySlice !conn !payload !off0 !len0 = go off0 len0 where go !off !len = if len > 0 then internalSendMutable conn payload off len >>= \case Left e -> pure (Left e) Right sz' -> do let sz = csizeToInt sz' go (off + sz) (len - sz) else pure (Right ()) -- The length must be greater than zero. internalSendMutable :: Connection -- ^ Connection -> MutableByteArray RealWorld -- ^ Buffer (will be sliced) -> Int -- ^ Offset into payload -> Int -- ^ Length of slice into buffer -> IO (Either SocketException CSize) internalSendMutable (Connection !s) !payload !off !len = do e1 <- S.uninterruptibleSendMutableByteArray s payload (intToCInt off) (intToCSize len) mempty case e1 of Left err1 -> if err1 == eWOULDBLOCK || err1 == eAGAIN then do threadWaitWrite s e2 <- S.uninterruptibleSendMutableByteArray s payload (intToCInt off) (intToCSize len) mempty case e2 of Left err2 -> pure (Left (errorCode Send err2)) Right sz -> pure (Right sz) else pure (Left (errorCode Send err1)) Right sz -> pure (Right sz) -- The length must be greater than zero. internalSend :: Connection -- ^ Connection -> ByteArray -- ^ Buffer (will be sliced) -> Int -- ^ Offset into payload -> Int -- ^ Length of slice into buffer -> IO (Either SocketException CSize) internalSend (Connection !s) !payload !off !len = do debug ("send: about to send chunk on stream socket, offset " ++ show off ++ " and length " ++ show len) e1 <- S.uninterruptibleSendByteArray s payload (intToCInt off) (intToCSize len) mempty debug "send: just sent chunk on stream socket" case e1 of Left err1 -> if err1 == eWOULDBLOCK || err1 == eAGAIN then do debug "send: waiting to for write ready on stream socket" threadWaitWrite s e2 <- S.uninterruptibleSendByteArray s payload (intToCInt off) (intToCSize len) mempty case e2 of Left err2 -> do debug "send: encountered error after sending chunk on stream socket" pure (Left (errorCode Send err2)) Right sz -> pure (Right sz) else pure (Left (errorCode Send err1)) Right sz -> pure (Right sz) -- The maximum number of bytes to receive must be greater than zero. -- The operating system guarantees us that the returned actual number -- of bytes is less than or equal to the requested number of bytes. -- This function does not validate that the result size is greater -- than zero. Functions calling this must perform that check. This -- also does not trim the buffer. The caller must do that if it is -- necessary. internalReceiveMaximally :: Connection -- ^ Connection -> Int -- ^ Maximum number of bytes to receive -> MutableByteArray RealWorld -- ^ Receive buffer -> Int -- ^ Offset into buffer -> IO (Either SocketException Int) internalReceiveMaximally (Connection !fd) !maxSz !buf !off = do debug "receive: stream socket about to wait" threadWaitRead fd debug ("receive: stream socket is now readable, receiving up to " ++ show maxSz ++ " bytes at offset " ++ show off) e <- S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty debug "receive: finished reading from stream socket" case e of Left err -> pure (Left (errorCode Receive err)) Right recvSz -> pure (Right (csizeToInt recvSz)) -- | Receive exactly the given number of bytes. If the remote application -- shuts down its end of the connection before sending the required -- number of bytes, this returns -- @'Left' ('SocketException' 'Receive' 'RemoteShutdown')@. receiveByteArray :: Connection -- ^ Connection -> Int -- ^ Number of bytes to receive -> IO (Either SocketException ByteArray) receiveByteArray !conn0 !total = do marr <- PM.newByteArray total go conn0 marr 0 total where go !conn !marr !off !remaining = case compare remaining 0 of GT -> internalReceiveMaximally conn remaining marr off >>= \case Left err -> pure (Left err) Right sz -> if sz /= 0 then go conn marr (off + sz) (remaining - sz) else pure (Left (exception Receive RemoteShutdown)) EQ -> do arr <- PM.unsafeFreezeByteArray marr pure (Right arr) LT -> pure (Left (exception Receive NegativeBytesRequested)) -- | Receive a number of bytes exactly equal to the size of the mutable -- byte array. If the remote application shuts down its end of the -- connection before sending the required number of bytes, this returns -- @'Left' ('SocketException' 'Receive' 'RemoteShutdown')@. receiveMutableByteArray :: Connection -> MutableByteArray RealWorld -> IO (Either SocketException ()) receiveMutableByteArray !conn0 !marr0 = do total <- PM.getSizeofMutableByteArray marr0 go conn0 marr0 0 total where go !conn !marr !off !remaining = if remaining > 0 then internalReceiveMaximally conn remaining marr off >>= \case Left err -> pure (Left err) Right sz -> if sz /= 0 then go conn marr (off + sz) (remaining - sz) else pure (Left (exception Receive RemoteShutdown)) else pure (Right ()) -- | Receive up to the given number of bytes. If the remote application -- shuts down its end of the connection instead of sending any bytes, -- this returns -- @'Left' ('SocketException' 'Receive' 'RemoteShutdown')@. receiveBoundedByteArray :: Connection -- ^ Connection -> Int -- ^ Maximum number of bytes to receive -> IO (Either SocketException ByteArray) receiveBoundedByteArray !conn !total | total > 0 = do m <- PM.newByteArray total internalReceiveMaximally conn total m 0 >>= \case Left err -> pure (Left err) Right sz -> if sz /= 0 then do shrinkMutableByteArray m sz fmap Right (PM.unsafeFreezeByteArray m) else pure (Left (exception Receive RemoteShutdown)) | total == 0 = pure (Right mempty) | otherwise = pure (Left (exception Receive NegativeBytesRequested)) endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet endpointToSocketAddressInternet (Endpoint {address, port}) = S.SocketAddressInternet { port = S.hostToNetworkShort port , address = S.hostToNetworkLong (getIPv4 address) } socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint socketAddressInternetToEndpoint (S.SocketAddressInternet {address,port}) = Endpoint { address = IPv4 (S.networkToHostLong address) , port = S.networkToHostShort port } errorCode :: Context -> Errno -> SocketException errorCode func (Errno x) = SocketException func (ErrorCode x) exception :: Context -> Reason -> SocketException exception func reason = SocketException func reason intToCInt :: Int -> CInt intToCInt = fromIntegral intToCSize :: Int -> CSize intToCSize = fromIntegral csizeToInt :: CSize -> Int csizeToInt = fromIntegral shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO () shrinkMutableByteArray (MutableByteArray arr) (I# sz) = PM.primitive_ (shrinkMutableByteArray# arr sz)