-- | A DNS resolver. This code acts like the resolver library from libc, except
--   that it can work asynchronously, and can return much more information.
--
--   At the moment, the interface is very much undecided but currently looks
--   like this:
--
--   > import qualified Network.DNS.Client as DNS
--   >
--   > DNS.resolve DNS.A "somedomain.com"
--   > Right [(2008-02-01 00:27:14.861098 UTC, DNS.RRA [2466498203])]
--
--   The first element of the tuple is the time when the information expires.
--   The second depends on the record type requested (A, in this case) and
--   A records contain IP address, so that's a HostAddress in there.
--
--   This module parses @/etc/resolv.conf@ for its configuration. It needs a
--   recursive server to do the hard work. If you're lacking a recursive
--   server, you can set up dnscache (from djbdns) locally and point at that.

module Network.DNS.Client
  ( module Network.DNS.Types
  , resolve
  , resolveAsync
  , DNSError(..)
  ) where

import Data.Word
import Data.List (nub)
import Data.Maybe (fromMaybe)
import Data.Time
import Control.Monad (when)
import Control.Timeout
import Control.Concurrent (forkIO)
import System.IO.Unsafe (unsafePerformIO)
import System.Random (mkStdGen, random, Random, randomR)
import Control.Concurrent.STM

import qualified Data.Binary.Get as G

import qualified Data.Map as Map

import Network.Socket hiding (sendTo, recvFrom)
import Network.Socket.ByteString

import qualified Data.Binary.Put as P
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL

import Network.DNS.Common
import Network.DNS.ResolveConfParse
import Network.DNS.Types

-- | Build the header of an outgoing query that carries the given request id.
--   Every other 'Header' field is fixed; the positional meaning of those
--   fields comes from the 'Header' declaration, which lives outside this
--   module.
--   NOTE(review): presumably the lone @True@ is the recursion-desired flag
--   and the trailing @1 0 0 0@ are the per-section counts (one question) --
--   confirm against the 'Header' declaration.
queryHeader :: Word16 -> Header
queryHeader ident =
  Header ident False QUERY False False True False NoError 1 0 0 0

-- | The errors reported by this library. 'Timeout' and 'AnswerNotIncluded'
--   are generated from within this code; 'DNSError' wraps a response code
--   that came directly from the DNS server.
data DNSError = Timeout  -- ^ the DNS server didn't answer
               | AnswerNotIncluded
               -- ^ this is returned when the DNS server returned a valid
               --   answer, but the answer didn't include the information we
               --   were looking for. Firstly, this isn't a recursive
               --   resolver, so if you point it at a non-recursive server
               --   you'll get this for nearly every query as the server will
               --   just be telling us the location of the roots.
               --
               --   This can also occur when you ask for a resource which
               --   doesn't exist - like an AAAA record from www.google.com
               | DNSError ResponseCode  -- ^ errors from the DNS server
               deriving (Show, Eq)

-- | The state kept for a request that has been (or is about to be) sent to a
--   nameserver. Values of this type are stored in 'nsInflight', keyed by the
--   request id, until a reply arrives or the timeout fires.
data InflightRequest =
  InflightRequest { infRequest :: B.ByteString  -- ^ the serialised question section (sent after the header)
                  , infCallback :: (Either DNSError ([String], ([Entry], [Entry], [Entry])) -> IO ())
                  -- ^ called with the final outcome: either an error, or the
                  --   name that was queried together with the answer,
                  --   nameserver and additional sections of the reply
                  , infCurrent :: [String]  -- ^ the current name being queried, as labels
                  , infSearch :: [[String]]  -- ^ alternative names still to try
                  , infAttempts :: Int  -- ^ transmission counter, compared against 'rcAttempts'
                  , infTimeout :: TimeoutTag  -- ^ the timeout guarding the current transmission
                  , infType :: DNSType  -- ^ the type of the query
                  -- | In the case that we need to always send this request to
                  --   a specific nameserver (e.g. it's a probe packet), this
                  --   is set to a non-Nothing value.
                  , infSpecificNameserver :: Maybe Nameserver
                  }

-- | A nameserver from the configuration, together with the mutable state
--   that tracks its health and the requests currently outstanding against it.
data Nameserver = Nameserver { nsAddress :: Word32  -- ^ ip address
                             , nsUp :: TVar Bool  -- ^ do we believe that this server is up?
                             , nsInflight :: TVar (Map.Map Word16 InflightRequest)
                             -- ^ outstanding requests, keyed by request id
                             , nsTimeouts :: TVar Int  -- ^ how many requests have timed out in a row
                             }

-- | The resolver's configuration plus the shared mutable state (round-robin
--   position and random seeds) used when submitting requests.
data ResolverConfig = ResolverConfig
  {   rcNameservers :: [Nameserver]
      -- | This is a list of lists of labels. If a requested domain name is short
      --   enough (see 'rcNdots') it's tried with these suffixes first
    , rcSearchPath :: [[String]]
      -- | Names with fewer dots than this are tried against the search path
      --   before being tried as given
    , rcNdots :: Int
      -- | Number of attempts per request
    , rcAttempts :: Int
      -- | Number of seconds to wait for the nameserver
    , rcTimeout :: Int
      -- | This is the index (into 'rcNameservers') of the next nameserver to try
    , rcRobin :: TVar Int
      -- | This is an infinite list of random seeds; one is consumed per
      --   submitted request
    , rcSeeds :: TVar [Int]
    }

-- | This is the max number of requests which can be outstanding against
--   any one server at a time. Request ids are generated by rolling random
--   16-bit numbers until one is found that is not already in use, so setting
--   this above roughly 60000 is dangerous: if the number of inflight
--   requests ever reached 2**16 the id search would livelock.
maxInflightPerServer :: Int
maxInflightPerServer = 2048

-- | This is a socket that is bound, locally, to a port and all outgoing
--   packets are written to it. It's also the socket that we listen on.
--
--   The socket is created on first use via 'unsafePerformIO'. The @NOINLINE@
--   pragma is required: without it GHC is free to inline the definition at
--   each use site, which would create and bind a fresh socket per use, so
--   replies would arrive on a socket nobody is reading.
globalSocket :: Socket
{-# NOINLINE globalSocket #-}
globalSocket = unsafePerformIO $ do
  s <- socket AF_INET Datagram 0
  -- port 0 asks the OS to pick an ephemeral local port
  bindSocket s $ SockAddrInet (PortNum 0) iNADDR_ANY
  return s

-- | Match an incoming reply (or an expired timeout) against the table of
--   inflight requests. The nameserver is found by source address; if it has
--   an inflight request with the given id, that request is removed from the
--   table, its timeout is cancelled, and it is returned together with the
--   nameserver. Removal and cancellation happen in a single transaction.
lookupReply :: ResolverConfig
            -> Word16  -- ^ The id of the reply
            -> Word32  -- ^ The IP address of the source
            -> IO (Maybe (InflightRequest, Nameserver))
lookupReply config reqId addr =
  case [server | server <- rcNameservers config, nsAddress server == addr] of
    [] -> return Nothing
    ns : _ -> atomically $ do
      pending <- readTVar $ nsInflight ns
      let found = Map.lookup reqId pending
      -- the table is written back unconditionally; deleting a missing key
      -- leaves it unchanged
      writeTVar (nsInflight ns) $ Map.delete reqId pending
      case found of
        -- cancel the timeout if we found an inflight
        Just inf -> cancelTimeout (infTimeout inf) >> return ()
        Nothing -> return ()
      return $ fmap (\inf -> (inf, ns)) found

-- | This function never returns and it's expected that it runs in its own
--   thread, forever reading from the network.
--
--   Each iteration reads one datagram and then loops. Packets that fail to
--   parse, are not responses, are truncated, come from a non-IPv4 source, or
--   do not match an inflight request are silently dropped. Matching replies
--   are dispatched on their response code: 'ServerError' is treated as
--   transient (the request may be retried), 'NoError' is delivered as a
--   reply, and anything else is a failure.
--
--   The only recursive call is the one in tail position at the end of the
--   body. Recursing from inside the case branches (while the trailing
--   recursive call is still pending) would push a continuation frame for
--   every dropped packet and grow the stack without bound.
readerThread :: Socket -> IO ()
readerThread sock = do
  (bytes, from) <- recvFrom sock 1500
  case from of
    SockAddrInet _ addr ->
      case parsePacket bytes of
        Left _ -> return ()
        Right (Packet header _ ans nses additional)
          | not (headIsResponse header) || headIsTruncated header -> return ()
          | otherwise -> do
              config <- getResolveConfig
              minfns <- lookupReply config (headId header) addr
              case minfns of
                Nothing -> return ()
                Just (inf, ns) ->
                  case headResponseCode header of
                    ServerError -> handleTransientError config inf $ DNSError ServerError
                    NoError -> handleReply config inf ns (ans, nses, additional)
                    x -> handleFailure config inf $ DNSError x
    -- not an IPv4 source address: cannot belong to any configured nameserver
    _ -> return ()

  readerThread sock

-- | There can only ever really be a single config in action at any one time,
--   because there is only a single @readerThread@. One could imagine having a
--   socket bound to get packets only from a single nameserver, and that would
--   work so long as the sets of nameservers didn't overlap. However, I don't
--   think that's a common requirement, so I don't support it.
--
--   The variable starts out as 'Nothing' and is filled in lazily by
--   @getResolveConfig@. The @NOINLINE@ pragma is required for a top-level
--   'unsafePerformIO' variable: without it GHC may inline the definition and
--   allocate a distinct 'TVar' at each use site, defeating the sharing.
globalConfig :: TVar (Maybe ResolverConfig)
{-# NOINLINE globalConfig #-}
globalConfig = unsafePerformIO $ newTVarIO Nothing

-- | Build a 'ResolverConfig' by parsing @/etc/resolv.conf@.
--
--   Duplicate nameserver addresses are dropped. Options missing from the
--   file fall back to defaults: ndots 1, attempts 2, timeout 5 seconds.
--   A parse failure surfaces as a pattern-match failure in 'IO'.
resolverConfigFromResolvConf :: IO ResolverConfig
resolverConfigFromResolvConf = do
  Right resolvconf <- parseResolveConf "/etc/resolv.conf"
  robin <- newTVarIO 0

  -- For the list of seeds we have a lazy IO on /dev/urandom going through a
  -- Get monad, which parses 4 byte lumps as Ints
  urandom <- BL.readFile "/dev/urandom"
  tseeds <- newTVarIO $ G.runGet seedStream urandom

  servers <- mapM mkNameserver $ nub $ resolveNameservers resolvconf
  return $ ResolverConfig
    { rcNameservers = servers
    , rcSearchPath = resolveSearch resolvconf
    , rcNdots = fromMaybe 1 $ resolveNdots resolvconf
    , rcAttempts = fromMaybe 2 $ resolveAttempts resolvconf
    , rcTimeout = fromMaybe 5 $ resolveTimeout resolvconf
    , rcRobin = robin
    , rcSeeds = tseeds
    }
  where
    -- a nameserver starts out up, with nothing inflight and no timeouts
    mkNameserver ip = do
      up <- newTVarIO True
      inflight <- newTVarIO Map.empty
      timeouts <- newTVarIO 0
      return $ Nameserver ip up inflight timeouts

    -- an endless stream of big-endian 32-bit words, read as Ints
    seedStream :: G.Get [Int]
    seedStream = do
      word <- G.getWord32be
      more <- seedStream
      return $ fromIntegral word : more

-- | Pick a nameserver from a config. The round-robin counter is advanced on
--   every call, and the server list is rotated so that the counter's previous
--   position comes first. The first server in that rotation that is believed
--   to be up wins; if every server is down, the first one in the rotation is
--   returned anyway.
selectNameserver :: ResolverConfig -> STM Nameserver
selectNameserver config = do
  let servers = rcNameservers config
  start <- readTVar $ rcRobin config
  writeTVar (rcRobin config) $ (start + 1) `mod` length servers
  let (before, after) = splitAt start servers
      rotated = after ++ before
  flags <- mapM (readTVar . nsUp) rotated
  case [ns | (True, ns) <- zip flags rotated] of
    [] -> return $ head rotated  -- all down, just pick one
    ns : _ -> return ns

-- | Random generation of 'Word16' values, implemented by drawing an 'Int'
--   from the corresponding range and narrowing it.
--   NOTE(review): this is an orphan instance; check whether the version of
--   the random package in use already provides one.
instance Random Word16 where
  randomR (lo, hi) g = (fromIntegral n, g')
    where
      (n, g') = randomR (fromIntegral lo :: Int, fromIntegral hi) g
  -- the full range of Word16 is 0 .. 65535
  random = randomR (minBound, maxBound)

-- | Register a request as inflight, inside a single STM transaction.
--
--   A nameserver is chosen (the request's specific nameserver if it has one,
--   otherwise via 'selectNameserver'), an unused request id is generated, a
--   timeout is armed through the supplied function, and the request is
--   stored in the nameserver's inflight table with its timeout tag set and
--   its attempt counter incremented. The transaction retries while the
--   nameserver has more than 'maxInflightPerServer' requests outstanding.
--
--   Returns the request id and the address of the chosen nameserver.
submit4 :: ResolverConfig
        -> InflightRequest
        -> (IO () -> STM TimeoutTag)
        -> STM (Word16, Word32)
submit4 config inf mtag = do
  ns <- maybe (selectNameserver config) return $ infSpecificNameserver inf
  inflight <- readTVar $ nsInflight ns
  when (Map.size inflight > maxInflightPerServer) retry
  -- We cannot reuse an id number if we already have an inflight request to
  -- the same nameserver with the same id. So we take one strong random seed
  -- from the config and use it with the standard Haskell PRNG to generate a
  -- stream of candidate ids, keeping the first one not currently in use.
  seeds <- readTVar (rcSeeds config)
  writeTVar (rcSeeds config) $ tail seeds
  let pickId gen
        | Map.member candidate inflight = pickId gen'
        | otherwise = candidate
        where (candidate, gen') = random gen
      reqId = pickId $ mkStdGen $ head seeds
      addr = nsAddress ns
  tag <- mtag $ handleTimeout config reqId addr
  let registered = inf { infTimeout = tag, infAttempts = infAttempts inf + 1 }
  writeTVar (nsInflight ns) $ Map.insert reqId registered inflight
  return (reqId, addr)

-- | This is the maximum number of timeouts, in a row, that can happen from a
--   single nameserver before we consider that server to be down. When we mark
--   a server as down we start sending probes to it. Once it responds to one of
--   those probes with anything save a 'Timeout', we'll mark it good again.
--   A good reply from the server resets its count of consecutive timeouts.
maxTimeoutsPerServer :: Int
maxTimeoutsPerServer = 5

-- | This is the callback from a probe request. A probe that timed out
--   schedules another probe of the same nameserver 60 seconds later; any
--   other outcome (a reply or an error from the server) marks the
--   nameserver as up.
probeCallback :: Nameserver
              -> Either DNSError ([String], ([Entry], [Entry], [Entry])) -> IO ()
probeCallback ns (Left Timeout) = addTimeout 60 (probeNameserver ns) >> return ()
probeCallback ns _ = atomically $ writeTVar (nsUp ns) True

-- | This is a timeout callback, scheduled when we mark a nameserver as down
--   (and again after each probe that times out). It sends a probe - an A
--   query for www.google.com - pinned to that specific nameserver. If the
--   probe returns any kind of DNS packet, we call the nameserver up again
--   (see 'probeCallback').
probeNameserver :: Nameserver -> IO ()
probeNameserver ns = do
  config <- getResolveConfig
  transmit config probe
  where
    labels = ["www", "google", "com"]
    Just name = serialiseDNSName labels
    question = B.concat $ BL.toChunks $ P.runPut $ serialiseQuestion name A
    probe = InflightRequest
      { infRequest = question
      , infCallback = probeCallback ns
      , infCurrent = labels
      , infSearch = [labels]
      , infAttempts = 1
      , infTimeout = undefined  -- filled in when the request is submitted
      , infType = A
      , infSpecificNameserver = Just ns
      }

-- | The timeout callback for an inflight request. If the request is still
--   inflight it is removed from the table, the nameserver's count of
--   consecutive timeouts is bumped, and the request is handed to
--   'handleTransientError' so that it may be retried.
--
--   Once a nameserver has had more than 'maxTimeoutsPerServer' timeouts in a
--   row it is marked as down and a probe is scheduled for 60 seconds later.
--   The comparison uses the count /including/ the current timeout, so the
--   server goes down on timeout number @maxTimeoutsPerServer + 1@.
handleTimeout :: ResolverConfig  -- ^ the config when the request started
              -> Word16  -- ^ the id of the request which has timed out
              -> Word32  -- ^ the IP address of the nameserver
              -> IO ()
handleTimeout config reqId addr = do
  -- There has to be a nameserver with the correct IP address because
  -- we chose it from this same config
  minfns <- lookupReply config reqId addr
  case minfns of
       Nothing -> return ()  -- we raced the network and lost
       Just (inf, ns) -> do
         -- First, deal with the nameserver. If it's been handling a lot of failures
         -- recently we might want to consider it down
         timeouts <- atomically $ do
           previous <- readTVar $ nsTimeouts ns
           let current = previous + 1
           writeTVar (nsTimeouts ns) $! current
           return current
         when (timeouts > maxTimeoutsPerServer) $ do
           mtimeout <- addTimeoutAtomic 60
           atomically $ do
             upflag <- readTVar $ nsUp ns
             -- only the transition from up to down schedules a probe, so
             -- there is never more than one probe chain per nameserver
             when upflag $ do
               writeTVar (nsUp ns) False
               _ <- mtimeout $ probeNameserver ns
               return ()

         handleTransientError config inf Timeout

-- | Deliver a good reply. The nameserver that answered is marked as up and
--   its run of consecutive timeouts is reset; the request's callback is then
--   invoked with the name that was queried and the sections of the reply.
--   The config argument is not used.
handleReply :: ResolverConfig -> InflightRequest -> Nameserver -> ([Entry], [Entry], [Entry]) -> IO ()
handleReply _ inf ns answers = do
  atomically $ writeTVar (nsTimeouts ns) 0 >> writeTVar (nsUp ns) True
  infCallback inf $ Right (infCurrent inf, answers)

-- | Send a request to a nameserver. A timeout of 'rcTimeout' seconds is
--   prepared, the request is registered as inflight (which picks the
--   nameserver and the request id), and the packet - a query header followed
--   by the request's serialised question - is written to the global socket,
--   addressed to port 53 of the chosen nameserver.
transmit :: ResolverConfig -> InflightRequest -> IO ()
transmit config inf = do
  mtag <- addTimeoutAtomic $ fromIntegral $ rcTimeout config
  (reqId, addr) <- atomically $ submit4 config inf mtag
  let strict = B.concat . BL.toChunks . P.runPut
      packet = strict (serialiseHeader $ queryHeader reqId) `B.append` infRequest inf
  -- FIXME: check return value?
  _ <- sendTo globalSocket packet $ SockAddrInet 53 addr
  return ()

-- | Reset the attempt counter, pop the next name from the search list and
--   transmit the given inflight for that name. Names that cannot be
--   serialised are skipped. When the search list is exhausted the callback
--   is invoked with an NXDomain error.
sendInflight :: ResolverConfig -> InflightRequest -> IO ()
sendInflight config inf =
  case infSearch inf of
    [] -> infCallback inf $ Left $ DNSError NXDomain
    target : remaining ->
      let next = inf { infSearch = remaining, infAttempts = 1 }
      in case serialiseDNSName target of
           Nothing -> sendInflight config next
           Just name ->
             let question = B.concat $ BL.toChunks $ P.runPut $
                              serialiseQuestion name $ infType inf
             in transmit config $ next { infRequest = question
                                       , infCurrent = target }

-- | This is for when we get a timeout or ServerError - there might not
--   be anything wrong with the query so we retry it, until the request has
--   used up its 'rcAttempts' transmissions. After that the error is treated
--   as a hard failure.
--
--   The attempt counter is deliberately /not/ incremented here: the request
--   handed to this function comes out of the inflight table, and @submit4@
--   already bumps the counter each time it registers a transmission.
--   Incrementing it again here would count every retry twice and cut short
--   any configuration with more than two attempts.
handleTransientError :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleTransientError config inf result
  | infAttempts inf > rcAttempts config = handleFailure config inf result
  | otherwise = transmit config inf

-- | This is for when we have a harder error - either multiple timeouts or
--   NXDomain etc from the server. If there are still alternative names on
--   the search list the next one is tried; otherwise the error is reported
--   to the callback. The timeout must have been canceled by here.
handleFailure :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleFailure config inf result
  | null (infSearch inf) = infCallback inf $ Left result
  | otherwise = sendInflight config inf

-- | An index over the records of a reply: maps a name (as a list of labels)
--   and a record type to the records of that type for that name, each
--   paired with the time at which it expires.
type DNSDB = Map.Map ([String], DNSType) [(UTCTime, RR)]

-- | Index the three sections of a reply (answers, nameservers, additional)
--   by name and record type. Each record's lifetime in seconds is turned
--   into an absolute expiry time relative to the given current time. Records
--   sharing a key are concatenated, section by section.
answersToDB :: UTCTime -> ([Entry], [Entry], [Entry]) -> DNSDB
answersToDB currentTime (ans, nses, additional) =
  Map.unionsWith (++) [sectionDB ans, sectionDB nses, sectionDB additional]
  where
    sectionDB :: [Entry] -> Map.Map ([String], DNSType) [(UTCTime, RR)]
    sectionDB entries = Map.fromListWith (++)
      [ ((host, rrToType rr), [(expiry secs, rr)]) | (host, secs, rr) <- entries ]

    expiry :: Word32 -> UTCTime
    expiry secs = addUTCTime (fromIntegral secs) currentTime

-- | Look up the records of the given type for the given name, following
--   CNAME records when the name has no records of the requested type.
--   Returns the empty list when nothing is found, when the CNAME chain is
--   longer than 16 links, or when a CNAME entry is malformed (an empty
--   record list or a record that is not an 'RRCNAME').
dbGet :: [String] -> DNSType -> DNSDB -> [(UTCTime, RR)]
dbGet host ty db = go host (0 :: Int) where
  -- this limits the CNAME depth
  go _ 16 = []
  go name depth =
    case Map.lookup (name, ty) db of
         Just records -> records
         Nothing -> case Map.lookup (name, CNAME) db of
                         Just ((_, RRCNAME target):_) -> go target (depth + 1)
                         -- no CNAME, or an entry of an unexpected shape:
                         -- report "not found" rather than crash
                         _ -> []

-- | This has to turn the information from the server into something
--   useful for the caller. Errors are passed straight through. For a
--   successful reply, all of its sections are indexed and the records of
--   the requested type for the queried name are extracted (following
--   CNAMEs); if there are none the caller gets 'AnswerNotIncluded'.
parseQuery :: DNSType
           -> (Either DNSError [(UTCTime, RR)] -> IO ())
           -> Either DNSError ([String], ([Entry], [Entry], [Entry]))
           -> IO ()
parseQuery _ cb (Left err) = cb $ Left err
parseQuery ty cb (Right (host, answers)) = do
  now <- getCurrentTime
  case dbGet host ty $ answersToDB now answers of
    [] -> cb $ Left AnswerNotIncluded
    records -> cb $ Right records

-- | Fetch the global resolver config, building it from @/etc/resolv.conf@
--   on first use. If several threads race to initialise it, each builds a
--   config but only the first one to store its config wins; the others
--   discard theirs and use the winner's. The thread that installs the config
--   also starts the network reader thread, so exactly one reader is forked.
getResolveConfig :: IO ResolverConfig
getResolveConfig = do
  existing <- atomically $ readTVar globalConfig
  case existing of
    Just config -> return config
    Nothing -> do
      fresh <- resolverConfigFromResolvConf
      (installed, config) <- atomically $ do
        current <- readTVar globalConfig
        case current of
          Just winner -> return (False, winner)
          Nothing -> do
            writeTVar globalConfig $ Just fresh
            return (True, fresh)
      when installed $ forkIO (readerThread globalSocket) >> return ()
      return config

-- | Lookup some information from DNS, blocking the calling thread until the
--   answer (or an error) is available.
resolve :: DNSType  -- ^ the type of DNS information requested
        -> String  -- ^ the domain to query
        -> IO (Either DNSError [(UTCTime, RR)]) -- ^ The RR values here will
                                                --   always be of the correct
                                                --   type for the requested
                                                --   DNSType
resolve ty hostname = do
  box <- atomically newEmptyTMVar
  resolveAsync ty hostname $ \answer -> atomically $ putTMVar box answer
  atomically $ takeTMVar box

-- | This is the same as 'resolve', but you get the answer asynchronously.
--   Blocking the thread which makes the callback in this case is bad - it'll
--   block the DNS network reading thread.
--
--   If the name does not end in a dot, has fewer dots than the configured
--   ndots, and the search path is non-empty, the name is tried with each
--   search suffix in turn and finally as given. Otherwise only the name as
--   given is tried. An empty hostname is treated like an absolute name (no
--   search) instead of crashing on the trailing-dot check.
resolveAsync :: DNSType
              -> String
              -> (Either DNSError [(UTCTime, RR)] -> IO ())
              -> IO ()
resolveAsync ty host cb = do
  config <- getResolveConfig
  let labels = splitDNSName host
      searchPath = rcSearchPath config
      -- guard against the empty string before looking at the last character
      isRelative = not (null host) && last host /= '.'
      useSearch = isRelative
               && length labels - 1 < rcNdots config
               && not (null searchPath)
      -- get the list of names that we'll try to resolve, in order.
      -- if we are searching, that could be a list of length > 1
      names | useSearch = map (labels ++) searchPath ++ [labels]
            | otherwise = [labels]
      inf = InflightRequest
        { -- placeholder: sendInflight serialises the real question for each
          -- candidate name before anything is transmitted
          infRequest = B.empty
        , infCallback = parseQuery ty cb
        , infCurrent = labels
        , infSearch = names
        , infAttempts = 1
          -- the real tag is set when the request is registered as inflight
        , infTimeout = error "Network.DNS.Client.resolveAsync: timeout tag read before the request was submitted"
        , infType = ty
        , infSpecificNameserver = Nothing
        }
  sendInflight config inf