{-# LANGUAGE OverloadedStrings #-}

-- | Simple functions to run TCP clients and servers.
module Network.Run.TCP (
    runTCPClient,
    runTCPServer,

    -- * Generalized API
    runTCPServerWithSocket,
    openServerSocket,
) where

import Control.Concurrent (forkFinally)
import qualified Control.Exception as E
import Control.Monad (forever, void)
import Network.Socket

import Network.Run.Core

-- | Running a TCP client with a connected socket.
--
-- Resolves @host@ and @port@ as a stream socket address (with
-- 'AI_ADDRCONFIG'), opens a client socket for it with 'openClientSocket',
-- and runs @client@ on that socket. The socket is released with 'gclose'
-- whether @client@ returns normally or throws; @client@'s result is
-- returned.
runTCPClient :: HostName -> ServiceName -> (Socket -> IO a) -> IO a
runTCPClient host port client = withSocketsDo $ do
    addr <- resolve Stream (Just host) port [AI_ADDRCONFIG]
    -- 'E.bracket' runs 'openClientSocket' with asynchronous exceptions
    -- masked, so once it returns the socket is guaranteed to reach 'gclose'.
    -- Wrapping the acquisition in @E.bracketOnError ... close return@ adds
    -- no further protection: 'return' cannot throw in that masked context.
    E.bracket (openClientSocket addr) gclose client

-- | Running a TCP server, calling the handler with each accepted socket.
--
-- Each incoming connection is served by the handler in its own thread.
-- Only the connected 'Socket' is passed to the handler; the peer's address
-- is not. Use 'getPeerName' on the socket if the handler needs it.
--
-- Equivalent to 'runTCPServerWithSocket' with 'openServerSocket' as the
-- socket initializer.
runTCPServer :: Maybe HostName -> ServiceName -> (Socket -> IO a) -> IO a
runTCPServer = runTCPServerWithSocket openServerSocket

----------------------------------------------------------------
-- Generalized API

-- | Generalization of 'runTCPServer' that lets the caller choose how the
-- listening socket is created.
--
-- Resolves the local address (with 'AI_PASSIVE') and initializes a socket
-- for it with the supplied function. It then calls 'listen' on the socket
-- with a backlog of 1024 and accepts connections forever.
--
-- Each accepted socket is handed to the handler in a new thread and closed
-- with 'gclose' once the handler finishes. The handler's result, and any
-- exception it throws, is discarded.
--
-- This function only exits by throwing (e.g. when 'accept' fails or the
-- calling thread is killed). In that case the listening socket is closed.
runTCPServerWithSocket
    :: (AddrInfo -> IO Socket)
    -- ^ Initialize socket.
    --
    -- This function is called while exceptions are masked.
    --
    -- The default (used by 'runTCPServer') is 'openServerSocket'.
    -> Maybe HostName
    -- ^ Local host to bind to. Passed to 'resolve' as-is; presumably
    -- 'Nothing' selects the wildcard address (TODO confirm in 'resolve').
    -> ServiceName
    -- ^ Port number or service name to listen on.
    -> (Socket -> IO a)
    -- ^ Called for each incoming connection, in a new thread
    -> IO a
runTCPServerWithSocket initSocket mhost port server = withSocketsDo $ do
    addr <- resolve Stream mhost port [AI_PASSIVE]
    E.bracket (open addr) close loop
  where
    -- Create the server socket and start listening on it. If 'listen'
    -- throws, the freshly created socket is closed before the exception
    -- propagates.
    open :: AddrInfo -> IO Socket
    open addr = E.bracketOnError (initSocket addr) close $ \sock -> do
        listen sock 1024
        return sock

    -- Accept connections until an exception escapes. If forking the
    -- handler thread fails, 'E.bracketOnError' closes the accepted socket.
    -- Once the thread is running, its 'forkFinally' finalizer owns the
    -- socket and closes it with 'gclose' however the handler ends.
    loop :: Socket -> IO b
    loop sock = forever $
        E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) ->
            void $ forkFinally (server conn) (const $ gclose conn)