{-# LANGUAGE CPP #-}

module Lambdabot.Util.Network (
    connectTo',
) where

import Control.Exception
import Data.List (intercalate)
import System.IO

import Network.BSD
import Network.Socket

-- | Open a TCP connection to @host@ on @port@ and return it as a
-- read/write 'Handle'.
--
-- This is essentially a reimplementation of the former @Network.connectTo@
-- function, except that no service-name lookup is done: the port number is
-- passed to 'getAddrInfo' as a numeric string.
--
-- Every address returned by 'getAddrInfo' is tried in order and the first
-- successful connection is returned. If no address is returned, or every
-- connection attempt fails, an 'IOException' built with 'userError' is
-- thrown. When attempts fail, the message lists the underlying errors.
-- Exceptions raised by 'getProtocolNumber' or 'getAddrInfo' themselves
-- propagate unchanged.
--
-- Code originally from the network package.
connectTo' :: HostName -> PortNumber -> IO Handle
connectTo' host port = do
    proto <- getProtocolNumber "tcp"
    let hints = defaultHints { addrFlags      = [AI_ADDRCONFIG]
                             , addrProtocol   = proto
                             , addrSocketType = Stream }
    addrs <- getAddrInfo (Just hints) (Just host) (Just (show port))
    firstSuccessful $ map tryToConnect addrs
  where
    -- Create a socket for one resolved address and connect it. If any step
    -- after socket creation throws, 'bracketOnError' closes the socket
    -- before re-throwing; on success the socket is handed to the Handle.
    tryToConnect :: AddrInfo -> IO Handle
    tryToConnect addr =
      bracketOnError
          (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
          close  -- only done if there's an error
          (\sock -> do
              connect sock (addrAddress addr)
              socketToHandle sock ReadWriteMode)

    -- Run the actions in order and return the result of the first one that
    -- does not throw an 'IOException'. Failures are accumulated (most recent
    -- first) so they can be reported if every attempt fails.
    firstSuccessful :: [IO a] -> IO a
    firstSuccessful = go []
      where
        go :: [IOException] -> [IO a] -> IO a
        -- No attempts were made at all: nothing was resolved.
        go [] [] = ioError . userError $ "host name `" ++ show host ++
                        "` could not be resolved"
        -- Every attempt failed: report the causes in the order they occurred.
        go errs@(_:_) [] = ioError . userError $
                        "could not connect to host `" ++ show host ++ "`: " ++
                        intercalate "; " (map show (reverse errs))
        go acc (act:followingActs) = do
            er <- try act
            case er of
                Left err -> go (err:acc) followingActs
                Right r  -> return r