-- | Compatibility layer for the network package, including our own 'PortID' type
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-}

module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo,
                                          lookupReplicaSetName, lookupSeedList) where


#if !MIN_VERSION_network(2, 9, 0)

import qualified Network as N
import System.IO (Handle)

#else

import Control.Exception (bracketOnError)
import Network.BSD as BSD
import qualified Network.Socket as N
import System.IO (Handle, IOMode(ReadWriteMode))

#endif

import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd, lookup)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)


-- | Port identifier accepted by 'connectTo': either a port number, wrapping
-- network's 'N.PortNumber', or the path of a Unix socket.
-- The 'UnixSocket' constructor only exists when none of @mingw32_HOST_OS@,
-- @cygwin32_HOST_OS@ and @_WIN32@ is defined at build time.
-- Defined here to ease compatibility between older and newer network versions.
data PortID = PortNumber N.PortNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
            | UnixSocket String
#endif
            deriving (Eq, Ord, Show)


#if !MIN_VERSION_network(2, 9, 0)

-- Variant used with network < 2.9: translate our 'PortID' into network's own
-- PortID and delegate to network's connectTo.
connectTo :: N.HostName         -- Hostname
          -> PortID             -- Port Identifier
          -> IO Handle          -- Connected Socket
connectTo hostname portId = case portId of
    PortNumber port -> N.connectTo hostname (N.PortNumber port)
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    -- For a Unix socket the caller's hostname is discarded and an empty one is passed on.
    UnixSocket path -> N.connectTo "" (N.UnixSocket path)
#endif

#else

-- Variant used with network >= 2.9. Based on network 2.8's 'connectTo', but taking our own 'PortID' type.
-- https://github.com/haskell/network/blob/e73f0b96c9da924fe83f3c73488f7e69f712755f/Network.hs#L120-L129
--
-- A 'PortNumber' is reached over an AF_INET (IPv4) stream socket using the "tcp"
-- protocol number, after resolving the hostname with getHostByName.
-- A 'UnixSocket' (where compiled in) is reached over an AF_UNIX stream socket and
-- ignores the hostname argument.
connectTo :: N.HostName         -- Hostname
          -> PortID             -- Port Identifier
          -> IO Handle          -- Connected Socket
connectTo hostname portId = case portId of
    PortNumber port -> do
        tcp <- BSD.getProtocolNumber "tcp"
        openHandle (N.socket N.AF_INET N.Stream tcp) $ do
            entry <- BSD.getHostByName hostname
            pure (N.SockAddrInet port (hostAddress entry))
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    UnixSocket path ->
        openHandle (N.socket N.AF_UNIX N.Stream 0) (pure (N.SockAddrUnix path))
#endif
  where
    -- Create the socket, work out the address (only after the socket exists),
    -- connect, and wrap the socket in a read/write Handle.
    -- The socket is closed only if one of those steps throws.
    openHandle acquire resolveAddr =
        bracketOnError acquire N.close $ \sock -> do
            addr <- resolveAddr
            N.connect sock addr
            N.socketToHandle sock ReadWriteMode

#endif

-- * Host

-- | A host name paired with the 'PortID' on which to reach it.
data Host = Host N.HostName PortID  deriving (Show, Eq, Ord)

lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
-- ^ Queries the TXT DNS record of the given hostname and returns the value of
-- its @replicaSet@ option. Only the first TXT record returned is inspected; it
-- is parsed as a URI query string. The result is 'Nothing' when the lookup
-- fails, when no record comes back, or when the option is missing or has no value.
lookupReplicaSetName hostname = do
  seed <- makeResolvSeed defaultResolvConf
  answer <- withResolver seed (\resolver -> lookupTXT resolver (pack hostname))
  case answer of
    Right (txt:_) -> pure (replicaSetIn txt)
    _ -> pure Nothing
  where
    -- A key that is absent and a key present without a value both collapse to Nothing.
    replicaSetIn txt = fromMaybe Nothing (lookup "replicaSet" (parseQueryText txt))

lookupSeedList :: N.HostName -> IO [Host]
-- ^ Queries the SRV DNS records at @_mongodb._tcp.\<hostname\>@ and turns each
-- one into a 'Host', in the order returned. Trailing dots are stripped from each
-- target name and the record's port becomes a 'PortNumber'. A failed lookup
-- yields an empty list.
lookupSeedList hostname = do
  seed <- makeResolvSeed defaultResolvConf
  answer <- withResolver seed (\resolver -> lookupSRV resolver (pack srvName))
  case answer of
    Left _ -> pure []
    Right records -> pure [toHost port target | (_, _, port, target) <- records]
  where
    srvName = "_mongodb._tcp." ++ hostname
    -- Only the port and target fields of an SRV record are used.
    toHost port target =
      Host (dropWhileEnd (== '.') (unpack target)) (PortNumber (fromIntegral port))