{-# LANGUAGE BangPatterns #-} {-# LANGUAGE OverloadedStrings #-} -- | DNS cache to resolve domains concurrently. module Network.DNS.Cache ( DNSCacheConf(..) , DNSCache , withDNSCache -- * Looking up , lookup , lookupCache -- * Resolving , Result(..) , resolve , resolveCache -- * Waiting , wait ) where import Control.Applicative ((<$>)) import Control.Concurrent (threadDelay, forkIO, killThread) import Control.Concurrent.Async (async, waitAnyCancel) import Control.Exception (bracket) import Control.Monad (forever, void) import qualified Data.ByteString.Char8 as BS import Data.IP (toHostAddress) import Data.Time (getCurrentTime, addUTCTime, NominalDiffTime) import Network.DNS hiding (lookup) import Network.DNS.Cache.Cache import qualified Network.DNS.Cache.Sync as S import Network.DNS.Cache.Types import Network.DNS.Cache.Utils import Network.DNS.Cache.Value import Prelude hiding (lookup) ---------------------------------------------------------------- -- | Configuration for DNS cache. data DNSCacheConf = DNSCacheConf { -- | A list of resolvers (cache DNS servers). -- A domain is resolved by the resolvers concurrently. -- The first reply is used regardless of success/failure at this moment resolvConfs :: [ResolvConf] -- | Capability of how many domains can be resolved concurrently , maxConcurrency :: Int -- | The minimum bound of cache duration for success replies in seconds. , minTTL :: NominalDiffTime -- | The maximum bound of cache duration for success replies in seconds. , maxTTL :: NominalDiffTime -- | The cache duration for failure replies in seconds. , negativeTTL :: NominalDiffTime } -- | An abstract data for DNS cache. -- Cached domains are expired every 10 seconds according to their TTL. data DNSCache = DNSCache { cacheSeeds :: [ResolvSeed] , cacheNofServers :: !Int , cacheRef :: CacheRef , cacheActiveRef :: S.ActiveRef , cacheConcVar :: S.ConcVar , cacheConcLimit :: Int , cacheMinTTL :: NominalDiffTime , cacheMaxTTL :: NominalDiffTime , cacheNegTTL :: NominalDiffTime } ---------------------------------------------------------------- -- | A basic function to create DNS cache. -- Domains should be resolved in the function of the second argument. withDNSCache :: DNSCacheConf -> (DNSCache -> IO a) -> IO a withDNSCache conf func = do seeds <- mapM makeResolvSeed (resolvConfs conf) let n = length seeds cacheref <- newCacheRef activeref <- S.newActiveRef lvar <- S.newConcVar let cache = DNSCache seeds n cacheref activeref lvar maxcon minttl maxttl negttl bracket (forkIO $ prune cacheref) killThread (const $ func cache) where maxcon = maxConcurrency conf minttl = minTTL conf maxttl = maxTTL conf negttl = negativeTTL conf ---------------------------------------------------------------- lookupPSQ :: DNSCache -> Domain -> IO (Key, Maybe (Prio, Entry)) lookupPSQ cache dom = do !mx <- lookupCacheRef key cacheref return (key,mx) where cacheref = cacheRef cache !key = newKey dom ---------------------------------------------------------------- -- | Lookup 'Domain' only in the cache. lookupCache :: DNSCache -> Domain -> IO (Maybe HostAddress) lookupCache cache dom = do mx <- resolveCache cache dom case mx of Nothing -> return Nothing Just ev -> return (fromEither ev) ---------------------------------------------------------------- -- | Lookup 'Domain' in the cache. -- If not exist, queries are sent to DNS servers and -- resolved IP addresses are cached. lookup :: DNSCache -> Domain -> IO (Maybe HostAddress) lookup cache dom = fromEither <$> resolve cache dom ---------------------------------------------------------------- -- | Lookup 'Domain' only in the cache. resolveCache :: DNSCache -> Domain -> IO (Maybe (Either DNSError Result)) resolveCache _ dom | isIPAddr dom = Just . Right . Numeric <$> return (tov4 dom) where tov4 = toHostAddress . read . BS.unpack resolveCache cache dom = do (_, mx) <- lookupPSQ cache dom case mx of Nothing -> return Nothing Just (_, Right v) -> Just . Right . Hit <$> rotate v Just (_, Left e) -> Just . Left <$> return e -- | Lookup 'Domain' in the cache. -- If not exist, queries are sent to DNS servers and -- resolved IP addresses are cached. resolve :: DNSCache -> Domain -> IO (Either DNSError Result) resolve _ dom | isIPAddr dom = return $ Right $ Numeric $ tov4 dom where tov4 = toHostAddress . read . BS.unpack resolve cache dom = do (key,mx) <- lookupPSQ cache dom case mx of Just (_,ev) -> case ev of Left e -> Left <$> return e Right v -> Right . Hit <$> rotate v Nothing -> do -- If this domain is being resolved by another thread -- let's wait. ma <- S.lookupActiveRef key activeref case ma of Just avar -> S.listen avar Nothing -> do avar <- S.newActiveVar S.insertActiveRef key avar activeref x <- sendQuery cache dom !res <- case x of Left err -> insertNegative cache key err Right [] -> insertNegative cache key UnexpectedRDATA Right addrs -> insertPositive cache key addrs S.deleteActiveRef key activeref S.tell avar (toHit res) return res where activeref = cacheActiveRef cache toHit (Right (Resolved addr)) = Right (Hit addr) toHit x = x insertPositive :: DNSCache -> Key -> [(HostAddress, TTL)] -> IO (Either DNSError Result) insertPositive _ _ [] = error "insertPositive" insertPositive cache key addrs@((addr,ttl):_) = do !ent <- positiveEntry $ map fst addrs !tim <- addUTCTime lifeTime <$> getCurrentTime insertCacheRef key tim ent cacheref return $! Right $ Resolved addr where minttl = cacheMinTTL cache maxttl = cacheMaxTTL cache !lifeTime = minttl `max` (maxttl `min` fromIntegral ttl) cacheref = cacheRef cache insertNegative :: DNSCache -> Key -> DNSError -> IO (Either DNSError Result) insertNegative cache key err = do !tim <- addUTCTime lifeTime <$> getCurrentTime insertCacheRef key tim (Left err) cacheref return $ Left err where lifeTime = cacheNegTTL cache cacheref = cacheRef cache ---------------------------------------------------------------- sendQuery :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)]) sendQuery cache dom = bracket setup teardown body where setup = waitIncrease cache teardown _ = decrease cache body _ = concResolv cache dom waitIncrease :: DNSCache -> IO () waitIncrease cache = S.waitIncrease lvar lim where lvar = cacheConcVar cache lim = cacheConcLimit cache decrease :: DNSCache -> IO () decrease cache = S.decrease lvar where lvar = cacheConcVar cache concResolv :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)]) concResolv cache dom = withResolvers seeds $ \resolvers -> do eans <- resolv n resolvers dom return $ case eans of Left err -> Left err Right ans -> fromDNSFormat ans getHostAddressandTTL where n = cacheNofServers cache seeds = cacheSeeds cache isA r = rrtype r == A unTag (RD_A ip) = ip unTag _ = error "unTag" toAddr = toHostAddress . unTag . rdata hostAddressandTTL r = (toAddr r, rrttl r) getHostAddressandTTL = map hostAddressandTTL . filter isA . answer resolv :: Int -> [Resolver] -> Domain -> IO (Either DNSError DNSFormat) resolv 1 resolvers dom = lookupRaw (head resolvers) dom A resolv _ resolvers dom = do asyncs <- mapM async actions snd <$> waitAnyCancel asyncs where actions = map (\res -> lookupRaw res dom A) resolvers ---------------------------------------------------------------- -- | Wait until the predicate in the second argument is satisfied. -- The predicate are given the number of the current resolving domains. -- -- For instance, if you ensure that no resolvings are going on: -- -- > wait cache (== 0) -- -- If you want to ensure that capability of concurrent resolving is not full: -- -- > wait cache (< maxCon) -- -- where 'maxCon' represents 'maxConcurrency' in 'DNSCacheConf'. wait :: DNSCache -> (Int -> Bool) -> IO () wait cache cond = S.wait lvar cond where lvar = cacheConcVar cache prune :: CacheRef -> IO () prune cacheref = forever $ do threadDelay 10000000 tim <- getCurrentTime pruneCacheRef tim cacheref