module SocketsAndPipes.Serve.Setup ( withSocketOnPort ) where

import SocketsAndPipes.Serve.Sockets
    ( PortNumber, Socket, PassiveSocket (..), closePassiveSocket )

import SocketsAndPipes.Serve.Exceptions
    ( BindFailed (..), AddrTried (..),
      overException, firstSuccessOrAllExceptions )

import Control.Monad ( (>=>), when )
import Data.Function ( on )

import qualified Control.Exception.Safe as Exception
import qualified Data.List              as List
import qualified Network.Socket         as Socket

withSocketOnPort :: PortNumber -> (PassiveSocket -> IO a) -> IO a
withSocketOnPort port =
    Exception.bracket (bindToPort port) closePassiveSocket
{- ^
    Binds a listening socket to the given port (via 'bindToPort'), hands
    it to the supplied action, and releases it with 'closePassiveSocket'.

    Acquisition and release go through 'Exception.bracket', so the socket
    is closed both when the action returns and when it throws.
-}

bindToPort :: PortNumber -> IO PassiveSocket
bindToPort = addrsForPort >=> chooseAddrAndBind
{- ^
    Looks up the candidate addresses for the port with 'addrsForPort',
    then passes that list to 'chooseAddrAndBind' to obtain a bound,
    listening socket.
-}

addrsForPort :: PortNumber -> IO [Socket.AddrInfo]
addrsForPort port = Socket.getAddrInfo hints hostName serviceName
  where
    -- Restrict the lookup to the kind of address a server wants.
    hints :: Maybe Socket.AddrInfo
    hints = Just serverAddrHints

    -- No host name is given; only the port is specified.
    hostName :: Maybe Socket.HostName
    hostName = Nothing

    -- The port is passed as the "service name", rendered as text.
    serviceName :: Maybe Socket.ServiceName
    serviceName = Just (show port)
{- ^
    The first thing we have to do when starting a server is figure
    out exactly what network address to listen on.

    We've been given a port number, but that's only half the story;
    a network address actually includes a lot more obnoxious details
    in addition to the port number.

    'addrsForPort' uses the 'Socket.getAddrInfo' function from the network
    library to find a list of possible addresses for us to choose from.
-}

serverAddrHints :: Socket.AddrInfo
serverAddrHints =
    Socket.defaultHints
        { Socket.addrSocketType = socketType
        , Socket.addrFlags      = flags
        }
  where
    {- A "stream" socket uses TCP to make sure all the
       packets arrive in the right order. -}
    socketType :: Socket.SocketType
    socketType = Socket.Stream

    {- A "passive" socket is a socket that will be
       used to listen for incoming connections. -}
    flags :: [Socket.AddrInfoFlag]
    flags = [Socket.AI_PASSIVE]
{- ^
    The hints handed to the address lookup: 'Socket.defaultHints' with
    the socket type and the flags overridden as described above.

    The fields are assigned explicitly (rather than with record puns) so
    that this definition does not depend on any language extension.
-}

chooseAddrAndBind :: [Socket.AddrInfo] -> IO PassiveSocket
chooseAddrAndBind =
    firstSuccessOrAllExceptions BindFailed
    . map bindToAddr
    . List.sortBy (compare `on` addrPreference)
{- ^
    Sorts the candidate addresses so that the most preferable one (the
    lowest 'addrPreference' rank) comes first, turns each address into
    a bind attempt with 'bindToAddr', and hands the list of attempts to
    'firstSuccessOrAllExceptions'.

    NOTE(review): judging by its name and how it is applied here,
    'firstSuccessOrAllExceptions' presumably yields the first attempt
    that succeeds and otherwise throws a 'BindFailed' built from every
    collected 'AddrTried' -- confirm against its definition.
-}

addrPreference :: Socket.AddrInfo -> Int
addrPreference addr =
    case Socket.addrFamily addr of
        Socket.AF_INET6 -> 1 {- IPv6 is best, because these addresses can
                                accept both IPv4 and IPv6 connections. -}
        Socket.AF_INET  -> 2 {- IPv4 is next best, if IPv6 is not supported. -}
        _               -> 3 {- Other addressing schemes are unfamiliar. -}
{- ^
    Assigns a ranking to each address, indicating our relative preference.
    A lesser number indicates a more preferable address.
-}

bindToAddr :: Socket.AddrInfo -> IO PassiveSocket
bindToAddr addr =
    overException (AddrTried addr) $
        Exception.bracketOnError (Socket.openSocket addr) Socket.close $ \s ->
          do
            initServerSocket addr s
            return (PassiveSocket s)
{- ^
    Opens a socket for the given address, configures/binds it with
    'initServerSocket', and wraps it as a 'PassiveSocket'.

    'Exception.bracketOnError' closes the freshly opened socket only if
    the setup throws; on success the socket stays open and is returned
    to the caller, who becomes responsible for closing it.

    NOTE(review): 'overException' presumably rethrows any failure wrapped
    in 'AddrTried' together with the address that was attempted -- confirm
    against its definition.
-}

initServerSocket :: Socket.AddrInfo -> Socket -> IO ()
initServerSocket addr s =
  do
    -- Socket options are applied first, before the socket is bound.
    setReuseAddr s       -- Disable some safety to permit fast restarts.
    setKeepAlive s       -- Send empty packets to keep connections alive.
    setNoDelay s         -- Send bytes immediately without buffering.
    allowIPv4and6 addr s -- If it's an IPv6 address, enable IPv4 also.
    bind addr s          -- Assign the address to the socket.
    listen s             -- Announce willingness to receive connections.

bind :: Socket.AddrInfo -> Socket -> IO ()
bind addr s = Socket.bind s (Socket.addrAddress addr)
{- ^
    Binds the socket to the socket address carried by the 'Socket.AddrInfo'.
    Note the argument order: address info first, then the socket.
-}

allowIPv4and6 :: Socket.AddrInfo -> Socket -> IO ()
allowIPv4and6 addr s =
    when (Socket.addrFamily addr == Socket.AF_INET6) $
        Socket.setSocketOption s Socket.IPv6Only 0
{- ^
    For an IPv6 address, switches the 'Socket.IPv6Only' option off (value 0)
    on the socket. For any other address family this does nothing.
-}

setReuseAddr :: Socket -> IO ()
setReuseAddr s = Socket.setSocketOption s Socket.ReuseAddr 1
{- ^
    By default, the operating system will not let us restart our server and
    bind to the same address immediately, because the new process will
    receive any TCP packets that were in flight during the restart, which
    is typically undesirable.

    Overriding the default behavior like this is not really safe!
    But it lets us restart our server quickly :)
-}

setKeepAlive :: Socket -> IO ()
setKeepAlive s = Socket.setSocketOption s Socket.KeepAlive 1
{- ^
    This enables a nice TCP feature: if there is a long period of time
    with no activity on the socket, the OS will occasionally send an
    empty packet. This has two benefits:

      1. It lets the peer know that we're still here; otherwise the peer will
         close the connection, assuming that we've abandoned the conversation.

      2. It lets us know whether the peer is still there. If we don't receive
         an acknowledgement of the empty packet, we can close the connection.
-}

setNoDelay :: Socket -> IO ()
setNoDelay s = Socket.setSocketOption s Socket.NoDelay 1
{- ^
    Since it's more efficient to transmit a few large packets than many
    small packets, the OS doesn't always send your bytes right away when
    you write to a socket; by default, it makes some effort to group
    together small writes into larger packets.

    The downside of this optimization is that it means sometimes we don't
    immediately see the effect of writing to a socket. For experimental
    and demonstration purposes, this can be quite undesirable, so we use
    this setting to disable the feature.
-}

listen :: Socket -> IO ()
listen s = Socket.listen s listenBacklog
{- ^
    Informs the operating system that this socket will be used to
    accept incoming connection requests, requesting a connection queue
    of size 'listenBacklog'.

    Such a socket is called a "passive socket".
-}

listenBacklog :: Int
listenBacklog = 1024
{- ^
    The 'Socket.accept' function pulls sockets from a queue maintained by
    the operating system. This is the size we are requesting for that queue.
    (The OS might not actually give us a queue as big as we ask for.)
-}