-- | Compatibility layer for the network package, including the 'PortID' wrapper type
{-# LANGUAGE CPP, OverloadedStrings #-}

module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo, 
                                          lookupReplicaSetName, lookupSeedList) where

#if !MIN_VERSION_network(2, 9, 0)

import qualified Network as N
import System.IO (Handle)

#else

import Control.Exception (bracketOnError)
import Network.BSD as BSD
import qualified Network.Socket as N
import System.IO (Handle, IOMode(ReadWriteMode))

#endif

import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)


-- | Wraps network's 'PortNumber'.
-- Used to ease compatibility between older and newer network versions.
--
-- The 'UnixSocket' constructor only exists on non-Windows builds (see the
-- CPP guard below).
data PortID = PortNumber N.PortNumber  -- ^ A numeric port
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
            | UnixSocket String        -- ^ Path of a Unix domain socket
#endif
            deriving (Eq, Ord, Show)


#if !MIN_VERSION_network(2, 9, 0)

-- | Connect to a server by translating our 'PortID' into network's own
-- 'N.PortID' and delegating to network's 'N.connectTo'.
connectTo :: N.HostName         -- ^ Hostname
          -> PortID             -- ^ Port Identifier
          -> IO Handle          -- ^ Connected Socket
connectTo host portId = case portId of
    PortNumber portNum -> N.connectTo host (N.PortNumber portNum)
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    -- The hostname is ignored for Unix sockets; an empty one is passed along.
    UnixSocket socketPath -> N.connectTo "" (N.UnixSocket socketPath)
#endif

#else

-- | Open a stream socket to the given endpoint and return it as a 'Handle'.
--
-- Copied implementation from network 2.8's 'connectTo', but using our 'PortID' type.
-- https://github.com/haskell/network/blob/e73f0b96c9da924fe83f3c73488f7e69f712755f/Network.hs#L120-L129
--
-- For 'PortNumber' the host is resolved with 'BSD.getHostByName' and the
-- connection is made over an @AF_INET@ socket. In both clauses the socket is
-- closed only if an exception is raised before the handle is returned.
connectTo :: N.HostName         -- ^ Hostname
          -> PortID             -- ^ Port Identifier
          -> IO Handle          -- ^ Connected Socket
connectTo hostname (PortNumber port) = do
    proto <- BSD.getProtocolNumber "tcp"
    bracketOnError
        (N.socket N.AF_INET N.Stream proto)
        N.close  -- only done if there's an error
        (\sock -> do
          he <- BSD.getHostByName hostname
          N.connect sock (N.SockAddrInet port (hostAddress he))
          N.socketToHandle sock ReadWriteMode
        )

#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
-- The hostname is irrelevant for a Unix domain socket; only the path is used.
connectTo _ (UnixSocket path) =
    bracketOnError
        (N.socket N.AF_UNIX N.Stream 0)
        N.close  -- only done if there's an error
        (\sock -> do
          N.connect sock (N.SockAddrUnix path)
          N.socketToHandle sock ReadWriteMode
        )
#endif

#endif

-- * Host

-- | A server address: a host name paired with the 'PortID' to connect on.
data Host = Host N.HostName PortID  deriving (Show, Eq, Ord)

-- | Retrieves the replica set name from the TXT DNS record for the given hostname.
--
-- Only the first TXT record returned is inspected; it is parsed as a URL
-- query string and the value of its @replicaSet@ key is returned.
-- Yields 'Nothing' when the lookup fails, when no TXT record exists, or when
-- the key is absent or has no value.
lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
lookupReplicaSetName hostname = do
  rs <- makeResolvSeed defaultResolvConf
  res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname)
  pure $ case res of
    -- 'lookup' yields @Maybe (Maybe Text)@ (key present, value optional);
    -- collapse both layers into a single 'Maybe'.
    Right (x:_) -> fromMaybe Nothing (lookup "replicaSet" (parseQueryText x))
    -- DNS error or empty answer
    _           -> Nothing

-- | Retrieves the replica set seed list from the SRV DNS record for the given hostname.
--
-- The query is made for @_mongodb._tcp.\<hostname\>@. Each SRV answer becomes
-- a 'Host' built from its target and port; the first two fields of each
-- answer are ignored. A failed lookup yields an empty list.
lookupSeedList :: N.HostName -> IO [Host]
lookupSeedList hostname = do
  rs <- makeResolvSeed defaultResolvConf
  res <- withResolver rs $ \resolver ->
    lookupSRV resolver (pack ("_mongodb._tcp." ++ hostname))
  pure $ case res of
    Left _    -> []
    Right srv -> map toHost srv
  where
    -- Strip any trailing dots from the SRV target before using it as a host name.
    toHost (_, _, por, tar) =
      Host (dropWhileEnd (== '.') (unpack tar)) (PortNumber (fromIntegral por))