{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Network.DNS.Transport (
Resolver(..)
, resolve
) where
import Control.Concurrent.Async (async, wait, waitAny, waitAnyCancel, withAsync)
import Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import qualified Data.List.NonEmpty as NE
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)
import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types
import Network.DNS.Types.Internal
-- | 'True' exactly when 'checkRespM' finds nothing wrong with the response,
-- i.e. its identifier matches @seqno@ and it echoes back question @q@.
checkResp :: Question -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = maybe True (const False) (checkRespM q seqno resp)
-- | Validate a response against the query that produced it.
--
-- Returns 'SequenceNumberMismatch' when the header identifier differs from
-- @seqno@ (checked first), 'QuestionMismatch' when the question section is
-- not exactly @[q]@, and 'Nothing' when both checks pass.
checkRespM :: Question -> Identifier -> DNSMessage -> Maybe DNSError
checkRespM q seqno resp
    | wrongId       = Just SequenceNumberMismatch
    | wrongQuestion = Just QuestionMismatch
    | otherwise     = Nothing
  where
    wrongId       = identifier (header resp) /= seqno
    wrongQuestion = question resp /= [q]
-- | Internal control-flow exception: 'udpLookup' throws it when a UDP
-- response has the truncation bit set, and 'udpTcpLookup' catches it to
-- repeat the query over TCP.
data TCPFallback = TCPFallback
    deriving (Show, Typeable)

instance Exception TCPFallback where
-- | The tail shared by all resolver functions: extra query controls, then a
-- receiver that reads one DNS message from a socket, producing either a
-- 'DNSError' or the answer.
type Rslv0 =
       QueryControls
    -> (Socket -> IO DNSMessage)
    -> IO (Either DNSError DNSMessage)

-- | A resolver given the question, then the timeout (handed to
-- 'System.Timeout.timeout'), then the retry count, followed by 'Rslv0'.
type Rslv1 = Question -> Int -> Int -> Rslv0

-- | A single TCP exchange: server address, question, timeout (handed to
-- 'System.Timeout.timeout') and query controls; failures are thrown.
type TcpRslv =
    AddrInfo -> Question -> Int -> QueryControls -> IO DNSMessage

-- | A UDP exchange: retry count and socket receiver, followed by the same
-- arguments as 'TcpRslv'.
type UdpRslv = Int -> (Socket -> IO DNSMessage) -> TcpRslv
-- | Look up @dom@ with record type @typ@ through the name servers of @rlv@.
--
-- Domains rejected by 'isIllegal' yield 'IllegalDomain', and AXFR requests
-- yield 'InvalidAXFRLookup', without touching the network. Otherwise the
-- domain is made absolute (a trailing @.@ is appended when missing), and the
-- caller's @qctls@ are combined with the configured query controls. A lone
-- name server is queried directly. With several servers they are raced when
-- 'resolvConcurrent' is set, and tried one after another otherwise.
resolve :: Domain -> TYPE -> Resolver -> Rslv0
resolve dom typ rlv qctls rcv
    | isIllegal dom = return (Left IllegalDomain)
    | typ == AXFR   = return (Left InvalidAXFRLookup)
    | single        =
        resolveOne (NE.head servers) (NE.head idGens)
                   quest tmo retries allCtls rcv
    | useConcurrent =
        resolveConcurrent serverList genList quest tmo retries allCtls rcv
    | otherwise     =
        resolveSequential serverList genList quest tmo retries allCtls rcv
  where
    -- Only forced after the 'isIllegal' guard, which rejects empty domains,
    -- so 'BS.last' is safe here.
    quest
        | BS.last dom == '.' = Question dom typ
        | otherwise          = Question (dom <> ".") typ
    cfg           = resolvconf (resolvseed rlv)
    servers       = nameservers (resolvseed rlv)
    idGens        = genIds rlv
    serverList    = NE.toList servers
    genList       = NE.toList idGens
    single        = NE.length servers == 1
    allCtls       = qctls <> resolvQueryControls cfg
    useConcurrent = resolvConcurrent cfg
    tmo           = resolvTimeout cfg
    retries       = resolvRetry cfg
-- | Try the name servers one at a time, each paired with its identifier
-- generator. A 'Left' from one server moves on to the next. The last
-- server's result, success or error, is returned as-is. The two lists are
-- expected to have the same non-zero length; anything else hits 'error'.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q tm retry ctls rcv = go nss gs
  where
    attempt ai gen = resolveOne ai gen q tm retry ctls rcv

    go [ai] [gen] = attempt ai gen
    go (ai:moreServers) (gen:moreGens) =
        attempt ai gen >>= either (\_ -> go moreServers moreGens) (return . Right)
    go _ _ = error "resolveSequential:loop"
-- | Query all name servers at once, each with its paired identifier
-- generator, and return the first successful ('Right') answer.
--
-- A server that fails ('Left') is dropped and the remaining lookups keep
-- running, matching 'resolveSequential', which also moves on after a
-- 'Left'. If every server fails, the result of the last lookup to finish is
-- returned. Exceptions other than 'DNSError' escaping a lookup are
-- re-thrown, and every outstanding lookup is cancelled once a result is
-- chosen or if this thread is interrupted.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q tm retry ctls rcv = spawn (zip nss gens) []
  where
    -- Start one lookup per (server, generator) pair. Nesting 'withAsync'
    -- guarantees each started thread is cancelled when we leave, even if
    -- we are interrupted part-way through starting them.
    spawn [] running = firstSuccess (reverse running)
    spawn ((ai, gen) : rest) running =
        withAsync (resolveOne ai gen q tm retry ctls rcv) $ \a ->
            spawn rest (a : running)

    -- Collect completions in the order they finish. A 'Right' wins at once.
    -- A 'Left' only removes that server from the race.
    firstSuccess []  = error "resolveConcurrent: no name servers"
    firstSuccess [a] = wait a
    firstSuccess pending = do
        (done, res) <- waitAny pending
        case res of
            Right _ -> return res
            Left _  -> firstSuccess (filter (/= done) pending)
-- | Resolve against a single server (UDP first, TCP if the answer is
-- truncated). 'DNSError' exceptions are turned into a 'Left' result.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q tm retry ctls rcv = E.try exchange
  where
    exchange = udpTcpLookup gen retry rcv ai q tm ctls
-- | Query over UDP with a freshly generated identifier. If 'udpLookup'
-- signals 'TCPFallback' (a truncated answer), repeat the query over TCP.
-- 'tcpLookup' draws its own identifier from @gen@.
udpTcpLookup :: IO Identifier -> UdpRslv
udpTcpLookup gen retry rcv ai q tm ctls =
    gen >>= \ident -> E.handle overTcp (udpLookup ident retry rcv ai q tm ctls)
  where
    overTcp TCPFallback = tcpLookup gen ai q tm ctls
-- | Re-throw an 'IOError' as a 'NetworkFailure'. The error is first
-- annotated via 'annotateIOError': @show ai@ becomes the location, and
-- @tag@ (e.g. "UDP" or "TCP") is passed in the file-path slot. Never
-- returns normally.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai tag ioe =
    throwIO . NetworkFailure $ annotateIOError ioe (show ai) Nothing (Just tag)
-- | Create a socket with the family, type and protocol from @ai@ and
-- 'connect' it to @addrAddress ai@.
--
-- If 'connect' throws, the new socket is closed before the exception
-- propagates. The caller's 'bracket' only takes ownership once this action
-- returns, so without 'bracketOnError' the descriptor would leak.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    E.bracketOnError
        (socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai))
        close
        (\sock -> do
            connect sock (addrAddress ai)
            return sock)
-- | Send the question over a connected UDP socket and wait for a matching
-- answer.
--
-- Each attempt waits @tm@ (via 'timeout'). A timed-out attempt counts
-- against @retry@, and once @retry@ attempts have been used
-- 'RetryLimitExceeded' is thrown. A non-positive @retry@ throws it
-- immediately instead of retrying forever. A truncated answer throws
-- 'TCPFallback'. A FORMERR answer without an EDNS header is retried once
-- with EDNS cleared from the controls; that retry does not consume a count.
-- 'IOError's are converted by 'ioErrorToDNSError'.
udpLookup :: Identifier -> UdpRslv
udpLookup ident retry rcv ai q tm ctls = do
    let qry = encodeQuestion ident q ctls
    E.handle (ioErrorToDNSError ai "UDP") $
        bracket (udpOpen ai) close (loop qry ctls 0 RetryLimitExceeded)
  where
    loop qry lctls cnt err sock
        -- '>=' rather than '==' so a negative retry count cannot loop forever.
        | cnt >= retry = E.throwIO err
        | otherwise    = do
            mres <- timeout tm (send sock qry >> getAns sock)
            case mres of
                Nothing  -> loop qry lctls (cnt + 1) RetryLimitExceeded sock
                Just res -> do
                    let fl = flags $ header res
                        tc = trunCation fl
                        rc = rcode fl
                        eh = ednsHeader res
                        cs = ednsEnabled FlagClear <> lctls
                    if tc
                        then E.throwIO TCPFallback
                        -- Retry without EDNS, but only if clearing it actually
                        -- changes the controls; otherwise accept the answer.
                        else if rc == FormatErr && eh == NoEDNS && cs /= lctls
                            then let qry' = encodeQuestion ident q cs
                                  in loop qry' cs cnt RetryLimitExceeded sock
                            else return res
    -- Keep reading until a response passes 'checkResp'; stray or stale
    -- datagrams are discarded (bounded by the enclosing 'timeout').
    getAns sock = do
        mres <- rcv sock
        if checkResp q ident mres
            then return mres
            else getAns sock
-- | Create an unconnected TCP stream socket whose family matches @peer@.
-- Only IPv4 and IPv6 addresses are supported; any other address kind throws
-- 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen peer = case familyOf peer of
    Just fam -> socket fam Stream defaultProtocol
    Nothing  -> E.throwIO ServerFailure
  where
    familyOf SockAddrInet{}  = Just AF_INET
    familyOf SockAddrInet6{} = Just AF_INET6
    familyOf _               = Nothing
-- | Perform the query over TCP, using a fresh connection and a fresh
-- identifier from @gen@ for each exchange.
--
-- Connecting, sending and receiving together are bounded by @tm@; running
-- out of time throws 'TimeoutExpired'. A response failing 'checkRespM'
-- throws that error. A FORMERR answer without an EDNS header is retried
-- once, over a new connection, with EDNS cleared (only when that changes the
-- controls). 'IOError's are converted by 'ioErrorToDNSError'.
tcpLookup :: IO Identifier -> TcpRslv
tcpLookup gen ai q tm ctls =
    E.handle (ioErrorToDNSError ai "TCP") $ do
        firstAnswer <- runQuery ctls
        if wantsNoEdnsRetry firstAnswer
            then runQuery noEdns
            else return firstAnswer
  where
    peer   = addrAddress ai
    noEdns = ednsEnabled FlagClear <> ctls

    wantsNoEdnsRetry res =
        rcode (flags (header res)) == FormatErr
            && ednsHeader res == NoEDNS
            && noEdns /= ctls

    runQuery qctls = bracket (tcpOpen peer) close (exchange qctls)

    exchange qctls vc = do
        ident <- gen
        let qry = encodeQuestion ident q qctls
        mres <- timeout tm (connect vc peer >> sendVC vc qry >> receiveVC vc)
        case mres of
            Nothing  -> E.throwIO TimeoutExpired
            Just res -> maybe (return res) E.throwIO (checkRespM q ident res)
-- | 'True' for an empty domain or one longer than the textual limit: 253
-- characters, or 254 when the name ends with a trailing dot.
badLength :: Domain -> Bool
badLength dom
    | BS.null dom = True
    | otherwise   = BS.length dom > limit
  where
    -- Only forced for non-empty domains, so 'BS.last' is safe.
    limit
        | BS.last dom == '.' = 254
        | otherwise          = 253
-- | Reject domains that must not be sent to a name server. A domain is
-- rejected when 'badLength' flags it, when it has no dot at all, when it
-- contains @:@ or @/@, or when any dot-separated label exceeds 63
-- characters.
isIllegal :: Domain -> Bool
isIllegal dom =
       badLength dom
    || BS.notElem '.' dom
    || BS.elem ':' dom
    || BS.elem '/' dom
    || any ((> 63) . BS.length) (BS.split '.' dom)