{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

-- | DNS cache to resolve domains concurrently.
--
-- A domain is queried on all configured resolvers at once, and both
-- successful and failed replies are cached for a bounded time.
module Network.DNS.Cache (
    DNSCacheConf(..)
  , DNSCache
  , withDNSCache
  -- * Looking up
  , lookup
  , lookupCache
  -- * Resolving
  , Result(..)
  , resolve
  , resolveCache
  -- * Waiting
  , wait
  ) where

import Control.Applicative ((<$>))
import Control.Concurrent.Lifted (threadDelay, fork, killThread)
import Control.Concurrent.Async (async, waitAnyCancel)
import Control.Exception.Lifted (bracket)
import Control.Monad (forever)
import Control.Monad.IO.Class
import Control.Monad.Trans.Control
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Short as B
import Data.IP (toHostAddress)
import Data.Time (getCurrentTime, addUTCTime, NominalDiffTime)
import Network.DNS hiding (lookup)
import Network.DNS.Cache.Cache
import qualified Network.DNS.Cache.Sync as S
import Network.DNS.Cache.Types
import Network.DNS.Cache.Utils
import Network.DNS.Cache.Value
import Prelude hiding (lookup)

----------------------------------------------------------------

-- | Configuration for DNS cache.
data DNSCacheConf = DNSCacheConf {
    -- | A list of resolvers (cache DNS servers).
    -- A domain is resolved by all of these resolvers concurrently.
    -- Currently the first reply to arrive is used, regardless of whether
    -- it reports success or failure.
    resolvConfs :: [ResolvConf]
    -- | The maximum number of domains that can be resolved concurrently.
    -- Each resolution first takes a slot of this limit, so resolutions
    -- beyond it wait until a slot is free.
  , maxConcurrency :: Int
    -- | The lower bound, in seconds, on how long a successful reply is
    -- cached. The reply's TTL is clamped into [minTTL, maxTTL].
  , minTTL :: NominalDiffTime
    -- | The upper bound, in seconds, on how long a successful reply is
    -- cached.
  , maxTTL :: NominalDiffTime
    -- | How long, in seconds, a failed reply is cached.
  , negativeTTL :: NominalDiffTime
  }

-- | An abstract data type for the DNS cache.
-- Every 10 seconds a background thread prunes cached domains whose
-- lifetime (derived from their TTL) has expired.
data DNSCache = DNSCache {
    cacheSeeds      :: [ResolvSeed]    -- ^ One seed per configured resolver.
  , cacheNofServers :: !Int            -- ^ Number of entries in 'cacheSeeds'.
  , cacheRef        :: CacheRef        -- ^ The cache of resolved domains.
  , cacheActiveRef  :: S.ActiveRef     -- ^ Domains currently being resolved.
  , cacheConcVar    :: S.ConcVar       -- ^ Tracks resolutions in flight.
  , cacheConcLimit  :: Int             -- ^ Limit applied to 'cacheConcVar'.
  , cacheMinTTL     :: NominalDiffTime -- ^ See 'minTTL'.
  , cacheMaxTTL     :: NominalDiffTime -- ^ See 'maxTTL'.
  , cacheNegTTL     :: NominalDiffTime -- ^ See 'negativeTTL'.
  }

----------------------------------------------------------------

-- | Create a DNS cache and run the given action with it.
-- Domains should be resolved inside that action: a thread that prunes the
-- cache is forked before the action starts and is killed once the action
-- finishes, whether normally or by an exception.
withDNSCache :: (MonadBaseControl IO m, MonadIO m)
             => DNSCacheConf -> (DNSCache -> m a) -> m a
withDNSCache conf func = do
    cache <- liftIO $ do
        seeds   <- mapM makeResolvSeed (resolvConfs conf)
        ref     <- newCacheRef
        active  <- S.newActiveRef
        concvar <- S.newConcVar
        return DNSCache {
            cacheSeeds      = seeds
          , cacheNofServers = length seeds
          , cacheRef        = ref
          , cacheActiveRef  = active
          , cacheConcVar    = concvar
          , cacheConcLimit  = maxConcurrency conf
          , cacheMinTTL     = minTTL conf
          , cacheMaxTTL     = maxTTL conf
          , cacheNegTTL     = negativeTTL conf
          }
    let pruner = liftIO (prune (cacheRef cache))
    bracket (fork pruner) killThread (\_ -> func cache)

----------------------------------------------------------------

-- | Look a domain up in the cache.
-- Returns the cache key built from the domain, so callers can reuse it,
-- together with the cached (priority, entry) pair if there is one.
lookupPSQ :: DNSCache -> Domain -> IO (Key, Maybe (Prio, Entry))
lookupPSQ cache dom = do
    let !key = B.toShort dom
    !found <- lookupCacheRef key (cacheRef cache)
    return (key, found)

----------------------------------------------------------------

-- | Lookup 'Domain' only in the cache; no DNS query is sent.
-- Yields 'Nothing' on a cache miss, and otherwise whatever 'fromEither'
-- extracts from the cached result.
lookupCache :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookupCache cache dom = (>>= fromEither) <$> resolveCache cache dom

----------------------------------------------------------------

-- | Lookup 'Domain' in the cache.
-- On a miss, queries are sent to the DNS servers and the resolved
-- IP addresses are cached.
lookup :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookup cache = fmap fromEither . resolve cache

----------------------------------------------------------------

-- | Lookup 'Domain' only in the cache; no DNS query is sent.
-- A domain that is already a literal IP address is answered as 'Numeric'.
resolveCache :: DNSCache -> Domain -> IO (Maybe (Either DNSError Result))
resolveCache _ dom
  | isIPAddr dom = return $ Just $ Right $ Numeric $ tov4 dom
  where
    tov4 = toHostAddress . read . BS.unpack
resolveCache cache dom = do
    (_, mx) <- lookupPSQ cache dom
    case mx of
        Nothing           -> return Nothing
        -- 'rotate' is applied on every hit; presumably it cycles the
        -- cached addresses (see Network.DNS.Cache.Value) - confirm.
        Just (_, Right v) -> Just . Right . Hit <$> rotate v
        Just (_, Left e)  -> return $ Just $ Left e

-- | Lookup 'Domain' in the cache.
-- If it is not cached, queries are sent to the DNS servers and the
-- outcome (success or failure) is cached. A caller asking for a domain
-- that another thread is already resolving waits for that thread's
-- result instead of sending its own queries.
resolve :: DNSCache -> Domain -> IO (Either DNSError Result)
resolve _ dom
  | isIPAddr dom = return $ Right $ Numeric $ tov4 dom
  where
    tov4 = toHostAddress . read . BS.unpack
resolve cache dom = do
    (key, mx) <- lookupPSQ cache dom
    case mx of
        Just (_, Left e)  -> return $ Left e
        Just (_, Right v) -> Right . Hit <$> rotate v
        Nothing -> do
            -- If this domain is being resolved by another thread,
            -- let's wait for its result.
            ma <- S.lookupActiveRef key activeref
            case ma of
                Just avar -> S.listen avar
                Nothing   -> do
                    avar <- S.newActiveVar
                    -- The active entry is removed even if the query throws.
                    -- Otherwise every later caller for this domain would
                    -- find the stale entry and listen on an ActiveVar that
                    -- is never told.
                    -- NOTE(review): callers already listening on 'avar'
                    -- are still not told when the query throws.
                    -- NOTE(review): lookup and insert are separate steps,
                    -- so two threads may both query the same domain;
                    -- presumably harmless - confirm against Sync.
                    res <- bracket (S.insertActiveRef key avar activeref)
                                   (\_ -> S.deleteActiveRef key activeref)
                                   (\_ -> queryAndCache key)
                    S.tell avar (toHit res)
                    return res
  where
    activeref = cacheActiveRef cache
    -- Query the servers and cache the outcome; an empty answer is cached
    -- as an 'UnexpectedRDATA' failure. The result is forced before the
    -- active entry is removed.
    queryAndCache key = do
        x <- sendQuery cache dom
        !res <- case x of
            Left err    -> insertNegative cache key err
            Right []    -> insertNegative cache key UnexpectedRDATA
            Right addrs -> insertPositive cache key addrs
        return res
    -- Threads that waited receive 'Hit' rather than 'Resolved'.
    toHit (Right (Resolved addr)) = Right (Hit addr)
    toHit x = x

-- | Cache a successful answer and return it as 'Resolved' with the first
-- address. The cache lifetime is the TTL of the first record only,
-- clamped into [cacheMinTTL, cacheMaxTTL]. The list must be non-empty;
-- 'resolve' routes empty answers to 'insertNegative'.
insertPositive :: DNSCache -> Key -> [(HostAddress, TTL)] -> IO (Either DNSError Result)
insertPositive _ _ [] = error "insertPositive"
insertPositive cache key addrs@((addr, ttl):_) = do
    !ent <- positiveEntry $ map fst addrs
    !tim <- addUTCTime lifeTime <$> getCurrentTime
    insertCacheRef key tim ent cacheref
    return $! Right $ Resolved addr
  where
    minttl = cacheMinTTL cache
    maxttl = cacheMaxTTL cache
    !lifeTime = minttl `max` (maxttl `min` fromIntegral ttl)
    cacheref = cacheRef cache

-- | Cache a failure for 'cacheNegTTL' seconds and return it.
insertNegative :: DNSCache -> Key -> DNSError -> IO (Either DNSError Result)
insertNegative cache key err = do
    !tim <- addUTCTime lifeTime <$> getCurrentTime
    insertCacheRef key tim (Left err) cacheref
    return $ Left err
  where
    lifeTime = cacheNegTTL cache
    cacheref = cacheRef cache

----------------------------------------------------------------

-- | Resolve a domain while holding one slot of the concurrency limit.
-- The slot is released even if resolution throws.
sendQuery :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
sendQuery cache dom = bracket setup teardown body
  where
    setup = waitIncrease cache
    teardown _ = decrease cache
    body _ = concResolv cache dom

-- | Take a slot of the concurrency counter, bounded by 'cacheConcLimit'.
waitIncrease :: DNSCache -> IO ()
waitIncrease cache = S.waitIncrease lvar lim
  where
    lvar = cacheConcVar cache
    lim = cacheConcLimit cache

-- | Give back a slot taken by 'waitIncrease'.
decrease :: DNSCache -> IO ()
decrease cache = S.decrease lvar
  where
    lvar = cacheConcVar cache

-- | Query all resolvers concurrently and extract the A records of the
-- first reply as (address, TTL) pairs.
concResolv :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
concResolv cache dom = withResolvers seeds $ \resolvers -> do
    eans <- resolv n resolvers dom
    return $ case eans of
        Left err  -> Left err
        Right ans -> fromDNSFormat ans getHostAddressandTTL
  where
    n = cacheNofServers cache
    seeds = cacheSeeds cache
    -- Keep only records of type A that carry 'RD_A' data. A record whose
    -- type is A but whose data is something else is skipped instead of
    -- crashing the resolving thread.
    getHostAddressandTTL ans =
        [ (toHostAddress ip, rrttl r)
        | r <- answer ans
        , rrtype r == A
        , RD_A ip <- [rdata r]
        ]

-- | Send an A query to every resolver and return the first reply,
-- cancelling the others. With a single resolver the query is sent
-- directly. The count argument is kept for the call site, but dispatch is
-- on the list itself, so the two can never disagree.
resolv :: Int -> [Resolver] -> Domain -> IO (Either DNSError DNSFormat)
resolv _ [] _ =
    -- 'waitAnyCancel []' could never complete; fail loudly instead.
    ioError $ userError "Network.DNS.Cache.resolv: no resolvers configured"
resolv _ [resolver] dom = lookupRaw resolver dom A
resolv _ resolvers dom = do
    asyncs <- mapM async actions
    snd <$> waitAnyCancel asyncs
  where
    actions = map (\res -> lookupRaw res dom A) resolvers

----------------------------------------------------------------

-- | Wait until the predicate in the second argument is satisfied.
-- The predicate is given the number of domains currently being resolved.
--
-- For instance, to make sure that no resolution is in progress:
--
-- > wait cache (== 0)
--
-- To make sure that the capacity for concurrent resolution is not full:
--
-- > wait cache (< maxCon)
--
-- where 'maxCon' is the 'maxConcurrency' value given in 'DNSCacheConf'.
wait :: DNSCache -> (Int -> Bool) -> IO ()
wait = S.wait . cacheConcVar

-- | Loop forever, pruning the cache every 10 seconds. Each round hands the
-- current time to 'pruneCacheRef', which presumably drops the entries
-- whose expiry time has passed.
prune :: CacheRef -> IO ()
prune cacheref = forever $
    threadDelay interval >> getCurrentTime >>= (`pruneCacheRef` cacheref)
  where
    -- 'threadDelay' counts microseconds: this is 10 seconds.
    interval :: Int
    interval = 10 * 1000 * 1000