{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}

-- | This module provides convenience functions for working with raw TCP connections.
--
-- Please use 'E.bracket' or its friends to ensure exception safety.
--
-- This module is intended to be imported @qualified@, e.g.:
--
-- @
-- import           "Data.Connection"
-- import qualified "System.IO.Streams.TCP" as TCP
-- @
--
module System.IO.Streams.TCP
  ( TCPConnection
    -- * client
  , connect
  , connectSocket
  , socketToConnection
  , defaultChunkSize
    -- * server
  , bindAndListen
  , bindAndListenWith
  , accept
  , acceptWith
  ) where

import qualified Control.Exception         as E
import           Control.Monad
import           Data.Connection
import qualified Data.ByteString           as B
import qualified Data.ByteString.Lazy.Internal as L
import qualified Network.Socket            as N
import qualified Network.Socket.ByteString as NB
import qualified Network.Socket.ByteString.Lazy as NL
import qualified System.IO.Streams         as S
import           Foreign.Storable   (sizeOf)

-- | The IPv4 wildcard address @0.0.0.0@. Binding a listening socket to it
-- accepts connections arriving on any local IPv4 interface.
addrAny :: N.HostAddress
#if MIN_VERSION_network(2,7,0)
-- Newer network releases build the address from its four octets.
addrAny = N.tupleToHostAddress (0, 0, 0, 0)
#else
addrAny = N.iNADDR_ANY
#endif

-- | Type alias for a TCP connection.
--
-- The 'connExtraInfo' field holds the underlying 'N.Socket' together with the
-- remote 'N.SockAddr' (as produced by 'connectSocket' or 'N.accept').
-- Normally you shouldn't use the 'N.Socket' directly; it is exposed so that
-- you can tune it with 'N.setSocketOption' if you need to.
--
type TCPConnection = Connection (N.Socket, N.SockAddr)

-- | The chunk size used for I\/O: 32 KiB minus an allowance for the
-- memory-management overhead of each allocated chunk.
--
-- The allowance is two machine words, so the value is 32752 on a platform
-- with 8-byte 'Int's.
--
defaultChunkSize :: Int
defaultChunkSize = thirtyTwoKiB - overhead
  where
    thirtyTwoKiB :: Int
    thirtyTwoKiB = 32 * 1024

    -- two machine words of per-chunk bookkeeping
    overhead :: Int
    overhead = 2 * sizeOf (0 :: Int)

-- | Initiate a raw TCP connection to the given @('N.HostName', 'N.PortNumber')@
-- combination.
--
-- The host and port are resolved with 'N.getAddrInfo' using the
-- 'N.AI_ADDRCONFIG' and 'N.AI_NUMERICSERV' hints, so both numeric IPv4/IPv6
-- addresses and domain names are accepted.
--
-- Every resolved address is tried in the order 'N.getAddrInfo' returns them,
-- and the first one that accepts the connection is used. This matters for
-- names such as @localhost@, which may resolve to both @::1@ and @127.0.0.1@
-- while the server (e.g. one started with 'bindAndListen') only listens on
-- IPv4. If every address fails, the exception from the last attempt is
-- rethrown.
--
-- @TCP_NODELAY@ is enabled on the returned socket; use 'N.setSocketOption'
-- to adjust it.
--
connectSocket :: N.HostName             -- ^ hostname to connect to
              -> N.PortNumber           -- ^ port number to connect to
              -> IO (N.Socket, N.SockAddr)
connectSocket host port = do
    addrInfos <- N.getAddrInfo (Just hints) (Just host) (Just (show port))
    connectFirst addrInfos
  where
    hints :: N.AddrInfo
    hints = N.defaultHints
        { N.addrFlags      = [N.AI_ADDRCONFIG, N.AI_NUMERICSERV]
        , N.addrSocketType = N.Stream
        }

    -- Try each candidate in turn. Only 'E.IOException's (connection refused,
    -- network unreachable, ...) move on to the next candidate; the failure of
    -- the last candidate propagates unchanged.
    connectFirst :: [N.AddrInfo] -> IO (N.Socket, N.SockAddr)
    connectFirst [] =
        -- network normally throws instead of returning an empty list, but
        -- stay total rather than relying on a partial pattern match.
        ioError (userError ("System.IO.Streams.TCP.connectSocket: no address found for " ++ host))
    connectFirst [ai] = connectTo ai
    connectFirst (ai : ais) =
        connectTo ai `E.catch` \(_ :: E.IOException) -> connectFirst ais

    -- Open a socket for one resolved address and connect it. The socket is
    -- closed if anything fails before it is handed back to the caller.
    connectTo :: N.AddrInfo -> IO (N.Socket, N.SockAddr)
    connectTo ai =
        E.bracketOnError
            (N.socket (N.addrFamily ai) (N.addrSocketType ai) (N.addrProtocol ai))
            N.close
            (\sock -> do
                let addr = N.addrAddress ai
                N.connect sock addr
                N.setSocketOption sock N.NoDelay 1
                return (sock, addr))

-- | Make a 'Connection' from a connected socket and its address, reading with
-- the given buffer size.
--
-- Each read issues a single 'NB.recv' of at most @bufsiz@ bytes. An empty
-- result is reported as end of stream. A lazy 'L.ByteString' made of a single
-- chunk is written with 'NB.sendAll'; anything longer goes through
-- 'NL.sendAll'. Closing the connection closes the socket.
--
socketToConnection
    :: Int                      -- ^ receive buffer size
    -> (N.Socket, N.SockAddr)   -- ^ socket address pair
    -> IO TCPConnection
socketToConnection bufsiz (sock, addr) = do
    input <- S.makeInputStream readChunk
    return (Connection input writeLazy (N.close sock) (sock, addr))
  where
    -- An empty read means the peer closed its side: signal end of stream.
    readChunk = do
        chunk <- NB.recv sock bufsiz
        if B.null chunk
            then return Nothing
            else return (Just chunk)

    -- Avoid the lazy sender's overhead when there is only one chunk.
    writeLazy lbs = case lbs of
        L.Empty            -> return ()
        L.Chunk bs L.Empty -> unless (B.null bs) (NB.sendAll sock bs)
        _                  -> NL.sendAll sock lbs

-- | Connect to a server with 'connectSocket' and wrap the resulting socket,
-- using 'defaultChunkSize' as the receive buffer size.
--
connect :: N.HostName             -- ^ hostname to connect to
        -> N.PortNumber           -- ^ port number to connect to
        -> IO TCPConnection
connect host port = do
    sockAndAddr <- connectSocket host port
    socketToConnection defaultChunkSize sockAndAddr

-- | Bind to the given port on all IPv4 interfaces and start listening.
--
-- @SO_REUSEADDR@ and @TCP_NODELAY@ are set on the socket before it is bound.
--
bindAndListen :: Int                 -- ^ connection limit (the backlog passed to 'N.listen')
              -> N.PortNumber        -- ^ port number
              -> IO N.Socket
bindAndListen = bindAndListenWith enableDefaults
  where
    -- options are applied in order: SO_REUSEADDR first, then TCP_NODELAY
    enableDefaults sock =
        mapM_ (\opt -> N.setSocketOption sock opt 1) [N.ReuseAddr, N.NoDelay]

-- | Bind to the given port on the IPv4 wildcard address and start listening,
-- after letting the caller configure the socket.
--
-- The socket is closed if configuring, binding or listening throws.
--
-- Note: the following socket options are inherited by a connected TCP socket
-- from the listening socket:
--
-- @
-- SO_DEBUG
-- SO_DONTROUTE
-- SO_KEEPALIVE
-- SO_LINGER
-- SO_OOBINLINE
-- SO_RCVBUF
-- SO_RCVLOWAT
-- SO_SNDBUF
-- SO_SNDLOWAT
-- TCP_MAXSEG
-- TCP_NODELAY
-- @
--
bindAndListenWith :: (N.Socket -> IO ()) -- ^ set socket options before binding
                  -> Int                 -- ^ connection limit (the backlog passed to 'N.listen')
                  -> N.PortNumber        -- ^ port number
                  -> IO N.Socket
bindAndListenWith setOptions backlog port =
    E.bracketOnError openSocket N.close prepare
  where
    -- an IPv4 stream socket with the default protocol
    openSocket = N.socket N.AF_INET N.Stream 0

    prepare sock = do
        setOptions sock
        N.bind sock (N.SockAddrInet port addrAny)
        N.listen sock backlog
        return sock

-- | Accept a connection on a listening socket, using 'defaultChunkSize' as
-- the receive buffer size.
--
accept :: N.Socket -> IO TCPConnection
accept = acceptWith wrap
  where
    wrap = socketToConnection defaultChunkSize

-- | Accept a connection and turn it into a 'TCPConnection' with a
-- user-supplied function.
--
-- The function receives the accepted socket together with the peer address
-- returned by 'N.accept'. Use it to set socket options, choose a receive
-- buffer size (e.g. via 'socketToConnection'), and so on.
--
-- If the function throws, the accepted socket is closed before the exception
-- propagates, so a failing setup step does not leak a file descriptor.
--
acceptWith :: ((N.Socket, N.SockAddr) -> IO TCPConnection) -- ^ set socket options, adjust receive buffer, etc.
           -> N.Socket                                     -- ^ socket to accept on
           -> IO TCPConnection
acceptWith f listenSock =
    -- 'E.bracketOnError' runs 'N.accept' with asynchronous exceptions masked,
    -- so none can arrive between 'N.accept' returning and the cleanup handler
    -- taking effect. On success, ownership of the socket passes to the
    -- 'TCPConnection' built by @f@.
    E.bracketOnError (N.accept listenSock) (N.close . fst) f