module Network.Transport.TCP.Internal
( ControlHeader(..)
, encodeControlHeader
, decodeControlHeader
, ConnectionRequestResponse(..)
, encodeConnectionRequestResponse
, decodeConnectionRequestResponse
, forkServer
, recvWithLength
, recvExact
, recvWord32
, encodeWord32
, tryCloseSocket
, tryShutdownSocketBoth
, decodeSockAddr
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, ProtocolVersion
, currentProtocolVersion
) where
#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Network.Transport.Internal
( decodeWord32
, encodeWord32
, void
, tryIO
, forkIOWithUnmask
)
import Network.Transport ( EndPointAddress(..) )
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, ServiceName
, Socket
, SocketType(Stream)
, SocketOption(ReuseAddr)
, getAddrInfo
, defaultHints
, socket
, bindSocket
, listen
, addrFamily
, addrAddress
, defaultProtocol
, setSocketOption
, accept
, sClose
, socketPort
, shutdown
, ShutdownCmd(ShutdownBoth)
, SockAddr(..)
, inet_ntoa
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif
import Data.Word (Word32)
import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
( MVar
, newEmptyMVar
, putMVar
, readMVar
)
import Control.Monad (forever, when)
import Control.Exception
( SomeException
, catch
, bracketOnError
, throwIO
, mask_
, mask
, finally
, onException
)
import Control.Applicative ((<$>), (<*>))
import Data.Word (Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)
-- | Numeric component of an endpoint address; it is the last
-- @:@-separated field produced by 'encodeEndPointAddress'.
type EndPointId = Word32

-- | Protocol version number, carried as a plain 'Word32'.
type ProtocolVersion = Word32

-- | The protocol version spoken by this implementation.
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0
-- | Control messages, each encoded on the wire as a 'Word32' by
-- 'encodeControlHeader' and read back by 'decodeControlHeader'.
data ControlHeader
  = CreatedNewConnection
  | CloseConnection
  | CloseSocket
  | CloseEndPoint
  | ProbeSocket
  | ProbeSocketAck
  deriving (Show)

-- | Inverse of 'encodeControlHeader'. Codes 0 through 5 map to the
-- constructors in declaration order; any other value yields 'Nothing'.
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader 0 = Just CreatedNewConnection
decodeControlHeader 1 = Just CloseConnection
decodeControlHeader 2 = Just CloseSocket
decodeControlHeader 3 = Just CloseEndPoint
decodeControlHeader 4 = Just ProbeSocket
decodeControlHeader 5 = Just ProbeSocketAck
decodeControlHeader _ = Nothing

-- | Wire code of a control header: 0 through 5, in declaration order.
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader CreatedNewConnection = 0
encodeControlHeader CloseConnection      = 1
encodeControlHeader CloseSocket          = 2
encodeControlHeader CloseEndPoint        = 3
encodeControlHeader ProbeSocket          = 4
encodeControlHeader ProbeSocketAck       = 5
-- | Responses to a connection request, each encoded on the wire as a
-- 'Word32' by 'encodeConnectionRequestResponse' and read back by
-- 'decodeConnectionRequestResponse'.
data ConnectionRequestResponse
  = ConnectionRequestUnsupportedVersion
  | ConnectionRequestAccepted
  | ConnectionRequestInvalid
  | ConnectionRequestCrossed
  | ConnectionRequestHostMismatch
  deriving (Show)

-- | Inverse of 'encodeConnectionRequestResponse'. Codes other than
-- 0xFFFFFFFF and 0 through 3 yield 'Nothing'.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse code
  | code == 0xFFFFFFFF = Just ConnectionRequestUnsupportedVersion
  | code == 0x00000000 = Just ConnectionRequestAccepted
  | code == 0x00000001 = Just ConnectionRequestInvalid
  | code == 0x00000002 = Just ConnectionRequestCrossed
  | code == 0x00000003 = Just ConnectionRequestHostMismatch
  | otherwise          = Nothing

-- | Wire code of a connection-request response. "Unsupported version" is
-- 0xFFFFFFFF; the remaining constructors are 0 through 3 in declaration
-- order.
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion = 0xFFFFFFFF
encodeConnectionRequestResponse ConnectionRequestAccepted           = 0x00000000
encodeConnectionRequestResponse ConnectionRequestInvalid            = 0x00000001
encodeConnectionRequestResponse ConnectionRequestCrossed            = 0x00000002
encodeConnectionRequestResponse ConnectionRequestHostMismatch       = 0x00000003
-- | Start a TCP server on a background thread.
--
-- The function resolves @host@ and @port@ with 'N.getAddrInfo' and creates a
-- stream socket for the first result. It sets @ReuseAddr@ when @reuseAddr@ is
-- 'True', binds the socket and listens with the given @backlog@. It then forks
-- a thread that accepts connections forever. Each accepted connection is
-- passed to @requestHandler@ on a freshly forked thread, and that connection's
-- socket is closed when the handler finishes.
--
-- Returns the listening socket's port (from 'N.socketPort', rendered with
-- 'show') and the 'ThreadId' of the accept thread.
forkServer :: N.HostName                  -- ^ Host to bind to
           -> N.ServiceName               -- ^ Port (service) to bind to
           -> Int                         -- ^ Listen backlog
           -> Bool                        -- ^ Whether to set ReuseAddr
           -> (SomeException -> IO ())    -- ^ Called when one accept iteration throws
           -> (SomeException -> IO ())    -- ^ Called when the accept loop ends, after the listening socket is closed
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
              -- ^ Per-connection handler. Its first argument blocks until the
              -- connection socket has been released.
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
    -- Only the first address is used. The pattern fails in IO if the list is
    -- empty (presumably getAddrInfo throws instead of returning [] -- confirm).
    addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    -- The listening socket is closed if any step inside the bracket throws.
    bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                   tryCloseSocket $ \sock -> do
      when reuseAddr $ N.setSocketOption sock N.ReuseAddr 1
      N.bindSocket sock (N.addrAddress addr)
      N.listen sock backlog

      -- Close a connection socket. The MVar is filled even if closing
      -- throws, so waiters on it are always woken.
      let release :: ((N.Socket, N.SockAddr), MVar ()) -> IO ()
          release ((sock, _), socketClosed) =
            N.sClose sock `finally` putMVar socketClosed ()

      -- Fork the per-connection handler. The child inherits the caller's
      -- masked state; 'restore' re-enables async exceptions for the handler,
      -- and 'finally' guarantees 'release' runs afterwards.
      -- NOTE(review): 'restore' wraps the 'finally' itself, which leaves a tiny
      -- unmasked window before 'finally' masks. Nobody holds this ThreadId
      -- (it is discarded by 'void'), so the window looks harmless in practice.
      let act restore (sock, sockAddr) = do
            socketClosed <- newEmptyMVar
            void $ forkIO $ restore $ do
                requestHandler (readMVar socketClosed) (sock, sockAddr)
              `finally`
                release ((sock, sockAddr), socketClosed)

      -- One accept iteration, run under 'mask'. A socket returned by accept
      -- is therefore handed to 'act' before an async exception can be
      -- delivered (nothing between accept and the fork is interruptible).
      let acceptRequest :: IO ()
          acceptRequest = mask $ \restore -> do
            (sock, sockAddr) <- N.accept sock
            -- If 'act' throws, the accepted socket is closed and the
            -- exception is dropped; errorHandler does not see it.
            let handler :: SomeException -> IO ()
                handler _ = N.sClose sock
            catch (act restore (sock, sockAddr)) handler

      -- The accept thread is forked masked ('mask_' + 'forkIOWithUnmask').
      -- The outer 'catch' is therefore installed before async exceptions are
      -- unmasked. Whatever ends the loop (e.g. killing the thread, or
      -- errorHandler throwing) closes the listening socket and then calls
      -- terminationHandler.
      (,) <$> fmap show (N.socketPort sock) <*>
        (mask_ $ forkIOWithUnmask $ \unmask ->
          catch (unmask (forever (catch acceptRequest errorHandler))) $ \ex -> do
            tryCloseSocket sock
            terminationHandler ex)
-- | Read a length-prefixed message. First a 'Word32' length is read with
-- 'recvWord32', then exactly that many bytes are read with 'recvExact' and
-- returned as chunks.
--
-- Throws an 'IOError' if the announced length exceeds @limit@. The check
-- happens before any payload is read.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock = recvWord32 sock >>= readPayload
  where
    readPayload announced
      | announced > limit = throwIO (userError "recvWithLength: limit exceeded")
      | otherwise         = recvExact sock announced
-- | Read exactly four bytes from the socket and decode them with
-- 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = do
  chunks <- recvExact sock 4
  return (decodeWord32 (BS.concat chunks))
-- | Best-effort close. The result of 'tryIO' is discarded, so errors it
-- catches (presumably 'IOException's) are ignored.
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket sock = do
  _ <- tryIO (N.sClose sock)
  return ()
-- | Best-effort shutdown of both directions of the socket. The result of
-- 'tryIO' is discarded, so errors it catches (presumably 'IOException's) are
-- ignored.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth sock = do
  _ <- tryIO (N.shutdown sock N.ShutdownBoth)
  return ()
-- | Read exactly @len@ bytes from the socket.
--
-- The bytes come back as a list of chunks in arrival order; concatenate them
-- to get the payload. Each @recv@ requests at most 'smallChunkSize' bytes.
-- Throws an 'IOError' ("recvExact: Socket closed") if @recv@ returns an
-- empty string before @len@ bytes have arrived.
recvExact :: N.Socket          -- ^ Socket to read from
          -> Word32            -- ^ Number of bytes to read
          -> IO [ByteString]
recvExact sock len = go [] len
  where
    go :: [ByteString] -> Word32 -> IO [ByteString]
    go acc 0 = return (reverse acc)
    go acc remaining = do
      -- Take the minimum in Word32 *before* converting to Int. Converting a
      -- large Word32 first could produce a negative size on 32-bit platforms.
      bs <- NBS.recv sock (fromIntegral (remaining `min` maxChunk))
      if BS.null bs
        then throwIO (userError "recvExact: Socket closed")
        else go (bs : acc) (remaining - fromIntegral (BS.length bs))

    -- 'smallChunkSize' as a Word32 (it is a small positive constant).
    maxChunk :: Word32
    maxChunk = fromIntegral smallChunkSize
-- | Render an IPv4 socket address as a host string (via 'N.inet_ntoa') and a
-- port string (via 'show'). Every other address family, including IPv6,
-- yields 'Nothing'.
decodeSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.ServiceName))
decodeSockAddr (N.SockAddrInet port host) = do
  hostString <- N.inet_ntoa host
  return (Just (hostString, show port))
decodeSockAddr _ = return Nothing
-- | Build the textual endpoint address @host:port:ix@.
--
-- 'decodeEndPointAddress' splits on the last two colons, so @host@ may
-- contain colons but @port@ must not. 'BSC.pack' keeps only the low 8 bits of
-- each character, so non-Latin-1 host names do not survive the round trip.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))
-- | Inverse of 'encodeEndPointAddress'. The address is split on its last two
-- colons into host, port and endpoint id, so the host part may itself
-- contain colons (e.g. an IPv6 literal).
--
-- The endpoint id must be a non-empty string of decimal digits whose value
-- fits in an 'EndPointId'; otherwise the result is 'Nothing'. The strict
-- check is deliberate. @reads@ at type 'Word32' would accept @"-1"@, leading
-- whitespace and parentheses, and would silently wrap out-of-range values,
-- letting distinct address strings alias the same endpoint.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 $ BSC.unpack bs of
    [host, port, endPointIdStr] ->
      fmap (\endPointId -> (host, port, endPointId)) (parseEndPointId endPointIdStr)
    _ ->
      Nothing
  where
    -- Strict decimal parse with an explicit range check (no wrap-around).
    parseEndPointId :: String -> Maybe EndPointId
    parseEndPointId str
      | null str || any (`notElem` ['0' .. '9']) str = Nothing
      | otherwise =
          case reads str :: [(Integer, String)] of
            [(n, "")] | n <= toInteger (maxBound :: EndPointId) -> Just (fromInteger n)
            _ -> Nothing
-- | Split a list at elements satisfying @p@, scanning from the end and
-- performing at most @n@ splits. Separators are dropped. Everything to the
-- left of the @n@-th separator (counting from the right) is kept as one
-- unsplit first segment.
--
-- > splitMaxFromEnd (== ':') 2 "fe80::1:80:7" == ["fe80::1", "80", "7"]
--
-- The result always has at least one segment, and at most @n + 1@ segments
-- when @n >= 0@. A negative @n@ never reaches the limit, so every separator
-- splits.
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p n = go [] [] n . reverse
  where
    -- @cur@ is the segment under construction. It is already in original
    -- order because we cons while walking the reversed input. @done@ holds
    -- the completed segments to the right of @cur@.
    go cur done _ [] = cur : done
    -- Split budget exhausted: the rest of the input is one segment.
    go cur done 0 xs = (reverse xs ++ cur) : done
    go cur done k (x : xs)
      | p x       = go [] (cur : done) (k - 1) xs
      | otherwise = go (x : cur) done k xs