module Network.DNS.PollResolver where
import Control.Concurrent ( forkIO )
import Control.Concurrent.MVar
import Control.Monad ( when )
import Data.List ( sortBy )
import Foreign
import Foreign.C
import Network ( HostName )
import Network.IP.Address
import Network.DNS.ADNS
import System.Posix.Poll
import System.Posix.GetTimeOfDay
-- | An asynchronous lookup function.
--
-- It takes the domain name to query, the record type and the query flags.
-- It returns an 'MVar' that holds (or will hold) the resulting 'Answer'.
type Resolver
  =  String
  -> RRType
  -> [QueryFlag]
  -> IO (MVar Answer)
-- | Initialise ADNS with the given flags and run the supplied action with a
-- 'Resolver' backed by that ADNS instance.
--
-- The pollfd buffer starts with room for 32 descriptors. The background
-- loop grows it when ADNS asks for more.
--
-- NOTE(review): the loop thread forked by 'resolve' may still be running
-- when the action returns; confirm that 'adnsInit' keeps the ADNS state
-- valid for as long as that thread can touch it.
initResolver :: [InitFlag] -> (Resolver -> IO a) -> IO a
initResolver flags f =
  adnsInit flags $ \dns -> do
    let startCapacity = 32
    pollBuffer <- mallocForeignPtrArray startCapacity
    stateVar <- newMVar (RState { adns     = dns
                                , pollfds  = pollBuffer
                                , capacity = startCapacity
                                , queries  = []
                                })
    f (resolve stateVar)
-- | Look up the A-record addresses of a host name.
--
-- Blocks until the answer arrives. A status other than 'sOK' is returned
-- as 'Left'; otherwise the addresses from the A records are returned.
resolveA :: Resolver -> HostName -> IO (Either Status [HostAddress])
resolveA resolver host = do
  pending <- resolver host A []
  Answer status _ _ _ records <- takeMVar pending
  let addrs = [ addr | RRA (RRAddr addr) <- records ]
  if status /= sOK
    then return (Left status)
    else return (Right addrs)
-- | Reverse-resolve an address to host names via PTR records.
--
-- The query name is built with 'ha2ptr' (presumably the in-addr.arpa
-- form; confirm in its definition). Blocks until the answer arrives. A
-- status other than 'sOK' is returned as 'Left'.
resolvePTR :: Resolver -> HostAddress -> IO (Either Status [HostName])
resolvePTR resolver addr = do
  pending <- resolver (ha2ptr addr) PTR []
  Answer status _ _ _ records <- takeMVar pending
  let names = [ name | RRPTR name <- records ]
  if status /= sOK
    then return (Left status)
    else return (Right names)
-- | Look up the mail exchangers of a domain, with their addresses.
--
-- Returns @(exchanger host name, address)@ pairs ordered by ascending MX
-- priority. Exchangers with equal priority keep the order in which the
-- resolver returned them, because 'sortBy' is stable. Exchangers whose
-- embedded address status is not 'sOK' contribute no pairs. A query
-- status other than 'sOK' is returned as 'Left'.
resolveMX :: Resolver -> HostName -> IO (Either Status [(HostName, HostAddress)])
resolveMX resolver x = do
  Answer rc _ _ _ rs <- resolver x MX [] >>= takeMVar
  if rc /= sOK
    then return (Left rc)
    else do
      -- Select the MX records before sorting. Non-MX records are ignored
      -- instead of reaching a partial comparator, which would crash only
      -- when there happened to be two or more records to compare.
      let mxs = [ (prio, target) | RRMX prio target <- rs ]
          byPriority (p1, _) (p2, _) = compare p1 p2
          as = [ (hn, a) | (_, RRHostAddr hn stat has) <- sortBy byPriority mxs
                         , stat == sOK
                         , RRAddr a <- has ]
      return (Right as)
-- | Adapt an 'Either'-returning lookup (such as 'resolveA') into one that
-- returns 'Maybe'.
--
-- * A successful result @r@ becomes @Just r@.
-- * 'sNXDOMAIN' is treated as an empty result, @Just []@.
-- * Any other failure status becomes 'Nothing'.
query :: (Resolver -> a -> IO (Either Status [b]))
      -> (Resolver -> a -> IO (Maybe [b]))
query f dns x = fmap (either onFailure Just) (f dns x)
  where
    onFailure rc = if rc == sNXDOMAIN then Just [] else Nothing
-- | A 'Resolver' that never contacts a server.
--
-- Every query immediately yields an already-full 'MVar' holding an
-- 'Answer' with status 'sSYSTEMFAIL', the queried name echoed back, and
-- no records.
--
-- NOTE(review): the fourth 'Answer' field is set to 1. That may have been
-- meant as (-1); confirm against the 'Answer' definition.
dummyDNS :: Resolver
dummyDNS host _ _ = newMVar failure
  where
    failure = Answer sSYSTEMFAIL Nothing (Just host) (1) []
-- | State shared (through an 'MVar') between 'resolve' and 'resolveLoop'.
data ResolverState = RState
  { adns     :: AdnsState
    -- ^ handle of the ADNS instance that queries are submitted to
  , pollfds  :: ForeignPtr Pollfd
    -- ^ buffer handed to 'adnsBeforePoll', 'poll' and 'adnsAfterPoll'
  , capacity :: Int
    -- ^ number of 'Pollfd' slots allocated in 'pollfds'
  , queries  :: [(Query, MVar Answer)]
    -- ^ outstanding queries, each paired with the 'MVar' that receives
    --   its answer
  }
-- | The 'Resolver' implementation.
--
-- It submits the query to ADNS, records it as outstanding, and returns the
-- (initially empty) 'MVar' that will receive its 'Answer'. When no
-- queries were outstanding beforehand, a 'resolveLoop' thread is forked.
-- That thread cannot see the state until this call releases the 'MVar'.
resolve :: MVar ResolverState -> Resolver
resolve mst r rt qfs = modifyMVar mst submit
  where
    submit st = do
      answerVar <- newEmptyMVar
      q <- adnsSubmit (adns st) r rt qfs
      let pending = queries st
          idle    = null pending
      when idle $ do
        _ <- forkIO (resolveLoop mst)
        return ()
      return (st { queries = (q, answerVar) : pending }, answerVar)
-- | Background loop that drives all outstanding queries.
--
-- Each iteration does the following:
--
-- 1. Checks every pending query with 'adnsCheck'.
-- 2. Puts finished answers into their 'MVar's and keeps the rest pending.
-- 3. If nothing is left pending, cancels whatever queries ADNS still
--    reports and exits. The next 'resolve' call forks a new loop.
-- 4. Otherwise, waits for socket activity or an ADNS timeout using
--    'adnsBeforePoll', 'poll' and 'adnsAfterPoll', then repeats.
--
-- NOTE(review): if this thread dies from an exception, 'queries' stays
-- non-empty. 'resolve' then never forks a replacement, and callers
-- waiting on their 'MVar's block forever.
resolveLoop :: MVar ResolverState -> IO ()
resolveLoop mst = do
  finished <- modifyMVar mst $ \(RState dns fds cap qs) -> do
    res' <- mapM (checkQuery dns) qs
    case [ x | Just x <- res' ] of
      [] -> do adnsQueries dns >>= mapM_ adnsCancel
               return (RState dns fds cap [], True)
      res -> return (RState dns fds cap res, False)
  when (not finished) (waitForIO >> resolveLoop mst)
  where
    -- Deliver the answer if the query has finished; otherwise keep it.
    checkQuery dns (q, mv) = do
      res <- adnsCheck dns q
      case res of
        Just a  -> putMVar mv a >> return Nothing
        Nothing -> return (Just (q, mv))

    waitForIO = do
      (nfds, to) <- beforePoll
      when (nfds > 0) (doPoll nfds to >> afterPoll nfds)

    -- Ask ADNS which descriptors to poll and for how long. If the buffer
    -- is too small (ERANGE), grow it and ask again.
    beforePoll = do
      b4 <- modifyMVar mst $ \st ->
        withForeignPtr (pollfds st) $ \fds ->
          alloca $ \nfds ->
            alloca $ \to ->
              alloca $ \now -> do
                poke nfds (toEnum (capacity st))
                -- Start with "no timeout of our own" (-1). ADNS only ever
                -- lowers this value to its next deadline. Seeding it with
                -- a small positive value would cap every poll at that many
                -- milliseconds and make the loop spin while queries wait.
                poke to (-1)
                getTimeOfDay now
                rc <- adnsBeforePoll (adns st) fds nfds to now
                n <- peek nfds
                tv <- peek to
                if rc == 0
                  then return (st, Right (n, tv))
                  else if Errno rc == eRANGE
                    then return (st, Left n)
                    else fail ("adnsBeforePoll returned unknown value " ++ show rc)
      case b4 of
        Left n  -> allocFds (fromEnum n) >> beforePoll
        Right x -> return x

    -- The state lock is not held while blocked in poll, so 'resolve' can
    -- keep submitting queries. A signal interrupting poll (EINTR) is
    -- retried rather than thrown, because an exception here would kill
    -- the loop and strand every pending query.
    doPoll nfds to = do
      fds' <- withMVar mst (return . pollfds)
      throwErrnoIfMinus1Retry_ "PollResolver.doPoll failed" $
        withForeignPtr fds' $ \fds ->
          poll fds (toEnum (fromEnum nfds)) to

    afterPoll nfds =
      withMVar mst $ \st ->
        alloca $ \now ->
          withForeignPtr (pollfds st) $ \fds -> do
            getTimeOfDay now
            adnsAfterPoll (adns st) fds nfds now

    -- Make room for at least n pollfds by doubling the capacity. The old
    -- contents are not copied; adnsBeforePoll refills the buffer.
    allocFds n = modifyMVar_ mst $ \st ->
      if n <= capacity st then return st else do
        let sizes = iterate (*2) (capacity st)
            (n':_) = dropWhile (<n) sizes
        fds <- mallocForeignPtrArray n'
        return st { pollfds = fds, capacity = n' }