-- | Caching DNS lookups on top of "Network.DNS".
--
-- 'withDnsCache' starts a pool of resolver threads and one cache thread
-- per record type (A, AAAA and MX); all of them are driven by commands
-- passed through 'MVar's.  'resolveA', 'resolveAAAA' and 'resolveMX'
-- answer single queries, while 'massLookup' and 'massLookupReport'
-- resolve a whole list of domains concurrently.
module Network.DnsCache
(
DnsCache,
withDnsCache,
resolveA,
resolveAAAA,
resolveMX,
MassResult(..),
MassType(..),
massLookup,
massLookupReport
)
where
import qualified Control.Exception as E
import qualified Data.Map as M
import Control.Concurrent
import Control.ContStuff
import Data.Char
import Data.IP
import Data.List
import Data.Map (Map)
import Data.Ord
import Data.Time.Clock
import Network.DNS hiding (lookup)
-- | Commands understood by a cache thread (see 'dnsCache').  The type
-- parameter @res@ is the kind of record the cache stores.
data CacheCmd res
    -- | Deliver the cached answer for the domain to the waiter.  While
    -- the domain is still being resolved the waiter is queued and called
    -- once the result is stored; unknown or not-found domains answer
    -- 'Nothing'.
    = CacheGet Domain (Waiter res)
    -- | Run the given action, then terminate the cache thread.
    | CacheQuit (IO ())
    -- | Ask whether the caller should resolve the domain.  The callback
    -- receives 'True' when the caller has to perform the lookup (the
    -- entry is then marked as being resolved) and 'False' when a lookup
    -- is already in progress or an unexpired result is cached.
    | CacheRequest Domain (Bool -> IO ())
    -- | Record that the lookup found nothing; queued waiters receive
    -- 'Nothing'.
    | CacheSetNotFound Domain
    -- | Record a successful lookup; queued waiters receive the records.
    | CacheSetResolved Domain [res]
-- | Handle to a running cache: the command channels of the three cache
-- threads and of the resolver thread pool.
data DnsCache = DnsCache
    { dnsAVar    :: MVar (CacheCmd IPv4)    -- ^ Command channel of the A record cache.
    , dnsAaaaVar :: MVar (CacheCmd IPv6)    -- ^ Command channel of the AAAA record cache.
    , dnsMxVar   :: MVar (CacheCmd Domain)  -- ^ Command channel of the MX record cache.
    , dnsResVar  :: MVar ResolverCmd        -- ^ Command channel shared by the resolver threads.
    }
-- | State of one domain inside a cache thread.
data Entry res
    -- | The last lookup returned no result.
    = NotFound
    -- | A lookup is in progress; the waiters are called once it finishes.
    | Resolving [Waiter res]
    -- | The records found, together with the time they were stored.
    | Resolved UTCTime [res]
-- | Outcome of a mass lookup for a single domain.  Record types that
-- were not requested, or for which the lookup returned nothing, are
-- reported as empty lists.
data MassResult = MassResult
    { massDomain :: Domain    -- ^ The domain that was looked up.
    , massA      :: [IPv4]    -- ^ A records.
    , massAAAA   :: [IPv6]    -- ^ AAAA records.
    , massMX     :: [Domain]  -- ^ MX hosts.
    }
    deriving Show
-- | Selects which record types a mass lookup resolves.
data MassType
    = MassA     -- ^ A records only.
    | MassAAAA  -- ^ AAAA records only.
    | MassMX    -- ^ MX records only.
    | MassAll   -- ^ A, AAAA and MX records.
-- | Commands understood by a resolver thread (see 'resolver').
data ResolverCmd where
    -- | Resolve a domain: the domain, the command channel of the cache
    -- to consult and update, the lookup action to run on a cache miss,
    -- and the callback that receives the answer.
    Resolve
        :: Domain
        -> MVar (CacheCmd res)
        -> (Resolver -> Domain -> IO (Maybe [res]))
        -> Waiter res
        -> ResolverCmd
    -- | Run the given action, then terminate the resolver thread.
    ResolverQuit
        :: IO ()
        -> ResolverCmd
-- | Callback that receives the answer to a query: 'Just' the records
-- that were found, or 'Nothing' when there is no result.
type Waiter res = Maybe [res] -> IO ()
-- | Body of a cache thread.  It keeps a map from domain to 'Entry' and
-- serves the commands taken from @cmdVar@ one at a time until a
-- 'CacheQuit' arrives.  A 'Resolved' entry older than @timeout@ no
-- longer prevents a new lookup from being granted.
dnsCache :: NominalDiffTime -> MVar (CacheCmd res) -> IO ()
dnsCache timeout cmdVar =
    evalStateT M.empty . forever $ io (takeMVar cmdVar) >>= handle
  where
    -- Answer from the cache, or queue the waiter behind a running lookup.
    handle (CacheGet dom waiter) = do
        entry <- getField (M.lookup dom)
        case entry of
            Just (Resolving queued) ->
                modify (M.insert dom (Resolving (waiter : queued)))
            Just (Resolved _ records) -> io $ waiter (Just records)
            Just NotFound -> io $ waiter Nothing
            Nothing -> io $ waiter Nothing

    -- Acknowledge, then leave the command loop.
    handle (CacheQuit ack) = io ack >> abort ()

    -- Grant the lookup unless one is running or a fresh result exists.
    handle (CacheRequest dom reply) = do
        entry <- getField (M.lookup dom)
        now <- io getCurrentTime
        let covered = case entry of
                Just (Resolving _) -> True
                Just (Resolved since _) -> diffUTCTime now since < timeout
                _ -> False
        if covered
            then io $ reply False
            else do
                modify (M.insert dom (Resolving []))
                io $ reply True

    -- Release the queued waiters with a negative answer and store it.
    handle (CacheSetNotFound dom) = do
        entry <- getField (M.lookup dom)
        io $ forM_ (pending entry) ($ Nothing)
        modify (M.insert dom NotFound)

    -- Release the queued waiters with the records and store them,
    -- stamped with the current time.
    handle (CacheSetResolved dom records) = do
        entry <- getField (M.lookup dom)
        io $ forM_ (pending entry) ($ Just records)
        now <- io getCurrentTime
        modify (M.insert dom (Resolved now records))

    -- Waiters queued on an entry; empty unless a lookup is in progress.
    pending :: Maybe (Entry r) -> [Waiter r]
    pending (Just (Resolving queued)) = queued
    pending _ = []
-- | Resolve all given domains concurrently and collect the results in a
-- map keyed by domain.  The call returns after one result per element of
-- the input list has been received; a domain listed more than once ends
-- up as a single map entry.
massLookup :: DnsCache -> MassType -> [Domain] -> IO (Map Domain MassResult)
massLookup dns mtype domains = do
    let expected = length domains
    results <- newEmptyMVar
    massResolver dns mtype domains results
    -- Take one result off the channel and file it under its domain.
    let collect = do
            result <- io (takeMVar results)
            modify (M.insert (massDomain result) result)
    execStateT M.empty (replicateM_ expected collect)
-- | Resolve all given domains concurrently and pass every result to the
-- reporting action as soon as it has been received.  The action runs
-- once per element of the input list, in the order the results arrive.
massLookupReport ::
    (Base m ~ IO, LiftBase m, Monad m) =>
    DnsCache -> MassType -> [Domain] -> (MassResult -> m ()) -> m ()
massLookupReport dns mtype domains rep = do
    let expected = length domains
    results <- io newEmptyMVar
    io (massResolver dns mtype domains results)
    let reportNext = do
            result <- io (takeMVar results)
            rep result
    replicateM_ expected reportNext
-- | Fork one thread per domain; each thread performs the lookups
-- selected by the 'MassType' and puts exactly one 'MassResult' into
-- @resultVar@.  Lookups that return 'Nothing' are reported as empty
-- lists.  The function itself returns as soon as all threads have been
-- started.
massResolver :: DnsCache -> MassType -> [Domain] -> MVar MassResult -> IO ()
massResolver dns mtype domains resultVar =
    forM_ domains $ \domain ->
        forkIO (lookupFor domain >>= putMVar resultVar)
  where
    -- Perform the lookups for one domain.  A single record type is
    -- resolved directly in the per-domain thread; forking another thread
    -- just to wait for it would gain nothing.
    lookupFor :: Domain -> IO MassResult
    lookupFor domain =
        case mtype of
            MassA -> do
                a <- orEmpty (resolveA dns domain)
                return (MassResult domain a [] [])
            MassAAAA -> do
                aaaa <- orEmpty (resolveAAAA dns domain)
                return (MassResult domain [] aaaa [])
            MassMX -> do
                mx <- orEmpty (resolveMX dns domain)
                return (MassResult domain [] [] mx)
            MassAll -> do
                -- The three record types are resolved concurrently: two
                -- in helper threads, the third in this thread.
                aVar <- background (resolveA dns domain)
                aaaaVar <- background (resolveAAAA dns domain)
                mx <- orEmpty (resolveMX dns domain)
                a <- takeMVar aVar
                aaaa <- takeMVar aaaaVar
                return (MassResult domain a aaaa mx)

    -- Treat "no result" as an empty record list.
    orEmpty :: IO (Maybe [a]) -> IO [a]
    orEmpty = fmap (maybe [] id)

    -- Run a lookup in its own thread; the returned variable is filled
    -- with the records once the lookup has finished.
    background :: IO (Maybe [a]) -> IO (MVar [a])
    background act = do
        var <- newEmptyMVar
        _ <- forkIO (orEmpty act >>= putMVar var)
        return var
-- | Look up the A records of a domain through the cache.  The request is
-- handed to a resolver thread and the call blocks until the answer has
-- been delivered; 'Nothing' means no result.
resolveA :: DnsCache -> Domain -> IO (Maybe [IPv4])
resolveA cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg) $
        Resolve dom (dnsAVar cfg) lookupA (putMVar reply)
    takeMVar reply
-- | Look up the AAAA records of a domain through the cache.  The request
-- is handed to a resolver thread and the call blocks until the answer
-- has been delivered; 'Nothing' means no result.
resolveAAAA :: DnsCache -> Domain -> IO (Maybe [IPv6])
resolveAAAA cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg) $
        Resolve dom (dnsAaaaVar cfg) lookupAAAA (putMVar reply)
    takeMVar reply
-- | Look up the MX hosts of a domain through the cache.  The hosts are
-- sorted by the second component of the pairs returned by 'lookupMX',
-- in ascending order, and that component is then dropped.  The call
-- blocks until the answer has been delivered; 'Nothing' means no result.
resolveMX :: DnsCache -> Domain -> IO (Maybe [Domain])
resolveMX cfg dom = do
    reply <- newEmptyMVar
    putMVar (dnsResVar cfg) $
        Resolve dom (dnsMxVar cfg) mxLookup (putMVar reply)
    takeMVar reply
  where
    -- Order the exchangers by the value paired with them, keep the hosts.
    hostsInOrder pairs = map fst (sortBy (comparing snd) pairs)
    mxLookup rslv name = fmap (fmap hostsInOrder) (lookupMX rslv name)
-- | Body of a resolver thread.  It serves the commands taken from
-- @cmdVar@ one at a time until a 'ResolverQuit' arrives.
--
-- For a 'Resolve' command the domain is lower-cased and the cache is
-- asked (via 'CacheRequest') whether a lookup is needed.  If so, the
-- lookup action is run, its outcome is stored in the cache and passed to
-- the waiter; otherwise the answer is fetched from the cache with
-- 'CacheGet'.
resolver :: MVar ResolverCmd -> IO ()
resolver cmdVar = do
    seed <- makeResolvSeed defaultResolvConf
    withResolver seed $ \rslv ->
        evalContT . forever $ do
            cmd <- io $ takeMVar cmdVar
            case cmd of
                Resolve dom' cacheVar dnsLookup waiter ->
                    io $ do
                        -- Cache keys are case-insensitive.
                        let dom = map toLower dom'
                        mayResVar <- newEmptyMVar
                        putMVar cacheVar (CacheRequest dom (putMVar mayResVar))
                        mayResolve <- takeMVar mayResVar
                        if mayResolve
                            then do
                                -- The cache entry is now marked as being
                                -- resolved, so an answer MUST be stored
                                -- from here on, even if the lookup fails.
                                mRes <- guardedLookup (dnsLookup rslv dom)
                                case mRes of
                                    Nothing -> do
                                        putMVar cacheVar (CacheSetNotFound dom)
                                        waiter Nothing
                                    Just res -> do
                                        putMVar cacheVar (CacheSetResolved dom res)
                                        waiter (Just res)
                            else do
                                resVar <- newEmptyMVar
                                putMVar cacheVar (CacheGet dom (putMVar resVar))
                                takeMVar resVar >>= waiter
                ResolverQuit c -> io c >> abort ()

-- | Run a lookup action, turning any exception it throws into 'Nothing'.
--
-- This is deliberate: if an exception escaped, the resolver thread would
-- die while the cache entry is still marked as being resolved, so the
-- caller and every later query for that domain would wait forever, and
-- shutting down the thread pool would block as well.
guardedLookup :: IO (Maybe a) -> IO (Maybe a)
guardedLookup act = act `E.catch` failed
  where
    failed :: E.SomeException -> IO (Maybe b)
    failed _ = return Nothing
-- | Run a computation with a freshly started DNS cache.
--
-- Starts @numThreads@ resolver threads and one cache thread each for A,
-- AAAA and MX records, where cached results expire after @timeout@.
-- When the computation finishes, normally or by an exception, every
-- resolver thread and then every cache thread is told to quit and its
-- acknowledgement is awaited.  With @numThreads <= 0@ no resolver
-- threads are started.
withDnsCache :: Int -> NominalDiffTime -> (DnsCache -> IO a) -> IO a
withDnsCache numThreads timeout comp = do
    resolverVar <- newEmptyMVar
    aVar <- newEmptyMVar
    aaaaVar <- newEmptyMVar
    mxVar <- newEmptyMVar

    replicateM_ numThreads (forkIO (resolver resolverVar))
    mapM_ forkIO
        [ dnsCache timeout aVar
        , dnsCache timeout aaaaVar
        , dnsCache timeout mxVar
        ]

    let cfg = DnsCache
            { dnsAVar = aVar
            , dnsAaaaVar = aaaaVar
            , dnsMxVar = mxVar
            , dnsResVar = resolverVar
            }

        -- Stop the resolver threads one by one, then the three caches.
        shutdown = do
            done <- newEmptyMVar
            let ack = putMVar done ()
            replicateM_ numThreads $
                putMVar resolverVar (ResolverQuit ack) >> takeMVar done
            putMVar aVar (CacheQuit ack)
            putMVar aaaaVar (CacheQuit ack)
            putMVar mxVar (CacheQuit ack)
            replicateM_ 3 (takeMVar done)

    comp cfg `E.finally` shutdown