module Network.Transport.TCP.Internal
( ControlHeader(..)
, encodeControlHeader
, decodeControlHeader
, ConnectionRequestResponse(..)
, encodeConnectionRequestResponse
, decodeConnectionRequestResponse
, forkServer
, recvWithLength
, recvExact
, recvWord32
, encodeWord32
, tryCloseSocket
, tryShutdownSocketBoth
, resolveSockAddr
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, randomEndPointAddress
, ProtocolVersion
, currentProtocolVersion
) where
#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Network.Transport.Internal
( decodeWord32
, encodeWord32
, void
, tryIO
, forkIOWithUnmask
)
import Network.Transport ( EndPointAddress(..) )
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, NameInfoFlag(NI_NUMERICHOST)
, ServiceName
, Socket
, SocketType(Stream)
, SocketOption(ReuseAddr)
, getAddrInfo
, defaultHints
, socket
, bind
, listen
, addrFamily
, addrAddress
, defaultProtocol
, setSocketOption
, accept
, close
, socketPort
, shutdown
, ShutdownCmd(ShutdownBoth)
, SockAddr(..)
, getNameInfo
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif
import Data.Word (Word32, Word64)
import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
( MVar
, newEmptyMVar
, putMVar
, readMVar
)
import Control.Monad (forever, when)
import Control.Exception
( SomeException
, catch
, bracketOnError
, throwIO
, mask_
, mask
, finally
, onException
)
import Control.Applicative ((<$>), (<*>))
import Data.Word (Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import Data.ByteString.Lazy (toStrict)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)
import Data.ByteString.Lazy.Builder (word64BE, toLazyByteString)
import Data.Monoid ((<>))
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
-- | Numeric endpoint identifier; it forms the final @:@-separated component
-- of an address built by 'encodeEndPointAddress'.
type EndPointId = Word32

-- | Version number of the wire protocol.
type ProtocolVersion = Word32

-- | The current protocol version (version 0).
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0
-- | Control message headers. Each constructor has a fixed 'Word32' wire
-- code, given by 'encodeControlHeader' and 'decodeControlHeader'; the two
-- functions must be kept in sync.
data ControlHeader
  = CreatedNewConnection -- ^ wire code 0
  | CloseConnection      -- ^ wire code 1
  | CloseSocket          -- ^ wire code 2
  | CloseEndPoint        -- ^ wire code 3
  | ProbeSocket          -- ^ wire code 4
  | ProbeSocketAck       -- ^ wire code 5
  deriving Show

-- | Decode a wire code into a 'ControlHeader'. Returns 'Nothing' for any
-- code outside the known range 0..5.
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader code = lookup code headerTable
  where
    headerTable =
      [ (0, CreatedNewConnection)
      , (1, CloseConnection)
      , (2, CloseSocket)
      , (3, CloseEndPoint)
      , (4, ProbeSocket)
      , (5, ProbeSocketAck)
      ]

-- | Encode a 'ControlHeader' as its wire code (inverse of
-- 'decodeControlHeader').
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader CreatedNewConnection = 0
encodeControlHeader CloseConnection      = 1
encodeControlHeader CloseSocket          = 2
encodeControlHeader CloseEndPoint        = 3
encodeControlHeader ProbeSocket          = 4
encodeControlHeader ProbeSocketAck       = 5
-- | Possible replies to a connection request. Each constructor has a fixed
-- 'Word32' wire code, given by 'encodeConnectionRequestResponse' and
-- 'decodeConnectionRequestResponse'; the two must be kept in sync.
data ConnectionRequestResponse
  = ConnectionRequestUnsupportedVersion -- ^ wire code 0xFFFFFFFF
  | ConnectionRequestAccepted           -- ^ wire code 0
  | ConnectionRequestInvalid            -- ^ wire code 1
  | ConnectionRequestCrossed            -- ^ wire code 2
  | ConnectionRequestHostMismatch       -- ^ wire code 3
  deriving Show

-- | Decode a wire code into a 'ConnectionRequestResponse'. Returns
-- 'Nothing' for any code not listed above.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse code = lookup code responseTable
  where
    responseTable =
      [ (0xFFFFFFFF, ConnectionRequestUnsupportedVersion)
      , (0x00000000, ConnectionRequestAccepted)
      , (0x00000001, ConnectionRequestInvalid)
      , (0x00000002, ConnectionRequestCrossed)
      , (0x00000003, ConnectionRequestHostMismatch)
      ]

-- | Encode a 'ConnectionRequestResponse' as its wire code (inverse of
-- 'decodeConnectionRequestResponse').
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion = 0xFFFFFFFF
encodeConnectionRequestResponse ConnectionRequestAccepted           = 0x00000000
encodeConnectionRequestResponse ConnectionRequestInvalid            = 0x00000001
encodeConnectionRequestResponse ConnectionRequestCrossed            = 0x00000002
encodeConnectionRequestResponse ConnectionRequestHostMismatch       = 0x00000003
-- | Generate a fresh endpoint address from a random (version 4) UUID,
-- using the UUID's ASCII textual rendering as the address bytes.
randomEndPointAddress :: IO EndPointAddress
randomEndPointAddress = EndPointAddress . UUID.toASCIIBytes <$> UUID.nextRandom
-- | Start a TCP server in a background thread.
--
-- The server socket is bound to the first address 'N.getAddrInfo' returns
-- for @host@/@port@ and then listens with the given backlog. Each accepted
-- connection is handed to @requestHandler@ in its own thread. Along with
-- the connection, the handler receives an action that blocks until that
-- connection's socket has been closed. The connection socket is closed,
-- and the blocking action released, when the handler returns or throws.
--
-- Any exception raised while accepting a connection is passed to
-- @errorHandler@. If @errorHandler@ returns normally the accept loop keeps
-- going, so it must rethrow anything that is meant to stop the server,
-- including asynchronous exceptions such as 'ThreadKilled'. Once an
-- exception escapes the accept loop, the server socket is closed and
-- @terminationHandler@ is run with that exception.
--
-- Returns the port of the server socket (as reported by 'N.socketPort')
-- and the 'ThreadId' of the accept-loop thread.
forkServer :: N.HostName                -- ^ host to bind to
           -> N.ServiceName             -- ^ port to bind to
           -> Int                       -- ^ listen backlog
           -> Bool                      -- ^ set 'N.ReuseAddr' on the server socket?
           -> (SomeException -> IO ())  -- ^ handler for exceptions in the accept loop
           -> (SomeException -> IO ())  -- ^ run when the accept-loop thread terminates
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
              -- ^ per-connection handler
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
  -- getAddrInfo throws rather than returning an empty list, so this
  -- pattern only fails if that contract is broken.
  addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
  bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                 tryCloseSocket $ \serverSock -> do
    when reuseAddr $ N.setSocketOption serverSock N.ReuseAddr 1
    N.bind serverSock (N.addrAddress addr)
    N.listen serverSock backlog

    let -- Close a connection socket and signal that it has been closed.
        release :: N.Socket -> MVar () -> IO ()
        release clientSock socketClosed =
          N.close clientSock `finally` putMVar socketClosed ()

        -- Run the request handler in a new thread. The child inherits the
        -- masked state of 'acceptRequest', so 'finally' is installed before
        -- any asynchronous exception can reach it. Only the handler itself
        -- runs with the mask state restored.
        act restore (clientSock, clientAddr) = do
          socketClosed <- newEmptyMVar
          void $ forkIO $
            restore (requestHandler (readMVar socketClosed) (clientSock, clientAddr))
              `finally` release clientSock socketClosed

        acceptRequest :: IO ()
        acceptRequest = mask $ \restore -> do
          -- Masked so that an accepted socket is always handed to 'act'
          -- (or closed) before an asynchronous exception is raised here.
          (clientSock, clientAddr) <- N.accept serverSock
          -- If starting the handler thread fails, close the socket here.
          let handler :: SomeException -> IO ()
              handler _ = N.close clientSock
          catch (act restore (clientSock, clientAddr)) handler

    -- mask_ + forkIOWithUnmask: the child starts masked, so the outer
    -- 'catch' (which closes the server socket and runs the termination
    -- handler) is installed before any asynchronous exception can arrive.
    -- Only the accept loop itself runs unmasked.
    (,) <$> fmap show (N.socketPort serverSock) <*>
      (mask_ $ forkIOWithUnmask $ \unmask ->
        catch (unmask (forever (catch acceptRequest errorHandler))) $ \ex -> do
          tryCloseSocket serverSock
          terminationHandler ex)
-- | Read a length-prefixed message: a 'Word32' length (read with
-- 'recvWord32') followed by exactly that many bytes, returned as chunks.
-- Throws an 'IOError' if the announced length exceeds @limit@.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock =
  recvWord32 sock >>= \announced ->
    if announced > limit
      then throwIO (userError "recvWithLength: limit exceeded")
      else recvExact sock announced
-- | Read exactly four bytes from the socket and decode them as a 'Word32'
-- with 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = decodeWord32 . BS.concat <$> recvExact sock 4
-- | Close a socket as a best effort. Any exception that 'tryIO' catches is
-- discarded (presumably 'IOException's; see "Network.Transport.Internal").
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket = void . tryIO . N.close
-- | Shut down both directions of a socket as a best effort. Any exception
-- that 'tryIO' catches is discarded.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth sock =
  void (tryIO (N.shutdown sock N.ShutdownBoth))
-- | Receive exactly @len@ bytes from the socket. The bytes are returned as
-- a list of chunks, in the order they were received.
--
-- Each 'NBS.recv' call asks for at most 'smallChunkSize' bytes. Throws an
-- 'IOError' ("recvExact: Socket closed") if 'NBS.recv' returns an empty
-- string before @len@ bytes have arrived.
recvExact :: N.Socket        -- ^ socket to read from
          -> Word32          -- ^ number of bytes to read
          -> IO [ByteString]
recvExact sock len = go [] len
  where
    go :: [ByteString] -> Word32 -> IO [ByteString]
    go acc 0 = return (reverse acc)
    go acc remaining = do
      -- Take the minimum in Word32 *before* converting to Int. Converting a
      -- large Word32 first overflows where Int is 32 bits and would pass a
      -- negative size to recv.
      let request = fromIntegral (remaining `min` fromIntegral smallChunkSize) :: Int
      chunk <- NBS.recv sock request
      if BS.null chunk
        then throwIO (userError "recvExact: Socket closed")
        else go (chunk : acc) (remaining - fromIntegral (BS.length chunk))
-- | Resolve an IPv4 ('N.SockAddrInet') address to
-- @(numericHost, resolvedHost, port)@.
--
-- * @resolvedHost@ is what 'N.getNameInfo' returns with no flags
--   (presumably a reverse-DNS name when one exists; confirm against the
--   platform's getnameinfo).
-- * @numericHost@ is the 'N.NI_NUMERICHOST' form.
-- * @port@ is the port number rendered with 'show'.
--
-- Any other kind of address yields 'Nothing'. Calls 'error' if the
-- unflagged lookup does not return exactly a host name. Throws an
-- 'IOError' if the numeric lookup yields no host.
resolveSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.HostName, N.ServiceName))
resolveSockAddr sockAddr = case sockAddr of
  N.SockAddrInet port _ -> do
    (mResolvedHost, mResolvedPort) <- N.getNameInfo [] True False sockAddr
    case (mResolvedHost, mResolvedPort) of
      (Just resolvedHost, Nothing) -> do
        (mNumericHost, _) <- N.getNameInfo [N.NI_NUMERICHOST] True False sockAddr
        case mNumericHost of
          Just numericHost ->
            return $ Just (numericHost, resolvedHost, show port)
          Nothing ->
            throwIO . userError $
              "resolveSockAddr: no numeric host for " ++ show sockAddr
      _ -> error $ concat [
          "resolveSockAddr: unexpected resolution "
        , show sockAddr
        , " -> "
        , show mResolvedHost
        , ", "
        , show mResolvedPort
        ]
  _ -> return Nothing
-- | Build an endpoint address of the form @host:port:ix@, with @ix@
-- rendered by 'show' (decimal). The string is packed with 'BSC.pack',
-- which keeps only the low 8 bits of each character.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))
-- | Inverse of 'encodeEndPointAddress'. Splits an address of the form
-- @host:port:ix@ on its last two colons, so @host@ may itself contain
-- colons.
--
-- Returns 'Nothing' unless both of these hold:
--
-- * the address contains at least two colons;
-- * @ix@ is a non-empty string of decimal digits whose value fits in an
--   'EndPointId'.
--
-- Parsing is deliberately strict. 'reads' at 'Word32' goes via 'Integer'
-- and 'fromInteger', so it would accept forms that 'encodeEndPointAddress'
-- never produces, such as @"-1"@ or values above 'maxBound'. Those silently
-- wrap around to a different endpoint id.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 $ BSC.unpack bs of
    [host, port, endPointIdStr] ->
      fmap (\endPointId -> (host, port, endPointId))
           (parseEndPointId endPointIdStr)
    _ ->
      Nothing
  where
    -- Accept only plain decimal digits within the range of EndPointId.
    parseEndPointId :: String -> Maybe EndPointId
    parseEndPointId str
      | not (null str)
      , all (\c -> c >= '0' && c <= '9') str
      , [(n, "")] <- (reads str :: [(Integer, String)])
      , n <= toInteger (maxBound :: EndPointId)
      = Just (fromInteger n)
      | otherwise
      = Nothing
-- | @splitMaxFromEnd p n xs@ splits @xs@ at elements satisfying @p@,
-- scanning from the right and making at most @n@ splits. Separators are
-- dropped. Everything to the left of the @n@-th separator (counted from
-- the end) is kept whole as the first piece. A negative @n@ means no limit.
--
-- > splitMaxFromEnd (== ':') 2 "a:b:c:d" == ["a:b", "c", "d"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd isSep maxSplits = loop maxSplits [] [] . reverse
  where
    -- loop splitsLeft currentPiece finishedPieces remainingReversedInput
    loop _ current finished [] = current : finished
    loop 0 [] finished rest    = reverse rest : finished
    loop k current finished (y:ys)
      | isSep y   = loop (k - 1) [] (current : finished) ys
      | otherwise = loop k (y : current) finished ys