{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language UnboxedTuples #-}
module Socket.Datagram.IPv4.Undestined
(
Socket(..)
, Endpoint(..)
, Message(..)
, withSocket
, send
, sendMutableByteArraySlice
, receiveByteArray
, receiveMutableByteArraySlice_
, receiveMany
, receiveManyUnless
, SocketException(..)
) where
import Control.Concurrent (threadWaitWrite,threadWaitRead)
import Control.Exception (throwIO,mask,onException)
import Data.Primitive (ByteArray,MutableByteArray(..))
import Data.Word (Word16)
import Foreign.C.Error (Errno(..),eWOULDBLOCK,eAGAIN,eACCES)
import Foreign.C.Types (CInt,CSize)
import GHC.Exts (Int(I#),RealWorld,shrinkMutableByteArray#,ByteArray#,touch#)
import GHC.IO (IO(..))
import Net.Types (IPv4(..))
import Socket (SocketException(..),SocketUnrecoverableException(..),Direction(..),Interruptibility(..))
import Socket (cgetsockname)
import Socket.Datagram (SendException(..),ReceiveException(..))
import Socket.Datagram.IPv4.Undestined.Internal (Message(..),Socket(..))
import Socket.Datagram.IPv4.Undestined.Multiple (receiveMany,receiveManyUnless)
import Socket.Debug (debug)
import Socket.IPv4 (Endpoint(..),describeEndpoint)
import qualified Socket as SCK
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
-- | Open an IPv4 datagram socket (created with the @closeOnExec@ and
-- @nonblocking@ flags applied), bind it to the given endpoint, and run the
-- callback with the socket and the port it is bound to. When the requested
-- port is zero, the bound port is read back from the socket with
-- @getsockname@; otherwise the requested port is passed through unchanged.
--
-- The whole action runs with asynchronous exceptions masked; only the
-- callback itself runs under @restore@. The socket is closed when the
-- callback returns and also when it throws.
--
-- Every failure of @socket@, @bind@, @getsockname@ or @close@ is thrown as
-- a 'SocketUnrecoverableException'. As written, this function never
-- produces a 'Left'.
withSocket ::
     Endpoint -- ^ Address and port to bind to
  -> (Socket -> Word16 -> IO a) -- ^ Callback given the socket and its bound port
  -> IO (Either SocketException a)
withSocket endpoint@Endpoint{port = specifiedPort} f = mask $ \restore -> do
  debug ("withSocket: opening socket " ++ here)
  opened <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.datagram)
    S.defaultProtocol
  debug ("withSocket: opened socket " ++ here)
  -- No descriptor exists yet when socket creation fails, so there is nothing to close.
  fd <- either (\err -> bail ["socket",here,describeErrorCode err]) pure opened
  bound <- S.uninterruptibleBind fd
    (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
  debug ("withSocket: requested binding for " ++ here)
  case bound of
    Left err -> closeAndBail fd ["bind",here,describeErrorCode err]
    Right _ -> pure ()
  actualPort <- if specifiedPort == 0
    then S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
      Left err -> closeAndBail fd ["getsockname",here,describeErrorCode err]
      Right (sockAddrRequiredSz,sockAddr)
        | sockAddrRequiredSz /= S.sizeofSocketAddressInternet ->
            closeAndBail fd [cgetsockname,here,"socket address size"]
        | otherwise -> case S.decodeSocketAddressInternet sockAddr of
            Nothing ->
              closeAndBail fd [cgetsockname,here,"non-internet socket family"]
            Just S.SocketAddressInternet{port = rawPort} -> do
              let cleanPort = S.networkToHostShort rawPort
              debug ("withSocket: successfully bound " ++ here ++ " and got port " ++ show cleanPort)
              pure cleanPort
    else pure specifiedPort
  -- Only the callback runs unmasked; if it throws, the socket is still closed.
  a <- onException (restore (f (Socket fd) actualPort)) (S.uninterruptibleErrorlessClose fd)
  S.uninterruptibleClose fd >>= \case
    Left err -> bail ["close",here,describeErrorCode err]
    Right _ -> pure (Right a)
  where
  -- Human-readable rendering of the requested endpoint, used in all messages.
  here = describeEndpoint endpoint
  -- Throw an unrecoverable exception attributed to this function.
  bail :: [String] -> IO b
  bail details = throwIO $ SocketUnrecoverableException
    moduleSocketDatagramIPv4Undestined
    functionWithSocket
    details
  -- Release the descriptor before reporting the failure.
  closeAndBail fd details = S.uninterruptibleErrorlessClose fd *> bail details
-- | Send a slice of an immutable byte array as a single datagram to the
-- given remote endpoint.
--
-- When the send reports @EAGAIN@ or @EWOULDBLOCK@, this waits with
-- 'threadWaitWrite' and tries again. It keeps doing so for as long as the
-- send reports that it would block, since another writer may consume the
-- buffer space between the wakeup and the retry.
--
-- Returns @Left (SentMessageTruncated n)@ when the number of bytes reported
-- as sent, @n@, differs from the requested length. Any other error code is
-- thrown as a 'SocketUnrecoverableException'.
send ::
     Socket -- ^ Socket to send from
  -> Endpoint -- ^ Remote address and port
  -> ByteArray -- ^ Buffer holding the payload
  -> Int -- ^ Offset of the payload within the buffer
  -> Int -- ^ Length of the payload in bytes
  -> IO (Either SocketException ())
send (Socket !s) !theRemote !thePayload !off !len = do
  debug ("send: about to send to " ++ show theRemote)
  attempt
  where
  -- Encoded destination address, shared by every attempt.
  encodedRemote = S.encodeSocketAddressInternet (endpointToSocketAddressInternet theRemote)
  attempt :: IO (Either SocketException ())
  attempt = do
    e <- S.uninterruptibleSendToByteArray s thePayload
      (intToCInt off)
      (intToCSize len)
      mempty
      encodedRemote
    debug ("send: just sent to " ++ show theRemote)
    case e of
      Left err
        | err == eWOULDBLOCK || err == eAGAIN -> do
            debug ("send: waiting for write ready to send to " ++ show theRemote)
            threadWaitWrite s
            -- Readiness is only a hint, so the retry may have to wait again.
            attempt
        | otherwise -> do
            debug ("send: encountered error after sending")
            throwIO $ SocketUnrecoverableException
              moduleSocketDatagramIPv4Undestined
              functionSend
              [show theRemote,describeErrorCode err]
      Right sz
        | csizeToInt sz == len -> do
            debug ("send: success")
            pure (Right ())
        | otherwise -> pure (Left (SentMessageTruncated (csizeToInt sz)))
-- | Send a slice of a mutable byte array as a single datagram to the given
-- remote endpoint.
--
-- When the send reports @EAGAIN@ or @EWOULDBLOCK@, this waits with
-- 'threadWaitWrite' and tries again. It keeps doing so for as long as the
-- send reports that it would block, since another writer may consume the
-- buffer space between the wakeup and the retry.
--
-- Returns @Left (SendTruncated n)@ when the number of bytes reported as
-- sent, @n@, differs from the requested length. Every other error code is
-- passed to 'handleSendException'.
sendMutableByteArraySlice ::
     Socket -- ^ Socket to send from
  -> Endpoint -- ^ Remote address and port
  -> MutableByteArray RealWorld -- ^ Buffer holding the payload
  -> Int -- ^ Offset of the payload within the buffer
  -> Int -- ^ Length of the payload in bytes
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArraySlice (Socket !s) !theRemote !thePayload !off !len = do
  debug ("send mutable: about to send to " ++ show theRemote)
  attempt
  where
  -- Encoded destination address, shared by every attempt.
  encodedRemote = S.encodeSocketAddressInternet (endpointToSocketAddressInternet theRemote)
  attempt :: IO (Either (SendException 'Uninterruptible) ())
  attempt = do
    e <- S.uninterruptibleSendToMutableByteArray s thePayload
      (intToCInt off)
      (intToCSize len)
      mempty
      encodedRemote
    debug ("send mutable: just sent to " ++ show theRemote)
    case e of
      Left err
        | err == eWOULDBLOCK || err == eAGAIN -> do
            debug ("send mutable: waiting for write ready to send to " ++ show theRemote)
            threadWaitWrite s
            -- Readiness is only a hint, so the retry may have to wait again.
            attempt
        | otherwise -> do
            debug ("send mutable: encountered error after sending")
            handleSendException functionSendMutableByteArray err
      Right sz
        | csizeToInt sz == len -> do
            debug ("send mutable: success")
            pure (Right ())
        | otherwise -> pure (Left (SendTruncated (csizeToInt sz)))
-- | Receive one datagram into a freshly allocated byte array, together with
-- the endpoint it came from.
--
-- The buffer of @maxSz@ bytes is allocated only after the socket first
-- becomes readable. If the receive then reports @EAGAIN@ or @EWOULDBLOCK@
-- (readiness is only a hint: the datagram may have been dropped by the
-- kernel or taken by another reader), this waits for readability again and
-- retries with the same buffer instead of failing.
--
-- Returns @Left (ReceiveTruncated n)@ when the reported datagram size @n@
-- exceeds @maxSz@. NOTE(review): the @L.truncate@ flag is presumably what
-- lets the reported size exceed the buffer size; confirm against
-- "Linux.Socket". Any other error code, an unexpected socket address size,
-- or an undecodable socket address is thrown as a
-- 'SocketUnrecoverableException'.
receiveByteArray ::
     Socket -- ^ Socket to receive from
  -> Int -- ^ Maximum datagram size in bytes
  -> IO (Either (ReceiveException 'Uninterruptible) Message)
receiveByteArray (Socket !fd) !maxSz = do
  debug "receive: about to wait"
  threadWaitRead fd
  debug "receive: socket is now readable"
  marr <- PM.newByteArray maxSz
  attempt marr
  where
  -- Throw an unrecoverable exception attributed to this function.
  unrecoverable :: [String] -> IO b
  unrecoverable details = throwIO $ SocketUnrecoverableException
    moduleSocketDatagramIPv4Undestined
    functionReceive
    details
  attempt :: MutableByteArray RealWorld -> IO (Either (ReceiveException 'Uninterruptible) Message)
  attempt marr = do
    e <- S.uninterruptibleReceiveFromMutableByteArray fd marr 0
      (intToCSize maxSz) (L.truncate) S.sizeofSocketAddressInternet
    debug "receive: finished reading from socket"
    case e of
      Left err
        | err == eWOULDBLOCK || err == eAGAIN -> do
            -- Nothing was written to the buffer, so it is safe to reuse.
            debug "receive: nothing to read after wakeup, waiting again"
            threadWaitRead fd
            attempt marr
        | otherwise -> unrecoverable [describeErrorCode err]
      Right (sockAddrRequiredSz,sockAddr,recvSz)
        | csizeToInt recvSz > maxSz ->
            pure (Left (ReceiveTruncated (csizeToInt recvSz)))
        | sockAddrRequiredSz /= S.sizeofSocketAddressInternet ->
            unrecoverable [SCK.crecvfrom,SCK.socketAddressSize]
        | otherwise -> case S.decodeSocketAddressInternet sockAddr of
            Nothing -> unrecoverable [SCK.crecvfrom,SCK.nonInternetSocketFamily]
            Just sockAddrInet -> do
              -- Trim the buffer to the received size before freezing it.
              shrinkMutableByteArray marr (csizeToInt recvSz)
              arr <- PM.unsafeFreezeByteArray marr
              pure $ Right (Message (socketAddressInternetToEndpoint sockAddrInet) arr)
-- | Receive one datagram into a slice of a caller-supplied mutable byte
-- array, discarding the sender's address.
--
-- Waits for the socket to become readable before each attempt. If the
-- receive reports @EAGAIN@ or @EWOULDBLOCK@ (readiness is only a hint: the
-- datagram may have been dropped by the kernel or taken by another reader),
-- this waits again instead of failing.
--
-- Returns @Right n@ with the received size when it fits in @maxSz@, and
-- @Left (ReceivedMessageTruncated n)@ when the reported size exceeds
-- @maxSz@. Any other error code is thrown as a
-- 'SocketUnrecoverableException'.
receiveMutableByteArraySlice_ ::
     Socket -- ^ Socket to receive from
  -> MutableByteArray RealWorld -- ^ Destination buffer
  -> Int -- ^ Offset into the buffer
  -> Int -- ^ Maximum datagram size in bytes
  -> IO (Either SocketException Int)
receiveMutableByteArraySlice_ (Socket !fd) !buf !off !maxSz = attempt
  where
  attempt :: IO (Either SocketException Int)
  attempt = do
    threadWaitRead fd
    e <- S.uninterruptibleReceiveFromMutableByteArray_ fd buf (intToCInt off) (intToCSize maxSz) (L.truncate)
    case e of
      Left err
        -- Spurious readiness: go back to waiting rather than failing.
        | err == eWOULDBLOCK || err == eAGAIN -> attempt
        | otherwise -> throwIO $ SocketUnrecoverableException
            moduleSocketDatagramIPv4Undestined
            functionReceiveMutableByteArray
            [describeErrorCode err]
      Right recvSz
        | csizeToInt recvSz <= maxSz -> pure (Right (csizeToInt recvSz))
        | otherwise -> pure (Left (ReceivedMessageTruncated (csizeToInt recvSz)))
-- | Build the socket address for an 'Endpoint', passing the port through
-- 'S.hostToNetworkShort' and the unwrapped IPv4 address through
-- 'S.hostToNetworkLong'.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet Endpoint{address = hostAddress, port = hostPort} =
  S.SocketAddressInternet
    { port = S.hostToNetworkShort hostPort
    , address = S.hostToNetworkLong (getIPv4 hostAddress)
    }
-- | Build an 'Endpoint' from a socket address, passing the address through
-- 'S.networkToHostLong' (then wrapping it in 'IPv4') and the port through
-- 'S.networkToHostShort'.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint
socketAddressInternetToEndpoint S.SocketAddressInternet{address = netAddress, port = netPort} =
  Endpoint
    { address = IPv4 (S.networkToHostLong netAddress)
    , port = S.networkToHostShort netPort
    }
-- | Wrap the raw number carried by an 'Errno' in the 'ErrorCode'
-- constructor of 'SocketException'.
errorCode :: Errno -> SocketException
errorCode = \(Errno raw) -> ErrorCode raw
-- | Shrink a mutable byte array in place to the given size by calling the
-- 'shrinkMutableByteArray#' primop and threading the state token through.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr) (I# sz) =
  IO (\s0 -> case shrinkMutableByteArray# arr sz s0 of s1 -> (# s1, () #))
-- | Apply 'touchByteArray#' to the unlifted array inside a 'ByteArray'.
touchByteArray :: ByteArray -> IO ()
touchByteArray wrapped = case wrapped of
  PM.ByteArray raw -> touchByteArray# raw
-- | Run the 'touch#' primop on an unlifted byte array as an 'IO' action
-- that returns unit.
touchByteArray# :: ByteArray# -> IO ()
touchByteArray# raw = IO (\s0 -> case touch# raw s0 of s1 -> (# s1, () #))
-- | Convert an 'Int' to a 'CInt' with 'fromIntegral'. No range check is
-- performed.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | Convert an 'Int' to a 'CSize' with 'fromIntegral'. No range check is
-- performed.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n
-- | Convert a 'CSize' to an 'Int' with 'fromIntegral'. No range check is
-- performed.
csizeToInt :: CSize -> Int
csizeToInt sz = fromIntegral sz
-- | Name of this module, passed as the first field of every
-- 'SocketUnrecoverableException' built here.
moduleSocketDatagramIPv4Undestined :: String
moduleSocketDatagramIPv4Undestined = "Socket.Datagram.IPv4.Undestined"

-- Function-name labels passed as the second field of
-- 'SocketUnrecoverableException' to identify where a failure was raised.
functionReceive, functionSend, functionSendMutableByteArray,
  functionReceiveMutableByteArray, functionWithSocket :: String
functionReceive = "receive"
functionSend = "send"
functionSendMutableByteArray = "sendMutableByteArray"
functionReceiveMutableByteArray = "receiveMutableByteArray"
functionWithSocket = "withSocket"
-- | Render an 'Errno' as the text @"error code "@ followed by its number.
describeErrorCode :: Errno -> String
describeErrorCode (Errno code) = showString "error code " (show code)
-- | Classify an error code returned by a send. @EACCES@ is reported as
-- @Left SendBroadcasted@; every other code is thrown as a
-- 'SocketUnrecoverableException' attributed to the function named by the
-- first argument.
handleSendException :: String -> Errno -> IO (Either (SendException i) a)
{-# INLINE handleSendException #-}
handleSendException func e =
  if e == eACCES
    then pure (Left SendBroadcasted)
    else throwIO $ SocketUnrecoverableException
      moduleSocketDatagramIPv4Undestined
      func
      [describeErrorCode e]