-- | Compatibility layer for the network package, including our own 'PortID' type
{-# LANGUAGE CPP, OverloadedStrings #-}

module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo, 
                                          lookupReplicaSetName, lookupSeedList) where

#if !MIN_VERSION_network(2, 9, 0)

import qualified Network as N
import System.IO (Handle)

#else

import Control.Exception (bracketOnError)
import Network.BSD as BSD
import qualified Network.Socket as N
import System.IO (Handle, IOMode(ReadWriteMode))

#endif

import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)


-- | Identifies what to connect to: a TCP port (wrapping network's
-- 'N.PortNumber') or, on non-Windows builds, the path of a Unix socket.
--
-- Used to ease compatibility between older and newer network versions.
data PortID = PortNumber N.PortNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
            | UnixSocket String
#endif
            deriving (Eq, Ord, Show)


#if !MIN_VERSION_network(2, 9, 0)

-- Translate our 'PortID' into network's own PortID and delegate to
-- network's connectTo.
connectTo :: N.HostName         -- Hostname
          -> PortID             -- Port Identifier
          -> IO Handle          -- Connected Socket
connectTo hostname portId =
    case portId of
        PortNumber port -> N.connectTo hostname (N.PortNumber port)
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
        -- The host name argument is ignored here; network is handed "".
        UnixSocket path -> N.connectTo "" (N.UnixSocket path)
#endif

#else

-- Copied implementation from network 2.8's 'connectTo', but using our own 'PortID' type.
-- https://github.com/haskell/network/blob/e73f0b96c9da924fe83f3c73488f7e69f712755f/Network.hs#L120-L129
--
-- Opens a stream socket to the given endpoint and returns it as a
-- read/write 'Handle'. If connecting or the handle conversion throws, the
-- socket is closed before the exception propagates; on success the socket
-- stays open and is owned by the returned handle.
connectTo :: N.HostName         -- Hostname
          -> PortID             -- Port Identifier
          -> IO Handle          -- Connected Socket
connectTo hostname (PortNumber port) = do
    proto <- BSD.getProtocolNumber "tcp"
    bracketOnError
        -- NOTE: the socket is created as AF_INET, so only IPv4 is attempted.
        (N.socket N.AF_INET N.Stream proto)
        N.close  -- only done if there's an error
        (\sock -> do
          he <- BSD.getHostByName hostname
          N.connect sock (N.SockAddrInet port (hostAddress he))
          N.socketToHandle sock ReadWriteMode
        )

#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
-- Unix socket variant: the host name argument is ignored and the socket is
-- connected to the given filesystem path.
connectTo _ (UnixSocket path) = do
    bracketOnError
        (N.socket N.AF_UNIX N.Stream 0)
        N.close  -- only done if there's an error
        (\sock -> do
          N.connect sock (N.SockAddrUnix path)
          N.socketToHandle sock ReadWriteMode
        )
#endif

#endif

-- * Host

-- | A server address: a host name together with the 'PortID' to connect to.
data Host = Host N.HostName PortID  deriving (Show, Eq, Ord)

-- | Retrieves the replica set name from the TXT DNS record for the given hostname.
--
-- Only the first TXT record returned is inspected. It is parsed as a URL
-- query string and the value of its @replicaSet@ key is returned.
--
-- The result is 'Nothing' when the DNS lookup fails, when no TXT record is
-- returned, or when the first record has no @replicaSet@ key or the key
-- carries no value.
lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
lookupReplicaSetName hostname = do
  rs <- makeResolvSeed defaultResolvConf
  res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname)
  pure $ case res of
    -- 'lookup' yields a nested Maybe (key missing vs. key without a value);
    -- both cases collapse to Nothing.
    Right (txt : _) -> fromMaybe Nothing (lookup "replicaSet" (parseQueryText txt))
    -- DNS error or an empty answer.
    _ -> Nothing

-- | Retrieves the replica set seed list from the SRV DNS record for the given hostname.
--
-- The name queried is @_mongodb._tcp.@ followed by the given hostname. Each
-- SRV entry returned becomes one 'Host', in the order the resolver returned
-- them. A failed DNS lookup yields an empty list.
lookupSeedList :: N.HostName -> IO [Host]
lookupSeedList hostname = do
  rs <- makeResolvSeed defaultResolvConf
  res <- withResolver rs $ \resolver ->
    lookupSRV resolver (pack ("_mongodb._tcp." ++ hostname))
  pure (either (const []) (map toHost) res)
  where
    -- Build a 'Host' from the third (port) and fourth (target) components of
    -- an SRV tuple; the first two components are ignored. Trailing dots are
    -- stripped from the target name.
    toHost (_, _, port, target) =
      Host (dropWhileEnd (== '.') (unpack target)) (PortNumber (fromIntegral port))