module Network.Transport.TCP.Internal
( ControlHeader(..)
, encodeControlHeader
, decodeControlHeader
, ConnectionRequestResponse(..)
, encodeConnectionRequestResponse
, decodeConnectionRequestResponse
, forkServer
, recvWithLength
, recvExact
, recvWord32
, encodeWord32
, tryCloseSocket
, tryShutdownSocketBoth
, resolveSockAddr
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, randomEndPointAddress
, ProtocolVersion
, currentProtocolVersion
) where
#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Network.Transport.Internal
( decodeWord32
, encodeWord32
, void
, tryIO
, forkIOWithUnmask
)
import Network.Transport ( EndPointAddress(..) )
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, NameInfoFlag(NI_NUMERICHOST)
, ServiceName
, Socket
, SocketType(Stream)
, SocketOption(ReuseAddr)
, getAddrInfo
, defaultHints
, socket
, bind
, listen
, addrFamily
, addrAddress
, defaultProtocol
, setSocketOption
, accept
, close
, socketPort
, shutdown
, ShutdownCmd(ShutdownBoth)
, SockAddr(..)
, getNameInfo
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif
import Data.Word (Word32, Word64)
import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
( MVar
, newEmptyMVar
, putMVar
, readMVar
)
import Control.Monad (forever, when)
import Control.Exception
( SomeException
, catch
, bracketOnError
, throwIO
, mask_
, mask
, finally
, onException
)
import Control.Applicative ((<$>), (<*>))
import Data.Word (Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import Data.ByteString.Lazy (toStrict)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)
import Data.ByteString.Builder (word64BE, toLazyByteString)
import Data.Monoid ((<>))
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
-- | Numeric identifier of an end point. It is rendered as the last component
-- of the @host:port:id@ address built by 'encodeEndPointAddress'.
type EndPointId = Word32

-- | Version number of the protocol spoken by this transport.
type ProtocolVersion = Word32

-- | The protocol version implemented by this module.
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0x00000000
-- | Control messages, encoded as a 'Word32' by 'encodeControlHeader' and
-- decoded by 'decodeControlHeader'. The numeric codes below are part of the
-- wire format and must not be changed.
data ControlHeader
  = CreatedNewConnection  -- ^ code 0
  | CloseConnection       -- ^ code 1
  | CloseSocket           -- ^ code 2
  | CloseEndPoint         -- ^ code 3
  | ProbeSocket           -- ^ code 4
  | ProbeSocketAck        -- ^ code 5
  deriving (Show, Eq)

-- | Decode a control header code. Returns 'Nothing' for any code that does
-- not correspond to a 'ControlHeader' (currently anything above 5).
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader w32 = case w32 of
  0 -> Just CreatedNewConnection
  1 -> Just CloseConnection
  2 -> Just CloseSocket
  3 -> Just CloseEndPoint
  4 -> Just ProbeSocket
  5 -> Just ProbeSocketAck
  _ -> Nothing

-- | Encode a control header as its numeric code. Inverse of
-- 'decodeControlHeader' on the codes it accepts.
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader ch = case ch of
  CreatedNewConnection -> 0
  CloseConnection      -> 1
  CloseSocket          -> 2
  CloseEndPoint        -> 3
  ProbeSocket          -> 4
  ProbeSocketAck       -> 5
-- | Responses to a connection request, encoded as a 'Word32' by
-- 'encodeConnectionRequestResponse'. The numeric codes below are part of the
-- wire format and must not be changed.
data ConnectionRequestResponse
  = ConnectionRequestUnsupportedVersion  -- ^ code 0xFFFFFFFF
  | ConnectionRequestAccepted            -- ^ code 0x00000000
  | ConnectionRequestInvalid             -- ^ code 0x00000001
  | ConnectionRequestCrossed             -- ^ code 0x00000002
  | ConnectionRequestHostMismatch        -- ^ code 0x00000003
  deriving (Show, Eq)

-- | Decode a connection request response code; 'Nothing' for unknown codes.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse w32 = case w32 of
  0xFFFFFFFF -> Just ConnectionRequestUnsupportedVersion
  0x00000000 -> Just ConnectionRequestAccepted
  0x00000001 -> Just ConnectionRequestInvalid
  0x00000002 -> Just ConnectionRequestCrossed
  0x00000003 -> Just ConnectionRequestHostMismatch
  _          -> Nothing

-- | Encode a connection request response as its numeric code. Inverse of
-- 'decodeConnectionRequestResponse' on the codes it accepts.
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse crr = case crr of
  ConnectionRequestUnsupportedVersion -> 0xFFFFFFFF
  ConnectionRequestAccepted           -> 0x00000000
  ConnectionRequestInvalid            -> 0x00000001
  ConnectionRequestCrossed            -> 0x00000002
  ConnectionRequestHostMismatch       -> 0x00000003
-- | Generate a fresh 'EndPointAddress' whose bytes are the ASCII text form of
-- a random UUID (from "Data.UUID.V4").
randomEndPointAddress :: IO EndPointAddress
randomEndPointAddress = EndPointAddress . UUID.toASCIIBytes <$> UUID.nextRandom
-- | Start a TCP server and run its accept loop in a new thread.
--
-- Resolves @host@ / @port@ with 'N.getAddrInfo', then creates, binds and
-- listens on a socket for the first returned address. If setup fails, the
-- socket is closed and the exception propagates.
--
-- Each accepted connection is handed to @requestHandler@ on its own thread.
-- Its first argument is an action that blocks until that client socket has
-- been closed. The socket is closed when @requestHandler@ returns or throws.
--
-- An exception raised while accepting a single connection goes to
-- @errorHandler@. The loop keeps going unless @errorHandler@ itself throws.
-- When an exception escapes the loop, the listening socket is closed and
-- @terminationHandler@ is called with that exception.
--
-- Returns the port the listening socket is bound to (via 'N.socketPort') and
-- the 'ThreadId' of the accept loop.
forkServer :: N.HostName                     -- ^ Host
           -> N.ServiceName                  -- ^ Port
           -> Int                            -- ^ Backlog passed to 'N.listen'
           -> Bool                           -- ^ Set 'N.ReuseAddr' on the socket?
           -> (SomeException -> IO ())       -- ^ Error handler: called with an
                                             --   exception raised while
                                             --   accepting a connection.
           -> (SomeException -> IO ())       -- ^ Termination handler: called
                                             --   with the exception that ends
                                             --   the accept loop (e.g. one
                                             --   rethrown by the error handler).
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
                                             -- ^ Request handler. Gets an
                                             --   action which completes when
                                             --   the client socket is closed.
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
    addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                   tryCloseSocket $ \serverSock -> do
      when reuseAddr $ N.setSocketOption serverSock N.ReuseAddr 1
      N.bind serverSock (N.addrAddress addr)
      N.listen serverSock backlog

      let -- Close a client socket, then fill its MVar (even if closing
          -- throws) so that anyone blocked on it is released.
          release :: N.Socket -> MVar () -> IO ()
          release clientSock socketClosed =
            N.close clientSock `finally` putMVar socketClosed ()

          -- Called with async exceptions masked (from acceptRequest). The
          -- forked thread inherits that mask. 'finally' is therefore
          -- installed before anything is unmasked, and only the request
          -- handler runs under 'restore'. This is the same pattern as
          -- 'forkFinally'. The client socket is thus always released, even
          -- if an asynchronous exception arrives the moment the thread
          -- unmasks.
          handOff :: (IO () -> IO ()) -> (N.Socket, N.SockAddr) -> IO ()
          handOff restore (clientSock, clientAddr) = do
            socketClosed <- newEmptyMVar
            void $ forkIO $
              restore (requestHandler (readMVar socketClosed) (clientSock, clientAddr))
                `finally` release clientSock socketClosed

          -- Accept one connection. 'mask' ensures the accepted socket is
          -- either handed to a handler thread or closed here.
          acceptRequest :: IO ()
          acceptRequest = mask $ \restore -> do
            (clientSock, clientAddr) <- N.accept serverSock
            let closeClient :: SomeException -> IO ()
                closeClient _ = N.close clientSock
            handOff restore (clientSock, clientAddr) `catch` closeClient

      boundPort <- show <$> N.socketPort serverSock
      -- mask_ before forking so the child starts masked. Its outer 'catch' is
      -- then in place before any asynchronous exception can be delivered.
      -- 'unmask' re-enables them only for the loop itself.
      threadId <- mask_ $ forkIOWithUnmask $ \unmask ->
        unmask (forever (acceptRequest `catch` errorHandler))
          `catch` \ex -> do
            tryCloseSocket serverSock
            terminationHandler ex
      return (boundPort, threadId)
-- | Receive a length-prefixed message. The prefix is read with 'recvWord32';
-- then exactly that many bytes are read with 'recvExact'.
--
-- Throws an 'IOError' if the announced length exceeds @limit@. This is
-- checked before any payload is read. 'recvExact' also throws an 'IOError'
-- if the socket closes early.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock = do
  len <- recvWord32 sock
  when (len > limit) $
    throwIO (userError "recvWithLength: limit exceeded")
  recvExact sock len
-- | Receive exactly four bytes (via 'recvExact') and decode them with
-- 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = decodeWord32 . BS.concat <$> recvExact sock 4
-- | Close a socket, discarding any 'IOError' thrown by 'N.close'
-- (best effort; other exception types still propagate).
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket sock = void (tryIO (N.close sock))
-- | Shut down both directions of a socket ('N.ShutdownBoth'), discarding any
-- 'IOError' thrown by 'N.shutdown'. The socket itself is not closed.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth sock = void (tryIO (N.shutdown sock N.ShutdownBoth))
-- | Read exactly @len@ bytes from the socket.
--
-- The data is returned as the list of chunks in the order they were
-- received. Each 'NBS.recv' asks for at most 'smallChunkSize' bytes.
-- Throws an 'IOError' if the socket reports end-of-stream (an empty read)
-- before @len@ bytes have arrived.
recvExact :: N.Socket      -- ^ Socket to read from
          -> Word32        -- ^ Number of bytes to read
          -> IO [ByteString]
recvExact sock len = go [] len
  where
    go :: [ByteString] -> Word32 -> IO [ByteString]
    go acc 0 = return (reverse acc)
    go acc remaining = do
      -- Clamp while still in Word32, then convert. Converting a large Word32
      -- straight to Int first would wrap to a negative size on platforms
      -- with a 32-bit Int.
      let chunkSize = fromIntegral (remaining `min` fromIntegral smallChunkSize) :: Int
      bs <- NBS.recv sock chunkSize
      if BS.null bs
        then throwIO (userError "recvExact: Socket closed")
        else go (bs : acc) (remaining - fromIntegral (BS.length bs))
-- | Resolve an 'N.SockAddrInet' address to
-- @(numeric host, resolved host name, port)@.
--
-- Any other kind of address yields 'Nothing'.
--
-- Two lookups are made: first the host name (no service lookup is
-- requested), then the host again with 'N.NI_NUMERICHOST'. The port is
-- rendered with 'show'.
--
-- Calls 'error' if the first lookup does not return exactly a host and no
-- service. If the numeric lookup returns no host, the failed pattern in
-- 'IO' throws an 'IOError'.
resolveSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.HostName, N.ServiceName))
resolveSockAddr sockAddr = case sockAddr of
  N.SockAddrInet port _ -> do
    (mResolvedHost, mResolvedPort) <- N.getNameInfo [] True False sockAddr
    case (mResolvedHost, mResolvedPort) of
      (Just resolvedHost, Nothing) -> do
        (Just numericHost, _) <- N.getNameInfo [N.NI_NUMERICHOST] True False sockAddr
        return $ Just (numericHost, resolvedHost, show port)
      _ -> error $ concat
        [ "resolveSockAddr: unexpected resolution "
        , show sockAddr
        , " -> "
        , show mResolvedHost
        , ", "
        , show mResolvedPort
        ]
  _ -> return Nothing
-- | Build an end point address of the form @host:port:id@. The id is
-- rendered in decimal with 'show'. 'decodeEndPointAddress' reverses this.
--
-- NOTE: 'BSC.pack' keeps only the low 8 bits of each 'Char', so host or port
-- strings outside Latin-1 would be corrupted.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress . BSC.pack $ concat [host, ":", port, ":", show ix]
-- | Parse an address produced by 'encodeEndPointAddress' (@host:port:id@).
--
-- The string is split at the last two colons only, so the host part may
-- itself contain colons. The id must be a non-empty run of ASCII decimal
-- digits whose value fits in an 'EndPointId'. Anything else gives 'Nothing'.
-- Signs, whitespace, hex literals and out-of-range values are rejected.
-- (They used to be accepted via 'reads' and silently wrapped, which made
-- distinct addresses decode to the same id.)
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 (BSC.unpack bs) of
    [host, port, endPointIdStr] ->
      (,,) host port <$> parseEndPointId endPointIdStr
    _ -> Nothing
  where
    parseEndPointId :: String -> Maybe EndPointId
    parseEndPointId str
      | null str || not (all isAsciiDigit str) = Nothing
      | otherwise = fromInteger <$> accumulate 0 str

    isAsciiDigit :: Char -> Bool
    isAsciiDigit c = c >= '0' && c <= '9'

    -- Fold the digits into an Integer. Stop as soon as the value exceeds
    -- the EndPointId range, so very long inputs are rejected early.
    accumulate :: Integer -> String -> Maybe Integer
    accumulate acc [] = Just acc
    accumulate acc (c : cs)
      | acc' > toInteger (maxBound :: EndPointId) = Nothing
      | otherwise = accumulate acc' cs
      where
        acc' = acc * 10 + toInteger (fromEnum c - fromEnum '0')
-- | Split a list at elements satisfying @p@, working from the end and making
-- at most @n@ splits. Separators to the left of the last permitted split stay
-- inside the first piece. A negative @n@ places no limit on the number of
-- splits. The result is never empty.
--
-- > splitMaxFromEnd (== ':') 2 "a:b:c:d" == ["a:b", "c", "d"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p n = go [] [] n . reverse
  where
    -- @current@ is the piece being built. The input is consumed back to
    -- front, so consing onto it restores the original element order.
    -- @done@ holds the finished pieces to the right of @current@.
    go current done _ [] = current : done
    go current done 0 rest = (reverse rest ++ current) : done
    go current done k (x : rest)
      | p x       = go [] (current : done) (k - 1) rest
      | otherwise = go (x : current) done k rest