-- |
-- Module:     Network.DnsCache
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez
-- Stability:  experimental
--
-- This module implements an asynchronous, caching DNS resolver.
--
-- A 'DnsCache' consists of a pool of resolver threads sharing one command
-- variable, plus one cache thread each for A, AAAA and MX records.  Each
-- cache thread owns a map from domain to 'Entry'.  Domain names are
-- lowercased by the resolver before they are used as cache keys.

{-# LANGUAGE FlexibleInstances, GADTs, RankNTypes, TypeFamilies #-}

module Network.DnsCache
    ( -- * DNS cache
      DnsCache,
      createDnsCache,
      freeDnsCache,
      withDnsCache,

      -- * Monadic interface
      DnsMonad(..),
      withDnsStateT,

      -- * DNS lookup
      resolveA,
      resolveAAAA,
      resolveMX,

      -- * DNS mass lookup
      MassResult(..),
      MassType(..),
      massLookup,
      massLookupReport,

      -- * Reexports
      Domain
    )
    where

import qualified Data.ByteString.Char8 as Bc
import qualified Data.Map as M
-- forkIO is hidden here; presumably the forkIO used below is the one
-- exported by Control.ContStuff (NOTE(review): confirm).
import Control.Concurrent hiding (forkIO)
import Control.ContStuff
import Data.Char
import Data.IORef
import Data.IP
import Data.List
import Data.Map (Map)
import Data.Ord
import Data.Time.Clock
import Network.DNS hiding (lookup)
import System.IO.Unsafe


-- | Command to a DNS cache thread.
data CacheCmd res
    = CacheGet Domain (Waiter res)
      -- ^ Answer the waiter from the stored entry ('Nothing' when the
      -- domain is unknown or 'NotFound'), or queue it while the domain is
      -- still being resolved.
    | CacheQuit (IO ())
      -- ^ Run the action, then stop the cache thread.
    | CacheRequest Domain (Bool -> IO ())
      -- ^ Ask whether the caller should perform the lookup.  The callback
      -- gets 'True' after the domain has been marked 'Resolving', and
      -- 'False' when a lookup is in flight or a fresh result is stored.
    | CacheSetNotFound Domain
      -- ^ Record a failed lookup and pass 'Nothing' to queued waiters.
    | CacheSetResolved Domain [res]
      -- ^ Record a result with the current time and pass it to queued
      -- waiters.


-- | DNS cache configuration.
data DnsCache =
    DnsCache {
      -- | Command variable of the A record cache thread.
      dnsAVar       :: MVar (CacheCmd IPv4),
      -- | Command variable of the AAAA record cache thread.
      dnsAaaaVar    :: MVar (CacheCmd IPv6),
      -- | Command variable of the MX record cache thread.
      dnsMxVar      :: MVar (CacheCmd Domain),
      -- | Command variable shared by all resolver threads.
      dnsResVar     :: MVar ResolverCmd,
      -- | Number of resolver threads; 'freeDnsCache' sends one quit
      -- command per thread.
      dnsNumThreads :: Int
    }


-- | DNS cache entry.
data Entry res
    = NotFound
      -- ^ The last lookup failed.  The next 'CacheRequest' re-resolves.
    | Resolving [Waiter res]
      -- ^ A lookup is in flight; these waiters receive its outcome.
    | Resolved UTCTime [res]
      -- ^ Lookup result together with the time it was stored.


-- | Result of a mass lookup.
data MassResult =
    MassResult {
      -- | The domain exactly as it appeared in the input list.
      massDomain :: Domain,
      -- | A records (empty if not requested or nothing was found).
      massA      :: [IPv4],
      -- | AAAA records (empty if not requested or nothing was found).
      massAAAA   :: [IPv6],
      -- | MX hosts (empty if not requested or nothing was found).
      massMX     :: [Domain]
    }
    deriving Show


-- | The type of resources to look up.
data MassType
    = MassA      -- ^ Only A records.
    | MassAAAA   -- ^ Only AAAA records.
    | MassMX     -- ^ Only MX records.
    | MassAll    -- ^ A, AAAA and MX records, looked up concurrently.


-- | DNS resolver command.  'Resolve' hides the record type @res@, so a
-- single command variable serves A, AAAA and MX lookups alike.
data ResolverCmd where
    -- | Resolve a domain: the cache to consult and update, the lookup
    -- function to run when the cache asks for it, and the waiter that
    -- receives the answer.
    Resolve ::
        Domain -> MVar (CacheCmd res) ->
        (Resolver -> Domain -> IO (Maybe [res])) ->
        Waiter res -> ResolverCmd
    -- | Run the action, then stop the resolver thread.
    ResolverQuit :: IO () -> ResolverCmd


-- | DNS resolver waiter: a callback receiving the outcome of a lookup,
-- 'Nothing' when the lookup failed or found nothing.
type Waiter res = Maybe [res] -> IO ()


-- | Monads, which contain a DNS cache.
-- Any such monad can run the lookup functions of this module against the
-- cache returned by 'getDnsCache'.
class MonadIO m => DnsMonad m where
    -- | Get the current DNS cache.
    getDnsCache :: m DnsCache

-- | A 'StateT' whose state is a 'DnsCache' hands out that state.
instance MonadIO m => DnsMonad (StateT r DnsCache m) where
    getDnsCache = get

-- | Plain 'IO' uses the process-wide 'globalCache'.
instance DnsMonad IO where
    getDnsCache = readIORef globalCache


{-# NOINLINE globalCache #-}
-- | Process-wide cache behind the 'IO' instance: one resolver thread and a
-- cache timeout of 10 seconds.
--
-- This is the usual top-level 'unsafePerformIO' idiom.  The NOINLINE pragma
-- keeps GHC from inlining the expression, which could otherwise create
-- several IORefs and therefore several sets of cache threads.
globalCache :: IORef DnsCache
globalCache = unsafePerformIO $ newIORef =<< createDnsCache 1 10


-- | DNS cache thread.  Handles the commands arriving in @cmdVar@ one at a
-- time against a map from domain to 'Entry', until 'CacheQuit' arrives.
-- A 'Resolved' entry is fresh for @timeout@ after it was stored.  Missing,
-- 'NotFound' and stale entries make the next 'CacheRequest' hand the lookup
-- back to the caller.
dnsCache :: NominalDiffTime -> MVar (CacheCmd res) -> IO ()
dnsCache timeout cmdVar =
    evalStateT M.empty . forever $ liftIO (takeMVar cmdVar) >>= step
  where
    -- Answer straight from the stored entry, or queue the waiter while a
    -- lookup for the domain is still running.
    step (CacheGet dom waiter) = do
        entry <- getField (M.lookup dom)
        case entry of
          Just (Resolving ws)  -> modify (M.insert dom (Resolving (waiter : ws)))
          Just (Resolved _ rs) -> liftIO (waiter (Just rs))
          _                    -> liftIO (waiter Nothing)

    step (CacheQuit onQuit) = liftIO onQuit >> abort ()

    -- Tell the caller whether it has to perform the lookup.  If it does,
    -- mark the domain as being resolved so that later requests wait.
    step (CacheRequest dom reply) = do
        entry <- getField (M.lookup dom)
        now   <- liftIO getCurrentTime
        let inFlightOrFresh =
                case entry of
                  Just (Resolving _)  -> True
                  Just (Resolved t _) -> diffUTCTime now t < timeout
                  _                   -> False
        if inFlightOrFresh
          then liftIO (reply False)
          else do
              modify (M.insert dom (Resolving []))
              liftIO (reply True)

    step (CacheSetNotFound dom) = do
        getField (M.lookup dom) >>= wakeWaiters Nothing
        modify (M.insert dom NotFound)

    step (CacheSetResolved dom rs) = do
        getField (M.lookup dom) >>= wakeWaiters (Just rs)
        now <- liftIO getCurrentTime
        modify (M.insert dom (Resolved now rs))

    -- Pass an outcome to every waiter queued on a 'Resolving' entry.
    wakeWaiters answer (Just (Resolving ws)) = liftIO (mapM_ ($ answer) ws)
    wakeWaiters _      _                     = return ()


-- | Resolve an A record.
resolveA :: DnsMonad m => Domain -> m (Maybe [IPv4])
resolveA dom = resolveA_ dom =<< getDnsCache


-- | Resolve an A record through an explicitly given cache.  Blocks until a
-- resolver thread has answered.
resolveA_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [IPv4])
resolveA_ dom cfg = liftIO $ do
    answerVar <- newEmptyMVar
    putMVar (dnsResVar cfg)
        (Resolve dom (dnsAVar cfg) lookupA (putMVar answerVar))
    takeMVar answerVar


-- | Resolve an AAAA record.
-- The lookup goes through the cache returned by 'getDnsCache'.
resolveAAAA :: DnsMonad m => Domain -> m (Maybe [IPv6])
resolveAAAA dom = getDnsCache >>= resolveAAAA_ dom


-- | Resolve an AAAA record through an explicitly given cache.  Blocks until
-- a resolver thread has answered.
resolveAAAA_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [IPv6])
resolveAAAA_ dom cfg = liftIO $ do
    answerVar <- newEmptyMVar
    putMVar (dnsResVar cfg)
        (Resolve dom (dnsAaaaVar cfg) lookupAAAA (putMVar answerVar))
    takeMVar answerVar


-- | Resolve an MX record.  The hosts are sorted by the second component of
-- the pairs returned by 'lookupMX' (presumably the MX preference), lowest
-- first.
resolveMX :: DnsMonad m => Domain -> m (Maybe [Domain])
resolveMX dom = getDnsCache >>= resolveMX_ dom


-- | Resolve an MX record through an explicitly given cache.  Blocks until a
-- resolver thread has answered.
resolveMX_ :: MonadIO m => Domain -> DnsCache -> m (Maybe [Domain])
resolveMX_ dom cfg = liftIO $ do
    let mxLookup r d = fmap (map fst . sortBy (comparing snd)) <$> lookupMX r d
    answerVar <- newEmptyMVar
    putMVar (dnsResVar cfg)
        (Resolve dom (dnsMxVar cfg) mxLookup (putMVar answerVar))
    takeMVar answerVar


-- | DNS resolver thread.  Takes commands from @cmdVar@ until it receives
-- 'ResolverQuit'.
--
-- For each 'Resolve' command the domain is lowercased and the record
-- type's cache is asked ('CacheRequest') whether a lookup is needed.  If
-- so, this thread performs it, stores the outcome in the cache and answers
-- the waiter.  Otherwise the waiter is handed to the cache ('CacheGet'),
-- which answers it from the stored entry or when the in-flight lookup
-- completes.  The waiters built by 'resolveA_', 'resolveAAAA_' and
-- 'resolveMX_' only fill a fresh MVar, so running them on the cache thread
-- cannot block it, and this thread is free for the next command at once.
--
-- An exception from the lookup function is reported as a failed lookup
-- ('Nothing').  If it escaped, it would kill this thread, leave the entry
-- in 'Resolving' forever, hang every waiter on that domain, and make
-- 'freeDnsCache' wait for an acknowledgement that never comes.
resolver :: MVar ResolverCmd -> IO ()
resolver cmdVar = do
    seed <- makeResolvSeed defaultResolvConf
    withResolver seed $ \rslv ->
        evalContT . forever $ do
            cmd <- liftIO (takeMVar cmdVar)
            case cmd of
              Resolve dom' cacheVar dnsLookup waiter -> liftIO $ do
                  let dom = Bc.map toLower dom'
                  mayResVar <- newEmptyMVar
                  putMVar cacheVar (CacheRequest dom (putMVar mayResVar))
                  mayResolve <- takeMVar mayResVar
                  if mayResolve
                    then do
                      outcome <- try (dnsLookup rslv dom)
                      case either (const Nothing) id outcome of
                        Nothing -> do
                          putMVar cacheVar (CacheSetNotFound dom)
                          waiter Nothing
                        Just res -> do
                          putMVar cacheVar (CacheSetResolved dom res)
                          waiter (Just res)
                    -- A lookup is in flight or a fresh result is cached:
                    -- let the cache thread answer the waiter.
                    else putMVar cacheVar (CacheGet dom waiter)
              ResolverQuit c -> liftIO c >> abort ()


-- ================== --
-- DNS cache creation --
-- ================== --


-- | Start a DNS cache with the given number of resolver threads and the
-- given cache timeout.
--
-- A thread count below one is raised to one, because without any resolver
-- thread every lookup would wait forever for an answer.  Besides the
-- resolver threads, one cache thread each is started for A, AAAA and MX
-- records.
createDnsCache :: MonadIO m => Int -> NominalDiffTime -> m DnsCache
createDnsCache numThreads timeout = liftIO $ do
    let threads = max 1 numThreads
    resolverVar <- newEmptyMVar
    aVar        <- newEmptyMVar
    aaaaVar     <- newEmptyMVar
    mxVar       <- newEmptyMVar
    replicateM_ threads . forkIO $ resolver resolverVar
    _ <- forkIO $ dnsCache timeout aVar
    _ <- forkIO $ dnsCache timeout aaaaVar
    _ <- forkIO $ dnsCache timeout mxVar
    -- Store the clamped count so 'freeDnsCache' sends one quit per thread
    -- that actually exists.
    return DnsCache { dnsAVar       = aVar,
                      dnsAaaaVar    = aaaaVar,
                      dnsMxVar      = mxVar,
                      dnsResVar     = resolverVar,
                      dnsNumThreads = threads }


-- | Free an existing DNS cache.  Stops the resolver threads one at a time,
-- waiting for each acknowledgement, then stops the three cache threads and
-- waits for their acknowledgements.  Lookups issued through the cache
-- afterwards will not be answered.
freeDnsCache :: MonadIO m => DnsCache -> m ()
freeDnsCache dns = liftIO $ do
    quitVar <- newEmptyMVar
    let ack = putMVar quitVar ()
    replicateM_ (dnsNumThreads dns) $ do
        putMVar (dnsResVar dns) (ResolverQuit ack)
        takeMVar quitVar
    putMVar (dnsAVar dns)    (CacheQuit ack)
    putMVar (dnsAaaaVar dns) (CacheQuit ack)
    putMVar (dnsMxVar dns)   (CacheQuit ack)
    replicateM_ 3 (takeMVar quitVar)


-- | Convenient wrapper around 'createDnsCache' and 'freeDnsCache'.  The
-- cache is released through 'bracket', so the computation raising an
-- exception still frees it.
withDnsCache ::
    (HasExceptions m, MonadIO m) =>
    Int -> NominalDiffTime -> (DnsCache -> m a) -> m a
withDnsCache numThreads timeout =
    bracket (createDnsCache numThreads timeout) freeDnsCache


-- ==================== --
-- Running computations --
-- ==================== --


-- | Run a 'StateT' computation whose state is a freshly created DNS cache,
-- which is freed afterwards.
withDnsStateT ::
    (Applicative m, HasExceptions m, MonadIO m) =>
    Int -> NominalDiffTime -> StateT a DnsCache m a -> m a
withDnsStateT numThreads timeout comp =
    withDnsCache numThreads timeout (\dns -> evalStateT dns comp)


-- =================== --
-- DNS mass resolution --
-- =================== --


-- | Perform a mass lookup through the cache of the current 'DnsMonad'.
massLookup :: DnsMonad m => MassType -> [Domain] -> m (Map Domain MassResult)
massLookup mtype domains = massLookup_ mtype domains =<< getDnsCache


-- | Perform a mass lookup.
-- The lookup goes through the given cache.  This function returns once a
-- result has arrived for every element of the domain list.  Results are
-- keyed by 'massDomain', so a domain listed twice yields one map entry.
massLookup_ ::
    MonadIO m => MassType -> [Domain] -> DnsCache -> m (Map Domain MassResult)
massLookup_ mtype domains dns = liftIO $ do
    resultVar <- newEmptyMVar
    massResolver dns mtype domains resultVar
    execStateT M.empty . replicateM_ (length domains) $ do
        result <- liftIO (takeMVar resultVar)
        modify (M.insert (massDomain result) result)


-- | Perform a mass lookup, passing each result to the report function as
-- soon as it arrives (in completion order, not list order).
massLookupReport ::
    DnsMonad m => MassType -> [Domain] -> (MassResult -> m ()) -> m ()
massLookupReport mtype domains rep =
    massLookupReport_ mtype domains rep =<< getDnsCache


-- | Perform a mass lookup with report function through an explicitly given
-- cache.  The report function runs once per element of the domain list.
massLookupReport_ ::
    MonadIO m =>
    MassType -> [Domain] -> (MassResult -> m ()) -> DnsCache -> m ()
massLookupReport_ mtype domains rep dns = do
    resultVar <- liftIO newEmptyMVar
    liftIO (massResolver dns mtype domains resultVar)
    replicateM_ (length domains) (liftIO (takeMVar resultVar) >>= rep)


-- | Mass resolver.  Forks one thread per domain, and each thread puts one
-- 'MassResult' into @resultVar@.  Failed lookups and record types that
-- were not requested give empty lists.  For 'MassAll' the three lookups
-- run in their own threads so they proceed concurrently.  The
-- single-type cases resolve directly in the per-domain thread instead of
-- forking a helper thread only to wait on it.
massResolver :: DnsCache -> MassType -> [Domain] -> MVar MassResult -> IO ()
massResolver dns mtype domains resultVar =
    forM_ domains $ \domain ->
        forkIO (lookupDomain domain >>= putMVar resultVar)
  where
    lookupDomain domain =
        case mtype of
          MassA -> do
              a <- orEmpty (resolveA_ domain dns)
              return (MassResult domain a [] [])
          MassAAAA -> do
              aaaa <- orEmpty (resolveAAAA_ domain dns)
              return (MassResult domain [] aaaa [])
          MassMX -> do
              mx <- orEmpty (resolveMX_ domain dns)
              return (MassResult domain [] [] mx)
          MassAll -> do
              aVar    <- spawn (orEmpty (resolveA_ domain dns))
              aaaaVar <- spawn (orEmpty (resolveAAAA_ domain dns))
              mxVar   <- spawn (orEmpty (resolveMX_ domain dns))
              a    <- takeMVar aVar
              aaaa <- takeMVar aaaaVar
              mx   <- takeMVar mxVar
              return (MassResult domain a aaaa mx)

    -- Treat a failed lookup as an empty result list.
    orEmpty :: IO (Maybe [a]) -> IO [a]
    orEmpty = fmap (maybe [] id)

    -- Run an action in its own thread; its result appears in the MVar.
    spawn :: IO a -> IO (MVar a)
    spawn act = do
        var <- newEmptyMVar
        _ <- forkIO (act >>= putMVar var)
        return var