-- | A DNS resolver. This code acts like the resolver library from libc, except
--   that it can work asynchronously, and can return much more information.
--
--   At the moment, the interface is very much undecided but currently looks
--   like this:
--
-- > import qualified Network.DNS.Client as DNS
-- >
-- > DNS.resolve DNS.A "somedomain.com"
-- > Right [(2008-02-01 00:27:14.861098 UTC, DNS.RRA [2466498203])]
--
--   The first element of the tuple is the time when the information expires.
--   The second depends on the record type requested (A, in this case) and
--   A records contain IP addresses, so that's a HostAddress in there.
--
--   This module parses @/etc/resolv.conf@ for its configuration. It needs a
--   recursive server to do the hard work. If you're lacking a recursive
--   server, you can set up dnscache (from djbdns) locally and point at that.
module Network.DNS.Client
  ( module Network.DNS.Types
  , resolve
  , resolveAsync
  , DNSError(..)
  ) where

import Data.Word
import Data.List (nub)
import Data.Maybe (fromMaybe)
import Data.Time
import Control.Monad (when)
import Control.Timeout
import Control.Concurrent (forkIO)
import System.IO.Unsafe (unsafePerformIO)
import System.Random (mkStdGen, random, Random, randomR)
import Control.Concurrent.STM
import qualified Data.Binary.Get as G
import qualified Data.Map as Map
import Network.Socket hiding (sendTo, recvFrom)
import Network.Socket.ByteString
import qualified Data.Binary.Put as P
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL

import Network.DNS.Common
import Network.DNS.ResolveConfParse
import Network.DNS.Types

-- | Build the header for an outgoing query carrying the given request id.
--   The remaining positional arguments follow the field order of 'Header',
--   which is declared outside this file (NOTE(review): confirm the meaning of
--   each flag and count against that declaration).
queryHeader :: Word16 -> Header
queryHeader ident =
  Header ident False QUERY False False True False NoError 1 0 0 0

-- | This is the type of errors from the library. Either it's one of the first
--   two errors (which are generated from within this code), or it's an error
--   directly from the DNS server.
data DNSError
  = Timeout
    -- ^ the DNS server didn't answer
  | AnswerNotIncluded
    -- ^ this is returned when the DNS server returned a valid
    --   answer, but the answer didn't include the information we
    --   were looking for. Firstly, this isn't a recursive
    --   resolver, so if you point it at a non-recursive server
    --   you'll get this for nearly every query as the server will
    --   just be telling us the location of the roots.
    --
    --   This can also occur when you ask for a resource which
    --   doesn't exist - like a AAAA record from www.google.com
  | DNSError ResponseCode
    -- ^ errors from the DNS server
  deriving (Show, Eq)

-- | The state kept for a single outstanding request. Note that some callers
--   build this record with 'undefined' placeholders for 'infRequest' and
--   'infTimeout', so the fields must stay lazy.
data InflightRequest = InflightRequest
  { infRequest :: B.ByteString
    -- ^ the question section
  , infCallback :: (Either DNSError ([String], ([Entry], [Entry], [Entry])) -> IO ())
    -- ^ called with the final outcome of the request
  , infCurrent :: [String]
    -- ^ the current name being queried
  , infSearch :: [[String]]
    -- ^ alternative names to try
  , infAttempts :: Int
    -- ^ number of times transmitted
  , infTimeout :: TimeoutTag
    -- ^ the timeout
  , infType :: DNSType
    -- ^ the type of the query
  , infSpecificNameserver :: Maybe Nameserver
    -- ^ In the case that we need to always send this request to
    --   a specific nameserver (e.g. it's a probe packet), this
    --   is set to a non-Nothing value.
  }

-- | A single upstream nameserver together with its mutable bookkeeping.
data Nameserver = Nameserver
  { nsAddress :: Word32
    -- ^ ip address
  , nsUp :: TVar Bool
    -- ^ do we believe that this server is up?
  , nsInflight :: TVar (Map.Map Word16 InflightRequest)
    -- ^ outstanding requests to this server, keyed by request id
  , nsTimeouts :: TVar Int
    -- ^ how many requests have timed out in a row
  }

-- | The resolver configuration, plus the mutable state shared by all requests.
data ResolverConfig = ResolverConfig
  { rcNameservers :: [Nameserver]
    -- ^ the nameservers that queries may be sent to
  , rcSearchPath :: [[String]]
    -- ^ This is a list of lists of labels. If a requested domain name is
    --   short enough (see ndots) it's tried with these suffixes first
  , rcNdots :: Int
    -- ^ If a name has this many dots, we try an initial absolute query
  , rcAttempts :: Int
    -- ^ Number of attempts per request
  , rcTimeout :: Int
    -- ^ Number of seconds to wait for the nameserver
  , rcRobin :: TVar Int
    -- ^ This is the next nameserver to try
  , rcSeeds :: TVar [Int]
    -- ^ This is an infinite list of random seeds
  }

-- | This is the max number of requests which can be outstanding against
--   any one server at a time. Since, to generate ids, we roll until we find
--   a good id, setting this > 60000 is dangerous. If it hit 2**16 we would
--   livelock
maxInflightPerServer :: Int
maxInflightPerServer = 2048

-- | This is a socket that is bound, locally, to a port and all outgoing
--   packets are written to it. It's also the socket that we listen on.
--
--   The NOINLINE pragma is required: without it GHC is free to inline the
--   'unsafePerformIO' action at each use site, which would create and bind
--   a separate socket per use instead of sharing one.
globalSocket :: Socket
globalSocket = unsafePerformIO $ do
  s <- socket AF_INET Datagram 0
  bindSocket s $ SockAddrInet (PortNum 0) iNADDR_ANY
  return s
{-# NOINLINE globalSocket #-}

-- | Lookup a reply from the network. If a nameserver with the given address
--   has an inflight request with the given id, that request is removed from
--   the server's inflight map, its timeout is cancelled, and it is returned
--   together with the nameserver. Otherwise the result is 'Nothing'.
lookupReply :: ResolverConfig
            -> Word16  -- ^ The id of the reply
            -> Word32  -- ^ The IP address of the source
            -> IO (Maybe (InflightRequest, Nameserver))
lookupReply config id addr = do
  case filter (\x -> nsAddress x == addr) $ rcNameservers config of
    [] -> return Nothing
    (ns:_) -> do
      minf <- atomically (do
        inflight <- readTVar $ nsInflight ns
        -- The update function always returns Nothing, so this deletes the
        -- entry (if any) and hands back the value that was stored there.
        let (minf, inflight') = Map.updateLookupWithKey (const $ const Nothing) id inflight
        writeTVar (nsInflight ns) inflight'
        -- cancel the timeout if we found an inflight
        case minf of
          Just inf -> cancelTimeout $ infTimeout inf
          _ -> return False
        return minf)
      return $ minf >>= (\inf -> return (inf, ns))

-- | This function never returns and it's expected that it runs in its own
--   thread, forever reading from the network.
readerThread :: Socket -> IO ()
readerThread sock = do
  -- Block for the next datagram (at most 1500 bytes) and note its source.
  (bytes, SockAddrInet _ addr) <- recvFrom sock 1500
  -- Unparsable packets are silently dropped.
  either (const $ return ()) (dispatch addr) (parsePacket bytes)
  readerThread sock
  where
    -- Route one parsed packet to the handler for its response code.
    dispatch addr (Packet header _ ans nses additional)
      -- Ignore anything which isn't a response, and truncated responses.
      | not (headIsResponse header) || headIsTruncated header = return ()
      | otherwise = do
          config <- getResolveConfig
          found <- lookupReply config (headId header) addr
          case found of
            -- No inflight request matches this id and source address.
            Nothing -> return ()
            Just (inf, ns) ->
              case headResponseCode header of
                ServerError -> handleTransientError config inf $ DNSError ServerError
                NoError -> handleReply config inf ns (ans, nses, additional)
                code -> handleFailure config inf $ DNSError code

-- | There can only ever really be a single config in action at any one time,
--   because there is only a single @readerThread@. One could imagine having a
--   socket bound to get packets only from a single nameserver, and that would
--   work so long as the sets of nameservers didn't overlap. However, I don't
--   think that's a common requirement, so I don't support it.
globalConfig :: TVar (Maybe ResolverConfig)
globalConfig = unsafePerformIO $ newTVarIO Nothing
-- Without NOINLINE, GHC may inline the unsafePerformIO action at each use
-- site and so create several independent TVars instead of one shared one.
{-# NOINLINE globalConfig #-}

-- | Build a ResolverConfig by parsing @/etc/resolv.conf@
--
--   If 'parseResolveConf' does not return a 'Right' value, the pattern match
--   below fails and this action throws.
resolverConfigFromResolvConf :: IO ResolverConfig
resolverConfigFromResolvConf = do
  Right resolvconf <- parseResolveConf "/etc/resolv.conf"
  robin <- newTVarIO 0
  let toNameserver ip = do
        up <- newTVarIO True
        inflight <- newTVarIO Map.empty
        timeouts <- newTVarIO 0
        return $ Nameserver ip up inflight timeouts
  -- For the list of seeds we lazily read /dev/urandom and cut the stream
  -- into 4 byte lumps, each decoded as a big-endian Word32 and widened to an
  -- Int. The list is produced directly from the lazy ByteString (rather than
  -- through an unbounded Get parser) so that it stays lazy regardless of how
  -- strictly the Get monad is evaluated.
  urandom <- BL.readFile "/dev/urandom"
  let toSeed :: BL.ByteString -> Int
      toSeed lump =
        fromIntegral (foldl (\acc b -> acc * 256 + fromIntegral b) 0 (BL.unpack lump) :: Word32)
      seedStream :: BL.ByteString -> [Int]
      seedStream bytes =
        let (lump, rest) = BL.splitAt 4 bytes
        in toSeed lump : seedStream rest
      seeds = seedStream urandom
  tseeds <- newTVarIO seeds
  -- Duplicate nameserver entries are collapsed so that each address gets
  -- exactly one Nameserver record.
  ns <- mapM toNameserver $ nub $ resolveNameservers resolvconf
  return $ ResolverConfig ns
                          (resolveSearch resolvconf)
                          (fromMaybe 1 (resolveNdots resolvconf))
                          (fromMaybe 2 (resolveAttempts resolvconf))
                          (fromMaybe 5 (resolveTimeout resolvconf))
                          robin
                          tseeds

-- | Pick a nameserver from a config. We try to round robin the nameservers,
--   avoiding the down nameservers. If all nameservers are down, we just pick
--   one anyway.
selectNameserver :: ResolverConfig -> STM Nameserver
selectNameserver config = do
  robin <- readTVar $ rcRobin config
  writeTVar (rcRobin config) $ (robin + 1) `mod` (length (rcNameservers config))
  -- Rotate the server list so that it starts at the round-robin position.
  let servers = (drop robin $ rcNameservers config) ++ (take robin $ rcNameservers config)
  servers' <- mapM (\x -> readTVar (nsUp x) >>= \up -> return (up, x)) servers
  case filter fst servers' of
    [] -> return $ head servers  -- all down, just pick one
    x:_ -> return $ snd x

-- | Random 'Word16' values, generated by drawing an 'Int' in the wanted
--   range and narrowing it.
instance Random Word16 where
  random g = (fromIntegral result, g')
    where
      result :: Int
      (result, g') = randomR (0, 65535) g
  randomR (lo, hi) g = (fromIntegral result, g')
    where
      result :: Int
      (result, g') = randomR (fromIntegral lo, fromIntegral hi) g

-- | Register a request as inflight against a nameserver. This picks the
--   nameserver (unless the request names a specific one), allocates a request
--   id which is unused on that server, creates the timeout and records the
--   request in the server's inflight map. The transaction retries while the
--   chosen server already has 'maxInflightPerServer' requests outstanding.
--
--   Returns the request id and the address of the chosen nameserver.
submit4 :: ResolverConfig
        -> InflightRequest
        -> (IO () -> STM TimeoutTag)  -- ^ registers the timeout action
        -> STM (Word16, Word32)
submit4 config inf mtag = do
  -- We cannot reuse an id number if we already have an inflight request to the
  -- same nameserver with the same id. Thus, once we have picked a nameserver
  -- we need to generate a stream of numbers until we find one which isn't
  -- currently in use. So we generate a strong random seed and use it with the
  -- standard Haskell PRNG to generate the stream of ids in the STM monad
  ns <- case infSpecificNameserver inf of
          Nothing -> selectNameserver config
          Just ns -> return ns
  inflight <- readTVar $ nsInflight ns
  -- Block once the limit is reached; a strict '>' here would let one request
  -- more than the documented maximum through.
  when (Map.size inflight >= maxInflightPerServer) retry
  -- rcSeeds is an infinite list, so taking its head and tail is safe.
  seeds <- readTVar (rcSeeds config)
  let seed = head seeds
  writeTVar (rcSeeds config) $ tail seeds
  let prng = mkStdGen seed
      -- Keep drawing candidates until one isn't already inflight.
      f prng = r
        where
          r = if Map.member candidate inflight then f prng' else candidate
          (candidate, prng') = random prng
      id = f prng
      addr = nsAddress ns
  tag <- mtag $ handleTimeout config id addr
  let inf' = inf { infTimeout = tag, infAttempts = infAttempts inf + 1 }
  writeTVar (nsInflight ns) $ Map.insert id inf' inflight
  return (id, addr)

-- | This is the maximum number of timeouts, in a row, that can happen from a
--   single nameserver before we consider that server to be down. When we mark
--   a server as down we start sending probes to it. Once it responds to one of
--   those probes with anything save a Timeout, we'll mark it good again.
maxTimeoutsPerServer :: Int
maxTimeoutsPerServer = 5

-- | This is the callback from a probe request. If it came back with a
--   reasonable reply we mark the nameserver as up; if the probe timed out we
--   schedule another probe.
probeCallback :: Nameserver
              -> Either DNSError ([String], ([Entry], [Entry], [Entry]))
              -> IO ()
probeCallback ns result =
  case result of
    Left Timeout -> (addTimeout 60 $ probeNameserver ns) >> return ()
    _ -> atomically (writeTVar (nsUp ns) True)

-- | This is a timeout callback. The timeout is called when we mark a
--   nameserver as down. This timeout is to start a probe. If the probe returns
--   any kind of DNS packet, we call the nameserver up again.
probeNameserver :: Nameserver -> IO ()
probeNameserver ns = do
  config <- getResolveConfig
  let labels = ["www", "google", "com"]
      Just name = serialiseDNSName labels
      req = B.concat $ BL.toChunks $ P.runPut $ serialiseQuestion name A
      -- The timeout field is a placeholder; it is filled in when the request
      -- is submitted. The request is pinned to the nameserver being probed.
      inf = InflightRequest req (probeCallback ns) labels [labels] 1 undefined A $ Just ns
  transmit config inf

-- | The timeout callback for an inflight request. It counts the timeout
--   against the nameserver, marks the server as down (and schedules a probe)
--   once it has had too many timeouts in a row, and then retries or fails the
--   request itself.
handleTimeout :: ResolverConfig  -- ^ the config when the request started
              -> Word16  -- ^ the id of the request which has timed out
              -> Word32  -- ^ the IP address of the nameserver
              -> IO ()
handleTimeout config id addr = do
  -- There has to be a nameserver with the correct IP address because
  -- we chose it from this same config
  minfns <- lookupReply config id addr
  case minfns of
    Nothing -> return ()  -- we raced the network and lost
    Just (inf, ns) -> do
      -- First, deal with the nameserver. If it's been handling a lot of failures
      -- recently we might want to consider it down
      timeouts <- atomically $ do
        timeouts <- readTVar $ nsTimeouts ns
        writeTVar (nsTimeouts ns) $ timeouts + 1
        -- note: this is the count from before this timeout was added
        return timeouts
      when (timeouts > maxTimeoutsPerServer) $ do
        mtimeout <- addTimeoutAtomic 60
        atomically $ do
          upflag <- readTVar $ nsUp ns
          -- only schedule a probe on the up -> down transition
          when (upflag == True) $ do
            writeTVar (nsUp ns) False
            mtimeout $ probeNameserver ns
            return ()
      handleTransientError config inf Timeout

-- | Handle a successful reply: reset the nameserver's health state and pass
--   the name that was queried, plus the three record sections, to the
--   request's callback.
handleReply :: ResolverConfig
            -> InflightRequest
            -> Nameserver
            -> ([Entry], [Entry], [Entry])
            -> IO ()
handleReply _ inf ns answers = do
  -- we got a good reply from a nameserver, mark as up and good
  atomically $ do
    writeTVar (nsTimeouts ns) 0
    writeTVar (nsUp ns) True
  (infCallback inf) $ Right (infCurrent inf, answers)

-- | Register the request as inflight (with a timeout of 'rcTimeout') and
--   send the query packet to port 53 of the chosen nameserver.
transmit :: ResolverConfig -> InflightRequest -> IO ()
transmit config inf = do
  mtag <- addTimeoutAtomic $ fromIntegral $ rcTimeout config
  (id, addr) <- atomically $ submit4 config inf mtag
  let header = B.concat $ BL.toChunks $ P.runPut $ serialiseHeader $ queryHeader id
      query = header `B.append` infRequest inf
  sendTo globalSocket query $ SockAddrInet 53 addr
  -- FIXME: check return value?
  return ()

-- | Reset the attempt counter, pop the next name from the search list and
--   transmit the given inflight. Names which cannot be serialised are
--   skipped. If the search list is empty the callback gets NXDomain.
sendInflight :: ResolverConfig -> InflightRequest -> IO ()
sendInflight config inf =
  if null $ infSearch inf
    then (infCallback inf) $ Left $ DNSError NXDomain
    else do
      let inf' = inf { infSearch = tail $ infSearch inf, infAttempts = 1 }
          target = head $ infSearch inf
      case serialiseDNSName target of
        Nothing -> sendInflight config inf'
        Just name -> do
          transmit config $
            inf' { infRequest = B.concat $ BL.toChunks $ P.runPut $
                                  serialiseQuestion name $ infType inf
                 , infCurrent = target }

-- | This is for when we get a timeout or ServerError - there might not
--   be anything wrong with the query so we retry it.
--
--   The attempt counter is deliberately not incremented here: @submit4@
--   already increments it every time a request is registered for
--   transmission. Incrementing it in both places counted each retry twice,
--   so a request was transmitted fewer times than 'rcAttempts' asks for.
handleTransientError :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleTransientError config inf result =
  if infAttempts inf > rcAttempts config
    then handleFailure config inf result
    else transmit config inf

-- | This is for when we have a harder error - either multiple timeouts or
--   NXDomain etc from the server. The timeout must have been canceled by here.
--   If there are search names left we move on to the next one, otherwise the
--   error is reported to the callback.
handleFailure :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleFailure config inf result =
  if not $ null $ infSearch inf
    then sendInflight config inf
    else (infCallback inf) $ Left result

-- | Records from a reply, indexed by owner name and record type. Each record
--   is paired with its expiry time.
type DNSDB = Map.Map ([String], DNSType) [(UTCTime, RR)]

-- | Index all three sections of a reply into a 'DNSDB', turning each record's
--   lifetime in seconds into an absolute time relative to the given time.
answersToDB :: UTCTime -> ([Entry], [Entry], [Entry]) -> DNSDB
answersToDB currentTime (ans, nses, additional) =
  Map.unionsWith (++) $ map toMap [ans, nses, additional]
  where
    toMap :: [Entry] -> Map.Map ([String], DNSType) [(UTCTime, RR)]
    toMap = Map.fromListWith (++)
          . map (\(host, secs, rr) -> ((host, rrToType rr), [(t secs, rr)]))
    t :: Word32 -> UTCTime
    t = (flip addUTCTime) currentTime . fromIntegral

-- | Find the records of the given type for a name, following CNAME records
--   where the name has no records of the wanted type. Returns the empty list
--   when nothing is found.
dbGet :: [String] -> DNSType -> DNSDB -> [(UTCTime, RR)]
dbGet host ty db = find host (0 :: Int)
  where
    -- this limits the CNAME depth
    find _ 16 = []
    find host n =
      case Map.lookup (host, ty) db of
        Just x -> x
        Nothing ->
          case Map.lookup (host, CNAME) db of
            Just ((_, RRCNAME host'):_) -> find host' (n + 1)
            -- no CNAME entry, or an entry of an unexpected shape
            _ -> []

-- | This has to turn the information from the server into something
--   useful for the caller. At the moment this is pretty stupid, just
--   grab records of the requested type (following CNAMEs) from the reply
parseQuery :: DNSType
           -> (Either DNSError [(UTCTime, RR)] -> IO ())
           -> Either DNSError ([String], ([Entry], [Entry], [Entry]))
           -> IO ()
parseQuery ty cb e =
  case e of
    Left x -> cb $ Left x
    Right (host, answers) -> do
      currentTime <- getCurrentTime
      let db = answersToDB currentTime answers
      case dbGet host ty db of
        [] -> cb $ Left AnswerNotIncluded
        x -> cb $ Right x

-- | Get the global resolver config, building it from @/etc/resolv.conf@ on
--   first use. Whichever caller installs the config also starts the reader
--   thread, so that thread is only started once.
getResolveConfig :: IO ResolverConfig
getResolveConfig = do
  config <- atomically $ readTVar globalConfig
  case config of
    Nothing -> do
      config <- resolverConfigFromResolvConf
      -- Somebody may have installed a config while we were building ours, so
      -- check again inside the transaction and prefer theirs.
      (set, config') <- atomically $ do
        mconfig <- readTVar globalConfig
        case mconfig of
          Nothing -> do
            writeTVar globalConfig $ Just config
            return (True, config)
          Just config'' -> return (False, config'')
      when set (forkIO (readerThread globalSocket) >> return ())
      return config'
    Just config' -> return config'

-- | Lookup some information from DNS
resolve :: DNSType  -- ^ the type of DNS information requested
        -> String  -- ^ the domain to query
        -> IO (Either DNSError [(UTCTime, RR)])  -- ^ The RR values here will
                                                 --   always be of the correct
                                                 --   type for the requested
                                                 --   DNSType
resolve ty hostname = do
  var <- atomically $ newEmptyTMVar
  resolveAsync ty hostname (atomically . putTMVar var)
  atomically $ takeTMVar var

-- | This is the same as resolve, above, but you get the answer asynchronously.
--   Blocking the thread which makes the callback in this case is bad - it'll
--   block the DNS network reading thread.
resolveAsync :: DNSType  -- ^ the type of DNS information requested
             -> String  -- ^ the domain to query
             -> (Either DNSError [(UTCTime, RR)] -> IO ())  -- ^ called with the result
             -> IO ()
resolveAsync ty host cb = do
  config <- getResolveConfig
  let labels = splitDNSName host
      wrappedCb = parseQuery ty cb
      -- The request and timeout fields are placeholders; they are filled in
      -- before the request is transmitted.
      inf = InflightRequest undefined wrappedCb labels [] 1 undefined ty Nothing
      -- A name is searched when it isn't written as absolute (no trailing
      -- dot), has fewer dots than ndots, and there is a search path at all.
      -- The emptiness check comes first so that an empty host string is
      -- treated as non-searchable rather than crashing in 'last'.
      searchable = not (null host)
                && last host /= '.'
                && length labels - 1 < rcNdots config
                && not (null (rcSearchPath config))
      -- get the list of names that we'll try to resolve, in order.
      -- if we are searching, that could be a list of length > 1
      names
        | searchable = map (labels ++) (rcSearchPath config) ++ [labels]
        | otherwise = [labels]
  sendInflight config $ inf { infSearch = names }