{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.DNS.Cache (
DNSCacheConf(..)
, DNSCache
, withDNSCache
, lookup
, lookupCache
, Result(..)
, resolve
, resolveCache
, wait
) where
import Control.Applicative ((<$>))
import Control.Concurrent.Lifted (threadDelay, fork, killThread)
import Control.Concurrent.Async (async, waitAnyCancel)
import Control.Exception.Lifted (bracket, mask, onException)
import Control.Monad (forever)
import Control.Monad.IO.Class
import Control.Monad.Trans.Control
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Short as B
import Data.IP (toHostAddress)
import Data.Time (getCurrentTime, addUTCTime, NominalDiffTime)
import Network.DNS hiding (lookup)
import Network.DNS.Cache.Cache
import qualified Network.DNS.Cache.Sync as S
import Network.DNS.Cache.Types
import Network.DNS.Cache.Utils
import Network.DNS.Cache.Value
import Prelude hiding (lookup)
-- | Settings consumed by 'withDNSCache'.
data DNSCacheConf = DNSCacheConf
    { resolvConfs    :: [ResolvConf]
      -- ^ Resolvers to use. One seed is built per entry; when there is
      -- more than one, each query is sent to all of them concurrently.
    , maxConcurrency :: Int
      -- ^ Limit handed to the concurrency gate that every outgoing query
      -- passes through.
    , minTTL         :: NominalDiffTime
      -- ^ Lower clamp applied to the TTL of a positive answer.
    , maxTTL         :: NominalDiffTime
      -- ^ Upper clamp applied to the TTL of a positive answer.
    , negativeTTL    :: NominalDiffTime
      -- ^ Lifetime of a cached failure.
    }
-- | Handle to a live cache. Values are built by 'withDNSCache' from a
-- 'DNSCacheConf'.
data DNSCache = DNSCache
    { cacheSeeds      :: [ResolvSeed]
      -- ^ One resolver seed per configured 'ResolvConf'.
    , cacheNofServers :: !Int
      -- ^ @length cacheSeeds@; picks the single-server or racing path.
    , cacheRef        :: CacheRef
      -- ^ The cache proper (answers keyed by domain, with expiry times).
    , cacheActiveRef  :: S.ActiveRef
      -- ^ Queries currently in flight, so concurrent misses can share one.
    , cacheConcVar    :: S.ConcVar
      -- ^ Counter guarding the number of simultaneous outgoing queries.
    , cacheConcLimit  :: Int
      -- ^ Limit for 'cacheConcVar' (from 'maxConcurrency').
    , cacheMinTTL     :: NominalDiffTime
      -- ^ Lower clamp for positive-entry lifetimes.
    , cacheMaxTTL     :: NominalDiffTime
      -- ^ Upper clamp for positive-entry lifetimes.
    , cacheNegTTL     :: NominalDiffTime
      -- ^ Lifetime of negative entries.
    }
-- | Build a 'DNSCache' from @conf@ and run @func@ with it.
--
-- A background thread running 'prune' on the cache is forked for the
-- duration of @func@ and killed (via 'bracket') when @func@ returns or
-- throws.
withDNSCache :: (MonadBaseControl IO m, MonadIO m)
             => DNSCacheConf -> (DNSCache -> m a) -> m a
withDNSCache conf func = do
    cache <- liftIO $ do
        seeds     <- mapM makeResolvSeed (resolvConfs conf)
        cacheref  <- newCacheRef
        activeref <- S.newActiveRef
        concvar   <- S.newConcVar
        return DNSCache
            { cacheSeeds      = seeds
            , cacheNofServers = length seeds
            , cacheRef        = cacheref
            , cacheActiveRef  = activeref
            , cacheConcVar    = concvar
            , cacheConcLimit  = maxConcurrency conf
            , cacheMinTTL     = minTTL conf
            , cacheMaxTTL     = maxTTL conf
            , cacheNegTTL     = negativeTTL conf
            }
    let startPruner = fork $ liftIO $ prune (cacheRef cache)
    bracket startPruner killThread $ \_ -> func cache
-- | Look @dom@ up in the cache. Returns the cache key derived from the
-- domain (for reuse by the caller) together with the cached priority and
-- entry, if any. Both the key and the lookup result are forced to WHNF.
lookupPSQ :: DNSCache -> Domain -> IO (Key, Maybe (Prio, Entry))
lookupPSQ cache dom = do
    let !k = B.toShort dom
    !found <- lookupCacheRef k (cacheRef cache)
    return (k, found)
-- | Cache-only counterpart of 'lookup': never sends a DNS query.
--
-- 'Nothing' when 'resolveCache' has nothing for the domain; otherwise the
-- stored result converted with 'fromEither'.
lookupCache :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookupCache cache dom = (>>= fromEither) <$> resolveCache cache dom
-- | Resolve a domain with 'resolve' (querying on a cache miss) and keep
-- only the address, converting the result with 'fromEither'.
lookup :: DNSCache -> Domain -> IO (Maybe HostAddress)
lookup cache = fmap fromEither . resolve cache
-- | Cache-only resolution: never sends a DNS query.
--
-- * An IPv4 literal (per 'isIPAddr') gives @Just (Right (Numeric addr))@.
-- * A cache miss gives 'Nothing'.
-- * A cached failure gives @Just (Left err)@.
-- * A cached success gives @Just (Right (Hit addr))@, with @addr@ taken
--   from 'rotate' on the stored entry.
resolveCache :: DNSCache -> Domain -> IO (Maybe (Either DNSError Result))
resolveCache cache dom
    | isIPAddr dom = return . Just . Right . Numeric $ literal
    | otherwise = do
        (_, cached) <- lookupPSQ cache dom
        case cached of
            Nothing           -> return Nothing
            Just (_, Left e)  -> return (Just (Left e))
            Just (_, Right v) -> Just . Right . Hit <$> rotate v
  where
    literal = toHostAddress (read (BS.unpack dom))
-- | Resolve a domain, consulting the cache first.
--
-- * An IPv4 literal (per 'isIPAddr') is returned as 'Numeric' without
--   touching the cache or the network.
-- * A cached entry is returned as-is: 'Left' for a cached failure, 'Hit'
--   for a cached success.
-- * On a miss, if another thread already has a query in flight for the
--   same key, this call waits for that query's result via 'S.listen'.
-- * Otherwise this thread sends the query, caches the outcome
--   ('insertPositive' / 'insertNegative'), and hands it to any waiters.
--   Waiters receive 'Hit' rather than 'Resolved'.
--
-- If the query is interrupted by an exception, the in-flight marker is
-- removed and waiters are told @Left ServerFailure@ before the exception
-- is re-thrown. Nothing is cached in that case, so the next call for the
-- domain retries.
resolve :: DNSCache -> Domain -> IO (Either DNSError Result)
resolve _ dom
  | isIPAddr dom = return $ Right $ Numeric $ tov4 dom
  where
    tov4 = toHostAddress . read . BS.unpack
resolve cache dom = do
    (key, mx) <- lookupPSQ cache dom
    case mx of
        Just (_, Left e)  -> return (Left e)
        Just (_, Right v) -> Right . Hit <$> rotate v
        Nothing -> do
            -- NOTE(review): lookup and insert of the active ref are two
            -- separate calls, so two simultaneous misses may both end up
            -- sending a query unless S makes this atomic — confirm.
            ma <- S.lookupActiveRef key activeref
            case ma of
                Just avar -> S.listen avar
                Nothing -> mask $ \restore -> do
                    avar <- S.newActiveVar
                    S.insertActiveRef key avar activeref
                    -- Without this handler, an exception (e.g. the caller
                    -- being killed) would leave the marker registered and
                    -- never call S.tell, so current waiters and every later
                    -- miss for this key would wait on a dead query.
                    !res <- restore (query key) `onException` abandon key avar
                    S.deleteActiveRef key activeref
                    S.tell avar (toHit res)
                    return res
  where
    activeref = cacheActiveRef cache
    -- Send the query and record its outcome in the cache.
    query key = do
        x <- sendQuery cache dom
        case x of
            Left err    -> insertNegative cache key err
            Right []    -> insertNegative cache key UnexpectedRDATA
            Right addrs -> insertPositive cache key addrs
    -- Cleanup when the query did not complete: unregister and wake waiters.
    abandon key avar = do
        S.deleteActiveRef key activeref
        S.tell avar (Left ServerFailure)
    toHit (Right (Resolved addr)) = Right (Hit addr)
    toHit x = x
-- | Cache a successful answer under @key@ and report the first address as
-- 'Resolved'.
--
-- All addresses are stored in one entry built by 'positiveEntry'. The
-- entry's lifetime is the smallest TTL in the answer, clamped to
-- ['cacheMinTTL', 'cacheMaxTTL']. An empty answer has no address to
-- report. It is therefore cached as a negative 'UnexpectedRDATA' entry,
-- matching how 'resolve' treats an empty reply, instead of calling 'error'.
insertPositive :: DNSCache -> Key -> [(HostAddress, TTL)]
               -> IO (Either DNSError Result)
insertPositive cache key [] = insertNegative cache key UnexpectedRDATA
insertPositive cache key addrs@((addr, ttl) : rest) = do
    !ent <- positiveEntry $ map fst addrs
    !tim <- addUTCTime lifeTime <$> getCurrentTime
    insertCacheRef key tim ent cacheref
    return $! Right $ Resolved addr
  where
    minttl = cacheMinTTL cache
    maxttl = cacheMaxTTL cache
    -- RFC 2181 §5.2: if TTLs within an RRset differ, treat them all as the
    -- lowest one. Using only the first record's TTL could over-cache.
    !shortest = minimum (ttl : map snd rest)
    !lifeTime = minttl `max` (maxttl `min` fromIntegral shortest)
    cacheref = cacheRef cache
-- | Cache a failure for @key@ that expires 'cacheNegTTL' from now, and
-- hand the same error back to the caller.
insertNegative :: DNSCache -> Key -> DNSError -> IO (Either DNSError Result)
insertNegative cache key err = do
    now <- getCurrentTime
    let !expiry = addUTCTime (cacheNegTTL cache) now
    insertCacheRef key expiry (Left err) (cacheRef cache)
    return (Left err)
-- | Run 'concResolv' while holding one concurrency slot. The slot is
-- acquired with 'waitIncrease' and released with 'decrease', even if the
-- query throws.
sendQuery :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
sendQuery cache dom =
    bracket (waitIncrease cache) (\_ -> decrease cache) (\_ -> concResolv cache dom)
-- | Acquire a concurrency slot. Delegates to 'S.waitIncrease' with the
-- cache's counter and 'cacheConcLimit'; presumably this blocks while the
-- limit is reached.
waitIncrease :: DNSCache -> IO ()
waitIncrease cache = S.waitIncrease (cacheConcVar cache) (cacheConcLimit cache)
-- | Release a concurrency slot taken by 'waitIncrease'.
decrease :: DNSCache -> IO ()
decrease = S.decrease . cacheConcVar
-- | Ask the configured resolvers (via 'resolv') for the A records of
-- @dom@. Returns each IPv4 address with its TTL, or the DNS error.
--
-- Addresses are taken by pattern-matching 'RD_A' on each type-A record.
-- Previously a partial helper called 'error' on any other rdata shape;
-- such records are now skipped.
concResolv :: DNSCache -> Domain -> IO (Either DNSError [(HostAddress,TTL)])
concResolv cache dom = withResolvers seeds $ \resolvers -> do
    eans <- resolv n resolvers dom
    return $ eans >>= \ans -> fromDNSMessage ans addressesWithTTL
  where
    n = cacheNofServers cache
    seeds = cacheSeeds cache
    addressesWithTTL msg =
        [ (toHostAddress ip, rrttl r)
        | r <- answer msg
        , rrtype r == A
        , RD_A ip <- [rdata r]
        ]
-- | Send an A query for @dom@.
--
-- * With one server, query it directly.
-- * Otherwise, query every resolver concurrently and return whichever
--   answer completes first (success or 'Left'), cancelling the rest.
resolv :: Int -> [Resolver] -> Domain -> IO (Either DNSError DNSMessage)
resolv n resolvers dom
    | n == 1    = queryA (head resolvers)
    | otherwise = do
        racers <- mapM (async . queryA) resolvers
        (_, firstAnswer) <- waitAnyCancel racers
        return firstAnswer
  where
    queryA resolver = lookupRaw resolver dom A
-- | Delegate to 'S.wait' on the cache's concurrency counter with the given
-- predicate on the counter value. Presumably this blocks until the
-- predicate holds; confirm in "Network.DNS.Cache.Sync".
wait :: DNSCache -> (Int -> Bool) -> IO ()
wait = S.wait . cacheConcVar
-- | Endless maintenance loop (forked by 'withDNSCache'). Every 10 seconds
-- it calls 'pruneCacheRef' with the current time. Presumably that evicts
-- entries whose expiry is in the past.
prune :: CacheRef -> IO ()
prune cacheref = forever $ do
    threadDelay pruneInterval
    now <- getCurrentTime
    pruneCacheRef now cacheref
  where
    -- threadDelay takes microseconds: 10 * 10^6 us = 10 s.
    pruneInterval :: Int
    pruneInterval = 10 * 1000 * 1000