{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
module Socket.Stream.IPv4
(
Listener
, Connection(..)
, Peer(..)
, withListener
, withAccepted
, withConnection
, forkAccepted
, forkAcceptedUnmasked
, interruptibleForkAcceptedUnmasked
, SendException(..)
, ReceiveException(..)
, ConnectException(..)
, SocketException(..)
, AcceptException(..)
, CloseException(..)
, Interruptibility(..)
, Family(..)
, Version(..)
, listen
, unlisten
, unlisten_
, connect
, disconnect
, disconnect_
, accept
, interruptibleAccept
) where
import Control.Concurrent (ThreadId)
import Control.Concurrent (forkIO, forkIOWithUnmask)
import Control.Exception (mask, mask_, onException, throwIO)
import Control.Monad.STM (atomically)
import Control.Concurrent.STM (TVar,modifyTVar')
import Data.Word (Word16)
import Foreign.C.Error (Errno(..), eAGAIN, eINPROGRESS, eWOULDBLOCK, eNOTCONN)
import Foreign.C.Error (eADDRINUSE,eHOSTUNREACH)
import Foreign.C.Error (eNFILE,eMFILE,eACCES,ePERM,eCONNABORTED)
import Foreign.C.Error (eTIMEDOUT,eADDRNOTAVAIL,eNETUNREACH,eCONNREFUSED)
import Foreign.C.Types (CInt)
import Net.Types (IPv4(..))
import Socket (Interruptibility(..))
import Socket (SocketUnrecoverableException(..),Family(Internet),Version(V4))
import Socket (cgetsockname,cclose)
import Socket.Error (die)
import Socket.Debug (debug)
import Socket.IPv4 (Peer(..),describeEndpoint)
import Socket.Stream (ConnectException(..),SocketException(..),AcceptException(..))
import Socket.Stream (SendException(..),ReceiveException(..),CloseException(..))
import Socket.Stream (Connection(..))
import System.Posix.Types(Fd)
import qualified Control.Concurrent.STM as STM
import qualified Data.Primitive as PM
import qualified Foreign.C.Error.Describe as D
import qualified Linux.Socket as L
import qualified Posix.Socket as S
import qualified Socket as SCK
import qualified Socket.EventManager as EM
-- | A listening IPv4 stream socket, wrapping its file descriptor.
-- Values are produced by 'listen' and released with 'unlisten' or
-- 'unlisten_'.
newtype Listener = Listener Fd
-- | Open a close-on-exec, non-blocking IPv4 stream socket, bind it to the
-- given endpoint, and start listening on it with a backlog of 16.
--
-- When the requested port is 0, the port actually assigned by the kernel is
-- read back with @getsockname@; otherwise the requested port is returned.
-- On success the descriptor is registered with the event manager.
--
-- Descriptor-limit failures from @socket@ and the bind/listen failures
-- classified by 'handleBindListenException' are returned as 'Left'. Any
-- other failure is thrown. The socket is closed on every failure path after
-- it has been created.
listen :: Peer -> IO (Either SocketException (Listener, Word16))
listen endpoint@Peer{port = specifiedPort} = do
  debug ("listen: opening listen " ++ describeEndpoint endpoint)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("listen: opened listen " ++ describeEndpoint endpoint)
  case e1 of
    Left err -> handleSocketListenException SCK.functionWithListener err
    Right fd -> do
      e2 <- S.uninterruptibleBind fd
        (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
      debug ("listen: requested binding for listen " ++ describeEndpoint endpoint)
      case e2 of
        Left err -> do
          _ <- S.uninterruptibleClose fd
          handleBindListenException specifiedPort err
        Right _ -> S.uninterruptibleListen fd 16 >>= \case
          Left err -> do
            _ <- S.uninterruptibleClose fd
            debug ("listen: listen failed with " ++ describeErrorCode err)
            handleBindListenException specifiedPort err
          Right _ -> do
            actualPort <- if specifiedPort == 0
              then S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
                Left err -> do
                  -- Close before throwing; every other failure path does
                  -- the same, and otherwise the descriptor would leak.
                  _ <- S.uninterruptibleClose fd
                  throwIO $ SocketUnrecoverableException
                    moduleSocketStreamIPv4
                    functionWithListener
                    [cgetsockname,describeEndpoint endpoint,describeErrorCode err]
                Right (sockAddrRequiredSz,sockAddr) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet
                  then case S.decodeSocketAddressInternet sockAddr of
                    Just S.SocketAddressInternet{port = boundPort} -> do
                      let cleanActualPort = S.networkToHostShort boundPort
                      debug ("listen: successfully bound listen " ++ describeEndpoint endpoint ++ " and got port " ++ show cleanActualPort)
                      pure cleanActualPort
                    Nothing -> do
                      _ <- S.uninterruptibleClose fd
                      throwIO $ SocketUnrecoverableException
                        moduleSocketStreamIPv4
                        functionWithListener
                        [cgetsockname,"non-internet socket family"]
                  else do
                    _ <- S.uninterruptibleClose fd
                    throwIO $ SocketUnrecoverableException
                      moduleSocketStreamIPv4
                      functionWithListener
                      [cgetsockname,describeEndpoint endpoint,"socket address size"]
              else pure specifiedPort
            let !mngr = EM.manager
            debug ("listen: registering fd " ++ show fd)
            EM.register mngr fd
            pure (Right (Listener fd, actualPort))
-- | Close a listener. If @close@ reports an error, a
-- 'SocketUnrecoverableException' is thrown.
unlisten :: Listener -> IO ()
unlisten (Listener fd) = do
  closed <- S.uninterruptibleClose fd
  either failure (const (pure ())) closed
  where
    failure err = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      functionWithListener
      [cclose,describeErrorCode err]
-- | Close a listener, discarding any error reported by @close@.
unlisten_ :: Listener -> IO ()
unlisten_ = \(Listener descriptor) -> S.uninterruptibleErrorlessClose descriptor
-- | Open a listener with 'listen', run the callback with it and its bound
-- port (async exceptions restored), then close it with 'unlisten'.
-- If the callback throws, the listener is closed with 'unlisten_' and the
-- exception is rethrown.
withListener ::
     Peer
  -> (Listener -> Word16 -> IO a)
  -> IO (Either SocketException a)
withListener endpoint f = mask $ \restore -> do
  opened <- listen endpoint
  case opened of
    Left err -> pure (Left err)
    Right (listener, assignedPort) -> do
      value <- restore (f listener assignedPort) `onException` unlisten_ listener
      unlisten listener
      pure (Right value)
-- | Block until a connection can be accepted on the listener.
-- When @accept4@ reports that nothing is ready, this waits on the event
-- manager's reader variable and retries. The accepted descriptor is
-- registered with the event manager before being returned.
accept :: Listener -> IO (Either (AcceptException 'Uninterruptible) (Connection,Peer))
accept (Listener !fd) = do
  debug ("accept: about to create manager, fd=" ++ show fd)
  let !mngr = EM.manager
  debug ("accept: about to get reader, fd=" ++ show fd)
  !tv <- EM.reader mngr fd
  let attempt !token = do
        debug ("accept: calling waitlessAccept for " ++ show fd)
        outcome <- waitlessAccept fd
        case outcome of
          Left Nothing -> EM.unreadyAndWait token tv >>= attempt
          Left (Just err) -> pure (Left err)
          Right accepted@(Connection conn,_) -> do
            debug ("accept: waitlessAccept succeeded for " ++ show fd)
            EM.register mngr conn
            pure (Right accepted)
  STM.readTVarIO tv >>= attempt
-- | Like 'accept', but waiting can be abandoned. The wait goes through
-- 'EM.interruptibleWait' with the @abandon@ variable; an interrupt token
-- yields 'AcceptInterrupted'.
interruptibleAccept ::
     TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Peer))
interruptibleAccept !abandon listener@(Listener fd) = do
  let !mngr = EM.manager
  readiness <- EM.reader mngr fd
  token <- EM.interruptibleWait abandon readiness
  if EM.isInterrupt token
    then pure (Left AcceptInterrupted)
    else do
      outcome <- waitlessAccept fd
      case outcome of
        Right accepted@(Connection conn,_) -> do
          EM.register mngr conn
          pure (Right accepted)
        Left (Just err) -> pure (Left err)
        Left Nothing -> do
          -- Nothing was actually ready: clear readiness and wait again.
          EM.unready token readiness
          interruptibleAccept abandon listener
-- | Variant of 'interruptibleAccept' that tracks accepts in @counter@.
--
-- Once readiness is observed, one slot is reserved in @counter@. The
-- reservation happens in the same STM transaction that re-checks
-- @abandon@. Code that sets @abandon@ and then waits for the counter to
-- drain therefore cannot miss an accept that is already in progress.
--
-- Slot ownership after return:
--
-- * 'AcceptInterrupted': no slot is held.
-- * any other 'Left': the slot is still held; the caller must release it.
-- * 'Right': the slot is still held; the caller must release it when the
--   connection is finished.
--
-- If no connection was actually pending, or 'waitlessAccept' throws, the
-- slot is given back here.
interruptibleAcceptCounting ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Peer))
interruptibleAcceptCounting !counter !abandon (Listener !fd) = do
  let !mngr = EM.manager
  tv <- EM.reader mngr fd
  token <- EM.interruptibleWait abandon tv
  if EM.isInterrupt token
    then pure (Left AcceptInterrupted)
    else do
      reserved <- atomically $ STM.readTVar abandon >>= \case
        True -> pure False
        False -> True <$ modifyTVar' counter (+ 1)
      if not reserved
        then pure (Left AcceptInterrupted)
        else do
          let release = atomically (modifyTVar' counter (subtract 1))
          (waitlessAccept fd `onException` release) >>= \case
            Left merr -> case merr of
              Nothing -> do
                -- Nothing was pending: give the slot back before retrying,
                -- because the retry reserves a fresh one.
                release
                EM.unready token tv
                interruptibleAcceptCounting counter abandon (Listener fd)
              Just err -> pure (Left err)
            Right r@(Connection conn,_) -> do
              EM.register mngr conn
              pure (Right r)
-- | Make one non-blocking @accept4@ attempt (close-on-exec, non-blocking)
-- on the listening descriptor.
--
-- Returns @Left Nothing@ when no connection is pending, and
-- @Left (Just e)@ for accept errors classified by 'handleAcceptException'.
-- If the returned peer address has an unexpected size or is not an
-- internet address, the accepted descriptor is closed and a
-- 'SocketUnrecoverableException' is thrown.
waitlessAccept :: Fd -> IO (Either (Maybe (AcceptException i)) (Connection,Peer))
waitlessAccept lstn = do
  attempt <- L.uninterruptibleAccept4 lstn S.sizeofSocketAddressInternet (L.closeOnExec <> L.nonblocking)
  case attempt of
    Left err -> handleAcceptException err
    Right (sockAddrRequiredSz,sockAddr,acpt)
      | sockAddrRequiredSz /= S.sizeofSocketAddressInternet ->
          abortWith acpt SCK.socketAddressSize
      | otherwise -> case S.decodeSocketAddressInternet sockAddr of
          Nothing -> abortWith acpt SCK.nonInternetSocketFamily
          Just sockAddrInet -> do
            let !acceptedEndpoint = socketAddressInternetToEndpoint sockAddrInet
            debug ("internalAccepted: successfully accepted connection from " ++ show acceptedEndpoint)
            pure (Right (Connection acpt, acceptedEndpoint))
  where
    -- Close the freshly accepted descriptor, then report the problem.
    abortWith acpt reason = do
      _ <- S.uninterruptibleClose acpt
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionWithAccepted
        [SCK.cgetsockname,reason]
-- | First phase of a graceful close: shut down the write side, then hand
-- off to 'gracefulCloseB'. @ENOTCONN@ from @shutdown@ is tolerated. Any
-- other @shutdown@ error closes the descriptor and throws.
gracefulCloseA :: Fd -> IO (Either CloseException ())
gracefulCloseA fd = do
  let !mngr = EM.manager
  !tv <- EM.reader mngr fd
  token0 <- STM.readTVarIO tv
  shutdownResult <- S.uninterruptibleShutdown fd S.write
  case shutdownResult of
    Left err | err /= eNOTCONN -> do
      _ <- S.uninterruptibleClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionGracefulClose
        [SCK.cshutdown,describeErrorCode err]
    _ -> gracefulCloseB tv token0 fd
-- | Second phase of a graceful close: peek one byte to tell whether the
-- remote end has shut down.
--
-- A zero-length read means the peer finished. The descriptor is then
-- closed and @Right ()@ is returned. If data is still arriving, the
-- descriptor is closed and 'ClosePeerContinuedSending' is returned.
-- @EAGAIN@/@EWOULDBLOCK@ waits on the reader variable and retries. Any
-- other @recv@ error closes the descriptor and throws.
gracefulCloseB :: TVar EM.Token -> EM.Token -> Fd -> IO (Either CloseException ())
gracefulCloseB !tv !token0 !fd = do
  !buf <- PM.newByteArray 1
  peeked <- S.uninterruptibleReceiveMutableByteArray fd buf 0 1 S.peek
  case peeked of
    Left err1
      | err1 == eWOULDBLOCK || err1 == eAGAIN -> do
          token1 <- EM.persistentUnreadyAndWait token0 tv
          gracefulCloseB tv token1 fd
      | otherwise -> do
          _ <- S.uninterruptibleClose fd
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionGracefulClose
            [SCK.crecv,describeErrorCode err1]
    Right 0 -> do
      closed <- S.uninterruptibleClose fd
      case closed of
        Right _ -> pure (Right ())
        Left err -> throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          SCK.functionGracefulClose
          [SCK.cclose,describeErrorCode err]
    Right _ -> do
      debug ("Socket.Stream.IPv4.gracefulClose: remote not shutdown B")
      _ <- S.uninterruptibleClose fd
      pure (Left ClosePeerContinuedSending)
-- | Accept one connection and run the callback on it with async exceptions
-- restored. The connection is then gracefully closed with 'disconnect'.
-- If the callback throws, the connection is closed with 'disconnect_'
-- and the exception propagates.
--
-- The close result and the callback's value are handed to
-- @consumeException@ after the masked region ends.
withAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> Peer -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) b)
withAccepted lstn@(Listener lstnFd) consumeException cb = do
  debug ("withAccepted: fd " ++ show lstnFd)
  outcome <- mask $ \restore -> accept lstn >>= \case
    Left e -> pure (Left e)
    Right (conn, endpoint) -> do
      a <- restore (cb conn endpoint) `onException` disconnect_ conn
      closeResult <- disconnect conn
      pure (Right (closeResult, a))
  traverse (uncurry consumeException) outcome
-- | Accept one connection (masked) and handle it in a new thread.
-- In the forked thread, the callback and @consumeException@ run via the
-- parent's @restore@. The connection is closed with 'disconnect', or with
-- 'disconnect_' if the callback throws.
forkAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Peer -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAccepted lstn consumeException cb = mask $ \restore -> do
  accepted <- accept lstn
  case accepted of
    Left e -> pure (Left e)
    Right (conn, endpoint) -> do
      let handler = do
            a <- restore (cb conn endpoint) `onException` disconnect_ conn
            e <- disconnect conn
            restore (consumeException e a)
      Right <$> forkIO handler
-- | Like 'forkAccepted', but the forked thread uses the @unmask@ supplied
-- by 'forkIOWithUnmask'. The callback and @consumeException@ therefore
-- run fully unmasked.
forkAcceptedUnmasked ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Peer -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAcceptedUnmasked lstn consumeException cb = mask_ $ do
  accepted <- accept lstn
  case accepted of
    Left e -> pure (Left e)
    Right (conn, endpoint) -> do
      tid <- forkIOWithUnmask $ \unmask -> do
        a <- unmask (cb conn endpoint) `onException` disconnect_ conn
        e <- disconnect conn
        unmask (consumeException e a)
      pure (Right tid)
-- | Interruptible, counting variant of 'forkAcceptedUnmasked'.
--
-- Accepts through 'interruptibleAcceptCounting'. Whenever an accept ends
-- in anything other than 'AcceptInterrupted', @counter@ is decremented
-- exactly once:
--
-- * immediately, for an accept error; or
-- * when the forked handler thread finishes, whether it returns normally
--   or any step (callback, 'disconnect', @consumeException@) throws.
interruptibleForkAcceptedUnmasked ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Peer -> IO a)
  -> IO (Either (AcceptException 'Interruptible) ThreadId)
interruptibleForkAcceptedUnmasked !counter !abandon !lstn consumeException cb =
  mask_ $ interruptibleAcceptCounting counter abandon lstn >>= \case
    Left e -> do
      case e of
        AcceptInterrupted -> pure ()
        _ -> release
      pure (Left e)
    Right (conn, endpoint) -> fmap Right $ forkIOWithUnmask $ \unmask -> do
      let serve = do
            a <- onException (unmask (cb conn endpoint)) (disconnect_ conn)
            e <- disconnect conn
            unmask (consumeException e a)
      -- Release on every exit path. Previously an exception from
      -- 'disconnect' or the consumer leaked the counter slot.
      r <- serve `onException` release
      release
      pure r
  where
    release = atomically (modifyTVar' counter (subtract 1))
-- | Open a TCP connection to the remote IPv4 peer.
--
-- The socket is registered with the event manager before @connect@ is
-- issued. On @EINPROGRESS@ this waits for writability; the outcome is then
-- inspected by 'afterEstablishment'. Any other immediate @connect@ error
-- closes the socket and is classified by 'handleConnectException'.
connect ::
     Peer
  -> IO (Either (ConnectException ('Internet 'V4) 'Uninterruptible) Connection)
connect !remote = beforeEstablishment remote >>= \case
  Left err -> pure (Left err)
  Right (fd,sockAddr) -> do
    let !mngr = EM.manager
    EM.register mngr fd
    tv <- EM.writer mngr fd
    debug ("connect: about to connect, fd=" ++ show fd)
    token0 <- STM.readTVarIO tv
    connected <- S.uninterruptibleConnect fd sockAddr
    case connected of
      Right _ -> do
        debug ("connect: succeeded immidiately, fd=" ++ show fd)
        afterEstablishment tv token0 fd
      Left err2
        | err2 == eINPROGRESS -> do
            debug ("connect: EINPROGRESS, fd=" ++ show fd)
            token1 <- EM.unreadyAndWait token0 tv
            afterEstablishment tv token1 fd
        | otherwise -> do
            debug ("connect: failed, fd=" ++ show fd)
            S.uninterruptibleErrorlessClose fd
            handleConnectException SCK.functionWithConnection err2
-- | Create a close-on-exec, non-blocking IPv4 stream socket and encode the
-- remote peer's socket address. No connection is attempted here.
-- @socket@ failures are classified by 'handleSocketConnectException'.
beforeEstablishment :: Peer -> IO (Either (ConnectException ('Internet 'V4) i) (Fd,S.SocketAddress))
{-# INLINE beforeEstablishment #-}
beforeEstablishment !remote = do
  debug ("beforeEstablishment: opening connection " ++ show remote)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("beforeEstablishment: opened connection " ++ show remote)
  either
    (handleSocketConnectException SCK.functionWithConnection)
    (\fd -> pure (Right (fd, S.encodeSocketAddressInternet (endpointToSocketAddressInternet remote))))
    e1
-- | Inspect @SO_ERROR@ to learn how a non-blocking @connect@ ended.
--
-- * 0: the connection is established.
-- * @EAGAIN@/@EWOULDBLOCK@: the token is unreadied, writability is awaited
--   again, and the check repeats.
-- * any other code: the socket is closed and the code is classified by
--   'handleConnectException'.
--
-- A failed @getsockopt@, or an option value of unexpected size, closes the
-- socket and throws 'SocketUnrecoverableException'. Both are attributed to
-- the connection path, consistent with the rest of 'connect'.
afterEstablishment ::
     TVar EM.Token
  -> EM.Token
  -> Fd
  -> IO (Either (ConnectException ('Internet 'V4) i) Connection)
afterEstablishment !tv !oldToken !fd = do
  debug ("afterEstablishment: finished waiting, fd=" ++ show fd)
  let expectedSize = intToCInt (PM.sizeOf (undefined :: CInt))
  e <- S.uninterruptibleGetSocketOption fd S.levelSocket S.optionError expectedSize
  case e of
    Left err -> do
      S.uninterruptibleErrorlessClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionWithConnection
        [SCK.cgetsockopt,describeErrorCode err]
    Right (sz,S.OptionValue val)
      | sz /= expectedSize -> do
          S.uninterruptibleErrorlessClose fd
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionWithConnection
            [SCK.cgetsockopt,connectErrorOptionValueSize]
      | otherwise -> do
          let err = PM.indexByteArray val 0 :: CInt
          if | err == 0 -> do
                 debug ("afterEstablishment: connection established, fd=" ++ show fd)
                 pure (Right (Connection fd))
             | Errno err == eAGAIN || Errno err == eWOULDBLOCK -> do
                 debug ("afterEstablishment: not ready yet, unreadying token and waiting, fd=" ++ show fd)
                 EM.unready oldToken tv
                 newToken <- EM.wait tv
                 afterEstablishment tv newToken fd
             | otherwise -> do
                 S.uninterruptibleErrorlessClose fd
                 handleConnectException SCK.functionWithConnection (Errno err)
-- | Gracefully close a connection: shut down writes, then wait for the
-- peer to finish (see 'gracefulCloseA').
disconnect :: Connection -> IO (Either CloseException ())
disconnect conn = case conn of
  Connection descriptor -> gracefulCloseA descriptor
-- | Close a connection immediately, discarding any error from @close@.
disconnect_ :: Connection -> IO ()
disconnect_ = \(Connection descriptor) -> S.uninterruptibleErrorlessClose descriptor
-- | Connect to the remote peer and run @f@ on the connection with async
-- exceptions restored. The connection is then gracefully closed with
-- 'disconnect'. If @f@ throws, the connection is closed with
-- 'disconnect_' and the exception propagates.
--
-- The close result and @f@'s value are passed to @g@. @g@ runs with the
-- caller's masking state, matching how 'withAccepted' and 'forkAccepted'
-- run their close-result consumers, so a user callback is never left
-- masked.
withConnection ::
     Peer
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> IO a)
  -> IO (Either (ConnectException ('Internet 'V4) 'Uninterruptible) b)
withConnection !remote g f = mask $ \restore -> do
  connect remote >>= \case
    Left err -> pure (Left err)
    Right conn -> do
      a <- onException (restore (f conn)) (disconnect_ conn)
      m <- disconnect conn
      b <- restore (g m a)
      pure (Right b)
-- | Convert a host-order 'Peer' into a socket address whose address and
-- port fields are in network byte order.
endpointToSocketAddressInternet :: Peer -> S.SocketAddressInternet
endpointToSocketAddressInternet Peer{address = ip, port = portNumber} =
  S.SocketAddressInternet
    { address = S.hostToNetworkLong (getIPv4 ip)
    , port = S.hostToNetworkShort portNumber
    }
-- | Convert a network-byte-order socket address back into a host-order
-- 'Peer'.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Peer
socketAddressInternetToEndpoint S.SocketAddressInternet{address = netAddr, port = netPort} =
  Peer
    { port = S.networkToHostShort netPort
    , address = IPv4 (S.networkToHostLong netAddr)
    }
-- | Convert an 'Int' to a 'CInt' via 'fromIntegral'.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | Module and function names attached to 'SocketUnrecoverableException'
-- values thrown from this module.
moduleSocketStreamIPv4, functionWithListener :: String
moduleSocketStreamIPv4 = "Socket.Stream.IPv4"
functionWithListener = "withListener"
-- | Render an errno as its name (via 'D.string') followed by its numeric
-- value in parentheses.
describeErrorCode :: Errno -> String
describeErrorCode err@(Errno code) =
  concat ["error code ", D.string err, " (", show code, ")"]
-- | Map a @connect@-time errno to a 'ConnectException'. Errors with no
-- mapping are thrown as 'SocketUnrecoverableException', tagged with the
-- given function name.
handleConnectException :: String -> Errno -> IO (Either (ConnectException ('Internet 'V4) i) a)
handleConnectException func e =
  if | e == eACCES || e == ePERM -> pure (Left ConnectFirewalled)
     | e == eNETUNREACH -> pure (Left ConnectNetworkUnreachable)
     | e == eHOSTUNREACH -> pure (Left ConnectHostUnreachable)
     | e == eCONNREFUSED -> pure (Left ConnectRefused)
     | e == eADDRNOTAVAIL -> pure (Left ConnectEphemeralPortsExhausted)
     | e == eTIMEDOUT -> pure (Left ConnectTimeout)
     | otherwise -> throwIO $ SocketUnrecoverableException
         moduleSocketStreamIPv4
         func
         [describeErrorCode e]
-- | Map a @socket@ errno seen while connecting. Descriptor-limit errors
-- become 'ConnectFileDescriptorLimit'; anything else is thrown as
-- 'SocketUnrecoverableException'.
handleSocketConnectException ::
     String
  -> Errno
  -> IO (Either (ConnectException ('Internet 'V4) i) a)
handleSocketConnectException func e
  | e `elem` [eMFILE, eNFILE] = pure (Left ConnectFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Map a @socket@ errno seen while opening a listener. Descriptor-limit
-- errors become 'SocketFileDescriptorLimit'; anything else is thrown as
-- 'SocketUnrecoverableException'.
handleSocketListenException :: String -> Errno -> IO (Either SocketException a)
handleSocketListenException func e =
  if e == eMFILE || e == eNFILE
    then pure (Left SocketFileDescriptorLimit)
    else throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Map a @bind@/@listen@ errno.
--
-- * @EACCES@ becomes 'SocketPermissionDenied'.
-- * @EADDRINUSE@ becomes 'SocketEphemeralPortsExhausted' when the requested
--   port was 0, and 'SocketAddressInUse' otherwise.
-- * Any other errno is passed to 'die'.
handleBindListenException ::
     Word16
  -> Errno
  -> IO (Either SocketException a)
handleBindListenException !thePort !e =
  if | e == eACCES -> pure (Left SocketPermissionDenied)
     | e == eADDRINUSE ->
         pure (Left (if thePort == 0 then SocketEphemeralPortsExhausted else SocketAddressInUse))
     | otherwise -> die ("Socket.Stream.IPv4.bindListen: " ++ describeErrorCode e)
-- | Map an @accept4@ errno.
--
-- * @EAGAIN@/@EWOULDBLOCK@ become @Left Nothing@ (no connection pending).
-- * Recognized failures become @Left (Just e)@.
-- * Anything else is passed to 'die'.
handleAcceptException :: Errno -> IO (Either (Maybe (AcceptException i)) a)
handleAcceptException e
  | e == eAGAIN || e == eWOULDBLOCK = pure (Left Nothing)
  | otherwise =
      case lookup e
             [ (eCONNABORTED, AcceptConnectionAborted)
             , (eMFILE, AcceptFileDescriptorLimit)
             , (eNFILE, AcceptFileDescriptorLimit)
             , (ePERM, AcceptFirewalled)
             ] of
        Just classified -> pure (Left (Just classified))
        Nothing -> die ("Socket.Stream.IPv4.accept: " ++ describeErrorCode e)
-- | Detail string reported when @getsockopt(SO_ERROR)@ returns a value
-- whose size is not that of a 'CInt'.
connectErrorOptionValueSize :: String
connectErrorOptionValueSize =
  "incorrectly sized value of SO_ERROR option"